[spirv] Tile additional reduction dims for matmul-like generic (#13603)
Previously we only handled the case with a single reduction dimension; this tiles all additional reduction dimensions with size 1 so their loops are materialized.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 7cd949d..67ec2c1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -708,6 +708,12 @@
workgroupSize, pipelineDepth, storeStage, subgroupSize,
maxBytes, elementBits);
+ // Tile all additional reduction dimensions with size 1 to materialize loops.
+ for (auto [i, it] : llvm::enumerate(op.getIteratorTypesArray())) {
+ if (linalg::isReductionIterator(it) && reductionTileSizes[i] == 0)
+ reductionTileSizes[i] = 1;
+ }
+
SmallVector<int64_t> threadTileSizes(numLoops, 0);
if (isBM) {
threadTileSizes[bIndex] = workgroupTileSizes[bIndex] / workgroupSize[2];
@@ -903,8 +909,7 @@
dimN, dimK);
if (!coopMatSize) return failure();
- auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::
- SPIRVCooperativeMatrixVectorize;
+ auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
std::optional<int64_t> subgroupSize = limits.getSubgroupSize();
// AMD RDNA architectures supports both wave32 and wave64 modes. Prefer to use
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
index a573a53..eaddaaa 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul.mlir
@@ -46,3 +46,59 @@
// CHECK: func.func @matmul_1x4096x9216()
// CHECK: linalg.matmul
// CHECK-SAME: lowering_config = #[[CONFIG]]
+
+// -----
+
+// Multi-reduction-dimension transposed-B matmul.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable private @multi_reduction_transposed_b_matmul {
+ hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU, #spirv.resource_limits<
+ max_compute_shared_memory_size = 49152,
+ max_compute_workgroup_invocations = 1024,
+ max_compute_workgroup_size = [1024, 1024, 64],
+ subgroup_size = 32>>
+ }> {
+ hal.executable.export public @multi_reduction_transposed_b_matmul layout(#pipeline_layout)
+ builtin.module {
+ func.func @multi_reduction_transposed_b_matmul() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x86x128xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<4096x86x128xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2048, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x86x128xf32>> -> tensor<2048x86x128xf32>
+ %5 = tensor.empty() : tensor<4096x2048xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<4096x2048xf32>) -> tensor<4096x2048xf32>
+ %7 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+ } ins(%3, %4 : tensor<4096x86x128xf32>, tensor<2048x86x128xf32>) outs(%6 : tensor<4096x2048xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %8 = arith.mulf %in, %in_0 : f32
+ %9 = arith.addf %out, %8 : f32
+ linalg.yield %9 : f32
+ } -> tensor<4096x2048xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [4096, 2048], strides = [1, 1] : tensor<4096x2048xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[32, 128], [4, 4], [0, 0, 1, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize pipeline_depth = 1>
+// CHECK: hal.executable.export public @multi_reduction_transposed_b_matmul
+// CHECK-SAME: translation_info = #[[TRANSLATION]]
+// CHECK-SAME: workgroup_size = [32 : index, 8 : index, 1 : index]
+// CHECK: func.func @multi_reduction_transposed_b_matmul()
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[CONFIG]]