Propagate reshapes through generics with reduction iterators (#18857)
Removes the constraint in `BubbleUpExpandShapes` that prevents moving
tensor reshape ops through reduction `linalg.generic` ops. This has the
benefit of increasing the dimensionality of reduction ops (more fusion
opportunities) as well as increasing the chance these ops will be moved
to the edge of the program.
Closes https://github.com/iree-org/iree/issues/18854
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 9849c57..a111077 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -222,7 +222,7 @@
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 247 \
+ --goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
@@ -243,7 +243,7 @@
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 247 \
+ --goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
index 8973ba5..0c4430f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
@@ -80,13 +80,13 @@
// CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
// CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
// CHECK: %[[GEN0:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK: arith.extui
// CHECK: arith.uitofp
// CHECK: arith.subf
// CHECK: arith.mulf
// CHECK: %[[GEN1:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "reduction", "reduction"]
+// CHECK-SAME: ["parallel", "parallel", "parallel", "reduction", "reduction"]
// CHECK-SAME: ins(
// CHECK-SAME: %[[GEN0]]
// CHECK-SAME: outs(
@@ -95,5 +95,4 @@
// CHECK: flow.dispatch.tensor.store %[[GEN1]]
// CHECK: util.func public @grouped_quantized_matmul(
// CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-// CHECK: %[[RS:.+]] = flow.tensor.reshape %[[T0]] : tensor<4096xf32> -> tensor<1x1x4096xf32>
-// CHECK: util.return %[[RS]]
+// CHECK: util.return %[[T0]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 79ae8d3..9ee67d6 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -57,12 +57,8 @@
return false;
}
- // Do not fuse producer generic op if it has more than one user
- // or any reduction iterators.
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
- return producerGenericOp->hasOneUse() &&
- llvm::all_of(producerGenericOp.getIteratorTypesArray(),
- linalg::isParallelIterator);
+ return true;
}
// Do not fuse with any producer linalg named ops for now.
@@ -70,11 +66,9 @@
return false;
}
- // Do not fuse with consumer linalg named ops or reductions.
+ // Do not fuse with consumer linalg named ops.
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
- return isa<linalg::GenericOp>(consumerLinalgOp) &&
- llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
- linalg::isParallelIterator);
+ return isa<linalg::GenericOp>(consumerLinalgOp);
}
// Fuse in all other cases.
return true;