Reapply "Propagate reshapes through generics with reduction… (#18968)
Reland after fixing SDXL int8 regressions via
https://github.com/iree-org/iree/pull/19012.
Running CI revealed further performance regressions, for which fixes are
pending: https://github.com/iree-org/iree/pull/19325 and
https://github.com/iree-org/iree/pull/19326.
This reverts commit 8d3faf8e0f739838a2c06adbeffae258a43d56a7.
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
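
Note for reviewers: the relanded change lets DispatchCreation propagate
`tensor.expand_shape` ops past `linalg.generic` producers (previously
required to have one use and only parallel iterators) and consumers
(previously required to have only parallel iterators). A minimal
hand-written sketch of the kind of movement this enables, assuming a
sum-reduction generic; this is not literal pass output, and the exact IR
depends on the full pipeline:

```mlir
// Before: the expand_shape is stuck below a reduction generic.
func.func @reduce_then_expand(%in : tensor<8x16xf32>, %init : tensor<8xf32>) -> tensor<2x4xf32> {
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<8x16xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%a: f32, %b: f32):
    %add = arith.addf %a, %b : f32
    linalg.yield %add : f32
  } -> tensor<8xf32>
  %expanded = tensor.expand_shape %sum [[0, 1]] output_shape [2, 4]
      : tensor<8xf32> into tensor<2x4xf32>
  func.return %expanded : tensor<2x4xf32>
}

// After bubbling: the reshapes move onto the operands, and the generic
// iterates over the expanded parallel dims while keeping its reduction dim.
func.func @expand_then_reduce(%in : tensor<8x16xf32>, %init : tensor<8xf32>) -> tensor<2x4xf32> {
  %in_exp = tensor.expand_shape %in [[0, 1], [2]] output_shape [2, 4, 16]
      : tensor<8x16xf32> into tensor<2x4x16xf32>
  %init_exp = tensor.expand_shape %init [[0, 1]] output_shape [2, 4]
      : tensor<8xf32> into tensor<2x4xf32>
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%in_exp : tensor<2x4x16xf32>) outs(%init_exp : tensor<2x4xf32>) {
  ^bb0(%a: f32, %b: f32):
    %add = arith.addf %a, %b : f32
    linalg.yield %add : f32
  } -> tensor<2x4xf32>
  func.return %sum : tensor<2x4xf32>
}
```

Moving the reshape above the reduction keeps it from splitting dispatch
formation at the reshape boundary, which is consistent with the vae golden
dispatch count dropping from 246 to 245 in the regression test config below.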
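The BubbleUpExtractSlices change additionally registers the `linalg.fill`
canonicalization patterns in the pass. Combined with the existing
`SwapExtractSliceOfFill` pattern and the `tensor.empty` folds, an
extract_slice of an expand_shape of a fill collapses into a single fill on
a smaller `tensor.empty`, which is what the new
`fold_extract_of_expand_of_fill` test below checks. A hand-written sketch
of the composed fold (again, not literal pass output):

```mlir
// Start: slice of a reshape of a fill (mirrors the new test below).
func.func @fill_reshape_slice(%d0 : index, %d1 : index, %d2 : index) -> tensor<?xf16> {
  %cst = arith.constant 0.0 : f16
  %empty = tensor.empty(%d0) : tensor<?xf16>
  %fill = linalg.fill ins(%cst : f16) outs(%empty : tensor<?xf16>) -> tensor<?xf16>
  %exp = tensor.expand_shape %fill [[0, 1]] output_shape [1, %d1]
      : tensor<?xf16> into tensor<1x?xf16>
  %slice = tensor.extract_slice %exp [0, 0] [1, %d2] [1, 1]
      : tensor<1x?xf16> to tensor<?xf16>
  func.return %slice : tensor<?xf16>
}

// End state: the fill canonicalization folds the expand_shape into the
// fill's init, SwapExtractSliceOfFill moves the slice onto the init, and
// the tensor.empty folds absorb both, leaving one fill of the sliced size.
func.func @fill_folded(%d0 : index, %d1 : index, %d2 : index) -> tensor<?xf16> {
  %cst = arith.constant 0.0 : f16
  %empty = tensor.empty(%d2) : tensor<?xf16>
  %fill = linalg.fill ins(%cst : f16) outs(%empty : tensor<?xf16>) -> tensor<?xf16>
  func.return %fill : tensor<?xf16>
}
```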
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 448c556..d194bba 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -125,7 +125,7 @@
--goldentime-rocm-vae-ms 310.0 \
--goldendispatch-rocm-unet 1602 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 246 \
+ --goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
@@ -150,7 +150,7 @@
--goldentime-rocm-vae-ms 75.0 \
--goldendispatch-rocm-unet 1602 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 246 \
+ --goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 71fe957..1f8e010 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -134,12 +134,8 @@
return false;
}
- // Do not fuse producer generic op if it has more than one user
- // or any reduction iterators.
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
- return producerGenericOp->hasOneUse() &&
- llvm::all_of(producerGenericOp.getIteratorTypesArray(),
- linalg::isParallelIterator);
+ return true;
}
// Do not fuse with any producer linalg named ops for now.
@@ -147,11 +143,9 @@
return false;
}
- // Do not fuse with consumer linalg named ops or reductions.
+ // Do not fuse with consumer linalg named ops.
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
- return isa<linalg::GenericOp>(consumerLinalgOp) &&
- llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
- linalg::isParallelIterator);
+ return isa<linalg::GenericOp>(consumerLinalgOp);
}
// Fuse in all other cases.
return true;
diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
index 47d1699..672ef9a 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
@@ -149,6 +149,7 @@
patterns.insert<BubbleUpExtract>(context);
patterns.insert<SwapExtractSliceOfFill>(context);
tensor::populateFoldTensorEmptyPatterns(patterns, false);
+ linalg::FillOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
index b582b56..c5311c2 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
@@ -141,3 +141,23 @@
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fold_extract_of_expand_of_fill(%arg0 : index, %arg1 : index, %arg2 : index) -> tensor<?xf16> {
+ %cst0 = arith.constant 0.0 : f16
+ %0 = tensor.empty(%arg0) : tensor<?xf16>
+ %2 = linalg.fill ins(%cst0 : f16) outs(%0 : tensor<?xf16>) -> tensor<?xf16>
+ %3 = tensor.expand_shape %2 [[0, 1]] output_shape[1, %arg1] : tensor<?xf16> into tensor<1x?xf16>
+ %4 = tensor.extract_slice %3 [0, 0] [1, %arg2] [1, 1] : tensor<1x?xf16> to tensor<?xf16>
+ func.return %4 : tensor<?xf16>
+}
+
+// CHECK-LABEL: func @fold_extract_of_expand_of_fill
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG2]])
+// CHECK-DAG: %[[CST0:.+]] = arith.constant 0.0
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]] : f16) outs(%[[EMPTY]]
+// CHECK: return %[[FILL]]