[NFC] Centralizing dispatch region constant inlining. (#12235)
This disables tensor constant inlining during closure optimization so that
dispatch region formation is the only place tensor constants get inlined into
dispatches. This avoids some of the issues with constant tensor access seen
in #12233.
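Concretely, the `DispatchWorkgroupsOp` canonicalization now registers the closure optimization pattern with constant inlining disabled, along the lines of the `FlowOpFolders.cpp` hunk below:

```cpp
// Mirrors the FlowOpFolders.cpp hunk in this diff: the canonicalization
// pattern is handed options that turn constant inlining off, leaving dispatch
// region formation as the only place tensor constants are pulled into
// dispatches.
void DispatchWorkgroupsOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  IREE::Util::ClosureOptimizationOptions closureOptions;
  closureOptions.maxInlinedConstantBytes = 0;
  results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>(
      context, closureOptions);
  results.insert<ReplaceDispatchResultIfZeroElements>(context);
}
```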
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 69e39a2..b7d26bd 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -196,8 +196,12 @@
void DispatchWorkgroupsOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
+ // Disable constant inlining as we have done it during dispatch region
+ // formation.
+ IREE::Util::ClosureOptimizationOptions closureOptions;
+ closureOptions.maxInlinedConstantBytes = 0;
results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>(
- context);
+ context, closureOptions);
results.insert<ReplaceDispatchResultIfZeroElements>(context);
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 46596ff..318521a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -960,9 +960,14 @@
return getResults();
}
-// Inline operations that the dispatch region can handle natively.
-static bool canDispatchRegionContainOp(Operation *op) {
- // Inline constant operations that are splat or small constants.
+bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
+ // For now we only allow constants; we could bring other ops across the
+ // boundary though if we want (particularly metadata ops).
+ // Note that the closure optimization may still filter out the constant op if
+ // it's not configured to inline constants of certain types/sizes.
+ // TODO(#12233): this should just be isa<ConstantOp> but today we need to
+ // ensure that we don't mess with tensors after dispatch region formation due
+ // to requirements around tensor access checking.
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
auto constantType = constantOp.getType();
if (constantType.isIntOrIndexOrFloat()) {
@@ -972,10 +977,6 @@
return false;
}
-bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
- return canDispatchRegionContainOp(op);
-}
-
// Refines the tensor access from what is declared on |type| based on actual
// usage. We expect that the access was set correctly to begin with but today
// we sometimes specify things too wide.
@@ -985,15 +986,20 @@
// If the argument is a result with `readwrite` access, return false if the
// value is only written to. Check this by looking at the uses of the
// argument being only the `target` of `flow.dispatch.tensor.store` ops.
- bool onlyWrites = true;
+ bool hasReads = false;
+ bool hasWrites = false;
for (OpOperand &uses : value.getUses()) {
- auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
- if (!(storeOp && storeOp.getTarget() == uses.get())) {
- onlyWrites = false;
- break;
- }
+ TypeSwitch<Operation *>(uses.getOwner())
+ .Case<DispatchTensorLoadOp>([&](auto loadOp) { hasReads = true; })
+ .Case<DispatchTensorStoreOp>([&](auto storeOp) { hasWrites = true; })
+ .Default([&](auto op) {
+ // Treat unknown ops conservatively as read/write.
+ hasReads = true;
+ hasWrites = true;
+ });
}
- if (onlyWrites) tensorAccess = TensorAccess::WriteOnly;
+ if (hasReads && !hasWrites) tensorAccess = TensorAccess::ReadOnly;
+ if (!hasReads && hasWrites) tensorAccess = TensorAccess::WriteOnly;
}
return tensorAccess;
}
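The refinement above amounts to a three-way decision over the argument's uses; a minimal standalone sketch (names are simplified stand-ins, not the exact IREE helpers):

```cpp
// Illustrative reduction of the refineTensorAccess() change above: a
// readwrite-declared argument is narrowed based on whether its uses are
// loads, stores, or both.
enum class TensorAccess { ReadOnly, ReadWrite, WriteOnly };

TensorAccess refineReadWrite(bool hasReads, bool hasWrites) {
  if (hasReads && !hasWrites) return TensorAccess::ReadOnly;
  if (!hasReads && hasWrites) return TensorAccess::WriteOnly;
  return TensorAccess::ReadWrite;  // mixed or unknown uses stay readwrite
}
```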
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
index 648bd88..8e4f5e4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
@@ -1,59 +1,5 @@
// RUN: iree-opt --allow-unregistered-dialect --split-input-file --canonicalize %s | iree-opt --allow-unregistered-dialect --split-input-file | FileCheck %s
-// CHECK-LABEL: @inlineWithTiedResults1
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x4xf32>)
-func.func @inlineWithTiedResults1(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
- // CHECK-NOT: constant 128
- %cst = arith.constant 128 : index
- // CHECK-DAG: %[[X:.+]] = arith.constant 100
- %x = arith.constant 100 : index
- // CHECK-DAG: %[[Y:.+]] = arith.constant 50
- %y = arith.constant 50 : index
- // CHECK: flow.dispatch.workgroups[%[[X]], %[[Y]]](%[[ARG0]]) : (tensor<1x4xf32>) -> %[[ARG0]] =
- // CHECK-NEXT: (%[[ARG0_INNER:.+]]: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>)
- %0 = flow.dispatch.workgroups[%x, %y](%cst, %arg0) : (index, tensor<1x4xf32>) -> %arg0 = (
- %cst_capture: index,
- %arg0_capture: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>
- ) {
- // CHECK: %[[INLINED_CST:.+]] = arith.constant 128 : index
- // CHECK-NEXT: "test.sink"(%[[INLINED_CST]])
- "test.sink"(%cst_capture) : (index) -> ()
- // CHECK-NEXT: "test.sink"(%[[ARG0_INNER]])
- "test.sink"(%arg0_capture) : (!flow.dispatch.tensor<readwrite:tensor<1x4xf32>>) -> ()
- flow.return
- }
- return %0 : tensor<1x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @inlineWithTiedResults2
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x4xf32>)
-func.func @inlineWithTiedResults2(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
- // CHECK-NOT: constant 128
- %cst = arith.constant 128 : index
- // CHECK-DAG: %[[X:.+]] = arith.constant 100
- %x = arith.constant 100 : index
- // CHECK-DAG: %[[Y:.+]] = arith.constant 50
- %y = arith.constant 50 : index
- // CHECK: flow.dispatch.workgroups[%[[X]], %[[Y]]](%[[ARG0]]) : (tensor<1x4xf32>) -> %[[ARG0]] =
- // CHECK-NEXT: (%[[ARG0_INNER:.+]]: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>)
- %0 = flow.dispatch.workgroups[%x, %y](%arg0, %cst) : (tensor<1x4xf32>, index) -> %arg0 = (
- %arg0_capture: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>,
- %cst_capture: index
- ) {
- // CHECK: %[[INLINED_CST:.+]] = arith.constant 128 : index
- // CHECK-NEXT: "test.sink"(%[[INLINED_CST]])
- "test.sink"(%cst_capture) : (index) -> ()
- // CHECK-NEXT: "test.sink"(%[[ARG0_INNER]])
- "test.sink"(%arg0_capture) : (!flow.dispatch.tensor<readwrite:tensor<1x4xf32>>) -> ()
- flow.return
- }
- return %0 : tensor<1x4xf32>
-}
-
-// -----
-
// CHECK-LABEL: func.func @dontInlineReadWrite
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x4xf32>)
func.func @dontInlineReadWrite(%arg0: tensor<1x4xf32>) -> tensor<4x8xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index a5af4f6..354b7b6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -48,8 +48,8 @@
// compiler implementation details.
static llvm::cl::opt<int> clInlineConstantByteLength(
"iree-flow-inline-constants-max-byte-length",
- llvm::cl::desc("Maximum byte-length of constant that can be inlined into a "
- "dispatch region"),
+ llvm::cl::desc("Maximum byte-length of a tensor constant that can be "
+ "inlined into a dispatch region, or 0 to disable inlining."),
llvm::cl::init(256));
static const char kRootOpAttr[] = "__root_op__";
@@ -214,6 +214,7 @@
return true;
}
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+ if (clInlineConstantByteLength == 0) return false;
auto constantValueAttr = constantOp.getValue();
auto constantType = constantOp.getType();
if (constantValueAttr.isa<SplatElementsAttr>()) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
index 1185bdd..3849874 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
@@ -93,6 +93,9 @@
}
result.push_back(definingOp);
worklist.append(definingOp->operand_begin(), definingOp->operand_end());
+ llvm::SetVector<Value> nestedValues;
+ mlir::getUsedValuesDefinedAbove(definingOp->getRegions(), nestedValues);
+ worklist.append(nestedValues.begin(), nestedValues.end());
}
return result;
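The FormDispatchWorkgroups tweak also walks values that an op's nested regions capture from the enclosing scope, not just the op's direct operands. A hedged sketch of that worklist step, assuming the usual MLIR helpers (the function name here is illustrative, not an IREE symbol):

```cpp
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/RegionUtils.h"

// Sketch of the worklist step added above: besides an op's direct operands,
// values that its nested regions capture from above are also enqueued.
static void enqueueOperandsAndCaptures(
    mlir::Operation *op, llvm::SmallVectorImpl<mlir::Value> &worklist) {
  worklist.append(op->operand_begin(), op->operand_end());
  llvm::SetVector<mlir::Value> capturedValues;
  mlir::getUsedValuesDefinedAbove(op->getRegions(), capturedValues);
  worklist.append(capturedValues.begin(), capturedValues.end());
}
```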
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 9fcf0c6..418a965 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -437,7 +437,7 @@
// CHECK-DAG: %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups[%[[ARG0_D0]], %[[ARG0_D1]]]
-// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]],
+// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]],
// CHECK-SAME: %[[ARG0_D0]], %[[ARG0_D1]], %[[ARG1_D0]], %[[ARG1_D1]])
// CHECK-SAME: tensor<?x?xf32>{%[[ARG0_D0]], %[[ARG0_D1]]}
// CHECK-SAME: tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
index 534c659..cdf9507 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
@@ -107,7 +107,8 @@
}
}
-// Returns true if |constantOp| represents a (logically) small constant value.
+// Returns true if |constantOp| represents a (logically) small constant value
+// that can be inlined into a closure.
//
// "Small" is relative and there's a risk that we'll bloat the closures by
// duplicating a bunch of constants however what we are able to save by not
@@ -119,13 +120,14 @@
// This is also still at a fairly high level (flow dialect): once the closures
// are expanded out in lower dialects things like CSE have a chance to once
// again get at the constants and dedupe them if they survive.
-static bool isConstantSmall(arith::ConstantOp constantOp) {
- // We could tune this/take it as a configuration setting.
- // The current value is chosen based on what is known to be reasonable to
- // inline into command buffers way down in the HAL, which is not great but at
- // least better than either allocating independent buffers for 4 byte
- // constants or inlining megabytes.
- static constexpr int kMaxInlinedConstantBytes = 256;
+static bool isConstantInlinable(const ClosureOptimizationOptions &options,
+ arith::ConstantOp constantOp) {
+ int64_t maxInlinedConstantBytes =
+ options.maxInlinedConstantBytes.value_or(INT64_MAX);
+ if (maxInlinedConstantBytes == 0) {
+ // Inlining of constants disabled.
+ return false;
+ }
auto constantValueAttr = constantOp.getValue();
auto constantType = constantOp.getType();
@@ -141,7 +143,7 @@
shapedType.getNumElements() *
getRoundedElementByteWidth(shapedType.getElementType());
return denseAttr.isSplat() ||
- estimatedByteLength <= kMaxInlinedConstantBytes;
+ estimatedByteLength <= maxInlinedConstantBytes;
} else if (constantType.isIntOrIndexOrFloat()) {
// Primitives can always go in.
return true;
@@ -155,11 +157,12 @@
// trees is hard and it'd be better to model that differently such as by having
// a wrapper region for immutable blobs that can be inlined that this then
// returns true for.
-static bool shouldInlineIntoClosure(Value value) {
+static bool shouldInlineIntoClosure(const ClosureOptimizationOptions &options,
+ Value value) {
auto definingOp = value.getDefiningOp();
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
// Constants are perfect!
- return isConstantSmall(constantOp);
+ return isConstantInlinable(options, constantOp);
}
return false;
}
@@ -171,7 +174,8 @@
// Note that if multiple operands reference the same value it will get cloned
// multiple times. That's fine, as anything we can inline here is something we
// should also be able to CSE and that happens later on anyway.
-static void inlineClosureOperands(ClosureOpInterface &closureOp,
+static void inlineClosureOperands(const ClosureOptimizationOptions &options,
+ ClosureOpInterface &closureOp,
Block &entryBlock,
PatternRewriter &rewriter) {
OpBuilder::InsertionGuard g(rewriter);
@@ -188,7 +192,7 @@
if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) continue;
if (closureOp.canClosureContainOp(sourceOp) &&
- shouldInlineIntoClosure(outerValue)) {
+ shouldInlineIntoClosure(options, outerValue)) {
// Clone the op (with regions).
auto *clonedOp = rewriter.clone(*sourceOp);
@@ -208,7 +212,8 @@
}
}
-LogicalResult optimizeClosureLikeOp(ClosureOpInterface closureOp,
+LogicalResult optimizeClosureLikeOp(const ClosureOptimizationOptions &options,
+ ClosureOpInterface closureOp,
PatternRewriter &rewriter) {
// NOTE: the block is transferred to the new op; we can update it in place.
Block &entryBlock = closureOp.getClosureBodyRegion().front();
@@ -218,7 +223,7 @@
// then elide below. When we do inline things the operands will be changed
// such that the following work is guaranteed to happen and thus our op will
// be rebuilt.
- inlineClosureOperands(closureOp, entryBlock, rewriter);
+ inlineClosureOperands(options, closureOp, entryBlock, rewriter);
// Build data structure for unused operand elision.
SmallVector<unsigned, 4> elidedOperands;
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
index c3a900f..d70c3dd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
@@ -43,11 +43,19 @@
void eraseRegionResults(Region &region,
ArrayRef<unsigned> excludedResultIndices);
+struct ClosureOptimizationOptions {
+ // Maximum size in bytes of constant values to inline into the closure.
+ // When 0, no constants will be inlined; when None, all constants will be
+ // inlined.
+ Optional<int64_t> maxInlinedConstantBytes = {256};
+};
+
// Optimizes closure |closureOp| to remove duplicate operands and unused
// results. The op may be mutated, destroyed, or replaced with a new one. If an
// optional |rewriter| is provided then it will be notified of the operations
// performed on the op. Returns true if the op was optimized.
-LogicalResult optimizeClosureLikeOp(ClosureOpInterface closureOp,
+LogicalResult optimizeClosureLikeOp(const ClosureOptimizationOptions &options,
+ ClosureOpInterface closureOp,
PatternRewriter &rewriter);
// A pattern that optimizes the given region-containing op T (CSE, DCE, etc).
@@ -56,14 +64,21 @@
//
// T must implement the IREE::Util::ClosureOpInterface.
template <typename T>
-struct ClosureOptimizationPattern : public OpRewritePattern<T> {
- using OpRewritePattern<T>::OpRewritePattern;
+class ClosureOptimizationPattern : public OpRewritePattern<T> {
+ public:
+ ClosureOptimizationPattern(MLIRContext *context,
+ ClosureOptimizationOptions options = {},
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<T>(context, benefit), options(options) {}
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
auto closureOp = cast<ClosureOpInterface>(op.getOperation());
- return optimizeClosureLikeOp(closureOp, rewriter);
+ return optimizeClosureLikeOp(options, closureOp, rewriter);
}
+
+ private:
+ const ClosureOptimizationOptions options;
};
} // namespace Util
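The Optional-valued threshold in `ClosureOptimizationOptions` gives three behaviors: unset means no size cap, 0 disables constant inlining entirely, and any other value caps the estimated byte length of dense tensor constants (splats and primitives pass regardless of size). A standalone sketch of that policy, with the attribute-derived size estimate replaced by a plain parameter and an illustrative function name:

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Standalone sketch of the isConstantInlinable() size policy: the real code
// derives estimatedByteLength from the constant's ShapedType and rounded
// element byte width; here it is passed in directly.
bool isInlinableBySize(std::optional<int64_t> maxInlinedConstantBytes,
                       bool isSplatOrPrimitive, int64_t estimatedByteLength) {
  int64_t limit =
      maxInlinedConstantBytes.value_or(std::numeric_limits<int64_t>::max());
  if (limit == 0) return false;          // inlining disabled outright
  if (isSplatOrPrimitive) return true;   // splats and primitives always fit
  return estimatedByteLength <= limit;
}
```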