Add CDNA3 MFMA BF16 intrinsics. (#18892)
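
This mirrors the existing F16 support: the MFMA_F32_16x16x16_BF16 and
MFMA_F32_32x32x8_BF16 intrinsics are added to the MMAIntrinsic enum,
advertised for CDNA3 targets (gfx940/gfx942) in KnownTargets, handled in
IREEGPUAttrs.cpp alongside the corresponding F16 cases, and covered by a
data-tiled materialization test and an e2e matmul test.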
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 7681816..578cd59 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
-// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
+// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,7 +26,7 @@
// GFX941-SAME: features = "+sramecc,-xnack"
// GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
+// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
index e20fc3b..90becb2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -1130,3 +1130,63 @@
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
+#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
+#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
+#pipeline_layout_4 = #hal.pipeline.layout<constants = 4, bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
+ %c0 = arith.constant 0 : index
+ %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index
+ %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index
+ %N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index
+ %K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_lhs>>{%B, %M, %K}
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_rhs>>{%B, %K, %N}
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_lhs>>{%B, %M, %K}
+ -> tensor<?x?x?xbf16, #encoding_lhs>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_rhs>>{%B, %K, %N}
+ -> tensor<?x?x?xbf16, #encoding_rhs>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+ -> tensor<?x?x?xf32, #encoding_result>
+ %6 = linalg.batch_matmul
+ ins(%3, %4 : tensor<?x?x?xbf16, #encoding_lhs>,
+ tensor<?x?x?xbf16, #encoding_rhs>)
+ outs(%5 : tensor<?x?x?xf32, #encoding_result>)
+ -> tensor<?x?x?xf32, #encoding_result>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+ : tensor<?x?x?xf32, #encoding_result>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
+// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
+// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x4xbf16>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x4xbf16>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_BF16, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
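As a cross-check on the CHECK lines above: each inner data-tile shape must
multiply out to the intrinsic tile (16x16x16) scaled by the unroll factors in
the data_tiled_mma_layout (unroll_m = 8, unroll_n = 2,
unroll_n_to_subgroups = 4, unroll_k = 2). A minimal standalone sketch of that
arithmetic (not part of the patch; constant names are informal, and
message-less static_assert needs C++17):

  #include <cstdint>

  // Intrinsic tile for MFMA_F32_16x16x16_BF16 and the unroll factors
  // from the data_tiled_mma_layout checked above.
  constexpr int64_t M = 16, N = 16, K = 16;
  constexpr int64_t unrollM = 8, unrollN = 2, unrollNToSubgroups = 4,
                    unrollK = 2;

  // LHS tile 8x4x16x2x4 holds (unroll_m * M) x (unroll_k * K) bf16s.
  static_assert(8 * 4 * 16 * 2 * 4 == (unrollM * M) * (unrollK * K));
  // RHS tile 4x2x4x16x2x4 holds
  // (unroll_n_to_subgroups * unroll_n * N) x (unroll_k * K) bf16s.
  static_assert(4 * 2 * 4 * 16 * 2 * 4 ==
                (unrollNToSubgroups * unrollN * N) * (unrollK * K));
  // ACC tile 8x4x2x4x16x4 holds
  // (unroll_m * M) x (unroll_n_to_subgroups * unroll_n * N) f32s.
  static_assert(8 * 4 * 2 * 4 * 16 * 4 ==
                (unrollM * M) * (unrollNToSubgroups * unrollN * N));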
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index d9c26ae..41c099f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -214,6 +214,7 @@
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context);
Type f16 = Float16Type::get(context);
+ Type bf16 = BFloat16Type::get(context);
Type f32 = Float32Type::get(context);
Type i8 = IntegerType::get(context, 8);
@@ -229,6 +230,12 @@
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
+ return OpaqueMmaLayout{16, 16, 16, bf16, bf16, f32};
+ }
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+ return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
+ }
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
@@ -336,6 +343,45 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
+ // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+ // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+ // #layout_c = #iree_vector_ext.layout<#inner, #outer>
+
+ auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+ auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
+ auto aMLayout = outer;
+ auto aKLayout = inner;
+ auto bKLayout = inner;
+ auto bNLayout = outer;
+ auto cMLayout = inner;
+ auto cNLayout = outer;
+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+ bNLayout, cMLayout, cNLayout};
+ }
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
+ // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
+ // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
+ // [4, 2, 4]>
+ // #layout_a = #iree_vector_ext.layout<#outer, #inner1>
+ // #layout_b = #iree_vector_ext.layout<#inner1, #outer>
+ // #layout_c = #iree_vector_ext.layout<#inner2, #outer>
+
+ auto outer = PerDimLayoutAttr::get(context, {laneX}, {32});
+ auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4});
+ auto aMLayout = outer;
+ auto aKLayout = inner;
+ auto bKLayout = inner;
+ auto bNLayout = outer;
+ auto cMLayout =
+ PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4});
+ auto cNLayout = outer;
+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+ bNLayout, cMLayout, cNLayout};
+ }
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
@@ -462,14 +508,16 @@
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
- case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
+ case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
- case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
+ case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({16}, getCType());
@@ -519,8 +567,10 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -540,8 +590,10 @@
switch (intrinsic) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -584,6 +636,7 @@
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
@@ -597,6 +650,7 @@
}
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
@@ -704,8 +758,10 @@
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
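The vector sizes chosen in getABCVectorTypes above are consistent with
distributing each intrinsic tile across a 64-lane wavefront
(subgroup_size_choices = [64] on these targets): both BF16 variants read 4
bf16 elements per lane, and the 16x16 and 32x32 accumulators come out to 4
and 16 f32 elements per lane respectively. A small sketch of that
bookkeeping, with a hypothetical helper that is not an IREE API:

  #include <cstdint>

  constexpr int64_t kWavefrontSize = 64; // CDNA3 subgroup size

  // Elements of an RxC tile owned by each lane.
  constexpr int64_t perLane(int64_t rows, int64_t cols) {
    return rows * cols / kWavefrontSize;
  }

  // MFMA_F32_16x16x16_BF16: vector<4xbf16> A/B, vector<4xf32> C.
  static_assert(perLane(16, 16) == 4); // A is MxK = 16x16
  static_assert(perLane(16, 16) == 4); // C is MxN = 16x16
  // MFMA_F32_32x32x8_BF16: vector<4xbf16> A/B, vector<16xf32> C.
  static_assert(perLane(32, 8) == 4);   // A is MxK = 32x8
  static_assert(perLane(32, 32) == 16); // C is MxN = 32x32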
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index d1c84a8..9d4ac2e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -121,6 +121,8 @@
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
+def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
+def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
@@ -143,6 +145,8 @@
MFMA_F32_16x16x4_F32,
MFMA_F32_16x16x16_F16,
MFMA_F32_32x32x8_F16,
+ MFMA_F32_16x16x16_BF16,
+ MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E5M2FNUZ,
MFMA_I32_16x16x32_I8,
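Note on the enum values: the 0x09xx MFMA block stays grouped by operand type,
with the new BF16 cases at 0x0920/0x0921 slotting between the F16 entries
(0x0910/0x0911) and the F8 entries (0x0930/0x0940), leaving the I8 entries at
0x098x untouched.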
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index c187f44..5e8f031 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -136,6 +136,8 @@
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
+ MMAIntrinsic::MFMA_F32_16x16x16_BF16,
+ MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
MMAIntrinsic::MFMA_I32_16x16x32_I8,
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 98a4ff1..f229434 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1572,6 +1572,35 @@
iree_generated_e2e_runner_test(
NAME
+ e2e_matmul_rocm_bf16_cdna3_mfma_data_tiled
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=f32"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ "--iree-opt-data-tiling"
+ "--iree-global-opt-experimental-rocm-data-tiling"
+ "--iree-global-opt-enable-early-materialization=true"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
e2e_matmul_rocm_i8_cdna3_mfma_data_tiled
TEST_TYPE
matmul