Split clone_preceding_op_into_dispatch transform op (#10396)
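The op is split into two ops:
`transform.iree.clone_preceding_op_into_dispatch_region` clones the target
into the dispatch region and leaves uses outside of the region unchanged.
`transform.iree.move_preceding_op_into_dispatch_region` moves the target into
the dispatch region and replaces all of its uses with newly added results of
the dispatch region.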
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
index c1c0bfc..96a208c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
@@ -603,11 +603,7 @@
// All uses of the target inside of the dispatch region are replaced with the
// results of the cloned op.
//
-// If `updateUsesOutsideOfRegion` is set, all uses of the target op after the
-// dispatch region, are also updated: The target op's results are returned from
-// the dispatch region an used in those places.
-//
-// Example when `updateUsesOutsideOfRegion` is unset:
+// Example:
//
// %0 = "some_op"() : () -> (tensor<?xf32>)
// %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
@@ -615,9 +611,41 @@
// flow.return %1 : tensor<?xf32>
// }
// %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+static LogicalResult clonePrecedingOpIntoDispatchRegion(
+ RewriterBase &rewriter, Operation *target,
+ Flow::DispatchRegionOp regionOp) {
+ Block &body = regionOp.getBody().front();
+
+ // Gather all uses of `target` inside of the dispatch region.
+ SmallVector<OpOperand *> usesInsideOfRegion;
+ for (OpOperand &use : target->getUses()) {
+ if (regionOp->isProperAncestor(use.getOwner()))
+ usesInsideOfRegion.push_back(&use);
+ }
+
+ // Clone op into dispatch region.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&body);
+ Operation *newTargetOp = rewriter.clone(*target);
+
+ // Replace all uses in the dispatch region.
+ for (OpOperand *use : usesInsideOfRegion) {
+ rewriter.updateRootInPlace(use->getOwner(), [&]() {
+ use->set(newTargetOp->getResult(
+ use->get().cast<OpResult>().getResultNumber()));
+ });
+ }
+
+ return success();
+}
+
+// Move a `target` op that is preceding the given dispatch region op into the
+// dispatch region.
//
-// In this example, "some_op" will be cloned into the dispatch region and the
-// OpOperand of "another_op" will be replaced:
+// All uses of the target outside of the dispatch region are replaced with the
+// newly added results of the dispatch region.
+//
+// Example:
//
// %0 = "some_op"() : () -> (tensor<?xf32>)
// %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
@@ -626,72 +654,44 @@
// flow.return %1 : tensor<?xf32>
// }
// %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
-static FailureOr<Flow::DispatchRegionOp> clonePrecedingOpIntoDispatchRegion(
- RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp,
- bool updateUsesOutsideOfRegion) {
- assert(target->isBeforeInBlock(regionOp) &&
- "expected that target comes first");
+static FailureOr<Flow::DispatchRegionOp> movePrecedingOpIntoDispatchRegion(
+ RewriterBase &rewriter, Operation *target,
+ Flow::DispatchRegionOp regionOp) {
+#ifndef NDEBUG
+ PostDominanceInfo domInfo;
+ for (OpOperand &use : target->getUses()) {
+ if (regionOp->isProperAncestor(use.getOwner())) continue;
+ assert(domInfo.postDominates(use.getOwner(), regionOp) &&
+ "found use that does not post-dominate target");
+ }
+#endif // NDEBUG
+
Block &body = regionOp.getBody().front();
// Gather all uses of `target`.
- SmallVector<OpOperand *> usesInsideOfRegion, usesAfterRegion;
- bool hasUsesBeforeRegion = false;
- for (OpOperand &use : target->getUses()) {
- if (regionOp->isProperAncestor(use.getOwner())) {
- usesInsideOfRegion.push_back(&use);
- } else {
- // Collect only uses that post-dominate the region.
- if (isAfterRegion(use.getOwner(), regionOp)) {
- usesAfterRegion.push_back(&use);
- } else {
- hasUsesBeforeRegion = true;
- }
- }
- }
+ SmallVector<OpOperand *> usesOutsideOfRegion;
+ for (OpOperand &use : target->getUses())
+ if (!regionOp->isProperAncestor(use.getOwner()))
+ usesOutsideOfRegion.push_back(&use);
- // Clone op into dispatch region.
- Operation *newTargetOp;
- if (usesAfterRegion.empty() && !hasUsesBeforeRegion) {
- // Optimization: If there are not uses outside of the region, we can simply
- // move the target instead of cloning it.
- target->moveBefore(&body.front());
- newTargetOp = target;
- } else {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&body);
- newTargetOp = rewriter.clone(*target);
-
- // Replace all uses in the dispatch region.
- for (OpOperand *use : usesInsideOfRegion) {
- rewriter.updateRootInPlace(use->getOwner(), [&]() {
- use->set(newTargetOp->getResult(
- use->get().cast<OpResult>().getResultNumber()));
- });
- }
- }
+ // Move op into dispatch region.
+ target->moveBefore(&body.front());
// Replace all uses outside of the dispatch region.
- if (updateUsesOutsideOfRegion && !usesAfterRegion.empty()) {
- // Fail if there are uses before the dispatch region. In that case it does
- // usually not make sense to update uses after the region; we can just keep
- // using the original op result.
- if (hasUsesBeforeRegion) return failure();
-
+ if (!usesOutsideOfRegion.empty()) {
unsigned previousNumResults = regionOp->getNumResults();
// Note: Appending results one-by-one here so that this can be extended to
// specific results in the future. Many ops have just one result, so this
// should not be a large overhead.
- for (Value v : newTargetOp->getResults()) {
+ for (Value v : target->getResults()) {
auto newRegionOp = appendDispatchRegionResult(rewriter, regionOp, v);
if (failed(newRegionOp)) return failure();
regionOp = *newRegionOp;
}
// Replace uses of `target` after the dispatch region.
- for (OpOperand *use : usesAfterRegion) {
- assert(DominanceInfo().properlyDominates(regionOp, use->getOwner()) &&
- "all target uses must be inside or after regionOp");
+ for (OpOperand *use : usesOutsideOfRegion) {
rewriter.updateRootInPlace(use->getOwner(), [&]() {
use->set(
regionOp->getResult(previousNumResults +
@@ -700,26 +700,9 @@
}
}
- // Remove the original target if it no longer has any uses.
- if (target->use_empty()) rewriter.eraseOp(target);
-
return regionOp;
}
-// Move a `target` op that is preceding the given dispatch region op into the
-// dispatch region. All uses of the target must be inside the region.
-static FailureOr<Flow::DispatchRegionOp> movePrecedingOpIntoDispatchRegion(
- RewriterBase &rewriter, Operation *target,
- Flow::DispatchRegionOp regionOp) {
- assert(llvm::all_of(target->getUses(),
- [&](OpOperand &use) {
- return regionOp->isProperAncestor(use.getOwner());
- }) &&
- "cannot move target into region");
- return clonePrecedingOpIntoDispatchRegion(
- rewriter, target, regionOp, /*updateUsesOutsideOfRegion=*/false);
-}
-
DiagnosedSilenceableFailure
transform_dialect::ClonePrecedingOpIntoDispatchRegionOp::apply(
transform::TransformResults &transformResults,
@@ -753,9 +736,51 @@
SmallVector<Operation *> orderedTargets =
llvm::to_vector(llvm::reverse(targetOps));
IRRewriter rewriter(regionOp->getContext());
+ for (Operation *target : orderedTargets)
+ if (failed(clonePrecedingOpIntoDispatchRegion(rewriter, target, regionOp)))
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+ transformResults.set(getTransformed().cast<OpResult>(),
+ regionOp.getOperation());
+ return DiagnosedSilenceableFailure(success());
+}
+
+DiagnosedSilenceableFailure
+transform_dialect::MovePrecedingOpIntoDispatchRegionOp::apply(
+ transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+ ArrayRef<Operation *> dispatchRegion =
+ state.getPayloadOps(getDispatchRegion());
+
+ if (targetOps.empty() && dispatchRegion.empty()) {
+ transformResults.set(getResult().cast<OpResult>(),
+ SmallVector<mlir::Operation *>{});
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ if (dispatchRegion.size() != 1)
+ return DiagnosedSilenceableFailure(this->emitOpError(
+ "requires exactly one target/dispatch region handle"));
+
+ auto regionOp = dyn_cast<Flow::DispatchRegionOp>(dispatchRegion.front());
+ if (!regionOp)
+ return DiagnosedSilenceableFailure(
+ this->emitOpError("expected 'dispatch.region' operand"));
+
+ // We are moving ops one-by-one, so the order must be reversed (as opposed
+ // to moving all ops in one go).
+ SmallVector<Operation *> targetOpsList(targetOps.begin(), targetOps.end());
+ bool sortResult = computeTopologicalSorting(
+ dispatchRegion.front()->getBlock(), targetOpsList);
+ (void)sortResult;
+ assert(sortResult && "unable to sort topologically");
+ SmallVector<Operation *> orderedTargets =
+ llvm::to_vector(llvm::reverse(targetOps));
+ IRRewriter rewriter(regionOp->getContext());
for (Operation *target : orderedTargets) {
- auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
- rewriter, target, regionOp, getUpdateUsesOutsideOfRegion());
+ auto newRegionOp =
+ movePrecedingOpIntoDispatchRegion(rewriter, target, regionOp);
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
regionOp = *newRegionOp;
@@ -915,8 +940,8 @@
makeEmptyDispatchRegion(rewriter, target->getLoc());
// Move the target into the dispatch region.
- auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
- rewriter, target, regionOp, /*updateUsesOutsideOfRegion=*/true);
+ auto newRegionOp =
+ movePrecedingOpIntoDispatchRegion(rewriter, target, regionOp);
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
index 20b955b..2ec4c0d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
@@ -111,27 +111,57 @@
handle must be mapped to exactly one payload op.
All uses of the target inside of the dispatch region are replaced with the
- results of the cloned op.
-
- If `update_uses_outside_of_region` is set (default value: `false`), all
- uses outside of the dispatch region are also replaced: The results of the
- cloned target op are yielded from the dispatch region and used in all uses
- outside of the dispatch region. The transform fails if there are uses that
- appear before the dispatch region.
+ results of the cloned op. Uses of the target outside of the dispatch region
+ remain unchanged.
#### Return modes
- This transform consumes both the `target` handle and the `dispatch_region`
+ This transform reads the `target` handle and consumes the `dispatch_region`
handle. It produces a new handle to the extended dispatch region.
}];
let arguments = (ins Arg<PDL_Operation, "",
- [TransformMappingRead,
- TransformMappingFree]>:$target,
+ [TransformMappingRead]>:$target,
Arg<PDL_Operation, "",
[TransformMappingRead,
- TransformMappingFree]>:$dispatch_region,
- DefaultValuedAttr<BoolAttr, "false">:$update_uses_outside_of_region);
+ TransformMappingFree]>:$dispatch_region);
+ let results = (outs Res<PDL_Operation, "",
+ [TransformMappingAlloc,
+ TransformMappingWrite]>:$transformed);
+ let assemblyFormat = "$target `into` $dispatch_region attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def MovePrecedingOpIntoDispatchRegionOp : Op<
+ Transform_Dialect, "iree.move_preceding_op_into_dispatch_region",
+ [TransformOpInterface]> {
+ let description = [{
+ Move the `target` op into the given dispatch region op. The dispatch region
+ handle must be mapped to exactly one payload op.
+
+ An extra result is added to the dispatch region for every result of the
+ target op. All uses of the target op are replaced with the newly added
+ results of the dispatch region.
+
+ Note: This transform generates invalid IR if there are uses of the target op
+ that appear before (i.e., dominate) the dispatch region.
+
+ #### Return modes
+
+ This transform reads the `target` handle and consumes the `dispatch_region`
+ handle. It produces a new handle to the extended dispatch region.
+ }];
+
+ let arguments = (ins Arg<PDL_Operation, "",
+ [TransformMappingRead]>:$target,
+ Arg<PDL_Operation, "",
+ [TransformMappingRead,
+ TransformMappingFree]>:$dispatch_region);
let results = (outs Res<PDL_Operation, "",
[TransformMappingAlloc,
TransformMappingWrite]>:$transformed);
@@ -152,7 +182,7 @@
handle must be mapped to exactly one payload op.
All operands of the target are replaced with values that are defined inside
- of the dispatch region when possible.
+ of the dispatch region when possible.
If `update_uses_outside_of_region` is set (default value: `true`), all uses
of the original target op are replaced: The results of the cloned target op
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
index d8a1061..687eeb5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
@@ -84,7 +84,7 @@
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
%1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1
- transform.iree.clone_preceding_op_into_dispatch_region %1 into %dispatch_op {update_uses_outside_of_region = true}
+ transform.iree.move_preceding_op_into_dispatch_region %1 into %dispatch_op
}
}
@@ -116,20 +116,18 @@
// -----
-// CHECK-LABEL: func @move_multiple_preceding
+// CHECK-LABEL: func @clone_multiple_preceding
// CHECK-DAG: arith.constant
// CHECK-DAG: arith.constant
// CHECK-DAG: tensor.dim
// CHECK-DAG: tensor.dim
-// CHECK-NEXT: "test.dummy_op"
-// CHECK-NEXT: "test.third_user"
-// CHECK-NEXT: flow.dispatch.region
+// CHECK: flow.dispatch.region
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.first_user"
// CHECK-NEXT: "test.second_user"
// CHECK-NEXT: "test.merge1"
// CHECK-NEXT: "test.merge2"
-func.func @move_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
+func.func @clone_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
%0 = "test.dummy_op"(%arg0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%1 = "test.first_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%2 = "test.second_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
diff --git a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
index 8986ab1..7ddaeb7 100644
--- a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
@@ -14,8 +14,8 @@
// TODO: this could be replaced by a C++ only version.
// Atm the IR produced is not the same so all pieces do not connect.
%region_op = transform.iree.wrap_in_dispatch_region %root
- %region_op_2 = transform.iree.clone_preceding_op_into_dispatch_region %red into %region_op
- %region_op_3 = transform.iree.clone_preceding_op_into_dispatch_region %fill into %region_op_2
+ %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %red into %region_op
+ %region_op_3 = transform.iree.move_preceding_op_into_dispatch_region %fill into %region_op_2
transform.iree.region_to_workgroups %region_op_3
}
}