Remove ad-hoc logic to make the dest of `tensor.insert_slice` op a tied argument in dispatch region formation. (#8607)
Now that tile-and-distribute has been moved out of the Flow dialect, this special-case handling is no longer needed:
the normal tied-operands computation produces the correct result on its own.
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 2476b40..e073f17 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -214,29 +214,6 @@
SmallVector<Value> operands, operandDims;
SmallVector<int64_t> tiedOperands;
- // TODO(#...) This special handling of `tensor.insert_slice` op does need to
- // be here anymore. It can be moved to the same place as other ops where
- // readwrite operands are computed.
-
- if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
- // Handle tensor.insert_slice in a special manner. This op is actually two
- // steps:
- // 1) Copy over the dest tensor to the result,
- // 2) Update the overwritten part of the result with the destination.
- // To actually make this work, the dispatch region needs the `dest` and
- // result to be tied operands. This is somehow special. It might fall out
- // naturally, but not sure how. For now, just do it by construction.
- operands.push_back(insertSliceOp.dest());
- ReifiedRankedShapedTypeDims resultShapes;
- (void)insertSliceOp.reifyResultShapes(rewriter, resultShapes);
- auto destType = insertSliceOp.dest().getType().cast<ShapedType>();
- for (auto shape : enumerate(destType.getShape())) {
- if (shape.value() != ShapedType::kDynamicSize) continue;
- operandDims.push_back(resultShapes[0][shape.index()]);
- }
- tiedOperands.push_back(0);
- }
-
auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
loc, count, op->getResultTypes(), resultDynamicDims, operands,
operandDims, tiedOperands);
@@ -262,30 +239,6 @@
rewriter.create<IREE::Flow::ReturnOp>(loc);
}
- // Handle read-write arguments. Need to insert a load of these as well to get
- // the tensor type from the !flow.dispatch.tensor type.
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointToStart(block);
- unsigned dynamicDimIdx = 0;
- auto readWriteArgs = llvm::make_filter_range(
- dispatchOp.body().getArguments(), [](BlockArgument arg) {
- auto flowTensorType =
- arg.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
- return flowTensorType && flowTensorType.getAccess() ==
- IREE::Flow::TensorAccess::ReadWrite;
- });
- for (auto it : llvm::enumerate(readWriteArgs)) {
- Value operand = dispatchOp.operands()[it.index()];
- auto operandType = operand.getType().cast<RankedTensorType>();
- auto dynamicDims = resultDynamicDims.slice(
- dynamicDimIdx, operandType.getNumDynamicDims());
- Value loadOp = rewriter.create<IREE::Flow::DispatchTensorLoadOp>(
- loc, operandType, it.value(), dynamicDims);
- clonedOp->replaceUsesOfWith(operand, loadOp);
- }
- }
-
LLVM_DEBUG(llvm::dbgs() << "Created dispatchOp shell \n"
<< *dispatchOp << "\n");
return {dispatchOp, clonedOp};
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 59553df..cff6bea 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -426,17 +426,17 @@
// CHECK-DAG: %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups[%[[ARG0_D1]], %[[ARG0_D0]], %[[C1]]]
-// CHECK-SAME: (%[[ARG1]], %[[ARG1_D0]], %[[ARG1_D1]], %[[ARG0]], %[[ARG0_D0]], %[[ARG0_D1]],
+// CHECK-SAME: (%[[ARG0]], %[[ARG0_D0]], %[[ARG0_D1]], %[[ARG1]], %[[ARG1_D0]], %[[ARG1_D1]],
// CHECK-SAME: %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]])
-// CHECK-SAME: tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]}
// CHECK-SAME: tensor<?x?xf32>{%[[ARG0_D0]], %[[ARG0_D1]]}
+// CHECK-SAME: tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]}
// CHECK-SAME: -> %[[ARG1]]{%[[ARG1_D0]], %[[ARG1_D1]]}
-// CHECK-NEXT: %[[ARG1_CAPTURE:.+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
-// CHECK-SAME: %[[ARG1_D0_CAPTURE:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG1_D1_CAPTURE:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG0_CAPTURE:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-NEXT: %[[ARG0_CAPTURE:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
// CHECK-SAME: %[[ARG0_D0_CAPTURE:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG0_D1_CAPTURE:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1_CAPTURE:.+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
+// CHECK-SAME: %[[ARG1_D0_CAPTURE:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1_D1_CAPTURE:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2_CAPTURE:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3_CAPTURE:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4_CAPTURE:[a-zA-Z0-9]+]]: index
@@ -998,10 +998,10 @@
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG3]], %[[C0]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG3]], %[[C1]]
// CHECK: flow.dispatch.workgroups[%[[D0]], %[[C1]], %[[C1]]]
-// CHECK-SAME: tensor<?x?xi32>{%[[D1]], %[[D2]]}
// CHECK-SAME: tensor<?xi32>{%[[D0]]}
-// CHECK-NEXT: %[[ARG4:.+]]: !flow.dispatch.tensor<readwrite:?x?xi32>
-// CHECK-SAME: %[[ARG5:.+]]: !flow.dispatch.tensor<readonly:?xi32>
+// CHECK-SAME: tensor<?x?xi32>{%[[D1]], %[[D2]]}
+// CHECK-NEXT: !flow.dispatch.tensor<readonly:?xi32>
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:?x?xi32>
// -----