[spirv] Invoke TransposeOp canonicalization after vectorization (#8726)

This is needed to fold `transpose(broadcast(<scalar>))` patterns away before
further vector-level transformations: every element of a scalar broadcast is
identical, so the transpose is a no-op and the pair canonicalizes to a single
`vector.broadcast` with the transposed shape.
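
For illustration (an IR sketch, not taken from this patch; shapes invented),
the canonicalization rewrites

  %b = vector.broadcast %s : f32 to vector<4x1xf32>
  %t = vector.transpose %b, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

into a single broadcast with the transposed shape:

  %t = vector.broadcast %s : f32 to vector<1x4xf32>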

Fixes https://github.com/google/iree/issues/8689
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 6b433a2..60283e0 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -116,6 +116,8 @@
       RewritePatternSet foldPatterns(context);
       // Fold consumer add ops into the contraction op itself.
       vector::ContractionOp::getCanonicalizationPatterns(foldPatterns, context);
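+      // Fold away transpose(broadcast(<scalar>)) chains left behind by vectorization.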
+      vector::TransposeOp::getCanonicalizationPatterns(foldPatterns, context);
       if (failed(
               applyPatternsAndFoldGreedily(funcOp, std::move(foldPatterns)))) {
         return signalPassFailure();
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index f28616c..dbc3f85 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -111,3 +111,40 @@
 // CHECK-COUNT-16: vector.transfer_read {{.*}} : tensor<8x8xf32>, vector<4xf32>
 // CHECK-COUNT-16: vector.fma
 // CHECK-COUNT-16: vector.transfer_write {{.*}} : vector<4xf32>, tensor<8x8xf32>
+
+// -----
+
+func.func @matmul_broadcast_add(%init: tensor<1x8xf32>, %a: tensor<1x8xf32>, %b: tensor<8x8xf32>, %c: tensor<1x8xf32>, %bias: tensor<1xf32>) -> tensor<1x8xf32> {
+  %matmul = linalg.matmul ins(%a, %b : tensor<1x8xf32>, tensor<8x8xf32>) outs(%c : tensor<1x8xf32>) -> tensor<1x8xf32>
+  %bcast_add = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%matmul, %bias : tensor<1x8xf32>, tensor<1xf32>) outs(%init : tensor<1x8xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    linalg.yield %add : f32
+  } -> tensor<1x8xf32>
+  return %bcast_add: tensor<1x8xf32>
+}
+
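+// The bias is applied via vector.splat; no vector.transpose should survive.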
+//    CHECK-LABEL: func @matmul_broadcast_add
+//     CHECK-SAME: (%[[INIT:[a-z0-9]+]]: tensor<1x8xf32>
+//     CHECK-SAME:  %[[BIAS:[a-z0-9]+]]: tensor<1xf32>)
+
+//      CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+
+// CHECK-COUNT-16:   vector.fma
+//      CHECK-NOT:   vector.transpose
+
+//          CHECK:   %[[READ:.+]] = vector.transfer_read %[[BIAS]]
+//          CHECK:   %[[EXT0:.+]] = vector.extract %[[READ]][0] : vector<1xf32>
+//          CHECK:   %[[BCST0:.+]] = vector.splat %[[EXT0]] : vector<4xf32>
+//          CHECK:   %[[ADD0:.+]] = arith.addf %{{.+}}, %[[BCST0]] : vector<4xf32>
+//          CHECK:   %[[EXT1:.+]] = vector.extract %[[READ]][0] : vector<1xf32>
+//          CHECK:   %[[BCST1:.+]] = vector.splat %[[EXT1]] : vector<4xf32>
+//          CHECK:   %[[ADD1:.+]] = arith.addf %{{.+}}, %[[BCST1]] : vector<4xf32>
+//          CHECK:   %[[WRITE0:.+]] = vector.transfer_write %[[ADD0]], %[[INIT]][%[[C0]], %[[C0]]]
+//          CHECK:   %[[WRITE1:.+]] = vector.transfer_write %[[ADD1]], %[[WRITE0]][%[[C0]], %[[C4]]]
+//          CHECK:   return %[[WRITE1]]