[spirv] Invoke TransposeOp canonicalization after vectorization (#8726)
This is needed to fold `transpose(broadcast(<scalar>))` patterns into plain
broadcasts before further vector-level transformations run.
Fixes https://github.com/google/iree/issues/8689
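
For illustration, a minimal sketch (not taken from the actual codegen output) of the kind of IR the `vector.transpose` canonicalization cleans up: transposing a broadcast of a scalar is shape-rearrangement only, so the transpose folds into the broadcast itself.

```mlir
// Hypothetical input: a scalar broadcast followed by a transpose.
func.func @fold_transpose_of_broadcast(%s: f32) -> vector<8x4xf32> {
  %b = vector.broadcast %s : f32 to vector<4x8xf32>
  %t = vector.transpose %b, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
  return %t : vector<8x4xf32>
}

// After applying vector.transpose canonicalization patterns, the transpose
// disappears and only the broadcast (with the transposed shape) remains:
//   %b = vector.broadcast %s : f32 to vector<8x4xf32>
//   return %b : vector<8x4xf32>
```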
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 6b433a2..60283e0 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -116,6 +116,7 @@
RewritePatternSet foldPatterns(context);
// Fold consumer add ops into the contraction op itself.
vector::ContractionOp::getCanonicalizationPatterns(foldPatterns, context);
+ vector::TransposeOp::getCanonicalizationPatterns(foldPatterns, context);
if (failed(
applyPatternsAndFoldGreedily(funcOp, std::move(foldPatterns)))) {
return signalPassFailure();
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index f28616c..dbc3f85 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -111,3 +111,46 @@
// CHECK-COUNT-16: vector.transfer_read {{.*}} : tensor<8x8xf32>, vector<4xf32>
// CHECK-COUNT-16: vector.fma
// CHECK-COUNT-16: vector.transfer_write {{.*}} : vector<4xf32>, tensor<8x8xf32>
+
+
+// -----
+
+func.func @matmul_broadcast_add(%init: tensor<1x8xf32>, %a: tensor<1x8xf32>, %b: tensor<8x8xf32>, %c: tensor<1x8xf32>, %bias: tensor<1xf32>) -> tensor<1x8xf32> {
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %c8 = arith.constant 8 : index
+ %c1 = arith.constant 1 : index
+
+ %matmul = linalg.matmul ins(%a, %b : tensor<1x8xf32>, tensor<8x8xf32>) outs(%c : tensor<1x8xf32>) -> tensor<1x8xf32>
+ %bcast_add = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%matmul, %bias : tensor<1x8xf32>, tensor<1xf32>) outs(%init : tensor<1x8xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %add = arith.addf %arg0, %arg1 : f32
+ linalg.yield %add : f32
+ } -> tensor<1x8xf32>
+ return %bcast_add: tensor<1x8xf32>
+}
+
+// CHECK-LABEL: func @matmul_broadcast_add
+// CHECK-SAME: (%[[INIT:[a-z0-9]+]]: tensor<1x8xf32>
+// CHECK-SAME: %[[BIAS:[a-z0-9]+]]: tensor<1xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+
+// CHECK-COUNT-16: vector.fma
+// CHECK-NOT: vector.transpose
+
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[BIAS]]
+// CHECK: %[[EXT0:.+]] = vector.extract %[[READ]][0] : vector<1xf32>
+// CHECK: %[[BCST0:.+]] = vector.splat %[[EXT0]] : vector<4xf32>
+// CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[BCST0]] : vector<4xf32>
+// CHECK: %[[EXT1:.+]] = vector.extract %[[READ]][0] : vector<1xf32>
+// CHECK: %[[BCST1:.+]] = vector.splat %[[EXT1]] : vector<4xf32>
+// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[BCST1]] : vector<4xf32>
+// CHECK: %[[WRITE0:.+]] = vector.transfer_write %[[ADD0]], %[[INIT]][%[[C0]], %[[C0]]]
+// CHECK: %[[WRITE1:.+]] = vector.transfer_write %[[ADD1]], %[[WRITE0]][%[[C0]], %[[C4]]]
+// CHECK: return %[[WRITE1]]