Improvements to e2e matmul tests (#15243)
* Test f16 on CPU, not just on GPU.
* Test bf16
* Test many SIMD variants that we now have dedicated ukernel code paths
for. These code paths were so far tested only in isolated unit tests,
not in integration.
* Rationalize/reorganize the collection of e2e matmul tests:
- Drop no longer relevant cases.
- Rename: `mmt4d` -> `dt` (for data-tiling), `direct` -> `nondt`, etc.
The `nondt` tests explicitly disable data-tiling, so they will continue
testing that through the merging of #15215.
- Group by target backend.
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index b7d6a24..73fc57a 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -8,7 +8,8 @@
function(iree_is_bytecode_module_test_excluded_by_labels _DST_IS_EXCLUDED_VAR _SRC_LABELS)
string(TOLOWER "${CMAKE_BUILD_TYPE}" _LOWERCASE_BUILD_TYPE)
- if(((IREE_ARCH MATCHES "^riscv_") AND ("noriscv" IN_LIST _SRC_LABELS)) OR
+ if(((IREE_ARCH MATCHES "^riscv_") AND ("noriscv" IN_LIST _SRC_LABELS)) OR
+ (EMSCRIPTEN AND ("nowasm" IN_LIST _SRC_LABELS)) OR
(IREE_ENABLE_ASAN AND ("noasan" IN_LIST _SRC_LABELS)) OR
(IREE_ENABLE_TSAN AND ("notsan" IN_LIST _SRC_LABELS)) OR
(CMAKE_CROSSCOMPILING AND "hostonly" IN_LIST _RULE_LABELS) OR
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index c4b9f31..6f2c849 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -349,6 +349,9 @@
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
*(uint16_t*)dst = iree_math_f32_to_f16((float)value);
break;
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+ *(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
+ break;
IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_32, float)
IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_64, double)
// clang-format on
@@ -402,6 +405,7 @@
case IREE_HAL_ELEMENT_TYPE_INT_16:
case IREE_HAL_ELEMENT_TYPE_SINT_16:
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
*min = -4;
*max = +4;
break;
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 6c76c35..d3ca5b1 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -18,156 +18,50 @@
srcs = ["generate_e2e_matmul_tests.py"],
)
-[iree_generated_trace_runner_test(
- name = "e2e_matmul_direct_%s_small" % lhs_rhs_type,
- generator = ":generate_e2e_matmul_tests",
- generator_args = [
- "--lhs_rhs_type=%s" % lhs_rhs_type,
- "--shapes=small",
- ],
- target_backends_and_drivers = [
- ("llvm-cpu", "local-task"),
- ],
- trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
- "i8",
- "f32",
-]]
+###########################################################################
+##
+## LLVMCPU backend
+##
+###########################################################################
-# Test asm
+# LLVMCPU, non-data-tiling, no microkernels
[iree_generated_trace_runner_test(
- name = "e2e_matmul_mmt4d_%s_small" % lhs_rhs_type,
+ name = "e2e_matmul_nondt_%s_%s_small" % (lhs_rhs_type, acc_type),
compiler_flags = [
- "--iree-opt-data-tiling",
+ "--iree-opt-data-tiling=false",
],
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--acc_type=%s" % acc_type,
"--shapes=small",
],
- target_backends_and_drivers = [
- ("llvm-cpu", "local-task"),
- ],
- target_cpu_features_variants = ["default"] +
- ([
- "arm_64:dotprod:+dotprod",
- "arm_64:i8mm:+i8mm",
- ] if lhs_rhs_type == "i8" else []),
- trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
- "i8",
- "f32",
-]]
-
-[iree_generated_trace_runner_test(
- name = "e2e_matmul_mmt4d_%s_large" % lhs_rhs_type,
- compiler_flags = [
- "--iree-opt-data-tiling",
- ],
- generator = ":generate_e2e_matmul_tests",
- generator_args = [
- "--lhs_rhs_type=%s" % lhs_rhs_type,
- "--shapes=large",
- ],
tags = [
- # "--shapes=large" can cause timeouts on riscv emulator and sanitizers.
+ # f16/bf16 trigger internal LLVM assertion errors on riscv and wasm.
"noriscv",
- "noasan",
- "notsan",
- ],
+ "nowasm",
+ ] if (lhs_rhs_type == "f16" or lhs_rhs_type == "bf16") else [],
target_backends_and_drivers = [
("llvm-cpu", "local-task"),
],
- target_cpu_features_variants = ["default"] +
- ([
- "arm_64:dotprod:+dotprod",
- "arm_64:i8mm:+i8mm",
- ] if lhs_rhs_type == "i8" else []),
trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
- "i8",
- "f32",
+) for (lhs_rhs_type, acc_type) in [
+ ("i8", "i32"),
+ ("f32", "f32"),
+ ("f16", "f16"),
+ ("f16", "f32"),
+ ("bf16", "bf16"),
+ ("bf16", "f32"),
]]
-# Test intrinsics. No need to run vmvx again, since it isn't affected by this
-# codegen flag. No need to run "large" sizes, since this only differs from other
-# tests in ways that are orthogonal to problem sizes.
-[iree_generated_trace_runner_test(
- name = "e2e_matmul_mmt4d_%s_intrinsics_%s" % (lhs_rhs_type, size),
- compiler_flags = [
- "--iree-codegen-mmt4d-use-intrinsics",
- "--iree-opt-data-tiling",
- ],
- generator = ":generate_e2e_matmul_tests",
- generator_args = [
- "--lhs_rhs_type=%s" % lhs_rhs_type,
- "--shapes=%s" % size,
- ],
- target_backends_and_drivers = [
- ("llvm-cpu", "local-task"),
- ],
- target_cpu_features_variants = ["default"] +
- ([
- "arm_64:dotprod:+dotprod",
- "arm_64:i8mm:+i8mm",
- ] if lhs_rhs_type == "i8" else []),
- trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
- "i8",
- "f32",
-] for size in [
- "small",
-]]
-
-# Test VMVX+ukernel, direct (not mmt4d)
-[iree_generated_trace_runner_test(
- name = "e2e_matmul_direct_%s_small_ukernel" % lhs_rhs_type,
- compiler_flags = [
- "--iree-vmvx-enable-microkernels",
- ],
- generator = ":generate_e2e_matmul_tests",
- generator_args = [
- "--lhs_rhs_type=%s" % lhs_rhs_type,
- "--shapes=small",
- ],
- target_backends_and_drivers = [
- ("vmvx", "local-task"),
- ],
- trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
- "i8",
- "f32",
-]]
-
-# Test VMVX+ukernel, mmt4d, with target CPU features variants relevant to each
-# lhs_rhs_type.
-[iree_generated_trace_runner_test(
- name = "e2e_matmul_mmt4d_%s_small_vmvx_ukernel" % lhs_rhs_type,
- compiler_flags = [
- "--iree-vmvx-enable-microkernels",
- "--iree-opt-data-tiling",
- ],
- generator = ":generate_e2e_matmul_tests",
- generator_args = [
- "--lhs_rhs_type=%s" % lhs_rhs_type,
- "--shapes=small",
- ],
- target_backends_and_drivers = [
- ("vmvx", "local-task"),
- ],
- trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
- "i8",
- "f32",
-]]
-
-X86_64_AVX2_FMA = [
+X86_64_AVX2 = [
"+avx",
"+avx2",
"+fma",
+ "+f16c",
]
-X86_64_AVX512_BASE = X86_64_AVX2_FMA + [
+X86_64_AVX512 = X86_64_AVX2 + [
"+avx512f",
"+avx512vl",
"+avx512cd",
@@ -175,49 +69,132 @@
"+avx512dq",
]
-X86_64_AVX512_VNNI = X86_64_AVX512_BASE + [
+X86_64_AVX512_VNNI = X86_64_AVX512 + [
"+avx512vnni",
]
-# Test mmt4d with --iree-llvmcpu-enable-microkernels.
+X86_64_AVX512_BF16 = X86_64_AVX512 + [
+ "+avx512bf16",
+]
+
+# LLVMCPU, data-tiling + microkernels.
+# TODO(#15241, #15215): also test data-tiling alone without microkernels. This currently
+# fails (#15241), which needs to be resolved to unblock data-tiling-by-default (#15215).
[iree_generated_trace_runner_test(
- name = "e2e_matmul_mmt4d_%s_%s_ukernel" % (lhs_rhs_type, size),
+ name = "e2e_matmul_dt_uk_%s_%s_%s" % (lhs_rhs_type, acc_type, size),
compiler_flags = [
+ "--iree-opt-data-tiling",
"--iree-llvmcpu-enable-microkernels",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--acc_type=%s" % acc_type,
+ "--shapes=%s" % size,
+ ],
+ tags = ([
+ # "--shapes=large" can cause timeouts on sanitizers.
+ "noasan",
+ "notsan",
+ ] if size == "large" else []) + ([
+ # "--shapes=large" can cause timeouts on RISC-V emulator.
+ # f16/bf16 trigger internal LLVM assertion errors on riscv and wasm.
+ "noriscv",
+ "nowasm",
+ ] if (lhs_rhs_type == "f16" or lhs_rhs_type == "bf16") else []),
+ target_backends_and_drivers = [
+ ("llvm-cpu", "local-task"),
+ ],
+ target_cpu_features_variants = ["default"] +
+ ([
+ "arm_64:dotprod:+dotprod",
+ "arm_64:i8mm:+i8mm",
+ "x86_64:avx512vnni:" + ",".join(X86_64_AVX512_VNNI),
+ ] if lhs_rhs_type == "i8" and acc_type == "i32" else [
+ "x86_64:avx2:" + ",".join(X86_64_AVX2),
+ "x86_64:avx512:" + ",".join(X86_64_AVX512),
+ ] if lhs_rhs_type == "f32" and acc_type == "f32" else [
+ "x86_64:avx2:" + ",".join(X86_64_AVX2),
+ "x86_64:avx512:" + ",".join(X86_64_AVX512),
+ "arm_64:fp16:+fp16",
+ ] if lhs_rhs_type == "f16" and acc_type == "f16" else [
+ "x86_64:avx2:" + ",".join(X86_64_AVX2),
+ "x86_64:avx512:" + ",".join(X86_64_AVX512),
+ "arm_64:fp16:+fp16fml",
+ ] if lhs_rhs_type == "f16" and acc_type == "f32" else [
+ "x86_64:avx2:" + ",".join(X86_64_AVX2),
+ "x86_64:avx512:" + ",".join(X86_64_AVX512),
+ "x86_64:avx512bf16:" + ",".join(X86_64_AVX512_BF16),
+ "arm_64:fp16:+bf16",
+ ] if lhs_rhs_type == "bf16" and acc_type == "bf16" else [
+ "x86_64:avx2:" + ",".join(X86_64_AVX2),
+ "x86_64:avx512:" + ",".join(X86_64_AVX512),
+ "x86_64:avx512bf16:" + ",".join(X86_64_AVX512_BF16),
+ "arm_64:fp16:+bf16",
+ ] if lhs_rhs_type == "bf16" and acc_type == "f32" else []),
+ trace_runner = "//tools:iree-e2e-matmul-test",
+) for (lhs_rhs_type, acc_type) in [
+ ("i8", "i32"),
+ ("f32", "f32"),
+ ("f16", "f16"),
+ ("f16", "f32"),
+ ("bf16", "bf16"),
+ ("bf16", "f32"),
+] for size in [
+ "small",
+ "large",
+]]
+
+# Some e2e testing for --iree-codegen-enable-vector-padding=false.
+iree_generated_trace_runner_test(
+ name = "e2e_matmul_nondt_f32_small_no_padding",
+ compiler_flags = [
+ "--iree-codegen-enable-vector-padding=false",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=f32",
+ "--shapes=small",
+ ],
+ target_backends_and_drivers = [
+ ("llvm-cpu", "local-task"),
+ ],
+ trace_runner = "//tools:iree-e2e-matmul-test",
+)
+
+###########################################################################
+##
+## VMVX backend
+##
+###########################################################################
+
+# VMVX, data-tiling + microkernels.
+[iree_generated_trace_runner_test(
+ name = "e2e_matmul_dt_uk_%s_small" % lhs_rhs_type,
+ compiler_flags = [
+ "--iree-vmvx-enable-microkernels",
"--iree-opt-data-tiling",
],
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
- "--shapes=%s" % size,
+ "--shapes=small",
],
- tags = [
- # "--shapes=large" can cause timeouts on riscv emulator and sanitizers.
- "noriscv",
- "noasan",
- "notsan",
- ] if size == "large" else [],
target_backends_and_drivers = [
- ("llvm-cpu", "local-task"),
+ ("vmvx", "local-task"),
],
- target_cpu_features_variants = [
- "default",
- "x86_64:avx2_fma:" + ",".join(X86_64_AVX2_FMA),
- "x86_64:avx512_base:" + ",".join(X86_64_AVX512_BASE),
- ] + ([
- "x86_64:avx512_vnni:" + ",".join(X86_64_AVX512_VNNI),
- "arm_64:dotprod:+dotprod",
- "arm_64:i8mm:+i8mm",
- ] if lhs_rhs_type == "i8" else []),
trace_runner = "//tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"i8",
"f32",
-] for size in [
- "small",
- "large",
]]
+###########################################################################
+##
+## CUDA backend
+##
+###########################################################################
+
iree_generated_trace_runner_test(
name = "e2e_matmul_direct_f32_gpu_large_LLVMGPUMatmulSimt",
generator = ":generate_e2e_matmul_tests",
@@ -415,13 +392,18 @@
],
target_backends_and_drivers = [
("cuda", "cuda"),
- ("llvm-cpu", "local-task"),
],
trace_runner = "//tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"f32",
]]
+###########################################################################
+##
+## Vulkan backend
+##
+###########################################################################
+
[iree_generated_trace_runner_test(
name = "e2e_matmul_direct_{0}_gpu_large_valhall".format(lhs_rhs_type),
compiler_flags = [
@@ -474,22 +456,6 @@
]]
iree_generated_trace_runner_test(
- name = "e2e_matmul_direct_f32_small_no_padding",
- compiler_flags = [
- "--iree-codegen-enable-vector-padding=false",
- ],
- generator = ":generate_e2e_matmul_tests",
- generator_args = [
- "--lhs_rhs_type=f32",
- "--shapes=small",
- ],
- target_backends_and_drivers = [
- ("llvm-cpu", "local-task"),
- ],
- trace_runner = "//tools:iree-e2e-matmul-test",
-)
-
-iree_generated_trace_runner_test(
name = "e2e_matmul_direct_f16_gpu_large_rdna3",
compiler_flags = [
"--iree-vulkan-target-triple=rdna3-unknown-linux",
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 0fd02ec..913aa3d 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -12,11 +12,12 @@
iree_generated_trace_runner_test(
NAME
- e2e_matmul_direct_i8_small
+ e2e_matmul_nondt_i8_i32_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
+ "--acc_type=i32"
"--shapes=small"
TRACE_RUNNER
iree-e2e-matmul-test
@@ -24,15 +25,20 @@
"llvm-cpu"
DRIVERS
"local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling=false"
+ LABELS
+
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_direct_f32_small
+ e2e_matmul_nondt_f32_f32_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=small"
TRACE_RUNNER
iree-e2e-matmul-test
@@ -40,15 +46,108 @@
"llvm-cpu"
DRIVERS
"local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling=false"
+ LABELS
+
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_i8_small
+ e2e_matmul_nondt_f16_f16_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f16"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling=false"
+ LABELS
+ "noriscv"
+ "nowasm"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_nondt_f16_f32_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling=false"
+ LABELS
+ "noriscv"
+ "nowasm"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_nondt_bf16_bf16_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=bf16"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling=false"
+ LABELS
+ "noriscv"
+ "nowasm"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_nondt_bf16_f32_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling=false"
+ LABELS
+ "noriscv"
+ "nowasm"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_i8_i32_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
+ "--acc_type=i32"
"--shapes=small"
TRACE_RUNNER
iree-e2e-matmul-test
@@ -58,39 +157,24 @@
"local-task"
COMPILER_FLAGS
"--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+
TARGET_CPU_FEATURES_VARIANTS
"default"
"arm_64:dotprod:+dotprod"
"arm_64:i8mm:+i8mm"
+ "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_f32_small
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--shapes=small"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-opt-data-tiling"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_i8_large
+ e2e_matmul_dt_uk_i8_i32_large
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
+ "--acc_type=i32"
"--shapes=large"
TRACE_RUNNER
iree-e2e-matmul-test
@@ -100,23 +184,51 @@
"local-task"
COMPILER_FLAGS
"--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
LABELS
- "noriscv"
"noasan"
"notsan"
TARGET_CPU_FEATURES_VARIANTS
"default"
"arm_64:dotprod:+dotprod"
"arm_64:i8mm:+i8mm"
+ "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_f32_large
+ e2e_matmul_dt_uk_f32_f32_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
+ "--acc_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_f32_f32_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--acc_type=f32"
"--shapes=large"
TRACE_RUNNER
iree-e2e-matmul-test
@@ -126,21 +238,24 @@
"local-task"
COMPILER_FLAGS
"--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
LABELS
- "noriscv"
"noasan"
"notsan"
TARGET_CPU_FEATURES_VARIANTS
"default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_i8_intrinsics_small
+ e2e_matmul_dt_uk_f16_f16_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
- "--lhs_rhs_type=i8"
+ "--lhs_rhs_type=f16"
+ "--acc_type=f16"
"--shapes=small"
TRACE_RUNNER
iree-e2e-matmul-test
@@ -149,17 +264,229 @@
DRIVERS
"local-task"
COMPILER_FLAGS
- "--iree-codegen-mmt4d-use-intrinsics"
"--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noriscv"
+ "nowasm"
TARGET_CPU_FEATURES_VARIANTS
"default"
- "arm_64:dotprod:+dotprod"
- "arm_64:i8mm:+i8mm"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "arm_64:fp16:+fp16"
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_f32_intrinsics_small
+ e2e_matmul_dt_uk_f16_f16_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f16"
+ "--shapes=large"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noasan"
+ "notsan"
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "arm_64:fp16:+fp16"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_f16_f32_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "arm_64:fp16:+fp16fml"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_f16_f32_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f32"
+ "--shapes=large"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noasan"
+ "notsan"
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "arm_64:fp16:+fp16fml"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_bf16_bf16_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=bf16"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+ "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_bf16_bf16_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=bf16"
+ "--shapes=large"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noasan"
+ "notsan"
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+ "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_bf16_f32_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+ "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_dt_uk_bf16_f32_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=f32"
+ "--shapes=large"
+ TRACE_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ COMPILER_FLAGS
+ "--iree-opt-data-tiling"
+ "--iree-llvmcpu-enable-microkernels"
+ LABELS
+ "noasan"
+ "notsan"
+ "noriscv"
+ "nowasm"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+ "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+ "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+ "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_nondt_f32_small_no_padding
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
@@ -172,51 +499,12 @@
DRIVERS
"local-task"
COMPILER_FLAGS
- "--iree-codegen-mmt4d-use-intrinsics"
- "--iree-opt-data-tiling"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
+ "--iree-codegen-enable-vector-padding=false"
)
iree_generated_trace_runner_test(
NAME
- e2e_matmul_direct_i8_small_ukernel
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=i8"
- "--shapes=small"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "vmvx"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-vmvx-enable-microkernels"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_direct_f32_small_ukernel
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--shapes=small"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "vmvx"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-vmvx-enable-microkernels"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_i8_small_vmvx_ukernel
+ e2e_matmul_dt_uk_i8_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
@@ -235,7 +523,7 @@
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_f32_small_vmvx_ukernel
+ e2e_matmul_dt_uk_f32_small
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
@@ -254,116 +542,6 @@
iree_generated_trace_runner_test(
NAME
- e2e_matmul_mmt4d_i8_small_ukernel
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=i8"
- "--shapes=small"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-llvmcpu-enable-microkernels"
- "--iree-opt-data-tiling"
- LABELS
-
- TARGET_CPU_FEATURES_VARIANTS
- "default"
- "x86_64:avx2_fma:+avx,+avx2,+fma"
- "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
- "x86_64:avx512_vnni:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
- "arm_64:dotprod:+dotprod"
- "arm_64:i8mm:+i8mm"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_i8_large_ukernel
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=i8"
- "--shapes=large"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-llvmcpu-enable-microkernels"
- "--iree-opt-data-tiling"
- LABELS
- "noriscv"
- "noasan"
- "notsan"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
- "x86_64:avx2_fma:+avx,+avx2,+fma"
- "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
- "x86_64:avx512_vnni:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
- "arm_64:dotprod:+dotprod"
- "arm_64:i8mm:+i8mm"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_f32_small_ukernel
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--shapes=small"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-llvmcpu-enable-microkernels"
- "--iree-opt-data-tiling"
- LABELS
-
- TARGET_CPU_FEATURES_VARIANTS
- "default"
- "x86_64:avx2_fma:+avx,+avx2,+fma"
- "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
-)
-
-iree_generated_trace_runner_test(
- NAME
- e2e_matmul_mmt4d_f32_large_ukernel
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--shapes=large"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-llvmcpu-enable-microkernels"
- "--iree-opt-data-tiling"
- LABELS
- "noriscv"
- "noasan"
- "notsan"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
- "x86_64:avx2_fma:+avx,+avx2,+fma"
- "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
-)
-
-iree_generated_trace_runner_test(
- NAME
e2e_matmul_direct_f32_gpu_large_LLVMGPUMatmulSimt
GENERATOR
"generate_e2e_matmul_tests.py"
@@ -545,10 +723,8 @@
iree-e2e-matmul-test
TARGET_BACKENDS
"cuda"
- "llvm-cpu"
DRIVERS
"cuda"
- "local-task"
COMPILER_FLAGS
"--iree-flow-split-matmul-reduction=4"
LABELS
@@ -694,24 +870,6 @@
iree_generated_trace_runner_test(
NAME
- e2e_matmul_direct_f32_small_no_padding
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--shapes=small"
- TRACE_RUNNER
- iree-e2e-matmul-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- COMPILER_FLAGS
- "--iree-codegen-enable-vector-padding=false"
-)
-
-iree_generated_trace_runner_test(
- NAME
e2e_matmul_direct_f16_gpu_large_rdna3
GENERATOR
"generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index bf0df7f..631fd6a 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -22,10 +22,12 @@
# as this also includes accumulator-specific types like i32.
@enum.unique
class MatrixElemTypeId(enum.Enum):
+ NONE = ""
I8 = "i8"
I32 = "i32"
F32 = "f32"
F16 = "f16"
+ BF16 = "bf16"
# Enumerates of the collections of shapes that we can generate tests for.
@@ -613,11 +615,19 @@
parser.add_argument(
"--lhs_rhs_type",
type=str,
- choices=["i8", "f32", "f16"],
+ choices=["i8", "f32", "f16", "bf16"],
help="Numeric type of input matrices",
required=True,
)
parser.add_argument(
+ "--acc_type",
+ type=str,
+ choices=["i32", "f32", "f16", "bf16"],
+ help="Numeric type of input matrices",
+ default="",
+ required=False,
+ )
+ parser.add_argument(
"--shapes",
type=str,
choices=[s.value for s in ShapesId],
@@ -632,7 +642,6 @@
default="",
required=False,
)
-
parser.add_argument(
"--module_path",
type=str,
@@ -704,16 +713,18 @@
# type, so we do that. That is temporary: eventually there will be cases
# where the same input types are used with different accumulator types, e.g.
# f16 inputs with both f16 and f32 accumulator.
-def infer_acc_type(lhs_rhs_type: MatrixElemTypeId):
+def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
+ if acc_type != MatrixElemTypeId.NONE:
+ return acc_type
if lhs_rhs_type == MatrixElemTypeId.I8:
return MatrixElemTypeId.I32
- else:
- return lhs_rhs_type
+ return lhs_rhs_type
def main(args):
lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
- acc_type = infer_acc_type(lhs_rhs_type)
+ acc_type = MatrixElemTypeId(args.acc_type)
+ acc_type = infer_acc_type(lhs_rhs_type, acc_type)
shapes_id = ShapesId(args.shapes)
compilation_info_id = CompilationInfoId(args.compilation_info)
(function_definitions, traces) = generate(
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 758ae35..1c0171b 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -59,6 +59,8 @@
IREE_E2E_TEST_VALUE_TYPE_F32 = 6,
// double.
IREE_E2E_TEST_VALUE_TYPE_F64 = 7,
+ // bfloat16
+ IREE_E2E_TEST_VALUE_TYPE_BF16 = 8,
} iree_e2e_test_value_type_t;
// Maximum size, in bytes, of any value type we can represent.
@@ -74,6 +76,7 @@
int64_t i64;
float f32;
uint16_t f16_u16;
+ uint16_t bf16_u16;
double f64;
uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE]; // max size of all
// value types
@@ -137,6 +140,14 @@
return result;
}
+static inline iree_e2e_test_value_t iree_e2e_test_value_make_bf16(
+ uint16_t value) {
+ iree_e2e_test_value_t result;
+ result.type = IREE_E2E_TEST_VALUE_TYPE_BF16;
+ result.bf16_u16 = value;
+ return result;
+}
+
static inline iree_e2e_test_value_t iree_e2e_test_value_make_f32(float value) {
iree_e2e_test_value_t result;
result.type = IREE_E2E_TEST_VALUE_TYPE_F32;
@@ -155,6 +166,12 @@
return value->f16_u16;
}
+// TODO(#5542): check the value type before accessing the union.
+static inline uint16_t iree_e2e_test_value_get_bf16(
+ iree_e2e_test_value_t* value) {
+ return value->bf16_u16;
+}
+
static inline iree_e2e_test_value_t iree_e2e_test_value_make_f64(double value) {
iree_e2e_test_value_t result;
result.type = IREE_E2E_TEST_VALUE_TYPE_F64;
@@ -343,10 +360,13 @@
switch (element_type) {
WRITE_INT_ELEMENT_CASE(INT_8, int8_t)
WRITE_INT_ELEMENT_CASE(INT_32, int32_t)
+ WRITE_INT_ELEMENT_CASE(FLOAT_32, float)
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
*(uint16_t*)dst = iree_math_f32_to_f16((float)value);
break;
- WRITE_INT_ELEMENT_CASE(FLOAT_32, float)
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+ *(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
+ break;
default:
IREE_ASSERT(false, "unhandled element type");
break;
@@ -382,6 +402,9 @@
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
((uint16_t*)data)[index] = iree_math_f32_to_f16((float)value);
return;
+ } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
+ ((uint16_t*)data)[index] = iree_math_f32_to_bf16((float)value);
+ return;
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
((float*)data)[index] = value;
return;
@@ -405,6 +428,8 @@
return iree_e2e_test_value_make_i32(((int32_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
return iree_e2e_test_value_make_f16(((uint16_t*)data)[index]);
+ } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
+ return iree_e2e_test_value_make_bf16(((uint16_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
return iree_e2e_test_value_make_f32(((float*)data)[index]);
}
@@ -483,9 +508,9 @@
static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE( \
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \
- iree_hal_element_type_t acc_type, LHSTYPE* lhs_data, RHSTYPE* rhs_data, \
- ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m, \
- iree_hal_dim_t n) { \
+ iree_hal_element_type_t acc_type, const LHSTYPE* lhs_data, \
+ const RHSTYPE* rhs_data, const ACCTYPE* acc_data, RESTYPE* result_data, \
+ iree_hal_dim_t m, iree_hal_dim_t n) { \
ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0; \
for (iree_hal_dim_t k = 0; k < k_size; ++k) { \
LHSTYPE lhs_value = lhs_data[k + m * k_size]; \
@@ -505,15 +530,15 @@
// [i32 <= i8 * i8 + i32]
IREE_TRACE_REPLAY_REFERENCE_MATMUL(int8_t, int8_t, int32_t, int32_t)
-// Reference mamtul for the half_t input, half_t accumlation, and half_t result.
+// Reference matmul for the f16 input, f16 accumulation, and f16 result.
// [f16 <= f16 * f16 + f16]
static void reference_matmul_f16_f16_f16_f16(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, uint16_t* lhs_data, uint16_t* rhs_data,
- uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
- iree_hal_dim_t n) {
- float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0;
+ iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+ const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
+ iree_hal_dim_t m, iree_hal_dim_t n) {
+ float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
acc = iree_math_round_to_nearest_f16(
iree_math_round_to_nearest_f16(
@@ -524,6 +549,57 @@
result_data[n + m * n_size] = iree_math_f32_to_f16(acc);
}
+// Reference matmul for the f16 input, f32 accumulation, and f32 result.
+// [f32 <= f16 * f16 + f32]
+static void reference_matmul_f16_f16_f32_f32(
+ iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+ iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+ iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+ const uint16_t* rhs_data, const float* acc_data, float* result_data,
+ iree_hal_dim_t m, iree_hal_dim_t n) {
+ float acc = acc_data ? acc_data[n + m * n_size] : 0.f;  // NULL acc => 0.
+ for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
+ iree_math_f16_to_f32(rhs_data[n + k * n_size]);
+ }
+ result_data[n + m * n_size] = acc;  // Writes only element (m, n).
+}
+
+// Reference matmul for the bf16 input, bf16 accumulation, and bf16 result.
+// [bf16 <= bf16 * bf16 + bf16]
+static void reference_matmul_bf16_bf16_bf16_bf16(
+ iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+ iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+ iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+ const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
+ iree_hal_dim_t m, iree_hal_dim_t n) {
+ float acc = acc_data ? iree_math_bf16_to_f32(acc_data[n + m * n_size]) : 0.f;
+ for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ acc = iree_math_round_to_nearest_bf16(  // Applied to product + acc.
+ iree_math_round_to_nearest_bf16(  // Applied to the product alone.
+ (iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
+ iree_math_bf16_to_f32(rhs_data[n + k * n_size]))) +
+ acc);
+ }
+ result_data[n + m * n_size] = iree_math_f32_to_bf16(acc);  // Element (m, n).
+}
+
+// Reference matmul for the bf16 input, f32 accumulation, and f32 result.
+// [f32 <= bf16 * bf16 + f32]
+static void reference_matmul_bf16_bf16_f32_f32(
+ iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+ iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+ iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+ const uint16_t* rhs_data, const float* acc_data, float* result_data,
+ iree_hal_dim_t m, iree_hal_dim_t n) {
+ float acc = acc_data ? acc_data[n + m * n_size] : 0.f;  // NULL acc => 0.
+ for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
+ iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
+ }
+ result_data[n + m * n_size] = acc;  // Writes only element (m, n).
+}
+
// Helper for reference_matmul.
// Computes one element in the result matrix.
static void reference_matmul_element(
@@ -535,21 +611,44 @@
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
reference_matmul_float_float_float_float(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type, (float*)lhs_data,
- (float*)rhs_data, (float*)acc_data, (float*)result_data, m, n);
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const float*)lhs_data, (const float*)rhs_data, (const float*)acc_data,
+ (float*)result_data, m, n);
} else if (iree_hal_element_type_is_integer(lhs_type, 8) &&
iree_hal_element_type_is_integer(rhs_type, 8) &&
iree_hal_element_type_is_integer(acc_type, 32)) {
reference_matmul_int8_t_int8_t_int32_t_int32_t(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type, (int8_t*)lhs_data,
- (int8_t*)rhs_data, (int32_t*)acc_data, (int32_t*)result_data, m, n);
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const int8_t*)lhs_data, (const int8_t*)rhs_data,
+ (const int32_t*)acc_data, (int32_t*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
- reference_matmul_f16_f16_f16_f16(m_size, k_size, n_size, lhs_type, rhs_type,
- acc_type, (uint16_t*)lhs_data,
- (uint16_t*)rhs_data, (uint16_t*)acc_data,
- (uint16_t*)result_data, m, n);
+ reference_matmul_f16_f16_f16_f16(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+ (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_matmul_f16_f16_f32_f32(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+ (const float*)acc_data, (float*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
+ reference_matmul_bf16_bf16_bf16_bf16(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+ (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
+ } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+ rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+ acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_matmul_bf16_bf16_f32_f32(
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+ (const float*)acc_data, (float*)result_data, m, n);
} else {
iree_status_abort(
iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
@@ -682,6 +781,12 @@
return fabsf(iree_math_f16_to_f32(actual.f16_u16) -
iree_math_f16_to_f32(expected.f16_u16)) <
FLAG_acceptable_fp_delta;
+ case IREE_E2E_TEST_VALUE_TYPE_BF16:
+ if (actual.bf16_u16 == expected.bf16_u16) return true;  // Same bits.
+ if (FLAG_require_exact_results) return false;
+ return fabsf(iree_math_bf16_to_f32(actual.bf16_u16) -
+ iree_math_bf16_to_f32(expected.bf16_u16)) <
+ FLAG_acceptable_fp_delta;
case IREE_E2E_TEST_VALUE_TYPE_F32:
if (actual.f32 == expected.f32) return true;
if (FLAG_require_exact_results) return false;
@@ -843,12 +948,7 @@
// We have a lot more freedom to pick k_start, k_end, since these parameters
// only affect which regions of the input lhs and rhs matrices are printed.
// If we were only testing random lhs and rhs, we would just pick
- // k_start = 0 and any reasonable k_end value. Since we are often using
- // identity matrices for lhs and rhs, and we expect the majority of
- // test failures to occur with such identity matrices, we try to pick
- // k_start and k_end so that nontrivial regions of identity matrices will be
- // printed. That means that we try to have [k_start, k_end) intervals
- // overlap [m_start, m_end) and [n_start, n_end).
+ // k_start = 0 and any reasonable k_end value.
int k_start = iree_max(0, iree_min(m_start, n_start));
int k_end = iree_min(k_size, iree_max(m_end, n_end));
// [k_start, k_end) could be arbitrarily long at this point. Constrain it a
@@ -978,29 +1078,6 @@
*
*****************************************************************************/
-static iree_status_t make_identity_matrix_callback(
- iree_hal_buffer_mapping_t* mapping, void* user_data) {
- iree_hal_buffer_view_t* src = (iree_hal_buffer_view_t*)user_data;
- iree_hal_element_type_t elem_type = iree_hal_buffer_view_element_type(src);
- iree_host_size_t elem_byte_count =
- iree_hal_element_dense_byte_count(elem_type);
- iree_hal_dim_t dims[2] = {0};
- IREE_RETURN_IF_ERROR(get_matrix_shape(src, dims));
- int rows = dims[0];
- int cols = dims[1];
- // Write 1 to matrix elements on the main diagonal.
- int diagonal_size = iree_min(rows, cols);
- memset(mapping->contents.data, 0, mapping->contents.data_length);
- intptr_t diagonal_elem_addr = (intptr_t)mapping->contents.data;
- for (int i = 0; i < diagonal_size; ++i) {
- write_int_element(elem_type, 1, (void*)diagonal_elem_addr);
- // Due to the row-major storage, the diagonal entries are every
- // (cols + 1)-th buffer elements.
- diagonal_elem_addr += elem_byte_count * (cols + 1);
- }
- return iree_ok_status();
-}
-
// Deep-copies device-local list of buffer_views |src| into |dst|.
static iree_status_t copy_device_buffer_views_to_device(
iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,