[flow] NFC: Move code for cloning producers into RegionOpUtils (#12349)

This commit prepares those functions to be shared among multiple source
files. It's a preparation step to move CloneProducersIntoDispatchRegions
into its own pass.

Progress towards https://github.com/openxla/iree/issues/12230
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 79a62b2..bea2ed2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -16,13 +16,11 @@
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -43,15 +41,6 @@
 
 #define DEBUG_TYPE "iree-flow-form-dispatch-regions"
 
-// NOTE: These flags are added for experimental purposes only
-// for developer control. These should be treated as internal
-// compiler implementation details.
-static llvm::cl::opt<int> clInlineConstantByteLength(
-    "iree-flow-inline-constants-max-byte-length",
-    llvm::cl::desc("Maximum byte-length of tensor constant that can be inlined "
-                   "into a dispatch region or 0 to disable inlining."),
-    llvm::cl::init(256));
-
 static const char kRootOpAttr[] = "__root_op__";
 static const char kFusionGroupsAttr[] = "__fused_op__";
 
@@ -203,44 +192,6 @@
          isa<LinalgExt::SetEncodingOp, LinalgExt::UnsetEncodingOp>(op);
 }
 
-/// Operations that are cloned into dispatch regions formed with other
-/// operations as roots.
-bool isClonableIntoDispatchOp(Operation *op) {
-  // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are
-  // trivially clonable too, but they cause problems
-  // with bufferization. Make them clonable when fixed.
-  if (isa<AffineApplyOp, arith::IndexCastOp, linalg::FillOp, tensor::EmptyOp,
-          tensor::CastOp, tensor::ExtractOp, tensor::ExtractSliceOp,
-          tensor::PadOp>(op)) {
-    return true;
-  }
-  if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
-    if (clInlineConstantByteLength == 0) return false;
-    auto constantValueAttr = constantOp.getValue();
-    auto constantType = constantOp.getType();
-    if (constantValueAttr.isa<SplatElementsAttr>()) {
-      return true;
-    } else if (auto denseAttr =
-                   constantValueAttr.dyn_cast<DenseElementsAttr>()) {
-      auto shapedType = constantOp.getType().cast<ShapedType>();
-      uint64_t estimatedByteLength =
-          (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
-          8;
-      return denseAttr.isSplat() ||
-             estimatedByteLength <= clInlineConstantByteLength;
-    } else if (constantType.isIntOrIndexOrFloat()) {
-      return true;
-    }
-  }
-  if (llvm::all_of(op->getOperands(),
-                   [&](Value v) { return v.getType().isIntOrFloat(); }) &&
-      llvm::all_of(op->getResults(),
-                   [&](Value v) { return v.getType().isIntOrFloat(); })) {
-    return true;
-  }
-  return false;
-}
-
 //===----------------------------------------------------------------------===//
 // Methods for getting the workload information for dispatch region creation.
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
index f39035f..b8d917e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
@@ -33,13 +33,6 @@
 namespace IREE {
 namespace Flow {
 
-/// A heuristic that decides which ops should be cloned and fused into a
-/// dispatch region.
-///
-/// Note: This function returns `false` for ops that should be tiled and fused
-/// into a dispatch region.
-bool isClonableIntoDispatchOp(Operation *op);
-
 /// Computes the workload and provides a workload region builder for the given
 /// root op.
 FailureOr<Flow::WorkloadBuilder> getWorkloadBuilder(OpBuilder &builder,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
index 3849874..3e621b2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
@@ -29,8 +29,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/RegionUtils.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
 
 #define DEBUG_TYPE "iree-flow-form-dispatch-workgroups"
 
@@ -39,86 +37,6 @@
 namespace IREE {
 namespace Flow {
 
-//===---------------------------------------------------------------------===//
-// Methods to legalize a dispatch region op, i.e. make it isolated from above.
-//===---------------------------------------------------------------------===//
-
-/// Checks if the `Value` has a use within the dispatch that is unfusable.
-static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) {
-  for (OpOperand &use : v.getUses()) {
-    Operation *user = use.getOwner();
-    Operation *ownerWorkgroupsOp =
-        user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>();
-    Operation *ownerRegionOp =
-        user->getParentOfType<IREE::Flow::DispatchRegionOp>();
-    Operation *owner = ownerWorkgroupsOp ? ownerWorkgroupsOp : ownerRegionOp;
-
-    // Ignore uses outside of dispatch workgroups op.
-    if (owner != dispatchOp) continue;
-
-    // Cannot fuse producer of `dest` with `tensor.insert_slice`.
-    if (auto insertSliceUser = dyn_cast<tensor::InsertSliceOp>(user)) {
-      if (insertSliceUser.getDest() == v) return true;
-    }
-  }
-  return false;
-}
-
-/// Collect all ops that should be cloned into the given dispatch region op.
-static SmallVector<Operation *> getCloneableOps(
-    Flow::DispatchRegionOp regionOp) {
-  // Find values that are used inside of the dispatch region but defined outside
-  // of the dispatch region.
-  llvm::SetVector<Value> valuesDefinedAbove;
-  mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove);
-  if (valuesDefinedAbove.empty()) return {};
-
-  // Traverse the defining ops of these values (and the ops on their reverse
-  // SSA use-def chain).
-  SmallVector<Operation *> result;
-  llvm::SetVector<Value> visited;
-  SmallVector<Value, 4> worklist;
-  worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
-  while (!worklist.empty()) {
-    Value outsideValue = worklist.pop_back_val();
-    // Skip values that were already visited.
-    if (visited.count(outsideValue)) continue;
-    visited.insert(outsideValue);
-
-    Operation *definingOp = outsideValue.getDefiningOp();
-    if (!definingOp || !isClonableIntoDispatchOp(definingOp) ||
-        hasUnfusableUseInDispatch(outsideValue, regionOp)) {
-      valuesDefinedAbove.insert(outsideValue);
-      continue;
-    }
-    result.push_back(definingOp);
-    worklist.append(definingOp->operand_begin(), definingOp->operand_end());
-    llvm::SetVector<Value> nestedValues;
-    mlir::getUsedValuesDefinedAbove(definingOp->getRegions(), nestedValues);
-    worklist.append(nestedValues.begin(), nestedValues.end());
-  }
-
-  return result;
-}
-
-/// Clone producers into the dispatch region.
-static LogicalResult cloneProducers(RewriterBase &rewriter,
-                                    Flow::DispatchRegionOp regionOp) {
-  SmallVector<Operation *> cloneableOps = getCloneableOps(regionOp);
-  bool sortResult = mlir::computeTopologicalSorting(cloneableOps);
-  (void)sortResult;
-  assert(sortResult && "could not compute topological sorting");
-
-  for (Operation *producer : llvm::reverse(cloneableOps)) {
-    if (failed(
-            clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp))) {
-      return failure();
-    }
-  }
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Dispatch workgroups formation
 //===----------------------------------------------------------------------===//
@@ -136,7 +54,7 @@
   // Clone additional producers and rewrite to DispatchWorkgroupsOp.
   SmallVector<Flow::DispatchWorkgroupsOp> result;
   for (auto regionOp : regionOps) {
-    if (failed(cloneProducers(rewriter, regionOp))) return failure();
+    if (failed(cloneProducersToRegion(rewriter, regionOp))) return failure();
     auto maybeWorkgroupOp =
         rewriteFlowDispatchRegionToFlowDispatchWorkgroups(regionOp, rewriter);
     if (failed(maybeWorkgroupOp)) return failure();
@@ -175,7 +93,7 @@
   // Wrap operation.
   auto regionOp = Flow::wrapOpInDispatchRegion(rewriter, op, workloadBuilder);
   if (failed(regionOp)) return failure();
-  if (failed(cloneProducers(rewriter, *regionOp))) return failure();
+  if (failed(cloneProducersToRegion(rewriter, *regionOp))) return failure();
   auto workgroupsOp = Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
       *regionOp, rewriter);
   if (failed(workgroupsOp)) return failure();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 5d67ae1..2b8ba21 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -6,14 +6,31 @@
 
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 
+#include <iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h>
+
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
+
+// NOTE: These flags are added for experimental purposes only
+// for developer control. These should be treated as internal
+// compiler implementation details.
+static llvm::cl::opt<int> clInlineConstantByteLength(
+    "iree-flow-inline-constants-max-byte-length",
+    llvm::cl::desc("Maximum byte-length of tensor constant that can be inlined "
+                   "into a dispatch region or 0 to disable inlining."),
+    llvm::cl::init(256));
 
 using namespace mlir;
 using namespace mlir::iree_compiler;
@@ -328,3 +345,120 @@
 
   return newRegionOp;
 }
+//===---------------------------------------------------------------------===//
+// Utilities to make a dispatch region isolated from above
+//===---------------------------------------------------------------------===//
+
+/// Operations that are cloned into dispatch regions formed with other
+/// operations as roots.
+bool Flow::isClonableIntoDispatchOp(Operation *op) {
+  // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are
+  // trivially clonable too, but they cause problems
+  // with bufferization. Make them clonable when fixed.
+  if (isa<AffineApplyOp, arith::IndexCastOp, linalg::FillOp, tensor::EmptyOp,
+          tensor::CastOp, tensor::ExtractOp, tensor::ExtractSliceOp,
+          tensor::PadOp>(op)) {
+    return true;
+  }
+  if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+    if (clInlineConstantByteLength == 0) return false;
+    auto constantValueAttr = constantOp.getValue();
+    auto constantType = constantOp.getType();
+    if (constantValueAttr.isa<SplatElementsAttr>()) {
+      return true;
+    } else if (auto denseAttr =
+                   constantValueAttr.dyn_cast<DenseElementsAttr>()) {
+      auto shapedType = constantOp.getType().cast<ShapedType>();
+      uint64_t estimatedByteLength =
+          (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
+          8;
+      return denseAttr.isSplat() ||
+             estimatedByteLength <= clInlineConstantByteLength;
+    } else if (constantType.isIntOrIndexOrFloat()) {
+      return true;
+    }
+  }
+  if (llvm::all_of(op->getOperands(),
+                   [&](Value v) { return v.getType().isIntOrFloat(); }) &&
+      llvm::all_of(op->getResults(),
+                   [&](Value v) { return v.getType().isIntOrFloat(); })) {
+    return true;
+  }
+  return false;
+}
+
+/// Checks if the `Value` has a use within the dispatch that is unfusable.
+static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+    Operation *ownerWorkgroupsOp =
+        user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>();
+    Operation *ownerRegionOp =
+        user->getParentOfType<IREE::Flow::DispatchRegionOp>();
+    Operation *owner = ownerWorkgroupsOp ? ownerWorkgroupsOp : ownerRegionOp;
+
+    // Ignore uses outside of dispatch workgroups op.
+    if (owner != dispatchOp) continue;
+
+    // Cannot fuse producer of `dest` with `tensor.insert_slice`.
+    if (auto insertSliceUser = dyn_cast<tensor::InsertSliceOp>(user)) {
+      if (insertSliceUser.getDest() == v) return true;
+    }
+  }
+  return false;
+}
+
+/// Collect all ops that should be cloned into the given dispatch region op.
+static SmallVector<Operation *> getCloneableOps(
+    Flow::DispatchRegionOp regionOp) {
+  // Find values that are used inside of the dispatch region but defined outside
+  // of the dispatch region.
+  llvm::SetVector<Value> valuesDefinedAbove;
+  mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove);
+  if (valuesDefinedAbove.empty()) return {};
+
+  // Traverse the defining ops of these values (and the ops on their reverse
+  // SSA use-def chain).
+  SmallVector<Operation *> result;
+  llvm::SetVector<Value> visited;
+  SmallVector<Value, 4> worklist;
+  worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
+  while (!worklist.empty()) {
+    Value outsideValue = worklist.pop_back_val();
+    // Skip values that were already visited.
+    if (visited.count(outsideValue)) continue;
+    visited.insert(outsideValue);
+
+    Operation *definingOp = outsideValue.getDefiningOp();
+    if (!definingOp || !Flow::isClonableIntoDispatchOp(definingOp) ||
+        hasUnfusableUseInDispatch(outsideValue, regionOp)) {
+      valuesDefinedAbove.insert(outsideValue);
+      continue;
+    }
+    result.push_back(definingOp);
+    worklist.append(definingOp->operand_begin(), definingOp->operand_end());
+    llvm::SetVector<Value> nestedValues;
+    mlir::getUsedValuesDefinedAbove(definingOp->getRegions(), nestedValues);
+    worklist.append(nestedValues.begin(), nestedValues.end());
+  }
+
+  return result;
+}
+
+/// Clone producers into the dispatch region.
+LogicalResult Flow::cloneProducersToRegion(RewriterBase &rewriter,
+                                           Flow::DispatchRegionOp regionOp) {
+  SmallVector<Operation *> cloneableOps = getCloneableOps(regionOp);
+  bool sortResult = mlir::computeTopologicalSorting(cloneableOps);
+  (void)sortResult;
+  assert(sortResult && "could not compute topological sorting");
+
+  for (Operation *producer : llvm::reverse(cloneableOps)) {
+    if (failed(
+            clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index a7cd075..1758724 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -7,7 +7,6 @@
 #define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_
 
 #include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
-#include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -84,6 +83,18 @@
     RewriterBase &rewriter, Operation *op,
     Optional<Flow::WorkloadBuilder> workloadBuilder = std::nullopt);
 
+/// Decide whether the given op should be cloned and fused into a dispatch
+/// region using heuristics.
+///
+/// Note: This function returns `false` for ops that should be tiled and fused
+/// into a dispatch region.
+bool isClonableIntoDispatchOp(Operation *op);
+
+/// Clone into the region producers of those value used in the region but
+/// defined above, to prepare the dispatch region isolated from above.
+LogicalResult cloneProducersToRegion(RewriterBase &rewriter,
+                                     Flow::DispatchRegionOp regionOp);
+
 }  // namespace Flow
 }  // namespace IREE
 }  // namespace iree_compiler