[NFC] Centralizing dispatch region constant inlining. (#12235)

This disables tensor constant inlining during closure optimization and
makes it so that dispatch region formation only inlines tensor values.
This avoids some issues with constant tensor access in #12233.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 69e39a2..b7d26bd 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -196,8 +196,12 @@
 
 void DispatchWorkgroupsOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
+  // Disable constant inlining as we have done it during dispatch region
+  // formation.
+  IREE::Util::ClosureOptimizationOptions closureOptions;
+  closureOptions.maxInlinedConstantBytes = 0;
   results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>(
-      context);
+      context, closureOptions);
   results.insert<ReplaceDispatchResultIfZeroElements>(context);
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 46596ff..318521a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -960,9 +960,14 @@
   return getResults();
 }
 
-// Inline operations that the dispatch region can handle natively.
-static bool canDispatchRegionContainOp(Operation *op) {
-  // Inline constant operations that are splat or small constants.
+bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
+  // For now we only allow constants; we could bring other ops across the
+  // boundary though if we want (particularly metadata ops).
+  // Note that the closure optimization may still filter out the constant op if
+  // it's not configured to inline constants of certain types/sizes.
+  // TODO(#12233): this should just be isa<ConstantOp> but today we need to
+  // ensure that we don't mess with tensors after dispatch region formation due
+  // to requirements around tensor access checking.
   if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
     auto constantType = constantOp.getType();
     if (constantType.isIntOrIndexOrFloat()) {
@@ -972,10 +977,6 @@
   return false;
 }
 
-bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
-  return canDispatchRegionContainOp(op);
-}
-
 // Refines the tensor access from what is declared on |type| based on actual
 // usage. We expect that the access was set correctly to begin with but today
 // we sometimes specify things too wide.
@@ -985,15 +986,20 @@
     // If the argument is a result with `readwrite` access, return false if the
     // value is only written to. Check this by looking at the uses of the
     // argument being only the `target` of `flow.dispatch.tensor.store` ops.
-    bool onlyWrites = true;
+    bool hasReads = false;
+    bool hasWrites = false;
     for (OpOperand &uses : value.getUses()) {
-      auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
-      if (!(storeOp && storeOp.getTarget() == uses.get())) {
-        onlyWrites = false;
-        break;
-      }
+      TypeSwitch<Operation *>(uses.getOwner())
+          .Case<DispatchTensorLoadOp>([&](auto loadOp) { hasReads = true; })
+          .Case<DispatchTensorStoreOp>([&](auto storeOp) { hasWrites = true; })
+          .Default([&](auto op) {
+            // Treat unknown ops conservatively as read/write.
+            hasReads = true;
+            hasWrites = true;
+          });
     }
-    if (onlyWrites) tensorAccess = TensorAccess::WriteOnly;
+    if (hasReads && !hasWrites) tensorAccess = TensorAccess::ReadOnly;
+    if (!hasReads && hasWrites) tensorAccess = TensorAccess::WriteOnly;
   }
   return tensorAccess;
 }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
index 648bd88..8e4f5e4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
@@ -1,59 +1,5 @@
 // RUN: iree-opt --allow-unregistered-dialect --split-input-file --canonicalize %s | iree-opt --allow-unregistered-dialect --split-input-file | FileCheck %s
 
