Treat nested implicit captures as producers for const-expr analysis. (#7938)

* Also reworks the hoisting transformation to consider the producer tree from the analysis authoritative vs relying on creating a backward slice.
* This is needed to handle implicit captures correctly and has the byproduct of making a more precise clone (the prior version would sometimes materialize extra dead ops for odd const-expr tree shapes).
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
index 672a3e1..d681ef6 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
@@ -127,37 +127,37 @@
 }
 
 void ConstExprAnalysis::expandToOp(Operation *op) {
-  bool eligible = isEligibleConstExprOp(op);
+  ConstExprOpInfo opInfo = ConstExprOpInfo::getForOp(op);
   for (auto result : op->getResults()) {
     auto foundIt = constInfoMap.find(result);
     if (foundIt != constInfoMap.end()) continue;
 
     // Generate new info record.
-    auto *info = addInfo(result);
-    if (!eligible) {
+    auto *valueInfo = addInfo(result);
+    if (!opInfo.isEligible) {
       // Put it in a NON_CONSTANT state and bail. This is terminal.
-      info->state = ConstValueInfo::NON_CONSTANT;
+      valueInfo->state = ConstValueInfo::NON_CONSTANT;
       LLVM_DEBUG(dbgs() << "  EXPAND TO INELIGIBLE: " << result << "\n");
       continue;
     }
 
     // If here, then an unknown state.
     LLVM_DEBUG(dbgs() << "  EXPAND TO UNKNOWN: " << result << "\n");
-    worklist.push_back(info);
+    worklist.push_back(valueInfo);
 
-    // Process operands.
-    for (auto operand : op->getOperands()) {
-      Operation *definingOp = operand.getDefiningOp();
+    // Process producers.
+    for (auto producer : opInfo.producers) {
+      Operation *definingOp = producer.getDefiningOp();
       if (!definingOp) {
         // Consider crossing out of block to be non-const.
-        info->state = ConstValueInfo::NON_CONSTANT;
+        valueInfo->state = ConstValueInfo::NON_CONSTANT;
         break;
       }
       expandToOp(definingOp);
 
-      ConstValueInfo *producerInfo = constInfoMap.lookup(operand);
+      ConstValueInfo *producerInfo = constInfoMap.lookup(producer);
       assert(producerInfo && "should have producer info in map");
-      info->producers.push_back(producerInfo);
+      valueInfo->producers.insert(producerInfo);
     }
   }
 }
@@ -172,6 +172,10 @@
       for (Value root : info->roots) {
         os << "      " << root << "\n";
       }
+      os << "    WITH PRODUCERS:\n";
+      for (ConstValueInfo *producerInfo : info->producers) {
+        os << "      " << producerInfo->constValue << "\n";
+      }
     }
   }
 }
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
index 9237998..c59bed0 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
@@ -99,7 +99,7 @@
     SmallPtrSet<Value, 4> roots;
 
     // Direct producers that feed into this constant value.
