[ROCM] fix layout for WMMA_F16_16x16x16_F16 intrinsic (#18206)
The existing layout for this intrinsic assumed subgroup=64, but we use
subgroup=32, which led to the error reported in
https://github.com/iree-org/iree/issues/18060
This PR switches to the correct layout for subgroup=32 and thereby fixes
https://github.com/iree-org/iree/issues/18060 and
https://github.com/iree-org/iree/issues/17807
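
For reference, here is a minimal standalone sketch (not part of the patch) of what the corrected single-subgroup accumulator layout means in practice. The layout numbers (`outer = {16, 1}`, `thread = {1, 16}`, `strides = {16, 1}`, `element = {1, 1}`) come from the new `getSingleSubgroupLayout` case below; the lane-id decomposition `tid_i = (lane floordiv stride_i) mod thread_i` and the printout are illustrative assumptions, not code from the compiler:

```cpp
// Standalone sketch -- not part of the patch. It prints which elements of the
// 16x16 accumulator each lane of a 32-wide subgroup owns under the corrected
// WMMA_F16_16x16x16_F16 C layout from getSingleSubgroupLayout().
#include <cstdio>

int main() {
  const int outer[2] = {16, 1};   // per-lane outer repetitions: 16 rows x 1 col
  const int threads[2] = {1, 16}; // lanes along M x N
  const int strides[2] = {16, 1}; // lane-id strides of the two thread dims
  const int element[2] = {1, 1};  // contiguous elements per lane per step

  for (int lane = 0; lane < 32; ++lane) {
    int tidM = (lane / strides[0]) % threads[0]; // = 0 for every lane
    int tidN = (lane / strides[1]) % threads[1]; // = lane mod 16
    std::printf("lane %2d -> column %2d, rows:", lane, tidN);
    for (int o = 0; o < outer[0]; ++o) {
      // Along M the outer index is the slow dimension; threads and elements
      // are interleaved inside it.
      int row = o * threads[0] * element[0] + tidM * element[0];
      std::printf(" %d", row);
    }
    std::printf("\n");
  }
  return 0;
}
```

With subgroup=32, every lane owns one full 16-element column of the 16x16 C tile (lane mod 16 selects the column, so lanes 16..31 map to the same columns as lanes 0..15 under the `(d0 mod 16)` map), which is why the accumulator vector type grows to `vector<16xf16>` and the lit tests below now expect `[16, 1, 16]` shapes.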
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 815f286..831bdfe 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -353,8 +353,7 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
- case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -372,6 +371,24 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
+ // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+ // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+
+ auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+ auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
+ auto aMLayout = outer;
+ auto aKLayout = inner;
+ auto bKLayout = inner;
+ auto bNLayout = outer;
+ auto cMLayout =
+ PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {16, 1, 1});
+ auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+ bNLayout, cMLayout, cNLayout};
+ }
default: {
break;
}
@@ -463,13 +480,18 @@
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
- case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
auto aType = VectorType::get({16}, getAType());
auto bType = VectorType::get({16}, getBType());
auto cType = VectorType::get({8}, getCType());
return std::make_tuple(aType, bType, cType);
}
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ auto aType = VectorType::get({16}, getAType());
+ auto bType = VectorType::get({16}, getBType());
+ auto cType = VectorType::get({16}, getCType());
+ return std::make_tuple(aType, bType, cType);
+ }
}
// This should not happen but just to make GCC happy.
return std::make_tuple(VectorType{}, VectorType{}, VectorType{});
@@ -597,11 +619,14 @@
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
- case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
return {/*outer=*/{8, 1}, /*thread=*/{2, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
+ /*element=*/{1, 1}};
+ }
}
return {};
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
index facbb84..1032fce 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
@@ -238,8 +238,8 @@
// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma
// CHECK-INPUTS: return %[[MMA]]
-// CHECK-RESULT: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0, 1], [2]] output_shape [8, 2, 16]
+// CHECK-RESULT: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0, 1], [2]] output_shape [16, 1, 16]
// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
-// CHECK-RESULT-SAME: : tensor<16x16xf16>, tensor<16x16xf16> into tensor<8x2x16xf16>
+// CHECK-RESULT-SAME: : tensor<16x16xf16>, tensor<16x16xf16> into tensor<16x1x16xf16>
// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0, 1], [2]]
// CHECK-RESULT: return %[[COLLAPSED]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
index 5569b3b..60efae3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
@@ -239,7 +239,7 @@
}
// CHECK-DAG: #[[$XMAP:.+]] = affine_map<(d0) -> (d0 mod 16)>
-// CHECK-DAG: #[[$YMAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 2)>
+// CHECK-DAG: #[[$YMAP:.+]] = affine_map<() -> ()>
// CHECK-LABEL: func @distribute_WMMA_F16_16x16x16_F16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<16x16xf16>
@@ -248,10 +248,9 @@
// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$XMAP]](%[[LANEID]])
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[IDX]], 0] [1, 16]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, %[[IDX]]] [16, 1]
-// CHECK-DAG: %[[IDY:.+]] = affine.apply #[[$YMAP]](%[[LANEID]])
-// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[IDY]], %[[IDX]]] [8, 1, 1]
+// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDX]]] [16, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>
-// CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<8x1x1xf16>
-// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[IDY]], %[[IDX]]] [8, 1, 1]
+// CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<16x1x1xf16>
+// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDX]]] [16, 1, 1]
// CHECK: mapping = [#iree_gpu.lane_id<0>]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 4ae197c..958baf1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -495,6 +495,6 @@
// CHECK-LABEL: func @matmul_transpose_b_wmma_f16_16x16x16_f16
// CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
-// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf16>)
-// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf16>
+// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>)
+// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
// CHECK: scf.yield
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 9074458..d8f9673 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -505,6 +505,58 @@
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
+hal.executable @matmul_256x256x256_f16_f16 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @matmul_256x256x256_f16_f16 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul_256x256x256_f16_f16() {
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+ %5 = tensor.empty() : tensor<256x256xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+ %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+ return
+ }
+ }
+}
+}
+
+// RDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32
+// RDNA3-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>,
+// RDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
+// RDNA3-SAME: prefetch_shared_memory
+
+// RDNA3-LABEL: func.func @matmul_256x256x256_f16_f16
+// RDNA3-SAME: translation_info = #[[$TRANSLATION]]
+// RDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x16x1x1x1xf16>)
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 wmma ops.
+// RDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
+// RDNA3: scf.yield %{{.+}} : vector<2x2x16x1x1x1xf16>
+// Since each subgroup handles 2 * 2 tiles, and for each tile each lane holds 4 values,
+// we will have 32 writes. We cannot do contiguous writes since the output columns have
+// interleaved thread ids.
+// RDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
hal.executable @unaligned_mk_batch_matmul_64x978x1281x1281_f16_f16 {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export @unaligned_mk_batch_matmul_64x978x1281x1281_f16_f16 layout(#pipeline_layout) {