Treat nested implicit captures as producers for const-expr analysis. (#7938)
* Also reworks the hoisting transformation to consider the producer tree from the analysis authoritative vs relying on creating a backward slice.
* This is needed to handle implicit captures correctly and has the byproduct of making a more precise clone (the prior version would sometimes materialize extra dead ops for odd const-expr tree shapes).
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
index 672a3e1..d681ef6 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
@@ -127,37 +127,37 @@
}
void ConstExprAnalysis::expandToOp(Operation *op) {
- bool eligible = isEligibleConstExprOp(op);
+ ConstExprOpInfo opInfo = ConstExprOpInfo::getForOp(op);
for (auto result : op->getResults()) {
auto foundIt = constInfoMap.find(result);
if (foundIt != constInfoMap.end()) continue;
// Generate new info record.
- auto *info = addInfo(result);
- if (!eligible) {
+ auto *valueInfo = addInfo(result);
+ if (!opInfo.isEligible) {
// Put it in a NON_CONSTANT state and bail. This is terminal.
- info->state = ConstValueInfo::NON_CONSTANT;
+ valueInfo->state = ConstValueInfo::NON_CONSTANT;
LLVM_DEBUG(dbgs() << " EXPAND TO INELIGIBLE: " << result << "\n");
continue;
}
// If here, then an unknown state.
LLVM_DEBUG(dbgs() << " EXPAND TO UNKNOWN: " << result << "\n");
- worklist.push_back(info);
+ worklist.push_back(valueInfo);
- // Process operands.
- for (auto operand : op->getOperands()) {
- Operation *definingOp = operand.getDefiningOp();
+ // Process producers.
+ for (auto producer : opInfo.producers) {
+ Operation *definingOp = producer.getDefiningOp();
if (!definingOp) {
// Consider crossing out of block to be non-const.
- info->state = ConstValueInfo::NON_CONSTANT;
+ valueInfo->state = ConstValueInfo::NON_CONSTANT;
break;
}
expandToOp(definingOp);
- ConstValueInfo *producerInfo = constInfoMap.lookup(operand);
+ ConstValueInfo *producerInfo = constInfoMap.lookup(producer);
assert(producerInfo && "should have producer info in map");
- info->producers.push_back(producerInfo);
+ valueInfo->producers.insert(producerInfo);
}
}
}
@@ -172,6 +172,10 @@
for (Value root : info->roots) {
os << " " << root << "\n";
}
+ os << " WITH PRODUCERS:\n";
+ for (ConstValueInfo *producerInfo : info->producers) {
+ os << " " << producerInfo->constValue << "\n";
+ }
}
}
}
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
index 9237998..c59bed0 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
@@ -99,7 +99,7 @@
SmallPtrSet<Value, 4> roots;
// Direct producers that feed into this constant value.
- SmallVector<ConstValueInfo *> producers;
+ SmallPtrSet<ConstValueInfo *, 8> producers;
// Whether this is a root.
bool isRoot = false;
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp
index d3f7a09..15bbc9f 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -15,44 +16,86 @@
namespace IREE {
namespace Util {
+namespace {
+
+void populateEscapingProducers(Operation *parentOp, ConstExprOpInfo &info) {
+ SmallPtrSet<Operation *, 8> containedOps;
+ parentOp->walk<WalkOrder::PreOrder>([&](Operation *itOp) {
+ containedOps.insert(parentOp);
+ // For the outer-most op, consider that all operands escape.
+ if (itOp == parentOp) {
+ info.producers.insert(itOp->getOperands().begin(),
+ itOp->getOperands().end());
+ return;
+ }
+
+ // For nested operations, only consider that they escape if they are
+ // defined outside of the parent.
+ for (Value operand : itOp->getOperands()) {
+ Block *block = operand.getParentBlock();
+ if (!containedOps.contains(block->getParentOp())) {
+ info.producers.insert(operand);
+ }
+ }
+ });
+}
+
+ConstExprOpInfo getInfoForDefaultConstExprOp(Operation *op) {
+ ConstExprOpInfo info;
+ info.isEligible = true;
+ populateEscapingProducers(op, info);
+ return info;
+}
+
+} // namespace
+
void registerConstExprDependentDialects(DialectRegistry ®istry) {
registry.insert<IREE::Util::UtilDialect>();
registry.insert<linalg::LinalgDialect>();
}
-bool isEligibleConstExprOp(Operation *op) {
+ConstExprOpInfo ConstExprOpInfo::getForOp(Operation *op) {
// Special carve-out for unregistered testing ops.
if (!op->isRegistered()) {
- if (op->getName().getStringRef() ==
- "iree_unregistered.non_leaf_const_expr") {
- return true;
- }
- if (op->getName().getStringRef() == "iree_unregistered.const_expr") {
- return true;
- }
+ // Reject.
if (op->getName().getStringRef() == "iree_unregistered.var_expr") {
- return false;
+ return {};
}
- return false;
+ // Accept.
+ if (op->getName().getStringRef() ==
+ "iree_unregistered.non_leaf_const_expr" ||
+ op->getName().getStringRef() == "iree_unregistered.const_expr") {
+ return getInfoForDefaultConstExprOp(op);
+ }
+ return {};
}
- // Allow linalg ops, even though they are not effect annotated.
+ // We have a specific allow-list for Linalg ops because we want to consider
+ // new additions carefully.
if (op->getDialect() ==
op->getContext()->getOrLoadDialect<linalg::LinalgDialect>()) {
- return true;
+ // Structured op implementations and a handful of pure ops are included.
+ // Notably: IndexOp is not included because it establishes a hidden
+ // dependency to the iterator and is non-const.
+ if (llvm::isa<linalg::LinalgOp>(op) || llvm::isa<linalg::PadTensorOp>(op) ||
+ llvm::isa<linalg::InitTensorOp>(op)) {
+ return getInfoForDefaultConstExprOp(op);
+ }
+
+ return {};
}
// By default any effects make it non const-expr.
if (!MemoryEffectOpInterface::hasNoEffect(op)) {
- return false;
+ return {};
}
// By default, ops without results are not const-expr.
if (op->getNumResults() == 0) {
- return false;
+ return {};
}
- return true;
+ return getInfoForDefaultConstExprOp(op);
}
bool isHoistableConstExprLeaf(const ConstExprAnalysis::ConstValueInfo *info) {
diff --git a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h
index f620115..343d04e 100644
--- a/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h
+++ b/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_DIALECT_IREE_UTIL_ANALYSIS_CONSTANT_OP_ORACLE_H_
#include "iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/IR/Operation.h"
namespace mlir {
@@ -18,15 +19,28 @@
// Registers dialects needed to query or construct const-expr information.
void registerConstExprDependentDialects(DialectRegistry ®istry);
-// Whether an op can be considered a pure expression, producing a constant if
-// provided constants and having no side effects beyond that.
-//
-// In order to enable testing, some unregistered ops are also recognized:
-// - iree_unregistered.non_leaf_const_expr : Will be treated as const-expr.
-// - iree_unregistered.const_expr : Will be treated as const-expr
-// - iree_unregistered.var_expr : Will be treated as not const-expr
-// Any other unregistered ops are treated as not const-expr.
-bool isEligibleConstExprOp(Operation *op);
+// Information about a possible const-expr op.
+struct ConstExprOpInfo {
+ // Whether the op is eligible to be considered const-expr, assuming that
+ // all of its producers are eligible.
+ bool isEligible = false;
+
+ // Producer values that must be const-expr for this op to be considered
+ // const-expr. This minimally includes operands, and for region-based ops
+ // may include implicit captures.
+ llvm::SmallPtrSet<Value, 8> producers;
+
+ // Gets information for an op.
+ // Whether an op can be considered a pure expression, producing a constant if
+ // provided constants and having no side effects beyond that.
+ //
+ // In order to enable testing, some unregistered ops are also recognized:
+ // - iree_unregistered.non_leaf_const_expr : Will be treated as const-expr.
+ // - iree_unregistered.const_expr : Will be treated as const-expr
+ // - iree_unregistered.var_expr : Will be treated as not const-expr
+ // Any other unregistered ops are treated as not const-expr.
+ static ConstExprOpInfo getForOp(Operation *op);
+};
// Whether a const-expr op is eligible to be hoistable. This enforces
// policies for excluding certain, otherwise eligible, const-expr ops from
diff --git a/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
index a3fdaf9..5b1ac52 100644
--- a/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
+++ b/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
@@ -131,8 +131,11 @@
Location loc = originalValue.getLoc();
OpBuilder builder = getModuleEndBuilder();
auto initializerOp = builder.create<InitializerOp>(loc);
- cloneConstExprInto(initializerOp, originalValue, hoistedMap,
- moduleSymbols);
+ Block *entryBlock = initializerOp.addEntryBlock();
+ OpBuilder initBuilder = OpBuilder::atBlockEnd(entryBlock);
+ BlockAndValueMapping valueMapping;
+ cloneConstExprInto(initializerOp.getLoc(), initBuilder, originalValue,
+ hoistedMap, moduleSymbols, valueMapping, constExprs);
existingGlobal = hoistedMap.lookup(originalValue);
}
@@ -147,54 +150,59 @@
operand->set(load);
}
+ void cloneProducerTreeInto(
+ OpBuilder &builder, const ConstExprAnalysis::ConstValueInfo *producerInfo,
+ HoistedValueMap &hoistedMap, BlockAndValueMapping &cloneMapping,
+ const ConstExprAnalysis &constExprs) {
+ if (cloneMapping.contains(producerInfo->constValue)) return;
+
+ // We either have a global associated already or we need to traverse
+ // down and materialize producers.
+ GlobalOp existingGlobal = hoistedMap.lookup(producerInfo->constValue);
+ if (existingGlobal) {
+ cloneMapping.map(producerInfo->constValue,
+ builder.create<GlobalLoadOp>(existingGlobal.getLoc(),
+ existingGlobal));
+ return;
+ }
+
+ // Materialize all producers recursively.
+ for (auto *producerInfo : producerInfo->producers) {
+ cloneProducerTreeInto(builder, producerInfo, hoistedMap, cloneMapping,
+ constExprs);
+ }
+
+ // And clone the requested op.
+ Operation *sourceOp = producerInfo->constValue.getDefiningOp();
+ assert(sourceOp && "must have defining op for const-expr values");
+ LLVM_DEBUG(dbgs() << " CLONE OP: " << *sourceOp << "\n");
+ Operation *clonedOp = sourceOp->clone(cloneMapping);
+ builder.insert(clonedOp);
+ }
+
// Clones the const expr tree rooted at `constExprValue` into the given
// initializer, noting any new hoisted value mappings that result. At
// a minimum, a mapping will be created for the requested value.
- void cloneConstExprInto(InitializerOp initializerOp, Value constExprValue,
- HoistedValueMap &hoistedMap,
- SymbolTable &moduleSymbols) {
- Block *entryBlock = initializerOp.addEntryBlock();
- OpBuilder initBuilder = OpBuilder::atBlockEnd(entryBlock);
-
- // Clone all dependents of the defining op.
+ void cloneConstExprInto(Location loc, OpBuilder &builder,
+ Value constExprValue, HoistedValueMap &hoistedMap,
+ SymbolTable &moduleSymbols,
+ BlockAndValueMapping &cloneMapping,
+ const ConstExprAnalysis &constExprs) {
+ // Do a depth first traversal of the producers, emitting them in a valid
+ // def-use order.
Operation *rootOp = constExprValue.getDefiningOp();
assert(rootOp && "const-expr value should have a defining op");
- SetVector<Operation *> slice;
- getBackwardSlice(rootOp, &slice);
- BlockAndValueMapping cloneMap;
+ auto *rootInfo = constExprs.lookup(rootOp);
+ assert(rootInfo && "must have const-value-info for const-expr root op");
- for (Operation *sourceOp : slice) {
- // Iterate over the source results and see if we have already hoisted.
- // Note that because we hoist all results of an op below, we can count
- // on all or none of them having hoisted. Initialization order is
- // correct because we greedily hoist in topological order of const-expr
- // ops above.
- bool needsClone = true;
- for (Value origResult : sourceOp->getResults()) {
- GlobalOp existingGlobal = hoistedMap.lookup(origResult);
- if (!existingGlobal) break;
- needsClone = false;
- cloneMap.map(origResult, initBuilder.create<GlobalLoadOp>(
- existingGlobal.getLoc(), existingGlobal));
- }
+ // Clone the whole tree as needed.
+ cloneProducerTreeInto(builder, rootInfo, hoistedMap, cloneMapping,
+ constExprs);
- if (needsClone) {
- LLVM_DEBUG(dbgs() << " CLONE OP: " << *sourceOp << "\n");
- Operation *cloneOp = sourceOp->clone(cloneMap);
- initBuilder.insert(cloneOp);
- }
- }
-
- // Now, for the defining op itself, create a global for each result and
- // store into it.
- // Note that we create globals at the beginning of the module because
- // they must precede accesses and this is guaranteed here.
+ // And for each result, create a global and store into it.
OpBuilder globalBuilder = getModuleBeginBuilder();
- Operation *clonedRootOp = rootOp->clone(cloneMap);
- initBuilder.insert(clonedRootOp);
for (Value origResult : rootOp->getResults()) {
- Value clonedResult = cloneMap.lookup(origResult);
- Location loc = clonedRootOp->getLoc();
+ Value clonedResult = cloneMapping.lookup(origResult);
GlobalOp globalOp = globalBuilder.create<GlobalOp>(loc, "hoisted", false,
origResult.getType());
StringAttr globalSymbol = moduleSymbols.insert(globalOp);
@@ -205,10 +213,12 @@
hoistedMap[origResult] = globalOp;
// And store into it.
- initBuilder.create<GlobalStoreOp>(loc, clonedResult, globalSymbol);
+ LLVM_DEBUG(dbgs() << " CREATE GLOBAL " << globalSymbol << " = "
+ << clonedResult << "\n");
+ builder.create<GlobalStoreOp>(loc, clonedResult, globalSymbol);
}
- initBuilder.create<InitializerReturnOp>(initializerOp.getLoc());
+ builder.create<InitializerReturnOp>(loc);
}
void cleanupDeadOps(const ConstExprAnalysis &constExprs) {
diff --git a/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
index 7c99dc5..df59eb2 100644
--- a/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
+++ b/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
@@ -121,3 +121,35 @@
// CHECK: util.initializer.return
// CHECK: }
}
+
+// -----
+// CHECK-LABEL: @hoist_implicit_capture
+module @hoist_implicit_capture {
+ // CHECK: util.global private @[[HOISTED_SYM:.*]] : i32
+ // CHECK: func @main
+ builtin.func @main() -> (i32) {
+ %0 = arith.constant 0 : i32
+ %1 = arith.constant 1 : i32
+ // CHECK-NOT: arith.constant
+ // CHECK-NOT: iree_unregistered.const_expr
+ // CHECK: %[[VAL:.*]] = util.global.load @[[HOISTED_SYM]] : i32
+ // CHECK: return %[[VAL]]
+ %2 = "iree_unregistered.const_expr"(%0) ({
+ ^bb0(%inner0 : i32):
+ %3 = arith.addi %inner0, %1 : i32
+ "iree_unregistered.yield"(%3) : (i32) -> i32
+ }) : (i32) -> i32
+ return %2 : i32
+ }
+ // Key checks: arith.constant 1 gets pulled in to the initializer
+ // and the reference is updated correctly in the custom op region.
+ // CHECK: util.initializer {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+ // CHECK: %[[CE0:.*]] = "iree_unregistered.const_expr"(%[[C0]])
+ // CHECK: ^bb0(%[[B0:.*]]: i32):
+ // CHECK: arith.addi %[[B0]], %[[C1]]
+ // CHECK: util.global.store %[[CE0]], @[[HOISTED_SYM]] : i32
+ // CHECK: util.initializer.return
+ // CHECK: }
+}