Add CDNA3 MFMA BF16 intrinsics. (#18892)

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 7681816..578cd59 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
 // GFX942: target = #iree_gpu.target<arch = "gfx942",
 // GFX942-SAME: wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8,
 // GFX942-SAME:         subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32,
-// GFX942-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
+// GFX942-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
 // GFX942-SAME:         subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
 // GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
 // GFX942-SAME:         max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,7 +26,7 @@
 // GFX941-SAME:         features = "+sramecc,-xnack"
 
 // GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
+// GFX940-SAME:         mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
 
 // GFX1100: target = #iree_gpu.target<arch = "gfx1100",
 // GFX1100-SAME:        mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
index e20fc3b..90becb2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -1130,3 +1130,63 @@
 // CHECK-SAME:    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
 // CHECK-SAME:    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
 // CHECK:       flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
+#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
+#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
+#pipeline_layout_4 = #hal.pipeline.layout<constants = 4, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
+  %c0 = arith.constant 0 : index
+  %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index
+  %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index
+  %N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index
+  %K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_lhs>>{%B, %M, %K}
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_rhs>>{%B, %K, %N}
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0)
+      : !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_lhs>>{%B, %M, %K}
+      -> tensor<?x?x?xbf16, #encoding_lhs>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_rhs>>{%B, %K, %N}
+      -> tensor<?x?x?xbf16, #encoding_rhs>
+  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+      : !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+      -> tensor<?x?x?xf32, #encoding_result>
+  %6 = linalg.batch_matmul
+      ins(%3, %4 : tensor<?x?x?xbf16, #encoding_lhs>,
+                   tensor<?x?x?xbf16, #encoding_rhs>)
+      outs(%5 : tensor<?x?x?xf32, #encoding_result>)
+      -> tensor<?x?x?xf32, #encoding_result>
+  flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+      : tensor<?x?x?xf32, #encoding_result>
+      -> !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK:     func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16
+// CHECK-DAG:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
+// CHECK-DAG:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
+// CHECK-DAG:   %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
+// CHECK-DAG:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x4xbf16>
+// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x4xbf16>
+// CHECK-DAG:   %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
+// CHECK:       %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME:    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_BF16, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+// CHECK:       flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index d9c26ae..41c099f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -214,6 +214,7 @@
   Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
   Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context);
   Type f16 = Float16Type::get(context);
+  Type bf16 = BFloat16Type::get(context);
   Type f32 = Float32Type::get(context);
 
   Type i8 = IntegerType::get(context, 8);
@@ -229,6 +230,12 @@
   case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
     return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
   }
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
+    return OpaqueMmaLayout{16, 16, 16, bf16, bf16, f32};
+  }
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+    return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
+  }
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
     return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
   }
@@ -336,6 +343,45 @@
     return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
                              bNLayout,     cMLayout, cNLayout};
   }
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
+    // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+    // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
+    // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+    // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+    // #layout_c = #iree_vector_ext.layout<#inner, #outer>
+
+    auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+    auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
+    auto aMLayout = outer;
+    auto aKLayout = inner;
+    auto bKLayout = inner;
+    auto bNLayout = outer;
+    auto cMLayout = inner;
+    auto cNLayout = outer;
+    return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+                             bNLayout,     cMLayout, cNLayout};
+  }
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+    // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
+    // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
+    // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
+    //                                           [4, 2, 4]>
+    // #layout_a = #iree_vector_ext.layout<#outer, #inner1>
+    // #layout_b = #iree_vector_ext.layout<#inner1, #outer>
+    // #layout_c = #iree_vector_ext.layout<#inner2, #outer>
+
+    auto outer = PerDimLayoutAttr::get(context, {laneX}, {32});
+    auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4});
+    auto aMLayout = outer;
+    auto aKLayout = inner;
+    auto bKLayout = inner;
+    auto bNLayout = outer;
+    auto cMLayout =
+        PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4});
+    auto cNLayout = outer;
+    return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+                             bNLayout,     cMLayout, cNLayout};
+  }
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
     // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
@@ -462,14 +508,16 @@
     return std::make_tuple(aType, bType, cType);
   }
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
     auto aType = VectorType::get({4}, getAType());
     auto bType = VectorType::get({4}, getBType());
     auto cType = VectorType::get({4}, getCType());
     return std::make_tuple(aType, bType, cType);
   }
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
     auto aType = VectorType::get({4}, getAType());
     auto bType = VectorType::get({4}, getBType());
     auto cType = VectorType::get({16}, getCType());
@@ -519,8 +567,10 @@
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F32_16x16x4_F32:
   case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
   case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -540,8 +590,10 @@
   switch (intrinsic) {
   case MMAIntrinsic::MFMA_F32_16x16x4_F32:
   case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
   case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -584,6 +636,7 @@
     }
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
   case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
     switch (fragment) {
     case MMAFragment::Lhs:
       return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
@@ -597,6 +650,7 @@
     }
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
     switch (fragment) {
     case MMAFragment::Lhs:
       return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
@@ -704,8 +758,10 @@
   }
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
   case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8:
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index d1c84a8..9d4ac2e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -121,6 +121,8 @@
 def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
 def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
 def MFMA_F32_32x32x8_F16  : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
+def MFMA_F32_16x16x16_BF16  : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
+def MFMA_F32_32x32x8_BF16  : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
 def MFMA_F32_16x16x32_F8E5M2FNUZ  : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
 def MFMA_F32_16x16x32_F8E4M3FNUZ  : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
 def MFMA_I32_16x16x32_I8  : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
@@ -143,6 +145,8 @@
       MFMA_F32_16x16x4_F32,
       MFMA_F32_16x16x16_F16,
       MFMA_F32_32x32x8_F16,
+      MFMA_F32_16x16x16_BF16,
+      MFMA_F32_32x32x8_BF16,
       MFMA_F32_16x16x32_F8E4M3FNUZ,
       MFMA_F32_16x16x32_F8E5M2FNUZ,
       MFMA_I32_16x16x32_I8,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index c187f44..5e8f031 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -136,6 +136,8 @@
       MMAIntrinsic::MFMA_F32_16x16x4_F32,
       MMAIntrinsic::MFMA_F32_16x16x16_F16,
       MMAIntrinsic::MFMA_F32_32x32x8_F16,
+      MMAIntrinsic::MFMA_F32_16x16x16_BF16,
+      MMAIntrinsic::MFMA_F32_32x32x8_BF16,
       MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
       MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
       MMAIntrinsic::MFMA_I32_16x16x32_I8,
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 98a4ff1..f229434 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1572,6 +1572,35 @@
 
 iree_generated_e2e_runner_test(
   NAME
+    e2e_matmul_rocm_bf16_cdna3_mfma_data_tiled
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-opt-data-tiling"
+    "--iree-global-opt-experimental-rocm-data-tiling"
+    "--iree-global-opt-enable-early-materialization=true"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
     e2e_matmul_rocm_i8_cdna3_mfma_data_tiled
   TEST_TYPE
     matmul