[ConstEval] Make ConstExprMaxSizeIncreaseThreshold be controlled by API. (#15183)
Fixes https://github.com/openxla/iree/issues/15073
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
index 6ef7295..bf82da3 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
@@ -25,12 +25,6 @@
namespace IREE {
namespace Util {
-static llvm::cl::opt<int64_t> clConstExprMaxSizeIncreaseThreshold(
- "iree-util-const-expr-max-size-increase-threshold",
- llvm::cl::desc("Maximum byte size increase allowed for constant expr "
- "hoisting policy to allow hoisting."),
- llvm::cl::init(1024 * 1024));
-
//===----------------------------------------------------------------------===//
// ConstExprAnalysis
//===----------------------------------------------------------------------===//
@@ -248,8 +242,9 @@
//===----------------------------------------------------------------------===//
ConstExprHoistingPolicy::ConstExprHoistingPolicy(
- const ConstExprAnalysis &analysis)
- : analysis(analysis), decisions(analysis.allocedConstInfos.size()) {
+ const ConstExprAnalysis &analysis, int64_t threshold)
+ : analysis(analysis), constExprMaxSizeIncreaseThreshold(threshold),
+ decisions(analysis.allocedConstInfos.size()) {
for (auto &it : analysis.allocedConstInfos) {
decisions[it.get()] = {};
}
@@ -320,7 +315,7 @@
}
static bool doesHoistingIncreaseSizeSignificantly(
- const ConstExprAnalysis::ConstValueInfo *info) {
+ const ConstExprAnalysis::ConstValueInfo *info, int64_t threshold) {
int64_t inSize = 0;
for (Value root : info->roots) {
@@ -354,7 +349,7 @@
getRoundedPhysicalStorageSize(elementCount, type.getElementType());
}
- return outSize > inSize + clConstExprMaxSizeIncreaseThreshold.getValue();
+ return outSize > inSize + threshold;
}
void ConstExprHoistingPolicy::makeInvariantDecision(
@@ -376,7 +371,8 @@
// Check 4: Does hoisting this value significantly increase the size of the
// module?
- if (doesHoistingIncreaseSizeSignificantly(info)) {
+ if (doesHoistingIncreaseSizeSignificantly(
+ info, constExprMaxSizeIncreaseThreshold)) {
return decision->disableHoist();
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
index a47c56f..a99d0de 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h
@@ -219,7 +219,7 @@
const ConstExprAnalysis &getAnalysis() const { return analysis; }
- ConstExprHoistingPolicy(const ConstExprAnalysis &analysis);
+ ConstExprHoistingPolicy(const ConstExprAnalysis &analysis, int64_t threshold);
void initialize();
Decision *getDecision(const ConstExprAnalysis::ConstValueInfo *info) {
return &decisions[info];
@@ -244,6 +244,8 @@
const ConstExprAnalysis &analysis;
+ int64_t constExprMaxSizeIncreaseThreshold;
+
// Map of ConstValueInfo * to decision structs. All are allocated at
// initialization and then the structure is not changed.
llvm::DenseMap<const ConstExprAnalysis::ConstValueInfo *, Decision> decisions;
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
index 8abb3d2..9d1d10b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
@@ -47,12 +47,16 @@
registerConstExprDependentDialects(registry);
}
+ HoistIntoGlobalsPass(int64_t threshold) {
+ this->maxSizeIncreaseThreshold.setValue(threshold);
+ }
+
void runOnOperation() override {
SymbolTable moduleSymbols(getOperation());
const auto &constExprs = getAnalysis<ConstExprAnalysis>();
LLVM_DEBUG(dbgs() << constExprs);
LLVM_DEBUG(dbgs() << "\n\n");
- ConstExprHoistingPolicy policy(constExprs);
+ ConstExprHoistingPolicy policy(constExprs, this->maxSizeIncreaseThreshold);
policy.initialize();
// Print analysis dot graph if requested.
@@ -252,8 +256,9 @@
} // namespace
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createHoistIntoGlobalsPass() {
- return std::make_unique<HoistIntoGlobalsPass>();
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createHoistIntoGlobalsPass(int64_t maxSizeIncreaseThreshold) {
+ return std::make_unique<HoistIntoGlobalsPass>(maxSizeIncreaseThreshold);
}
} // namespace Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 51f6877..73d8a83 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -23,7 +23,8 @@
createFixedPointIteratorPass(OpPassManager pipeline);
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFoldGlobalsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFuseGlobalsPass();
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createHoistIntoGlobalsPass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createHoistIntoGlobalsPass(int64_t maxSizeIncreaseThreshold = 2147483647);
std::unique_ptr<OperationPass<mlir::ModuleOp>> createIPOPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineConstantsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index 8794c56..bada10a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -105,6 +105,12 @@
let constructor = [{
mlir::iree_compiler::IREE::Util::createHoistIntoGlobalsPass()
}];
+ let options = [
+ Option<"maxSizeIncreaseThreshold", "max-size-increase-threshold", "int64_t",
+ /*default=*/"1048576",
+ "Maximum byte size increase allowed for constant expr hoisting policy to"
+ "allow hoisting. The threshold is 1MB by default.">
+ ];
}
def SimplifyGlobalAccesses :
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
index 6e038c4..0ad053c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-util-hoist-into-globals --allow-unregistered-dialect --iree-util-const-expr-max-size-increase-threshold=64 %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-util-hoist-into-globals="max-size-increase-threshold=64" --allow-unregistered-dialect %s | FileCheck %s
// CHECK-LABEL: @hoist_simple_const_expr
module @hoist_simple_const_expr {
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 16853dd..d2de795 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -88,7 +88,8 @@
pipeline.addPass(IREE::Util::createIPOPass());
if (transformOptions.options.constExprHoisting) {
- pipeline.addPass(IREE::Util::createHoistIntoGlobalsPass());
+ pipeline.addPass(IREE::Util::createHoistIntoGlobalsPass(
+ transformOptions.options.constExprMaxSizeIncreaseThreshold));
}
if (transformOptions.buildConstEvalPassPipeline) {
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index f1c4830..0909ab8 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -95,6 +95,10 @@
// Strips debug assertions after any useful information has been extracted.
bool stripAssertions = false;
+ // Maximum byte size increase allowed for constant expr hoisting policy to
+ // allow hoisting. The threshold is 1MB by default.
+ int64_t constExprMaxSizeIncreaseThreshold = 1024 * 1024;
+
void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<GlobalOptimizationOptions>;
};