-// CHECK-LABEL: @inlineWithTiedResults1
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x4xf32>)
-func.func @inlineWithTiedResults1(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
-  // CHECK-NOT: constant 128
-  %cst = arith.constant 128 : index
-  // CHECK-DAG: %[[X:.+]] = arith.constant 100
-  %x = arith.constant 100 : index
-  // CHECK-DAG: %[[Y:.+]] = arith.constant 50
-  %y = arith.constant 50 : index
-  //      CHECK: flow.dispatch.workgroups[%[[X]], %[[Y]]](%[[ARG0]]) : (tensor<1x4xf32>) -> %[[ARG0]] =
-  // CHECK-NEXT:   (%[[ARG0_INNER:.+]]: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>)
-  %0 = flow.dispatch.workgroups[%x, %y](%cst, %arg0) : (index, tensor<1x4xf32>) -> %arg0 = (
-    %cst_capture: index,
-    %arg0_capture: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>
-  ) {
-    //      CHECK: %[[INLINED_CST:.+]] = arith.constant 128 : index
-    // CHECK-NEXT: "test.sink"(%[[INLINED_CST]])
-    "test.sink"(%cst_capture) : (index) -> ()
-    // CHECK-NEXT: "test.sink"(%[[ARG0_INNER]])
-    "test.sink"(%arg0_capture) : (!flow.dispatch.tensor<readwrite:tensor<1x4xf32>>) -> ()
-    flow.return
-  }
-  return %0 : tensor<1x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @inlineWithTiedResults2
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x4xf32>)
-func.func @inlineWithTiedResults2(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
-  // CHECK-NOT: constant 128
-  %cst = arith.constant 128 : index
-  // CHECK-DAG: %[[X:.+]] = arith.constant 100
-  %x = arith.constant 100 : index
-  // CHECK-DAG: %[[Y:.+]] = arith.constant 50
-  %y = arith.constant 50 : index
-  //      CHECK: flow.dispatch.workgroups[%[[X]], %[[Y]]](%[[ARG0]]) : (tensor<1x4xf32>) -> %[[ARG0]] =
-  // CHECK-NEXT:   (%[[ARG0_INNER:.+]]: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>)
-  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %cst) : (tensor<1x4xf32>, index) -> %arg0 = (
-    %arg0_capture: !flow.dispatch.tensor<readwrite:tensor<1x4xf32>>,
-    %cst_capture: index
-  ) {
-    //      CHECK: %[[INLINED_CST:.+]] = arith.constant 128 : index
-    // CHECK-NEXT: "test.sink"(%[[INLINED_CST]])
-    "test.sink"(%cst_capture) : (index) -> ()
-    // CHECK-NEXT: "test.sink"(%[[ARG0_INNER]])
-    "test.sink"(%arg0_capture) : (!flow.dispatch.tensor<readwrite:tensor<1x4xf32>>) -> ()
-    flow.return
-  }
-  return %0 : tensor<1x4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @dontInlineReadWrite
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x4xf32>)
 func.func @dontInlineReadWrite(%arg0: tensor<1x4xf32>) -> tensor<4x8xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index a5af4f6..354b7b6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -48,8 +48,8 @@
 // compiler implementation details.
 static llvm::cl::opt<int> clInlineConstantByteLength(
     "iree-flow-inline-constants-max-byte-length",
-    llvm::cl::desc("Maximum byte-length of constant that can be inlined into a "
-                   "dispatch region"),
+    llvm::cl::desc("Maximum byte-length of tensor constant that can be inlined "
+                   "into a dispatch region or 0 to disable inlining."),
     llvm::cl::init(256));
 
 static const char kRootOpAttr[] = "__root_op__";
@@ -214,6 +214,7 @@
     return true;
   }
   if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+    if (clInlineConstantByteLength == 0) return false;
     auto constantValueAttr = constantOp.getValue();
     auto constantType = constantOp.getType();
     if (constantValueAttr.isa<SplatElementsAttr>()) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
index 1185bdd..3849874 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchWorkgroups.cpp
@@ -93,6 +93,9 @@
     }
     result.push_back(definingOp);
     worklist.append(definingOp->operand_begin(), definingOp->operand_end());
+    llvm::SetVector<Value> nestedValues;
+    mlir::getUsedValuesDefinedAbove(definingOp->getRegions(), nestedValues);
+    worklist.append(nestedValues.begin(), nestedValues.end());
   }
 
   return result;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 9fcf0c6..418a965 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -437,7 +437,7 @@
 //  CHECK-DAG:   %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
 //  CHECK-DAG:   %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
 //      CHECK:   %[[RESULT:.+]] = flow.dispatch.workgroups[%[[ARG0_D0]], %[[ARG0_D1]]]
-// CHECK-SAME:       (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], 
+// CHECK-SAME:       (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]],
 // CHECK-SAME:        %[[ARG0_D0]], %[[ARG0_D1]], %[[ARG1_D0]], %[[ARG1_D1]])
 // CHECK-SAME:       tensor<?x?xf32>{%[[ARG0_D0]], %[[ARG0_D1]]}
 // CHECK-SAME:       tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
index 534c659..cdf9507 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
@@ -107,7 +107,8 @@
   }
 }
 
-// Returns true if |constantOp| represents a (logically) small constant value.
+// Returns true if |constantOp| represents a (logically) small constant value
+// that can be inlined into a closure.
 //
 // "Small" is relative and there's a risk that we'll bloat the closures by
 // duplicating a bunch of constants however what we are able to save by not
@@ -119,13 +120,14 @@
 // This is also still at a fairly high level (flow dialect): once the closures
 // are expanded out in lower dialects things like CSE have a chance to once
 // again get at the constants and dedupe them if they survive.
