[Codegen][ROCDL] Extend mfma pipeline to support a few more matmul variants (#16582)
As with convolution, we simply tile all M, N, and K dimensions to 1 except the
innermost ones. This is in preparation for some fusion experiments.
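
As a reference for the heuristic change, here is a minimal sketch of the added
tiling logic (assuming contractionDims comes from
mlir::linalg::inferContractionDims and workgroupTileSizes is sized to the op's
iteration space, as in the diff below):

    // Tile every outer m/n/k dimension to 1; only the innermost dimensions
    // receive intrinsic-driven tile sizes afterwards.
    for (int64_t m : llvm::drop_end(contractionDims->m))
      workgroupTileSizes[m] = 1;
    for (int64_t n : llvm::drop_end(contractionDims->n))
      workgroupTileSizes[n] = 1;
    for (int64_t k : llvm::drop_end(contractionDims->k))
      workgroupTileSizes[k] = 1;
    // Innermost dimensions, e.g. for M:
    // workgroupTileSizes[contractionDims->m.back()] =
    //     schedule->mWarpCount * schedule->mTileCount * schedule->mSize;

For the new expanded_matmul_transpose_b test below (indexing maps
(d0, d2, d4) x (d1, d3, d4) -> (d0, d1, d2, d3)), this yields
tile_sizes = [[1, 1, 64, 64, 32]]: the outer M and N dimensions d0 and d1 are
tiled to 1, while the innermost M, N, and K dimensions get the MFMA-driven
sizes.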
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 48fa797..066fa32 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -402,8 +402,7 @@
workgroupTileSizes[batch] = 1;
}
// Tile all output image dimensions with unit size except the last one.
- for (int64_t oi :
- ArrayRef<unsigned int>(convolutionDims->outputImage).drop_back()) {
+ for (int64_t oi : llvm::drop_end(convolutionDims->outputImage)) {
workgroupTileSizes[oi] = 1;
}
// Compute the M/N dimension tile size by multiplying subgroup information.
@@ -463,15 +462,23 @@
mlir::linalg::inferContractionDims(op);
assert(succeeded(contractionDims) && "Could not infer contraction dims");
- // TODO: Relax this condition to strictly alignment requirements.
- if (contractionDims->k.size() != 1 || contractionDims->m.size() != 1 ||
- contractionDims->n.size() != 1) {
+ if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
+ contractionDims->n.size() < 1) {
return failure();
}
- int64_t mDim = contractionDims->m[0];
- int64_t nDim = contractionDims->n[0];
- int64_t kDim = contractionDims->k[0];
+ // For now we do not try to be smart and reshape dimensions to allow for
+ // better usage of intrinsics; instead we tile all dimensions except the
+ // innermost m, n, and k dimensions to 1.
+ int64_t mDim = contractionDims->m.back();
+ int64_t nDim = contractionDims->n.back();
+
+ // Bail out on matvec-like cases.
+ if (bounds[mDim] == 1 || bounds[nDim] == 1) {
+ return failure();
+ }
+
+ int64_t kDim = contractionDims->k.back();
Value lhs = op.getDpsInputOperand(0)->get();
Value rhs = op.getDpsInputOperand(1)->get();
@@ -520,6 +527,19 @@
for (int64_t batch : contractionDims->batch) {
workgroupTileSizes[batch] = 1;
}
+
+ // Tile all m, n, and k dimensions to 1 except the innermost. Unit dims
+ // from this tiling are folded before vectorization.
+ for (int64_t m : llvm::drop_end(contractionDims->m)) {
+ workgroupTileSizes[m] = 1;
+ }
+ for (int64_t n : llvm::drop_end(contractionDims->n)) {
+ workgroupTileSizes[n] = 1;
+ }
+ for (int64_t k : llvm::drop_end(contractionDims->k)) {
+ workgroupTileSizes[k] = 1;
+ }
+
// Compute the M/N dimension tile size by multiplying subgroup information.
workgroupTileSizes[mDim] =
schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
@@ -555,7 +575,7 @@
setVectorDistributionConfig(mlir::FunctionOpInterface entryPoint,
linalg::LinalgOp linalgOp,
const TargetInfo &targetInfo) {
- if (isMatmulOrBatchMatmul(linalgOp)) {
+ if (linalg::isaContractionOpInterface(linalgOp)) {
return setMatmulVectorDistributionConfig(entryPoint, linalgOp, targetInfo);
}
if (isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index 169e7c1..ae5d3bb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -5,6 +5,74 @@
// to be migrated to the rocdl heuristics, but for now is just physically
// located here.
+// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 64, 64, 32]{{\]}}
+// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
+// CHECK-SAME: intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>
+// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4,
+// CHECK-SAME: subgroup_m_tile_count = 4, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+hal.executable @expanded_matmul_transpose_b_executable {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
+ target_arch = "gfx940",
+ mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
+ #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+ }>) {
+ hal.executable.export @expanded_matmul_transpose_b layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @expanded_matmul_transpose_b() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 2048], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>> -> tensor<2x64x2048xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 64, 2048], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>> -> tensor<10x64x2048xf16>
+
+ %5 = tensor.empty() : tensor<2x10x64x64xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+ %7 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+ } ins(%3, %4 : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
+ ^bb0(%lhs: f16, %rhs: f16, %out: f16):
+ %mul = arith.mulf %lhs, %rhs : f16
+ %add = arith.addf %mul, %out : f16
+ linalg.yield %add : f16
+ } -> tensor<2x10x64x64xf16>
+
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 64, 64], strides = [1, 1, 1, 1]
+ : tensor<2x10x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: hal.executable public @expanded_matmul_transpose_b
+// CHECK: linalg.generic {{.*}}lowering_config = #[[$TILE_SIZES]]
+
+// -----
+
// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 64, 128, 1, 1, 32]{{\]}}
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 31bdc44..ddcc660 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -111,7 +111,7 @@
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf16>)
// CHECK: arith.extf %[[ARG]] : vector<2x4x1x1x1x4xf16> to vector<2x4x1x1x1x4xf32>
// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-// CHECK: %[[TRUNC:.+]] = arith.truncf %157 : vector<2x4x1x1x1x4xf32> to vector<2x4x1x1x1x4xf16>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x4x1x1x1x4xf32> to vector<2x4x1x1x1x4xf16>
// CHECK: scf.yield %[[TRUNC]] : vector<2x4x1x1x1x4xf16>
// CHECK-COUNT-8: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
@@ -123,6 +123,75 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
+hal.executable @expanded_matmul_transpose_b_executable {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
+ target_arch = "gfx940",
+ mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
+ #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+ }>) {
+ hal.executable.export @expanded_matmul_transpose_b layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @expanded_matmul_transpose_b() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 2048], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>> -> tensor<2x64x2048xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 64, 2048], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>> -> tensor<10x64x2048xf16>
+
+ %5 = tensor.empty() : tensor<2x10x64x64xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+ %7 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+ } ins(%3, %4 : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
+ ^bb0(%lhs: f16, %rhs: f16, %out: f16):
+ %mul = arith.mulf %lhs, %rhs : f16
+ %add = arith.addf %mul, %out : f16
+ linalg.yield %add : f16
+ } -> tensor<2x10x64x64xf16>
+
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 64, 64], strides = [1, 1, 1, 1]
+ : tensor<2x10x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: hal.executable.export public @expanded_matmul_transpose_b
+// CHECK-SAME: workgroup_size = [256 : index, 1 : index, 1 : index]
+
+// CHECK-LABEL: func @expanded_matmul_transpose_b
+// CHECK: scf.for {{.*}} = %c0 to %c2048 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x1x4xf16>)
+// CHECK: arith.extf %[[ARG]] : vector<4x1x1x1x1x4xf16> to vector<4x1x1x1x1x4xf32>
+// CHECK-COUNT-8: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x1x4xf32> to vector<4x1x1x1x1x4xf16>
+// CHECK: scf.yield %[[TRUNC]] : vector<4x1x1x1x1x4xf16>
+// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<2x10x64x64xf16, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
hal.executable @conv_nhwc_dispatch_0 {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
target_arch = "gfx940",