[spirv] Generalize transposed batch matmul op (#15197)
This allows it to go down generic op kernel configuration deduction so
that we can use subgroup reduction pipeline for batch matvec cases.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVGeneralizeNamedOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVGeneralizeNamedOps.cpp
index ed687f8..1c6973d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVGeneralizeNamedOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVGeneralizeNamedOps.cpp
@@ -32,8 +32,8 @@
auto funcOp = getOperation();
SmallVector<linalg::LinalgOp> namedOpCandidates;
funcOp.walk([&](linalg::LinalgOp linalgOp) {
- if (isa<linalg::MatmulTransposeBOp, linalg::VecmatOp, linalg::MatvecOp>(
- linalgOp))
+ if (isa<linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeBOp,
+ linalg::VecmatOp, linalg::MatvecOp>(linalgOp))
namedOpCandidates.push_back(linalgOp);
});
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/generalize_named_ops.mlir
index 650dd9c..eb54c00 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/generalize_named_ops.mlir
@@ -75,3 +75,28 @@
// CHECK: %[[A0:.+]] = arith.mulf %[[IN]], %[[IN0]] : f32
// CHECK: %[[A1:.+]] = arith.addf %[[OUT]], %[[A0]] : f32
// CHECK: linalg.yield %[[A1]] : f32
+
+// -----
+
+module {
+ func.func @transpose_batch_matmul(%arg0: tensor<32x1x128xf16>, %arg1: tensor<32x?x128xf16>, %dim: index) -> tensor<32x1x?xf16> {
+ %f0 = arith.constant 0.0 : f16
+ %empty = tensor.empty(%dim) : tensor<32x1x?xf16>
+ %fill = linalg.fill ins(%f0 : f16) outs(%empty : tensor<32x1x?xf16>) -> tensor<32x1x?xf16>
+ %2 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : tensor<32x1x128xf16>, tensor<32x?x128xf16>) outs(%fill : tensor<32x1x?xf16>) -> tensor<32x1x?xf16>
+ return %2 : tensor<32x1x?xf16>
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @transpose_batch_matmul
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK: ^bb0(%[[A:.+]]: f16, %[[B:.+]]: f16, %[[OUT:.+]]: f16):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[A]], %[[B]] : f16
+// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f16
+// CHECK: linalg.yield %[[ADD]] : f16