Add F8_16x16x32_F32 support for MFMA (#17792)
Added F8_16x16x32_F32 MFMA layout support and the corresponding e2e tests.
The e2e matmul test CMake needed an extra branch because only gfx94x GPUs have FP8 MFMA intrinsics, and gfx94x also uses a different I8 intrinsic shape/layout than gfx90x.
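
A rough sketch of the arithmetic involved (not part of the change; the helper name `wg_tile_sizes` is made up for illustration): the 16x16x32 intrinsics produce a 16x16 output tile and consume 32 elements along K per issue, which is what the generator's new workgroup-tile branch and the lit test's mfma count rely on.

```python
# Sketch only: mirrors the wg_tile_* arithmetic added to
# generate_e2e_matmul_tests.py for the 16x16x32 intrinsics
# (MFMA_F8E4M3FNUZ_16x16x32_F32 and MFMA_I8_16x16x32_I32).
def wg_tile_sizes(m_count, n_count, m_tile_count, n_tile_count, k_tile_count):
    wg_tile_m = m_count * m_tile_count * 16  # 16 output rows per intrinsic tile
    wg_tile_n = n_count * n_tile_count * 16  # 16 output cols per intrinsic tile
    wg_tile_k = k_tile_count * 32            # 32 elements along K per issue
    return wg_tile_m, wg_tile_n, wg_tile_k

# The new pipeline_vector_distribute.mlir test below runs a 256x256x256 matmul;
# per its comment, each subgroup handles 2 * 2 tiles and accumulates
# 256 / 32 = 8 times along K, i.e. 2 * 2 * 8 = 32 mfma ops (CHECK-COUNT-32).
assert 2 * 2 * (256 // 32) == 32
```
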
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index c801f9b..7859c0b 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
-// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>
// GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index cbca100..2e4d771 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -203,6 +203,7 @@
static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
MMAIntrinsic type) {
+ Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
Type f16 = Float16Type::get(context);
Type f32 = Float32Type::get(context);
@@ -216,6 +217,9 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: {
+ return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
+ }
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
}
@@ -290,6 +294,7 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
@@ -416,6 +421,7 @@
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
@@ -452,6 +458,7 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
@@ -467,6 +474,7 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return 64;
@@ -490,6 +498,7 @@
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 8}};
@@ -517,6 +526,7 @@
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{8, 1}};
@@ -537,6 +547,7 @@
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
@@ -573,6 +584,7 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
auto [m, n, k] = getMNKShape();
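
As a quick aside (not part of the patch), a small check that the operand layout added above for MFMA_F8E4M3FNUZ_16x16x32_F32 covers a full 16x32 f8 tile with one 64-lane subgroup and 8 values per lane, matching the vector<8xf8E4M3FNUZ> operand type registered above:

```python
# Sketch: A-operand single-subgroup layout from the hunk above:
# outer = {1, 1}, thread = {16, 4}, element = {1, 8}.
outer, thread, element = (1, 1), (16, 4), (1, 8)
rows = outer[0] * thread[0] * element[0]  # 16 rows
cols = outer[1] * thread[1] * element[1]  # 4 lanes * 8 elements = 32 columns
lanes = thread[0] * thread[1]             # 16 * 4 = 64 lanes, one full subgroup
per_lane = element[0] * element[1]        # 8 f8 values -> vector<8xf8E4M3FNUZ>
assert (rows, cols, lanes, per_lane) == (16, 32, 64, 8)
```
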
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 0423c2f..108deb9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -101,16 +101,18 @@
// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
-def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
-def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
+def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>;
+def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
+def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
-def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
-def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
+def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 6>;
+def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;
def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
+ MFMA_F8E4M3FNUZ_16x16x32_F32,
MFMA_I8_16x16x32_I32,
MFMA_I8_32x32x16_I32,
WMMA_F16_16x16x16_F32,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index f4584fe..c9e200f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -124,6 +124,7 @@
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
+ MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32,
MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 25fcc8e..24f0353 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -179,6 +179,56 @@
// -----
+// Basic f8, f8 -> f32 matmul.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+hal.executable @matmul_256x256x256_f8_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @matmul_256x256x256_f8_f32 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul_256x256x256_f8_f32() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+ %5 = tensor.empty() : tensor<256x256xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ return
+ }
+ }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for f8 inputs.
+
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F8E4M3FNUZ_16x16x32_F32>,
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
+// CHECK-SAME: prefetch_shared_memory
+
+// CHECK-LABEL: func.func @matmul_256x256x256_f8_f32()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 mfma ops.
+// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
// Basic i8, i8 -> i32 matmul.
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 88ef71a..cc4de30 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -845,6 +845,7 @@
return makeElementTypeValue(numericalType, intType.getWidth());
} else if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
+ case APFloat::S_Float8E4M3FNUZ:
case APFloat::S_IEEEhalf:
case APFloat::S_IEEEsingle:
case APFloat::S_IEEEdouble:
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 1c6dc9f..c908a8a 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2318,6 +2318,39 @@
"requires-gpu-cdna3"
)
+if(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx94")
+
+# I8 intrinsics have a different layout on CDNA3/gfx94x,
+# and only CDNA3/gfx94x has F8 intrinsics.
+
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_rocm_f8_large_cdna3_mfma
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f8E4M3FNUZ"
+ "--acc_type=f32"
+ "--shapes=gpu_large_aligned"
+ "--compilation_info=LLVMGPUVectorDistributeMFMA"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_i8_large_cdna3_mfma_tb
@@ -2375,6 +2408,7 @@
"noubsan"
"requires-gpu-cdna3"
)
+endif()
elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 003f3de..8b1fbf6 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -27,6 +27,7 @@
I32 = "i32"
F32 = "f32"
F16 = "f16"
+ F8E4M3FNUZ = "f8E4M3FNUZ"
BF16 = "bf16"
@@ -271,6 +272,10 @@
MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
+ MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 1, 1, 1, 1, 1),
+ MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 2, 2, 1, 1, 2),
+ MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 1, 4, 1, 1),
+ MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 2, 4, 2, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
@@ -310,7 +315,10 @@
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 8
- elif schedule.intrinsic == "MFMA_I8_16x16x32_I32":
+ elif (
+ schedule.intrinsic == "MFMA_I8_16x16x32_I32"
+ or schedule.intrinsic == "MFMA_F8E4M3FNUZ_16x16x32_F32"
+ ):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 32
@@ -454,6 +462,20 @@
return s.value or "DYN"
+# Gets a friendlier type that we can use for the function arguments and then cast to the target type.
+def cast_argtype_if_required(target_type: MatrixElemTypeId):
+ if target_type == MatrixElemTypeId.F8E4M3FNUZ:
+ return MatrixElemTypeId.F32
+ return target_type
+
+
+# Gets the op needed to cast from the friendly argument type back to the target type.
+def get_castback_from_arg_op(target_type: MatrixElemTypeId):
+    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
+        return "arith.truncf"
+    raise ValueError(f"Unhandled castback type of {target_type}")
+
+
# Describes the fully resolved shape dimensions of all 3 input matrices,
# LHS, RHS, and Accumulator, in a testcase.
# Each value is a string, which may either represent a positive integer such as "123",
@@ -559,8 +581,10 @@
rhs_c = int_or_question_mark(shapes.rhs_cols)
acc_r = int_or_question_mark(shapes.acc_rows)
acc_c = int_or_question_mark(shapes.acc_cols)
- lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
- rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
+
+ casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
+ lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
+ rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
if transpose_rhs:
@@ -603,13 +627,22 @@
)
func_definition = func_definition + compilation_info_string
generate_function.compilation_index += 1
-
+ compute = f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ if casted_lhs_rhs_type != lhs_rhs_type:
+ castback_op = get_castback_from_arg_op(lhs_rhs_type)
+ compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+ compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
+ compute = (
+ f" %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
+ f" %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
+ f" %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
+ )
if shape.accumulate:
signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
- f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f"{compute}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
@@ -627,7 +660,7 @@
f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
- f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f"{compute}"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
@@ -639,7 +672,7 @@
f" %init_acc = tensor.empty() : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
- f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f"{compute}"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
@@ -733,8 +766,9 @@
rhs_shape = [shape.k, shape.n]
transpose_rhs = 0
- op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
- op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
+ casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
+ op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
+ op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
if shape.accumulate:
op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
# TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -822,7 +856,7 @@
parser.add_argument(
"--lhs_rhs_type",
type=str,
- choices=["i32", "i8", "f32", "f16", "bf16"],
+ choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"],
help="Numeric type of input matrices",
required=True,
)
@@ -915,6 +949,8 @@
def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
if acc_type != MatrixElemTypeId.NONE:
return acc_type
+ if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
+ return MatrixElemTypeId.F32
if lhs_rhs_type == MatrixElemTypeId.I8:
return MatrixElemTypeId.I32
return lhs_rhs_type
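
To tie the generator changes together, a simplified, hypothetical sketch (the helper `emit_f8_matmul` does not exist in the generator) of the MLIR that the new cast helpers lead to for an f8E4M3FNUZ testcase: the test runner still feeds f32 buffers, and the generated function truncates them to f8 before the matmul.

```python
# Simplified sketch of what the generator emits for --lhs_rhs_type=f8E4M3FNUZ:
# arguments stay f32 (cast_argtype_if_required) and are truncated back to f8
# inside the function body (get_castback_from_arg_op -> arith.truncf).
def emit_f8_matmul(m, n, k):
    lhs = f"tensor<{m}x{k}xf32>"
    rhs = f"tensor<{k}x{n}xf32>"
    acc = f"tensor<{m}x{n}xf32>"
    lhs_f8 = f"tensor<{m}x{k}xf8E4M3FNUZ>"
    rhs_f8 = f"tensor<{k}x{n}xf8E4M3FNUZ>"
    return (
        f"func.func @matmul_{m}x{n}x{k}(%lhs: {lhs}, %rhs: {rhs}, %acc: {acc}) -> {acc} {{\n"
        f"  %lhs_casted = arith.truncf %lhs: {lhs} to {lhs_f8}\n"
        f"  %rhs_casted = arith.truncf %rhs: {rhs} to {rhs_f8}\n"
        f"  %result = linalg.matmul ins(%lhs_casted, %rhs_casted: {lhs_f8}, {rhs_f8}) outs(%acc: {acc}) -> {acc}\n"
        f"  return %result: {acc}\n"
        f"}}\n"
    )

print(emit_f8_matmul(256, 256, 256))
```
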