[GlobalOpt] Prevent fusing transposed extend in RaiseSpecialOps (#18901)
The NamedImplicitCastOpConversion pattern was incorrectly fusing a transposed
element-wise extend into the consuming named Linalg op, silently dropping the
transpose. Only fuse when every indexing map on the producer is an identity.
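
A sketch of the bad rewrite (shapes taken from the regression test added
below; %lhs, %rhs, and %acc stand in for function arguments). Given an extend
whose producer also transposes its result:

  %init = tensor.empty() : tensor<20x40xi32>
  %t = linalg.generic
         {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                           affine_map<(d0, d1) -> (d1, d0)>],
          iterator_types = ["parallel", "parallel"]}
         ins(%rhs : tensor<40x20xi16>) outs(%init : tensor<20x40xi32>) {
  ^bb0(%in: i16, %out: i32):
    %e = arith.extsi %in : i16 to i32
    linalg.yield %e : i32
  } -> tensor<20x40xi32>
  %r = linalg.matmul ins(%lhs, %t : tensor<10x20xi32>, tensor<20x40xi32>)
         outs(%acc : tensor<10x40xi32>) -> tensor<10x40xi32>

the pattern would fold the generic away and feed %rhs directly to the matmul,
relying on the named op's implicit sign extension and losing the transpose.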
---------
Signed-off-by: Cullen Rhodes <cullen.rhodes@arm.com>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
index b700869..a1b579e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
@@ -304,6 +304,12 @@
return false;
}
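+  // Only fuse pure element-wise casts: a producer with a non-identity
+  // indexing map (e.g. a transpose) must not be folded away.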
+ if (!llvm::all_of(producer.getIndexingMapsArray(),
+ [](AffineMap map) { return map.isIdentity(); }))
+ return false;
+
std::optional<CastOpInterface> castOp =
getDefiningNonI1ExtendingCastOp(operand.get());
if (!castOp) {
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir
index a1cd2d6..c84f128 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir
@@ -566,6 +566,33 @@
// CHECK: util.return %[[RESULT]]
// -----
+// Regression test: the extsi producer is transposed, don't fuse it into the matmul.
+util.func public @matmul_extsi_transposed(%arg0 : tensor<10x20xi32>,
+ %arg1 : tensor<40x20xi16>) -> tensor<10x40xi32> {
+ %0 = tensor.empty() : tensor<20x40xi32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1 : tensor<40x20xi16>) outs(%0 : tensor<20x40xi32>) {
+ ^bb0(%b0 : i16, %b1 : i32):
+ %e = arith.extsi %b0 : i16 to i32
+ linalg.yield %e : i32
+ } -> tensor<20x40xi32>
+ %2 = tensor.empty() : tensor<10x40xi32>
+ %3 = arith.constant 0 : i32
+ %4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32>
+ %5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xi32>, tensor<20x40xi32>)
+ outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32>
+ util.return %5 : tensor<10x40xi32>
+}
+// CHECK-LABEL: util.func public @matmul_extsi_transposed
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<40x20xi16>
+// CHECK: %[[GEN:.+]] = linalg.generic
+// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[GEN]]
+// CHECK: util.return %[[RESULT]]
+// -----
+
util.func public @matmul_extsi_a(%arg0 : tensor<10x20xi16>,
%arg1 : tensor<20x40xi32>) -> tensor<10x40xi32> {
%0 = tensor.empty() : tensor<10x20xi32>