Add F8_16x16x32_F32 support for MFMA (#17792)

Added support for the MFMA_F8E4M3FNUZ_16x16x32_F32 MFMA layout and the corresponding e2e tests.
The e2e matmul test CMake needed an extra branch on the target chip because only gfx94x GPUs have FP8 MFMA intrinsics, and gfx94x also uses a different I8 intrinsic shape/layout than gfx90x.
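
For reference, here is a minimal sketch of the f8 testcase the generator now emits (assembled from the generate_e2e_matmul_tests.py changes below; the 256x256 shapes and the function name are illustrative). Since f8E4M3FNUZ is awkward to pass directly as an argument type, the generator keeps the function arguments in the friendlier f32 type and truncates them back to f8 inside the generated function:

  // Sketch of a generated f8 matmul testcase (accumulate variant, illustrative shapes).
  func.func @matmul_f8_example(%lhs: tensor<256x256xf32>, %rhs: tensor<256x256xf32>,
                               %acc: tensor<256x256xf32>) -> tensor<256x256xf32> {
    // Arguments arrive as f32 and are cast back to the f8 compute type.
    %lhs_casted = arith.truncf %lhs: tensor<256x256xf32> to tensor<256x256xf8E4M3FNUZ>
    %rhs_casted = arith.truncf %rhs: tensor<256x256xf32> to tensor<256x256xf8E4M3FNUZ>
    %result = linalg.matmul ins(%lhs_casted, %rhs_casted: tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>)
                            outs(%acc: tensor<256x256xf32>) -> tensor<256x256xf32>
    return %result: tensor<256x256xf32>
  }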

 
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index c801f9b..7859c0b 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
 // GFX942: target = #iree_gpu.target<arch = "gfx942",
 // GFX942-SAME: wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8,
 // GFX942-SAME:         subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32,
-// GFX942-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX942-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 // GFX942-SAME:         subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
 // GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
 // GFX942-SAME: chip = <wgp_count = 304>>
 
 // GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX940-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 
 // GFX1100: target = #iree_gpu.target<arch = "gfx1100",
 // GFX1100-SAME:        mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index cbca100..2e4d771 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -203,6 +203,7 @@
 
 static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
                                            MMAIntrinsic type) {
+  Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
   Type f16 = Float16Type::get(context);
   Type f32 = Float32Type::get(context);
 
@@ -216,6 +217,9 @@
   case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
     return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
   }
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: {
+    return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
+  }
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
   }
@@ -290,6 +294,7 @@
     return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
                              bNLayout,     cMLayout, cNLayout};
   }
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
     // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
@@ -416,6 +421,7 @@
     auto cType = VectorType::get({16}, getCType());
     return std::make_tuple(aType, bType, cType);
   }
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     auto aType = VectorType::get({8}, getAType());
     auto bType = VectorType::get({8}, getBType());
@@ -452,6 +458,7 @@
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32:
   case MMAIntrinsic::MFMA_I8_32x32x16_I32:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16:
@@ -467,6 +474,7 @@
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32:
   case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
     return 64;
@@ -490,6 +498,7 @@
     return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
             /*element=*/{1, 4}};
   }
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
             /*element=*/{1, 8}};
@@ -517,6 +526,7 @@
     return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
             /*element=*/{4, 1}};
   }
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
             /*element=*/{8, 1}};
@@ -537,6 +547,7 @@
 MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
             /*element=*/{4, 1}};
@@ -573,6 +584,7 @@
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
   case MMAIntrinsic::MFMA_I8_16x16x32_I32:
   case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
     auto [m, n, k] = getMNKShape();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 0423c2f..108deb9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -101,16 +101,18 @@
 // Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
 def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
 def MFMA_F16_32x32x8_F32  : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
-def MFMA_I8_16x16x32_I32  : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
-def MFMA_I8_32x32x16_I32  : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
+def MFMA_F8E4M3FNUZ_16x16x32_F32  : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>;
+def MFMA_I8_16x16x32_I32  : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
+def MFMA_I8_32x32x16_I32  : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
 // TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
-def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
-def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
+def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 6>;
+def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;
 
 def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
     "Descriptor for different MMA intrinsics", [
       MFMA_F16_16x16x16_F32,
       MFMA_F16_32x32x8_F32,
+      MFMA_F8E4M3FNUZ_16x16x32_F32,
       MFMA_I8_16x16x32_I32,
       MFMA_I8_32x32x16_I32,
       WMMA_F16_16x16x16_F32,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index f4584fe..c9e200f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -124,6 +124,7 @@
   static const MMAIntrinsic cdna3MMAOps[] = {
       MMAIntrinsic::MFMA_F16_16x16x16_F32,
       MMAIntrinsic::MFMA_F16_32x32x8_F32,
+      MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32,
       MMAIntrinsic::MFMA_I8_16x16x32_I32,
       MMAIntrinsic::MFMA_I8_32x32x16_I32,
   };
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 25fcc8e..24f0353 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -179,6 +179,56 @@
 
 // -----
 
+// Basic f8, f8 -> f32 matmul.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @matmul_256x256x256_f8_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @matmul_256x256x256_f8_f32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+  builtin.module {
+    func.func @matmul_256x256x256_f8_f32() {
+      %cst = arith.constant 0.000000e+00 : f32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+      %5 = tensor.empty() : tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      return
+    }
+  }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for f8 inputs.
+
+//       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
+//  CHECK-SAME:   mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F8E4M3FNUZ_16x16x32_F32>,
+//  CHECK-SAME:     subgroup_m_count = 2, subgroup_n_count = 2>
+//  CHECK-SAME:   prefetch_shared_memory
+
+//    CHECK-LABEL: func.func @matmul_256x256x256_f8_f32()
+//     CHECK-SAME:    translation_info = #[[$TRANSLATION]]
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 mfma ops.
+// CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+//  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
 // Basic i8, i8 -> i32 matmul.
 
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 88ef71a..cc4de30 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -845,6 +845,7 @@
     return makeElementTypeValue(numericalType, intType.getWidth());
   } else if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
     switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
+    case APFloat::S_Float8E4M3FNUZ:
     case APFloat::S_IEEEhalf:
     case APFloat::S_IEEEsingle:
     case APFloat::S_IEEEdouble:
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 1c6dc9f..c908a8a 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2318,6 +2318,39 @@
     "requires-gpu-cdna3"
 )
 
+if(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx94")
+
+# I8 intrinsics have a different layout on CDNA3/gfx94x,
+# and only CDNA3/gfx94x has F8 intrinsics.
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_rocm_f8_large_cdna3_mfma
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f8E4M3FNUZ"
+    "--acc_type=f32"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
 iree_generated_e2e_runner_test(
   NAME
     e2e_matmul_rocm_i8_large_cdna3_mfma_tb
@@ -2375,6 +2408,7 @@
     "noubsan"
     "requires-gpu-cdna3"
 )
+endif()
 
 elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
 
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 003f3de..8b1fbf6 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -27,6 +27,7 @@
     I32 = "i32"
     F32 = "f32"
     F16 = "f16"
+    F8E4M3FNUZ = "f8E4M3FNUZ"
     BF16 = "bf16"
 
 
@@ -271,6 +272,10 @@
             MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
             MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
             MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
+            MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 2, 2, 1, 1, 2),
+            MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 1, 4, 1, 1),
+            MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 2, 4, 2, 1),
             MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
             MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
             MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
@@ -310,7 +315,10 @@
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
             wg_tile_k = schedule.k_tile_count * 8
-        elif schedule.intrinsic == "MFMA_I8_16x16x32_I32":
+        elif (
+            schedule.intrinsic == "MFMA_I8_16x16x32_I32"
+            or schedule.intrinsic == "MFMA_F8E4M3FNUZ_16x16x32_F32"
+        ):
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
             wg_tile_k = schedule.k_tile_count * 32
@@ -454,6 +462,20 @@
     return s.value or "DYN"
 
 
+# Gets a friendlier type that we can use for the function arguments and then cast into the target_type.
+def cast_argtype_if_required(target_type: MatrixElemTypeId):
+    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
+        return MatrixElemTypeId.F32
+    return target_type
+
+
+# Gets the op needed to cast/convert from the friendly form/type into the target_type.
+def get_castback_from_arg_op(target_type: MatrixElemTypeId):
+    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
+        return "arith.truncf"
+    raise ValueError(f"Unhandled castback type of {target_type}")
+
+
 # Describes the fully resolved shape dimensions of all 3 input matrices,
 # LHS, RHS, and Accumulator, in a testcase.
 # Each value is a string, which may either represent a positive integer such as "123",
@@ -559,8 +581,10 @@
     rhs_c = int_or_question_mark(shapes.rhs_cols)
     acc_r = int_or_question_mark(shapes.acc_rows)
     acc_c = int_or_question_mark(shapes.acc_cols)
-    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
-    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
+
+    casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
+    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
+    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
     acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
 
     if transpose_rhs:
@@ -603,13 +627,22 @@
         )
         func_definition = func_definition + compilation_info_string
         generate_function.compilation_index += 1
-
+    compute = f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+    if casted_lhs_rhs_type != lhs_rhs_type:
+        castback_op = get_castback_from_arg_op(lhs_rhs_type)
+        compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+        compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
+        compute = (
+            f"  %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
+            f"  %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
+            f"  %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
+        )
     if shape.accumulate:
         signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
         import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
         func_definition = func_definition + (
             f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
-            f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+            f"{compute}\n"
             f"  return %result: {acc_tensor_type}\n"
             f"}}\n"
         )
@@ -627,7 +660,7 @@
                 f"  %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
                 f"  %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
                 f"  %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
-                f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+                f"{compute}"
                 f"  return %result: {acc_tensor_type}\n"
                 f"}}\n"
             )
@@ -639,7 +672,7 @@
                 f"  %init_acc = tensor.empty() : {acc_tensor_type}\n"
                 f"  %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
                 f"  %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
-                f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+                f"{compute}"
                 f"  return %result: {acc_tensor_type}\n"
                 f"}}\n"
             )
@@ -733,8 +766,9 @@
         rhs_shape = [shape.k, shape.n]
         transpose_rhs = 0
 
-    op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
-    op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
+    casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
+    op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
+    op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
     if shape.accumulate:
         op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
         # TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -822,7 +856,7 @@
     parser.add_argument(
         "--lhs_rhs_type",
         type=str,
-        choices=["i32", "i8", "f32", "f16", "bf16"],
+        choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"],
         help="Numeric type of input matrices",
         required=True,
     )
@@ -915,6 +949,8 @@
 def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
     if acc_type != MatrixElemTypeId.NONE:
         return acc_type
+    if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
+        return MatrixElemTypeId.F32
     if lhs_rhs_type == MatrixElemTypeId.I8:
         return MatrixElemTypeId.I32
     return lhs_rhs_type