-    SmallVector<ConstValueInfo *> producers;
+    SmallPtrSet<ConstValueInfo *, 8> producers;
 
     // Whether this is a root.
     bool isRoot = false;
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp
index d3f7a09..15bbc9f 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h"
 
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
@@ -15,44 +16,86 @@
 namespace IREE {
 namespace Util {
 
+namespace {
+
+void populateEscapingProducers(Operation *parentOp, ConstExprOpInfo &info) {
+  SmallPtrSet<Operation *, 8> containedOps;
+  parentOp->walk<WalkOrder::PreOrder>([&](Operation *itOp) {
+    containedOps.insert(parentOp);
+    // For the outer-most op, consider that all operands escape.
+    if (itOp == parentOp) {
+      info.producers.insert(itOp->getOperands().begin(),
+                            itOp->getOperands().end());
+      return;
+    }
+
+    // For nested operations, only consider that they escape if they are
+    // defined outside of the parent.
+    for (Value operand : itOp->getOperands()) {
+      Block *block = operand.getParentBlock();
+      if (!containedOps.contains(block->getParentOp())) {
+        info.producers.insert(operand);
+      }
+    }
+  });
+}
+
+ConstExprOpInfo getInfoForDefaultConstExprOp(Operation *op) {
+  ConstExprOpInfo info;
+  info.isEligible = true;
+  populateEscapingProducers(op, info);
+  return info;
+}
+
+}  // namespace
+
 void registerConstExprDependentDialects(DialectRegistry &registry) {
   registry.insert<IREE::Util::UtilDialect>();
   registry.insert<linalg::LinalgDialect>();
 }
 
-bool isEligibleConstExprOp(Operation *op) {
+ConstExprOpInfo ConstExprOpInfo::getForOp(Operation *op) {
   // Special carve-out for unregistered testing ops.
   if (!op->isRegistered()) {
-    if (op->getName().getStringRef() ==
-        "iree_unregistered.non_leaf_const_expr") {
-      return true;
-    }
-    if (op->getName().getStringRef() == "iree_unregistered.const_expr") {
-      return true;
-    }
+    // Reject.
     if (op->getName().getStringRef() == "iree_unregistered.var_expr") {
-      return false;
+      return {};
     }
-    return false;
+    // Accept.
+    if (op->getName().getStringRef() ==
+            "iree_unregistered.non_leaf_const_expr" ||
+        op->getName().getStringRef() == "iree_unregistered.const_expr") {
+      return getInfoForDefaultConstExprOp(op);
+    }
+    return {};
   }
 
-  // Allow linalg ops, even though they are not effect annotated.
+  // We have a specific allow-list for Linalg ops because we want to consider
+  // new additions carefully.
   if (op->getDialect() ==
       op->getContext()->getOrLoadDialect<linalg::LinalgDialect>()) {
-    return true;
+    // Structured op implementations and a handful of pure ops are included.
+    // Notably: IndexOp is not included because it establishes a hidden
+    // dependency to the iterator and is non-const.
+    if (llvm::isa<linalg::LinalgOp>(op) || llvm::isa<linalg::PadTensorOp>(op) ||
+        llvm::isa<linalg::InitTensorOp>(op)) {
+      return getInfoForDefaultConstExprOp(op);
+    }
+
+    return {};
   }
 
   // By default any effects make it non const-expr.
   if (!MemoryEffectOpInterface::hasNoEffect(op)) {
-    return false;
+    return {};
   }
 
   // By default, ops without results are not const-expr.
   if (op->getNumResults() == 0) {
-    return false;
+    return {};
   }
 
-  return true;
+  return getInfoForDefaultConstExprOp(op);
 }
 
 bool isHoistableConstExprLeaf(const ConstExprAnalysis::ConstValueInfo *info) {
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h
index f620115..343d04e 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_DIALECT_IREE_UTIL_ANALYSIS_CONSTANT_OP_ORACLE_H_
 
 #include "iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "mlir/IR/Operation.h"
 
 namespace mlir {
@@ -18,15 +19,28 @@
 // Registers dialects needed to query or construct const-expr information.
 void registerConstExprDependentDialects(DialectRegistry &registry);
 
-// Whether an op can be considered a pure expression, producing a constant if
-// provided constants and having no side effects beyond that.
-//
-// In order to enable testing, some unregistered ops are also recognized:
-//   - iree_unregistered.non_leaf_const_expr : Will be treated as const-expr.
-//   - iree_unregistered.const_expr : Will be treated as const-expr
-//   - iree_unregistered.var_expr : Will be treated as not const-expr
-// Any other unregistered ops are treated as not const-expr.
-bool isEligibleConstExprOp(Operation *op);
+// Information about a possible const-expr op.
+struct ConstExprOpInfo {
+  // Whether the op is eligible to be considered const-expr, assuming that
+  // all of its producers are eligible.
+  bool isEligible = false;
+
+  // Producer values that must be const-expr for this op to be considered
+  // const-expr. This minimally includes operands, and for region-based ops
+  // may include implicit captures.
+  llvm::SmallPtrSet<Value, 8> producers;
+
+  // Gets information for an op.
+  // Whether an op can be considered a pure expression, producing a constant if
+  // provided constants and having no side effects beyond that.
+  //
+  // In order to enable testing, some unregistered ops are also recognized:
+  //   - iree_unregistered.non_leaf_const_expr : Will be treated as const-expr.
+  //   - iree_unregistered.const_expr : Will be treated as const-expr
+  //   - iree_unregistered.var_expr : Will be treated as not const-expr
+  // Any other unregistered ops are treated as not const-expr.
+  static ConstExprOpInfo getForOp(Operation *op);
+};
 
 // Whether a const-expr op is eligible to be hoistable. This enforces
 // policies for excluding certain, otherwise eligible, const-expr ops from
diff --git a/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
index a3fdaf9..5b1ac52 100644
--- a/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
+++ b/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
@@ -131,8 +131,11 @@
       Location loc = originalValue.getLoc();
       OpBuilder builder = getModuleEndBuilder();
       auto initializerOp = builder.create<InitializerOp>(loc);
-      cloneConstExprInto(initializerOp, originalValue, hoistedMap,
-                         moduleSymbols);
+      Block *entryBlock = initializerOp.addEntryBlock();
+      OpBuilder initBuilder = OpBuilder::atBlockEnd(entryBlock);
+      BlockAndValueMapping valueMapping;
+      cloneConstExprInto(initializerOp.getLoc(), initBuilder, originalValue,
+                         hoistedMap, moduleSymbols, valueMapping, constExprs);
 
       existingGlobal = hoistedMap.lookup(originalValue);
     }
@@ -147,54 +150,59 @@
     operand->set(load);
   }
 
+  void cloneProducerTreeInto(
+      OpBuilder &builder, const ConstExprAnalysis::ConstValueInfo *producerInfo,
+      HoistedValueMap &hoistedMap, BlockAndValueMapping &cloneMapping,
+      const ConstExprAnalysis &constExprs) {
+    if (cloneMapping.contains(producerInfo->constValue)) return;
+
+    // We either have a global associated already or we need to traverse
+    // down and materialize producers.
+    GlobalOp existingGlobal = hoistedMap.lookup(producerInfo->constValue);
+    if (existingGlobal) {
+      cloneMapping.map(producerInfo->constValue,
+                       builder.create<GlobalLoadOp>(existingGlobal.getLoc(),
+                                                    existingGlobal));
+      return;
+    }
+
+    // Materialize all producers recursively.
+    for (auto *producerInfo : producerInfo->producers) {
+      cloneProducerTreeInto(builder, producerInfo, hoistedMap, cloneMapping,
+                            constExprs);
+    }
+
+    // And clone the requested op.
+    Operation *sourceOp = producerInfo->constValue.getDefiningOp();
+    assert(sourceOp && "must have defining op for const-expr values");
+    LLVM_DEBUG(dbgs() << "    CLONE OP: " << *sourceOp << "\n");
+    Operation *clonedOp = sourceOp->clone(cloneMapping);
+    builder.insert(clonedOp);
+  }
+
   // Clones the const expr tree rooted at `constExprValue` into the given
   // initializer, noting any new hoisted value mappings that result. At
   // a minimum, a mapping will be created for the requested value.
-  void cloneConstExprInto(InitializerOp initializerOp, Value constExprValue,
-                          HoistedValueMap &hoistedMap,
-                          SymbolTable &moduleSymbols) {
-    Block *entryBlock = initializerOp.addEntryBlock();
-    OpBuilder initBuilder = OpBuilder::atBlockEnd(entryBlock);
-
-    // Clone all dependents of the defining op.
+  void cloneConstExprInto(Location loc, OpBuilder &builder,
+                          Value constExprValue, HoistedValueMap &hoistedMap,
+                          SymbolTable &moduleSymbols,
+                          BlockAndValueMapping &cloneMapping,
+                          const ConstExprAnalysis &constExprs) {
+    // Do a depth first traversal of the producers, emitting them in a valid
+    // def-use order.
     Operation *rootOp = constExprValue.getDefiningOp();
     assert(rootOp && "const-expr value should have a defining op");
-    SetVector<Operation *> slice;
-    getBackwardSlice(rootOp, &slice);
-    BlockAndValueMapping cloneMap;
+    auto *rootInfo = constExprs.lookup(rootOp);
+    assert(rootInfo && "must have const-value-info for const-expr root op");
 
-    for (Operation *sourceOp : slice) {
-      // Iterate over the source results and see if we have already hoisted.
-      // Note that because we hoist all results of an op below, we can count
-      // on all or none of them having hoisted. Initialization order is
-      // correct because we greedily hoist in topological order of const-expr
-      // ops above.
-      bool needsClone = true;
-      for (Value origResult : sourceOp->getResults()) {
-        GlobalOp existingGlobal = hoistedMap.lookup(origResult);
-        if (!existingGlobal) break;
-        needsClone = false;
-        cloneMap.map(origResult, initBuilder.create<GlobalLoadOp>(
-                                     existingGlobal.getLoc(), existingGlobal));
-      }
+    // Clone the whole tree as needed.
+    cloneProducerTreeInto(builder, rootInfo, hoistedMap, cloneMapping,
+                          constExprs);
 
-      if (needsClone) {
-        LLVM_DEBUG(dbgs() << "    CLONE OP: " << *sourceOp << "\n");
-        Operation *cloneOp = sourceOp->clone(cloneMap);
-        initBuilder.insert(cloneOp);
-      }
-    }
-
-    // Now, for the defining op itself, create a global for each result and
-    // store into it.
-    // Note that we create globals at the beginning of the module because
-    // they must precede accesses and this is guaranteed here.
+    // And for each result, create a global and store into it.
     OpBuilder globalBuilder = getModuleBeginBuilder();
-    Operation *clonedRootOp = rootOp->clone(cloneMap);
-    initBuilder.insert(clonedRootOp);
     for (Value origResult : rootOp->getResults()) {
-      Value clonedResult = cloneMap.lookup(origResult);
-      Location loc = clonedRootOp->getLoc();
+      Value clonedResult = cloneMapping.lookup(origResult);
       GlobalOp globalOp = globalBuilder.create<GlobalOp>(loc, "hoisted", false,
                                                          origResult.getType());
       StringAttr globalSymbol = moduleSymbols.insert(globalOp);
@@ -205,10 +213,12 @@
       hoistedMap[origResult] = globalOp;
 
       // And store into it.
-      initBuilder.create<GlobalStoreOp>(loc, clonedResult, globalSymbol);
+      LLVM_DEBUG(dbgs() << "    CREATE GLOBAL " << globalSymbol << " = "
+                        << clonedResult << "\n");
+      builder.create<GlobalStoreOp>(loc, clonedResult, globalSymbol);
     }
 
-    initBuilder.create<InitializerReturnOp>(initializerOp.getLoc());
+    builder.create<InitializerReturnOp>(loc);
   }
 
   void cleanupDeadOps(const ConstExprAnalysis &constExprs) {
diff --git a/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
index 7c99dc5..df59eb2 100644
--- a/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
+++ b/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
@@ -121,3 +121,35 @@
   // CHECK:   util.initializer.return
   // CHECK: }
 }
+
+// -----
+// CHECK-LABEL: @hoist_implicit_capture
+module @hoist_implicit_capture {
+  // CHECK: util.global private @[[HOISTED_SYM:.*]] : i32
+  // CHECK: func @main
+  builtin.func @main() -> (i32) {
+    %0 = arith.constant 0 : i32
+    %1 = arith.constant 1 : i32
+    // CHECK-NOT: arith.constant
+    // CHECK-NOT: iree_unregistered.const_expr
+    // CHECK: %[[VAL:.*]] = util.global.load @[[HOISTED_SYM]] : i32
+    // CHECK: return %[[VAL]]
+    %2 = "iree_unregistered.const_expr"(%0) ({
+    ^bb0(%inner0 : i32):
+      %3 = arith.addi %inner0, %1 : i32
+      "iree_unregistered.yield"(%3) : (i32) -> i32
+    }) : (i32) -> i32
+    return %2 : i32
+  }
+  // Key checks: arith.constant 1 gets pulled in to the initializer
+  // and the reference is updated correctly in the custom op region.
+  // CHECK: util.initializer {
+  // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : i32
+  // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : i32
+  // CHECK:       %[[CE0:.*]] = "iree_unregistered.const_expr"(%[[C0]])
+  // CHECK:         ^bb0(%[[B0:.*]]: i32):
+  // CHECK:         arith.addi %[[B0]], %[[C1]]
+  // CHECK:       util.global.store %[[CE0]], @[[HOISTED_SYM]] : i32
+  // CHECK:       util.initializer.return
+  // CHECK: }
+}