-static bool isConstantSmall(arith::ConstantOp constantOp) {
-  // We could tune this/take it as a configuration setting.
-  // The current value is chosen based on what is known to be reasonable to
-  // inline into command buffers way down in the HAL, which is not great but at
-  // least better than either allocating independent buffers for 4 byte
-  // constants or inlining megabytes.
-  static constexpr int kMaxInlinedConstantBytes = 256;
+static bool isConstantInlinable(const ClosureOptimizationOptions &options,
+                                arith::ConstantOp constantOp) {
+  int64_t maxInlinedConstantBytes =
+      options.maxInlinedConstantBytes.value_or(INT64_MAX);
+  if (maxInlinedConstantBytes == 0) {
+    // Inlining of constants disabled.
+    return false;
+  }
 
   auto constantValueAttr = constantOp.getValue();
   auto constantType = constantOp.getType();
@@ -141,7 +143,7 @@
         shapedType.getNumElements() *
         getRoundedElementByteWidth(shapedType.getElementType());
     return denseAttr.isSplat() ||
-           estimatedByteLength <= kMaxInlinedConstantBytes;
+           estimatedByteLength <= maxInlinedConstantBytes;
   } else if (constantType.isIntOrIndexOrFloat()) {
     // Primitives can always go in.
     return true;
@@ -155,11 +157,12 @@
 // trees is hard and it'd be better to model that differently such as by having
 // a wrapper region for immutable blobs that can be inlined that this then
 // returns true for.
-static bool shouldInlineIntoClosure(Value value) {
+static bool shouldInlineIntoClosure(const ClosureOptimizationOptions &options,
+                                    Value value) {
   auto definingOp = value.getDefiningOp();
   if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
     // Constants are perfect!
-    return isConstantSmall(constantOp);
+    return isConstantInlinable(options, constantOp);
   }
   return false;
 }
@@ -171,7 +174,8 @@
 // Note that if multiple operands reference the same value it will get cloned
 // multiple times. That's fine, as anything we can inline here is something we
 // should also be able to CSE and that happens later on anyway.
-static void inlineClosureOperands(ClosureOpInterface &closureOp,
+static void inlineClosureOperands(const ClosureOptimizationOptions &options,
+                                  ClosureOpInterface &closureOp,
                                   Block &entryBlock,
                                   PatternRewriter &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
@@ -188,7 +192,7 @@
     if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) continue;
 
     if (closureOp.canClosureContainOp(sourceOp) &&
-        shouldInlineIntoClosure(outerValue)) {
+        shouldInlineIntoClosure(options, outerValue)) {
       // Clone the op (with regions).
       auto *clonedOp = rewriter.clone(*sourceOp);
 
@@ -208,7 +212,8 @@
   }
 }
 
-LogicalResult optimizeClosureLikeOp(ClosureOpInterface closureOp,
+LogicalResult optimizeClosureLikeOp(const ClosureOptimizationOptions &options,
+                                    ClosureOpInterface closureOp,
                                     PatternRewriter &rewriter) {
   // NOTE: the block is transferred to the new op; we can update it in place.
   Block &entryBlock = closureOp.getClosureBodyRegion().front();
@@ -218,7 +223,7 @@
   // then elide below. When we do inline things the operands will be changed
   // such that the following work is guaranteed to happen and thus our op will
   // be rebuilt.
-  inlineClosureOperands(closureOp, entryBlock, rewriter);
+  inlineClosureOperands(options, closureOp, entryBlock, rewriter);
 
   // Build data structure for unused operand elision.
   SmallVector<unsigned, 4> elidedOperands;
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
index c3a900f..d70c3dd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
@@ -43,11 +43,19 @@
 void eraseRegionResults(Region &region,
                         ArrayRef<unsigned> excludedResultIndices);
 
+struct ClosureOptimizationOptions {
+  // Maximum size in bytes of constant values to inline into the closure.
+  // When 0, no constants will be inlined; when None, all constants will be
+  // inlined.
+  Optional<int64_t> maxInlinedConstantBytes = {256};
+};
+
 // Optimizes closure |closureOp| to remove duplicate operands and unused
 // results. The op may be mutated, destroyed, or replaced with a new one. If an
 // optional |rewriter| is provided then it will be notified of the operations
 // performed on the op. Returns true if the op was optimized.
-LogicalResult optimizeClosureLikeOp(ClosureOpInterface closureOp,
+LogicalResult optimizeClosureLikeOp(const ClosureOptimizationOptions &options,
+                                    ClosureOpInterface closureOp,
                                     PatternRewriter &rewriter);
 
 // A pattern that optimizes the given region-containing op T (CSE, DCE, etc).
@@ -56,14 +64,21 @@
 //
 // T must implement the IREE::Util::ClosureOpInterface.
 template <typename T>
-struct ClosureOptimizationPattern : public OpRewritePattern<T> {
-  using OpRewritePattern<T>::OpRewritePattern;
+class ClosureOptimizationPattern : public OpRewritePattern<T> {
+ public:
+  ClosureOptimizationPattern(MLIRContext *context,
+                             ClosureOptimizationOptions options = {},
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern<T>(context, benefit), options(options) {}
 
   LogicalResult matchAndRewrite(T op,
                                 PatternRewriter &rewriter) const override {
     auto closureOp = cast<ClosureOpInterface>(op.getOperation());
-    return optimizeClosureLikeOp(closureOp, rewriter);
+    return optimizeClosureLikeOp(options, closureOp, rewriter);
   }
+
+ private:
+  const ClosureOptimizationOptions options;
 };
 
 }  // namespace Util