Adjust `isFusableUsingTileAndFuse` in `SinkReshapes` (#18921)
Adjust `isFusableUsingTileAndFuse` to return true if the producer
implements `LinalgExt::LinalgFusionOpInterface`. This is motivated by a
`tensor.expand_shape` getting stuck between `linalg.softmax` and a
`bit-truncate` op, preventing fusion and leading to materialization of
the higher-bit-width tensor.
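
Roughly, after sinking, the IR looks like the following. This is a
hand-written sketch consistent with the CHECK lines in the new test, not
verbatim pass output; the `expand_shape` now follows the truncation, so
`linalg.softmax` and the truncating `linalg.generic` can be fused and only
the f16 result is expanded:

```mlir
func.func @fuse_softmax_with_truncate(%arg0 : tensor<4x64x?xf32>) -> tensor<4x64x1x?xf16> {
  %c2 = arith.constant 2 : index
  %dim = tensor.dim %arg0, %c2 : tensor<4x64x?xf32>
  %0 = tensor.empty(%dim) : tensor<4x64x?xf32>
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<4x64x?xf32>) outs(%0 : tensor<4x64x?xf32>) -> tensor<4x64x?xf32>
  // Truncate on the collapsed shape, directly consuming the softmax result.
  %2 = tensor.empty(%dim) : tensor<4x64x?xf16>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<4x64x?xf32>) outs(%2 : tensor<4x64x?xf16>) {
  ^bb0(%in: f32, %out: f16):
    %t = arith.truncf %in : f32 to f16
    linalg.yield %t : f16
  } -> tensor<4x64x?xf16>
  // The reshape has been sunk below the truncate, so it expands f16, not f32.
  %expanded = tensor.expand_shape %3 [[0], [1, 2], [3]] output_shape [4, 64, 1, %dim] : tensor<4x64x?xf16> into tensor<4x64x1x?xf16>
  func.return %expanded : tensor<4x64x1x?xf16>
}
```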
Closes https://github.com/iree-org/iree/issues/18893
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index a111077..c6b7805 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -222,7 +222,7 @@
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 246 \
+ --goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
@@ -243,7 +243,7 @@
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1139 \
- --goldendispatch-rocm-vae 246 \
+ --goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.td b/compiler/src/iree/compiler/DispatchCreation/Passes.td
index 1f1132e..f6d2751 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.td
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.td
@@ -113,6 +113,7 @@
let dependentDialects = [
"mlir::affine::AffineDialect",
"mlir::arith::ArithDialect",
+ "IREE::LinalgExt::IREELinalgExtDialect",
];
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp b/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp
index 6e7c707..07abf08 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp
@@ -15,6 +15,8 @@
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
@@ -47,7 +49,8 @@
/// we just approximate it (and try to be optimistic)
static bool isFusableUsingTileAndFuse(Operation *producer,
Operation *consumer) {
- return llvm::isa_and_nonnull<linalg::LinalgOp, tensor::UnPackOp,
+ return llvm::isa_and_nonnull<IREE::LinalgExt::LinalgFusionOpInterface,
+ linalg::LinalgOp, tensor::UnPackOp,
IREE::Encoding::UnsetEncodingOp>(producer);
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
index 15a7e39..7cab209 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
@@ -211,3 +211,28 @@
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC1]] :
// CHECK: tensor.expand_shape %[[GENERIC2]]
+
+// -----
+
+func.func @fuse_softmax_with_truncate(%arg0 : tensor<4x64x?xf32>) -> tensor<4x64x1x?xf16> {
+ %cst = arith.constant 0xFC00 : f16
+ %cst_0 = arith.constant 0.000000e+00 : f16
+ %cst_1 = arith.constant 11.3137083 : f32
+ %c2 = arith.constant 2 : index
+ %dim = tensor.dim %arg0, %c2 : tensor<4x64x?xf32>
+ %0 = tensor.empty(%dim) : tensor<4x64x?xf32>
+ %2 = linalg.softmax dimension(2) ins(%arg0 : tensor<4x64x?xf32>) outs(%0 : tensor<4x64x?xf32>) -> tensor<4x64x?xf32>
+ %expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [4, 64, 1, %dim] : tensor<4x64x?xf32> into tensor<4x64x1x?xf32>
+ %3 = tensor.empty(%dim) : tensor<4x64x1x?xf16>
+ %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<4x64x1x?xf32>) outs(%3 : tensor<4x64x1x?xf16>) {
+ ^bb0(%in: f32, %out: f16):
+ %5 = arith.truncf %in : f32 to f16
+ linalg.yield %5 : f16
+ } -> tensor<4x64x1x?xf16>
+ func.return %4 : tensor<4x64x1x?xf16>
+}
+// CHECK-LABEL: func @fuse_softmax_with_truncate
+// CHECK: %[[SOFTMAX:.+]] = linalg.softmax
+// CHECK: %[[TRUNC:.+]] = linalg.generic {{.*}} ins(%[[SOFTMAX]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[TRUNC]]
+// CHECK: return %[[EXPAND]]