Generalize matmul tensorcore strategy to work with arbitrary unaligned f32 tens… (#13192)
…or sizes
After this change, we still bail on aligned cases for which some of the transforms do not yet compose.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index fa5aff19..5e1e1b2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -26,6 +26,8 @@
}
}
+// CHECK-LABEL: func @matmul
+
// One of this matmul's dimensions is divisible by 64/64/16, we currently bail on such cases.
// CHECK-NOT: transform.sequence
@@ -42,37 +44,6 @@
func.func @matmul() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2049x2556xf32>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2556x2556xf32>>
- %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2049x2556xf32>>
- %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2049, 2556], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2049x2556xf32>> -> tensor<2049x2556xf32>
- %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2556, 2556], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2556x2556xf32>> -> tensor<2556x2556xf32>
- %5 = tensor.empty() : tensor<2049x2556xf32>
- %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2049x2556xf32>) -> tensor<2049x2556xf32>
- %7 = linalg.matmul ins(%3, %4 : tensor<2049x2556xf32>, tensor<2556x2556xf32>) outs(%6 : tensor<2049x2556xf32>) -> tensor<2049x2556xf32>
- flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2049, 2556], strides = [1, 1] : tensor<2049x2556xf32> -> !flow.dispatch.tensor<writeonly:tensor<2049x2556xf32>>
- return
- }
- }
-}
-}
-
-// One of this matmul's dimensions is not divisible by 4, we currently bail on such cases.
-// CHECK-NOT: transform.sequence
-
-// -----
-
-hal.executable @matmul {
-hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
- hal.executable.export public @matmul ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
- ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
- %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @matmul() {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2052x2556xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2556x2052xf32>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2052x2052xf32>>
@@ -88,6 +59,8 @@
}
}
+// CHECK-LABEL: func @matmul
+
// CHECK: transform.sequence failures(propagate) {
// CHECK: transform.iree.match_callback failures(propagate) "matmul"
// CHECK: transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
@@ -128,3 +101,53 @@
// CHECK: transform.vector.lower_masks %{{.*}} : (!pdl.operation) -> !pdl.operation
// CHECK: transform.vector.materialize_masks %{{.*}} : (!pdl.operation) -> !pdl.operation
// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization} : (!pdl.operation) -> ()
+
+
+// -----
+
+hal.executable @matmul {
+hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
+ hal.executable.export public @matmul ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2051x2555xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2555x2050xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2051x2050xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2051, 2555], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2051x2555xf32>> -> tensor<2051x2555xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2555, 2051], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2555x2050xf32>> -> tensor<2555x2050xf32>
+ %5 = tensor.empty() : tensor<2051x2050xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2051x2050xf32>) -> tensor<2051x2050xf32>
+ %7 = linalg.matmul ins(%3, %4 : tensor<2051x2555xf32>, tensor<2555x2050xf32>) outs(%6 : tensor<2051x2050xf32>) -> tensor<2051x2050xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2051, 2050], strides = [1, 1] : tensor<2051x2050xf32> -> !flow.dispatch.tensor<writeonly:tensor<2051x2050xf32>>
+ return
+ }
+ }
+}
+}
+
+// CHECK-LABEL: func @matmul
+
+// CHECK: transform.sequence failures(propagate) {
+// CHECK: transform.iree.match_callback failures(propagate) "matmul"
+// CHECK: transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
+// CHECK: transform.structured.tile %{{.*}}[0, 0, 16]
+// align1
+// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear<x>, #gpu.linear<y>])
+// align2
+// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 64] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
+// align2
+// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 64] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
+// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
+// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
+// align1
+// CHECK: transform.structured.masked_vectorize %{{.*}} vector_sizes [16, 1]
+// align2
+// CHECK: transform.structured.masked_vectorize %{{.*}} vector_sizes [8, 2]
+// align2
+// CHECK: transform.structured.masked_vectorize %{{.*}} vector_sizes [64, 2]
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
index 1319685..eebefa8 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
@@ -504,9 +504,7 @@
bool alignedAny64x64x16 = matmulSize[0] % 64 == 0 ||
matmulSize[1] % 64 == 0 || matmulSize[2] % 16 == 0;
- bool alignedAll4x4x4 = matmulSize[0] % 4 == 0 && matmulSize[1] % 4 == 0 &&
- matmulSize[2] % 4 == 0;
- if (alignedAny64x64x16 || !alignedAll4x4x4) {
+ if (alignedAny64x64x16) {
LLVM_DEBUG(DBGS() << "alignment check failed\n");
return failure();
}