Split clone_preceding_op_into_dispatch transform op (#10396)

This op is split into two parts: One op that clones and one op that
moves.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
index c1c0bfc..96a208c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
@@ -603,11 +603,7 @@
 // All uses of the target inside of the dispatch region are replaced with the
 // results of the cloned op.
 //
-// If `updateUsesOutsideOfRegion` is set, all uses of the target op after the
-// dispatch region, are also updated: The target op's results are returned from
-// the dispatch region an used in those places.
-//
-// Example when `updateUsesOutsideOfRegion` is unset:
+// Example:
 //
 // %0 = "some_op"() : () -> (tensor<?xf32>)
 // %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
@@ -615,9 +611,41 @@
 //   flow.return %1 : tensor<?xf32>
 // }
 // %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+static LogicalResult clonePrecedingOpIntoDispatchRegion(
+    RewriterBase &rewriter, Operation *target,
+    Flow::DispatchRegionOp regionOp) {
+  Block &body = regionOp.getBody().front();
+
+  // Gather all uses of `target` inside of the dispatch region.
+  SmallVector<OpOperand *> usesInsideOfRegion;
+  for (OpOperand &use : target->getUses()) {
+    if (regionOp->isProperAncestor(use.getOwner()))
+      usesInsideOfRegion.push_back(&use);
+  }
+
+  // Clone op into dispatch region.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&body);
+  Operation *newTargetOp = rewriter.clone(*target);
+
+  // Replace all uses in the dispatch region.
+  for (OpOperand *use : usesInsideOfRegion) {
+    rewriter.updateRootInPlace(use->getOwner(), [&]() {
+      use->set(newTargetOp->getResult(
+          use->get().cast<OpResult>().getResultNumber()));
+    });
+  }
+
+  return success();
+}
+
+// Move a `target` op that is preceding the given dispatch region op into the
+// dispatch region.
 //
-// In this example, "some_op" will be cloned into the dispatch region and the
-// OpOperand of "another_op" will be replaced:
+// All uses of the target outside of the dispatch region are replaced with the
+// newly added results of the dispatch region.
+//
+// Example:
 //
 // %0 = "some_op"() : () -> (tensor<?xf32>)
 // %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
@@ -626,72 +654,44 @@
 //   flow.return %1 : tensor<?xf32>
 // }
 // %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
