Revert "Propagate reshapes through generics with reduction… (#18968)
…(#18857)"
This regresses SDXL int8 performance by increasing the dimensionality of
`attention` ops, which interferes with the attention spec. Revert this for
now and reland once `CollapseDimensionsPass` can handle attention.
This reverts commit 78481a6ed98c9be1dd9c33eda0572e391a4d8d89.
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
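For illustration, here is a minimal sketch of the pattern the reverted change enabled. The shapes and values are invented (loosely mirroring the `grouped_quantized_matmul` test updated below) and are not taken from the SDXL model:

```mlir
// Before propagation: a rank-2 reduction generic, with the reshape applied
// to its result afterwards.
func.func @before(%in: tensor<4096x32xf32>,
                  %init: tensor<4096xf32>) -> tensor<1x1x4096xf32> {
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<4096x32xf32>) outs(%init : tensor<4096xf32>) {
  ^bb0(%a: f32, %acc: f32):
    %0 = arith.addf %a, %acc : f32
    linalg.yield %0 : f32
  } -> tensor<4096xf32>
  %expanded = tensor.expand_shape %sum [[0, 1, 2]] output_shape [1, 1, 4096]
      : tensor<4096xf32> into tensor<1x1x4096xf32>
  return %expanded : tensor<1x1x4096xf32>
}

// After propagation (the reverted behavior): the expand_shape bubbles above
// the generic, which is rewritten over a rank-4 iteration space with unit
// dims. The same rank increase on operands feeding `attention` ops is what
// interfered with the attention spec.
func.func @after(%in: tensor<4096x32xf32>,
                 %init: tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32> {
  %expanded_in = tensor.expand_shape %in [[0, 1, 2], [3]]
      output_shape [1, 1, 4096, 32]
      : tensor<4096x32xf32> into tensor<1x1x4096x32xf32>
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%expanded_in : tensor<1x1x4096x32xf32>)
      outs(%init : tensor<1x1x4096xf32>) {
  ^bb0(%a: f32, %acc: f32):
    %0 = arith.addf %a, %acc : f32
    linalg.yield %0 : f32
  } -> tensor<1x1x4096xf32>
  return %sum : tensor<1x1x4096xf32>
}
```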
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index c6b7805..a111077 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -222,7 +222,7 @@
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 245 \
+ --goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
@@ -243,7 +243,7 @@
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 245 \
+ --goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
index 0c4430f..8973ba5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
@@ -80,13 +80,13 @@
// CHECK: flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
// CHECK: func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
// CHECK: %[[GEN0:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK: arith.extui
// CHECK: arith.uitofp
// CHECK: arith.subf
// CHECK: arith.mulf
// CHECK: %[[GEN1:.+]] = linalg.generic
-// CHECK-SAME: ["parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK-SAME: ["parallel", "reduction", "reduction"]
// CHECK-SAME: ins(
// CHECK-SAME: %[[GEN0]]
// CHECK-SAME: outs(
@@ -95,4 +95,5 @@
// CHECK: flow.dispatch.tensor.store %[[GEN1]]
// CHECK: util.func public @grouped_quantized_matmul(
// CHECK: %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-// CHECK: util.return %[[T0]]
+// CHECK: %[[RS:.+]] = flow.tensor.reshape %[[T0]] : tensor<4096xf32> -> tensor<1x1x4096xf32>
+// CHECK: util.return %[[RS]]
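The net effect at the Flow level, per the updated CHECK lines above: the reduction stays rank-1 inside the dispatch and the unit dims are reintroduced by a metadata-only reshape outside it. A hypothetical sketch, where `@ex::@entry` and the argument type are placeholders rather than the actual generated symbols:

```mlir
util.func public @grouped_quantized_matmul(%arg0: tensor<4096x32xf32>) -> tensor<1x1x4096xf32> {
  // The dispatch computes the collapsed rank-1 result...
  %0 = flow.dispatch @ex::@entry(%arg0) : (tensor<4096x32xf32>) -> tensor<4096xf32>
  // ...and the unit dims are restored outside it, matching the new CHECK.
  %1 = flow.tensor.reshape %0 : tensor<4096xf32> -> tensor<1x1x4096xf32>
  util.return %1 : tensor<1x1x4096xf32>
}
```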
diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 9ee67d6..79ae8d3 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -57,8 +57,12 @@
return false;
}
+ // Do not fuse with a producer generic op if it has more than one use
+ // or any reduction iterators.
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
- return true;
+ return producerGenericOp->hasOneUse() &&
+ llvm::all_of(producerGenericOp.getIteratorTypesArray(),
+ linalg::isParallelIterator);
}
// Do not fuse with any producer linalg named ops for now.
@@ -66,9 +70,11 @@
return false;
}
- // Do not fuse with consumer linalg named ops.
+ // Do not fuse with consumer linalg named ops or reductions.
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
- return isa<linalg::GenericOp>(consumerLinalgOp);
+ return isa<linalg::GenericOp>(consumerLinalgOp) &&
+ llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
+ linalg::isParallelIterator);
}
// Fuse in all other cases.
return true;
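For contrast, a hypothetical case (shapes invented) that the tightened guard still allows: a single-use, fully parallel producer generic, across which the expand_shape may continue to bubble:

```mlir
// Fully parallel, single-use producer: all iterators are parallel, so the
// new guard still permits bubbling the expand_shape above this generic.
func.func @still_bubbles(%in: tensor<4096xf32>,
                         %init: tensor<4096xf32>) -> tensor<1x1x4096xf32> {
  %neg = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%in : tensor<4096xf32>) outs(%init : tensor<4096xf32>) {
  ^bb0(%a: f32, %out: f32):
    %0 = arith.negf %a : f32
    linalg.yield %0 : f32
  } -> tensor<4096xf32>
  %expanded = tensor.expand_shape %neg [[0, 1, 2]] output_shape [1, 1, 4096]
      : tensor<4096xf32> into tensor<1x1x4096xf32>
  return %expanded : tensor<1x1x4096xf32>
}
```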