[DispatchCreation] Add constant expression hoisting (#19750)
This adds constant expression hoisting to the dispatch creation
preprocessing pipeline to catch new hoisting opportunities introduced by
earlier passes. In the future we will likely want to run this pass in more
places (e.g. near the end of DispatchCreation and Flow, once encodings have
been set).
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
index 14c18ca..76f7704 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -194,6 +194,19 @@
// - help with dispatch region formation.
// - move reduction iterators to be innermost.
.addPass(DispatchCreation::createTransposeGenericOpsPass);
+
+ // Run constant expression hoisting just before dispatch creation in case
+ // there are any new hoisting opportunities (e.g. transpose generics or
+ // horizontal fusion).
+ IREE::Util::ExprHoistingOptions options;
+ options.maxSizeIncreaseThreshold = 0;
+ options.registerDependentDialectsFn = [](DialectRegistry &registry) {
+ registry.insert<IREE::Flow::FlowDialect>();
+ };
+ passManager.addPass(IREE::Util::createHoistIntoGlobalsPass(options));
+ FunctionLikeNest(passManager)
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass);
}
// Pipeline to first create `flow.dispatch.region` ops and then lower to
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir
index 30e78fc..33ef6f9 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_region_formation_preprocessing.mlir
@@ -542,3 +542,28 @@
// CHECK: tensor.unpack
// CHECK: linalg.generic
// CHECK: tensor.expand_shape
+
+// -----
+
+util.func public @hoist_constant() -> tensor<2xf32> { // Test case: a linalg.generic fed only by a constant; the CHECK lines below expect it to move into a util.initializer.
+ %cst = arith.constant dense<[1.0, 2.0]> : tensor<2xf32> // The sole input of the generic below.
+ %empty = tensor.empty() : tensor<2xf32> // Output operand; its block argument (%b1) is not used in the body.
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>
+ ], iterator_types = ["parallel"]}
+ ins(%cst : tensor<2xf32>) outs(%empty : tensor<2xf32>) {
+ ^bb0(%b0: f32, %b1: f32):
+ %1 = math.exp %b0 : f32 // Elementwise exp of the constant input.
+ linalg.yield %1 : f32
+ } -> tensor<2xf32>
+ util.return %0 : tensor<2xf32> // The CHECK lines expect this value to come from a util.global.load after the pass runs.
+}
+// CHECK-LABEL: util.global private @__hoisted_tensor_2xf32
+// CHECK: util.initializer
+// CHECK: %[[EXP:.+]] = linalg.generic
+// CHECK: math.exp
+// CHECK: util.global.store %[[EXP]], @__hoisted_tensor_2xf32
+// CHECK: util.func public @hoist_constant
+// CHECK: util.global.load immutable @__hoisted_tensor_2xf32