-static FailureOr<Flow::DispatchRegionOp> clonePrecedingOpIntoDispatchRegion(
-    RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp,
-    bool updateUsesOutsideOfRegion) {
-  assert(target->isBeforeInBlock(regionOp) &&
-         "expected that target comes first");
+static FailureOr<Flow::DispatchRegionOp> movePrecedingOpIntoDispatchRegion(
+    RewriterBase &rewriter, Operation *target,
+    Flow::DispatchRegionOp regionOp) {
+#ifndef NDEBUG
+  PostDominanceInfo domInfo;
+  for (OpOperand &use : target->getUses()) {
+    if (regionOp->isProperAncestor(use.getOwner())) continue;
+    assert(domInfo.postDominates(use.getOwner(), regionOp) &&
+           "found use that does not post-dominate target");
+  }
+#endif  // NDEBUG
+
   Block &body = regionOp.getBody().front();
 
   // Gather all uses of `target`.
-  SmallVector<OpOperand *> usesInsideOfRegion, usesAfterRegion;
-  bool hasUsesBeforeRegion = false;
-  for (OpOperand &use : target->getUses()) {
-    if (regionOp->isProperAncestor(use.getOwner())) {
-      usesInsideOfRegion.push_back(&use);
-    } else {
-      // Collect only uses that post-dominate the region.
-      if (isAfterRegion(use.getOwner(), regionOp)) {
-        usesAfterRegion.push_back(&use);
-      } else {
-        hasUsesBeforeRegion = true;
-      }
-    }
-  }
+  SmallVector<OpOperand *> usesOutsideOfRegion;
+  for (OpOperand &use : target->getUses())
+    if (!regionOp->isProperAncestor(use.getOwner()))
+      usesOutsideOfRegion.push_back(&use);
 
-  // Clone op into dispatch region.
-  Operation *newTargetOp;
-  if (usesAfterRegion.empty() && !hasUsesBeforeRegion) {
-    // Optimization: If there are not uses outside of the region, we can simply
-    // move the target instead of cloning it.
-    target->moveBefore(&body.front());
-    newTargetOp = target;
-  } else {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&body);
-    newTargetOp = rewriter.clone(*target);
-
-    // Replace all uses in the dispatch region.
-    for (OpOperand *use : usesInsideOfRegion) {
-      rewriter.updateRootInPlace(use->getOwner(), [&]() {
-        use->set(newTargetOp->getResult(
-            use->get().cast<OpResult>().getResultNumber()));
-      });
-    }
-  }
+  // Move op into dispatch region.
+  target->moveBefore(&body.front());
 
   // Replace all uses outside of the dispatch region.
-  if (updateUsesOutsideOfRegion && !usesAfterRegion.empty()) {
-    // Fail if there are uses before the dispatch region. In that case it does
-    // usually not make sense to update uses after the region; we can just keep
-    // using the original op result.
-    if (hasUsesBeforeRegion) return failure();
-
+  if (!usesOutsideOfRegion.empty()) {
     unsigned previousNumResults = regionOp->getNumResults();
 
     // Note: Appending results one-by-one here so that this can be extended to
     // specific results in the future. Many ops have just one result, so this
     // should not be a large overhead.
-    for (Value v : newTargetOp->getResults()) {
+    for (Value v : target->getResults()) {
       auto newRegionOp = appendDispatchRegionResult(rewriter, regionOp, v);
       if (failed(newRegionOp)) return failure();
       regionOp = *newRegionOp;
     }
 
     // Replace uses of `target` after the dispatch region.
-    for (OpOperand *use : usesAfterRegion) {
-      assert(DominanceInfo().properlyDominates(regionOp, use->getOwner()) &&
-             "all target uses must be inside or after regionOp");
+    for (OpOperand *use : usesOutsideOfRegion) {
       rewriter.updateRootInPlace(use->getOwner(), [&]() {
         use->set(
             regionOp->getResult(previousNumResults +
@@ -700,26 +700,9 @@
     }
   }
 
-  // Remove the original target if it no longer has any uses.
-  if (target->use_empty()) rewriter.eraseOp(target);
-
   return regionOp;
 }
 
-// Move a `target` op that is preceding the given dispatch region op into the
-// dispatch region. All uses of the target must be inside the region.
-static FailureOr<Flow::DispatchRegionOp> movePrecedingOpIntoDispatchRegion(
-    RewriterBase &rewriter, Operation *target,
-    Flow::DispatchRegionOp regionOp) {
-  assert(llvm::all_of(target->getUses(),
-                      [&](OpOperand &use) {
-                        return regionOp->isProperAncestor(use.getOwner());
-                      }) &&
-         "cannot move target into region");
-  return clonePrecedingOpIntoDispatchRegion(
-      rewriter, target, regionOp, /*updateUsesOutsideOfRegion=*/false);
-}
-
 DiagnosedSilenceableFailure
 transform_dialect::ClonePrecedingOpIntoDispatchRegionOp::apply(
     transform::TransformResults &transformResults,
@@ -753,9 +736,51 @@
   SmallVector<Operation *> orderedTargets =
       llvm::to_vector(llvm::reverse(targetOps));
   IRRewriter rewriter(regionOp->getContext());
+  for (Operation *target : orderedTargets)
+    if (failed(clonePrecedingOpIntoDispatchRegion(rewriter, target, regionOp)))
+      return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  transformResults.set(getTransformed().cast<OpResult>(),
+                       regionOp.getOperation());
+  return DiagnosedSilenceableFailure(success());
+}
+
+DiagnosedSilenceableFailure
+transform_dialect::MovePrecedingOpIntoDispatchRegionOp::apply(
+    transform::TransformResults &transformResults,
+    transform::TransformState &state) {
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  ArrayRef<Operation *> dispatchRegion =
+      state.getPayloadOps(getDispatchRegion());
+
+  if (targetOps.empty() && dispatchRegion.empty()) {
+    transformResults.set(getResult().cast<OpResult>(),
+                         SmallVector<mlir::Operation *>{});
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  if (dispatchRegion.size() != 1)
+    return DiagnosedSilenceableFailure(this->emitOpError(
+        "requires exactly one target/dispatch region handle"));
+
+  auto regionOp = dyn_cast<Flow::DispatchRegionOp>(dispatchRegion.front());
+  if (!regionOp)
+    return DiagnosedSilenceableFailure(
+        this->emitOpError("expected 'dispatch.region' operand"));
+
+  // We are moving ops one-by-one, so the order must be reversed (as opposed
+  // to moving all ops in one go).
+  SmallVector<Operation *> targetOpsList(targetOps.begin(), targetOps.end());
+  bool sortResult = computeTopologicalSorting(
+      dispatchRegion.front()->getBlock(), targetOpsList);
+  (void)sortResult;
+  assert(sortResult && "unable to sort topologically");
+  SmallVector<Operation *> orderedTargets =
+      llvm::to_vector(llvm::reverse(targetOps));
+  IRRewriter rewriter(regionOp->getContext());
   for (Operation *target : orderedTargets) {
-    auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
-        rewriter, target, regionOp, getUpdateUsesOutsideOfRegion());
+    auto newRegionOp =
+        movePrecedingOpIntoDispatchRegion(rewriter, target, regionOp);
     if (failed(newRegionOp))
       return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
     regionOp = *newRegionOp;
@@ -915,8 +940,8 @@
       makeEmptyDispatchRegion(rewriter, target->getLoc());
 
   // Move the target into the dispatch region.
-  auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
-      rewriter, target, regionOp, /*updateUsesOutsideOfRegion=*/true);
+  auto newRegionOp =
+      movePrecedingOpIntoDispatchRegion(rewriter, target, regionOp);
   if (failed(newRegionOp))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
index 20b955b..2ec4c0d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
@@ -111,27 +111,57 @@
     handle must be mapped to exactly one payload op.
 
     All uses of the target inside of the dispatch region are replaced with the
-    results of the cloned op.
-
-    If `update_uses_outside_of_region` is set (default value: `false`), all
-    uses outside of the dispatch region are also replaced: The results of the
-    cloned target op are yielded from the dispatch region and used in all uses
-    outside of the dispatch region. The transform fails if there are uses that
-    appear before the dispatch region.
+    results of the cloned op. Uses of the target outside of the dispatch region
+    remain unchanged.
 
     #### Return modes
 
-    This transform consumes both the `target` handle and the `dispatch_region`
+    This transform reads the `target` handle and consumes the `dispatch_region`
     handle. It produces a new handle to the extended dispatch region.
   }];
 
   let arguments = (ins Arg<PDL_Operation, "",
-                           [TransformMappingRead,
-                            TransformMappingFree]>:$target,
+                           [TransformMappingRead]>:$target,
                        Arg<PDL_Operation, "",
                            [TransformMappingRead,
-                            TransformMappingFree]>:$dispatch_region,
-                       DefaultValuedAttr<BoolAttr, "false">:$update_uses_outside_of_region);
+                            TransformMappingFree]>:$dispatch_region);
+  let results = (outs Res<PDL_Operation, "",
+                           [TransformMappingAlloc,
+                            TransformMappingWrite]>:$transformed);
+  let assemblyFormat = "$target `into` $dispatch_region attr-dict";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def MovePrecedingOpIntoDispatchRegionOp : Op<
+    Transform_Dialect, "iree.move_preceding_op_into_dispatch_region",
+    [TransformOpInterface]> {
+  let description = [{
+    Move the `target` op into the given dispatch region op. The dispatch region
+    handle must be mapped to exactly one payload op.
+
+    An extra result is added to the dispatch region for every result of the
+    target op. All uses of the target op are replaced with the newly added
+    results of the dispatch region.
+
+    Note: This transform generates invalid IR if there are uses of the target op
+    that appear before (i.e., dominate) the dispatch region.
+
+    #### Return modes
+
+    This transform reads the `target` handle and consumes the `dispatch_region`
+    handle. It produces a new handle to the extended dispatch region.
+  }];
+
+  let arguments = (ins Arg<PDL_Operation, "",
+                           [TransformMappingRead]>:$target,
+                       Arg<PDL_Operation, "",
+                           [TransformMappingRead,
+                            TransformMappingFree]>:$dispatch_region);
   let results = (outs Res<PDL_Operation, "",
                            [TransformMappingAlloc,
                             TransformMappingWrite]>:$transformed);
@@ -152,7 +182,7 @@
     handle must be mapped to exactly one payload op.
 
     All operands of the target are replaced with values that are defined inside
-    of the dispatch region when possible. 
+    of the dispatch region when possible.
 
     If `update_uses_outside_of_region` is set (default value: `true`), all uses
     of the original target op are replaced: The results of the cloned target op
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
index d8a1061..687eeb5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
@@ -84,7 +84,7 @@
     %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
     %dispatch_op = transform.iree.wrap_in_dispatch_region %0
     %1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1
-    transform.iree.clone_preceding_op_into_dispatch_region %1 into %dispatch_op {update_uses_outside_of_region = true}
+    transform.iree.move_preceding_op_into_dispatch_region %1 into %dispatch_op
   }
 }
 
@@ -116,20 +116,18 @@
 
 // -----
 
-// CHECK-LABEL: func @move_multiple_preceding
+// CHECK-LABEL: func @clone_multiple_preceding
 //   CHECK-DAG:   arith.constant
 //   CHECK-DAG:   arith.constant
 //   CHECK-DAG:   tensor.dim
 //   CHECK-DAG:   tensor.dim
-//  CHECK-NEXT:   "test.dummy_op"
-//  CHECK-NEXT:   "test.third_user"
-//  CHECK-NEXT:   flow.dispatch.region
+//       CHECK:   flow.dispatch.region
 //  CHECK-NEXT:     "test.dummy_op"
 //  CHECK-NEXT:     "test.first_user"
 //  CHECK-NEXT:     "test.second_user"
 //  CHECK-NEXT:     "test.merge1"
 //  CHECK-NEXT:     "test.merge2"
-func.func @move_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
+func.func @clone_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
   %0 = "test.dummy_op"(%arg0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
   %1 = "test.first_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
   %2 = "test.second_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
diff --git a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
index 8986ab1..7ddaeb7 100644
--- a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
@@ -14,8 +14,8 @@
     // TODO: this could be replaced by a C++ only version.
     // Atm the IR produced is not the same so all pieces do not connect.
     %region_op = transform.iree.wrap_in_dispatch_region %root
-    %region_op_2 = transform.iree.clone_preceding_op_into_dispatch_region %red into %region_op
-    %region_op_3 = transform.iree.clone_preceding_op_into_dispatch_region %fill into %region_op_2
+    %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %red into %region_op
+    %region_op_3 = transform.iree.move_preceding_op_into_dispatch_region %fill into %region_op_2
     transform.iree.region_to_workgroups %region_op_3
   }
 }