[flow] NFC: Move code for cloning producers into RegionOpUtils (#12349)
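This commit moves `isClonableIntoDispatchOp` and the producer-cloning helpers
(exposed as `cloneProducersToRegion`) into RegionOpUtils so that they can be
shared among multiple source files. It is a preparation step for moving
CloneProducersIntoDispatchRegions into its own pass.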
Progress towards https://github.com/openxla/iree/issues/12230
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index 79a62b2..bea2ed2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -16,13 +16,11 @@
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -43,15 +41,6 @@
#define DEBUG_TYPE "iree-flow-form-dispatch-regions"
-// NOTE: These flags are added for experimental purposes only
-// for developer control. These should be treated as internal
-// compiler implementation details.
-static llvm::cl::opt<int> clInlineConstantByteLength(
- "iree-flow-inline-constants-max-byte-length",
- llvm::cl::desc("Maximum byte-length of tensor constant that can be inlined "
- "into a dispatch region or 0 to disable inlining."),
- llvm::cl::init(256));
-
static const char kRootOpAttr[] = "__root_op__";
static const char kFusionGroupsAttr[] = "__fused_op__";
@@ -203,44 +192,6 @@
isa<LinalgExt::SetEncodingOp, LinalgExt::UnsetEncodingOp>(op);
}
-/// Operations that are cloned into dispatch regions formed with other
-/// operations as roots.
-bool isClonableIntoDispatchOp(Operation *op) {
- // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are
- // trivially clonable too, but they cause problems
- // with bufferization. Make them clonable when fixed.
- if (isa<AffineApplyOp, arith::IndexCastOp, linalg::FillOp, tensor::EmptyOp,
- tensor::CastOp, tensor::ExtractOp, tensor::ExtractSliceOp,
- tensor::PadOp>(op)) {
- return true;
- }
- if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
- if (clInlineConstantByteLength == 0) return false;
- auto constantValueAttr = constantOp.getValue();
- auto constantType = constantOp.getType();
- if (constantValueAttr.isa<SplatElementsAttr>()) {
- return true;
- } else if (auto denseAttr =
- constantValueAttr.dyn_cast<DenseElementsAttr>()) {
- auto shapedType = constantOp.getType().cast<ShapedType>();
- uint64_t estimatedByteLength =
- (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
- 8;
- return denseAttr.isSplat() ||
- estimatedByteLength <= clInlineConstantByteLength;
- } else if (constantType.isIntOrIndexOrFloat()) {
- return true;
- }
- }
- if (llvm::all_of(op->getOperands(),
- [&](Value v) { return v.getType().isIntOrFloat(); }) &&
- llvm::all_of(op->getResults(),
- [&](Value v) { return v.getType().isIntOrFloat(); })) {
- return true;
- }
- return false;
-}
-
//===----------------------------------------------------------------------===//
// Methods for getting the workload information for dispatch region creation.
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
index f39035f..b8d917e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h
@@ -33,13 +33,6 @@
namespace IREE {
namespace Flow {
-/// A heuristic that decides which ops should be cloned and fused into a
-/// dispatch region.
-///
-/// Note: This function returns `false` for ops that should be tiled and fused
-/// into a dispatch region.
-bool isClonableIntoDispatchOp(Operation *op);
-
/// Computes the workload and provides a workload region builder for the given
/// root op.
FailureOr<Flow::WorkloadBuilder> getWorkloadBuilder(OpBuilder &builder,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
index 3849874..3e621b2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
@@ -29,8 +29,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/RegionUtils.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#define DEBUG_TYPE "iree-flow-form-dispatch-workgroups"
@@ -39,86 +37,6 @@
namespace IREE {
namespace Flow {
-//===---------------------------------------------------------------------===//
-// Methods to legalize a dispatch region op, i.e. make it isolated from above.
-//===---------------------------------------------------------------------===//
-
-/// Checks if the `Value` has a use within the dispatch that is unfusable.
-static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) {
- for (OpOperand &use : v.getUses()) {
- Operation *user = use.getOwner();
- Operation *ownerWorkgroupsOp =
- user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>();
- Operation *ownerRegionOp =
- user->getParentOfType<IREE::Flow::DispatchRegionOp>();
- Operation *owner = ownerWorkgroupsOp ? ownerWorkgroupsOp : ownerRegionOp;
-
- // Ignore uses outside of dispatch workgroups op.
- if (owner != dispatchOp) continue;
-
- // Cannot fuse producer of `dest` with `tensor.insert_slice`.
- if (auto insertSliceUser = dyn_cast<tensor::InsertSliceOp>(user)) {
- if (insertSliceUser.getDest() == v) return true;
- }
- }
- return false;
-}
-
-/// Collect all ops that should be cloned into the given dispatch region op.
-static SmallVector<Operation *> getCloneableOps(
- Flow::DispatchRegionOp regionOp) {
- // Find values that are used inside of the dispatch region but defined outside
- // of the dispatch region.
- llvm::SetVector<Value> valuesDefinedAbove;
- mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove);
- if (valuesDefinedAbove.empty()) return {};
-
- // Traverse the defining ops of these values (and the ops on their reverse
- // SSA use-def chain).
- SmallVector<Operation *> result;
- llvm::SetVector<Value> visited;
- SmallVector<Value, 4> worklist;
- worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
- while (!worklist.empty()) {
- Value outsideValue = worklist.pop_back_val();
- // Skip values that were already visited.
- if (visited.count(outsideValue)) continue;
- visited.insert(outsideValue);
-
- Operation *definingOp = outsideValue.getDefiningOp();
- if (!definingOp || !isClonableIntoDispatchOp(definingOp) ||
- hasUnfusableUseInDispatch(outsideValue, regionOp)) {
- valuesDefinedAbove.insert(outsideValue);
- continue;
- }
- result.push_back(definingOp);
- worklist.append(definingOp->operand_begin(), definingOp->operand_end());
- llvm::SetVector<Value> nestedValues;
- mlir::getUsedValuesDefinedAbove(definingOp->getRegions(), nestedValues);
- worklist.append(nestedValues.begin(), nestedValues.end());
- }
-
- return result;
-}
-
-/// Clone producers into the dispatch region.
-static LogicalResult cloneProducers(RewriterBase &rewriter,
- Flow::DispatchRegionOp regionOp) {
- SmallVector<Operation *> cloneableOps = getCloneableOps(regionOp);
- bool sortResult = mlir::computeTopologicalSorting(cloneableOps);
- (void)sortResult;
- assert(sortResult && "could not compute topological sorting");
-
- for (Operation *producer : llvm::reverse(cloneableOps)) {
- if (failed(
- clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp))) {
- return failure();
- }
- }
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Dispatch workgroups formation
//===----------------------------------------------------------------------===//
@@ -136,7 +54,7 @@
// Clone additional producers and rewrite to DispatchWorkgroupsOp.
SmallVector<Flow::DispatchWorkgroupsOp> result;
for (auto regionOp : regionOps) {
- if (failed(cloneProducers(rewriter, regionOp))) return failure();
+ if (failed(cloneProducersToRegion(rewriter, regionOp))) return failure();
auto maybeWorkgroupOp =
rewriteFlowDispatchRegionToFlowDispatchWorkgroups(regionOp, rewriter);
if (failed(maybeWorkgroupOp)) return failure();
@@ -175,7 +93,7 @@
// Wrap operation.
auto regionOp = Flow::wrapOpInDispatchRegion(rewriter, op, workloadBuilder);
if (failed(regionOp)) return failure();
- if (failed(cloneProducers(rewriter, *regionOp))) return failure();
+ if (failed(cloneProducersToRegion(rewriter, *regionOp))) return failure();
auto workgroupsOp = Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
*regionOp, rewriter);
if (failed(workgroupsOp)) return failure();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 5d67ae1..2b8ba21 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -6,14 +6,31 @@
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include <iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h>
+
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
+
+// NOTE: This flag is added for experimental purposes only, for developer
+// control. It should be treated as an internal compiler implementation
+// detail.
+static llvm::cl::opt<int> clInlineConstantByteLength(
+ "iree-flow-inline-constants-max-byte-length",
+ llvm::cl::desc("Maximum byte-length of tensor constant that can be inlined "
+ "into a dispatch region or 0 to disable inlining."),
+ llvm::cl::init(256));
using namespace mlir;
using namespace mlir::iree_compiler;
@@ -328,3 +345,120 @@
return newRegionOp;
}
+//===---------------------------------------------------------------------===//
+// Utilities to make a dispatch region isolated from above
+//===---------------------------------------------------------------------===//
+
+/// Returns true if `op` may be cloned into a dispatch region formed around
+/// another root op (listed ops, small/splat constants, scalar-only ops).
+bool Flow::isClonableIntoDispatchOp(Operation *op) {
+ // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are
+ // trivially clonable too, but they cause problems
+ // with bufferization. Make them clonable when fixed.
+ if (isa<AffineApplyOp, arith::IndexCastOp, linalg::FillOp, tensor::EmptyOp,
+ tensor::CastOp, tensor::ExtractOp, tensor::ExtractSliceOp,
+ tensor::PadOp>(op)) {
+ return true;
+ }
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+ if (clInlineConstantByteLength == 0) return false;
+ auto constantValueAttr = constantOp.getValue();
+ auto constantType = constantOp.getType();
+ if (constantValueAttr.isa<SplatElementsAttr>()) {
+ return true;
+ } else if (auto denseAttr =
+ constantValueAttr.dyn_cast<DenseElementsAttr>()) {
+ auto shapedType = constantOp.getType().cast<ShapedType>();
+ uint64_t estimatedByteLength =
+ (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
+ 8;
+ return denseAttr.isSplat() ||
+ estimatedByteLength <= clInlineConstantByteLength;
+ } else if (constantType.isIntOrIndexOrFloat()) {
+ return true;
+ }
+ }
+ if (llvm::all_of(op->getOperands(),
+ [&](Value v) { return v.getType().isIntOrFloat(); }) &&
+ llvm::all_of(op->getResults(),
+ [&](Value v) { return v.getType().isIntOrFloat(); })) {
+ return true;
+ }
+ return false;
+}
+
+/// Returns true if `v` is the `dest` of an insert_slice inside `dispatchOp`.
+static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) {
+ for (OpOperand &use : v.getUses()) {
+ Operation *user = use.getOwner();
+ Operation *ownerWorkgroupsOp =
+ user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>();
+ Operation *ownerRegionOp =
+ user->getParentOfType<IREE::Flow::DispatchRegionOp>();
+ Operation *owner = ownerWorkgroupsOp ? ownerWorkgroupsOp : ownerRegionOp;
+
+ // Ignore uses that are not nested in `dispatchOp` (region or workgroups op).
+ if (owner != dispatchOp) continue;
+
+ // A value used as `dest` of `tensor.insert_slice` is treated as unfusable.
+ if (auto insertSliceUser = dyn_cast<tensor::InsertSliceOp>(user)) {
+ if (insertSliceUser.getDest() == v) return true;
+ }
+ }
+ return false;
+}
+
+/// Collects clonable ops on the use-def chains of values `regionOp` captures.
+static SmallVector<Operation *> getCloneableOps(
+ Flow::DispatchRegionOp regionOp) {
+ // Seed the traversal with the values used inside the dispatch region but
+ // defined outside of it; there is nothing to clone if there are none.
+ llvm::SetVector<Value> valuesDefinedAbove;
+ mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove);
+ if (valuesDefinedAbove.empty()) return {};
+
+ // Worklist traversal over defining ops, following their operands and the
+ // values captured by their nested regions (reverse SSA use-def chain).
+ SmallVector<Operation *> result;
+ llvm::SetVector<Value> visited;
+ SmallVector<Value, 4> worklist;
+ worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
+ while (!worklist.empty()) {
+ Value outsideValue = worklist.pop_back_val();
+ // Skip values that were already visited.
+ if (visited.count(outsideValue)) continue;
+ visited.insert(outsideValue);
+
+ Operation *definingOp = outsideValue.getDefiningOp();
+ if (!definingOp || !Flow::isClonableIntoDispatchOp(definingOp) ||
+ hasUnfusableUseInDispatch(outsideValue, regionOp)) {
+ valuesDefinedAbove.insert(outsideValue);
+ continue;
+ }
+ result.push_back(definingOp);
+ worklist.append(definingOp->operand_begin(), definingOp->operand_end());
+ llvm::SetVector<Value> nestedValues;
+ mlir::getUsedValuesDefinedAbove(definingOp->getRegions(), nestedValues);
+ worklist.append(nestedValues.begin(), nestedValues.end());
+ }
+
+ return result;
+}
+
+/// Clones the clonable producers into `regionOp` in reverse topological order.
+LogicalResult Flow::cloneProducersToRegion(RewriterBase &rewriter,
+ Flow::DispatchRegionOp regionOp) {
+ SmallVector<Operation *> cloneableOps = getCloneableOps(regionOp);
+ bool sortResult = mlir::computeTopologicalSorting(cloneableOps);
+ (void)sortResult;
+ assert(sortResult && "could not compute topological sorting");
+
+ for (Operation *producer : llvm::reverse(cloneableOps)) {
+ if (failed(
+ clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp))) {
+ return failure();
+ }
+ }
+
+ return success();
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index a7cd075..1758724 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -7,7 +7,6 @@
#define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
-#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Support/LogicalResult.h"
@@ -84,6 +83,18 @@
RewriterBase &rewriter, Operation *op,
Optional<Flow::WorkloadBuilder> workloadBuilder = std::nullopt);
+/// Heuristic that decides whether `op` should be cloned into a dispatch region.
+/// Dense constants are capped by `iree-flow-inline-constants-max-byte-length`.
+///
+/// Note: This function returns `false` for ops that should be tiled and fused
+/// into a dispatch region.
+bool isClonableIntoDispatchOp(Operation *op);
+
+/// Clones into `regionOp` the producers of values that are used in the region
+/// but defined above it, as a step towards making it isolated from above.
+LogicalResult cloneProducersToRegion(RewriterBase &rewriter,
+ Flow::DispatchRegionOp regionOp);
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler