bf16: select appropriate tile sizes on x86 and Arm, and enable in x86 bitcode build (#15244)
This PR fixes 2 issues uncovered by e2e testing of bf16 matmuls, #15243.
I noticed that the optimized ukernel code paths weren't exercised as
they should be. There were 2 separate issues, both fixed by this PR:
1. We weren't picking the right tile size in MaterializeEncoding.
2. On x86, we weren't defining `IREE_UK_BUILD_X86_64_AVX512_BF16` in the
bitcode build.
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index d283a49..bdf7f80 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -57,6 +57,12 @@
Type out = elementTypes[2];
if (out.isF32() || out.isF16() || out.isBF16()) {
+ if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
+ if (hasFeature(target, "+bf16")) {
+ // Aim to use BFMMLA.
+ return MatmulTileParams{8, 4, 8};
+ }
+ }
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
@@ -94,6 +100,11 @@
Type out = elementTypes[2];
if (out.isF32() || out.isF16() || out.isBF16()) {
+ if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) {
+ if (hasFeature(target, "+avx512bf16")) {
+ return MatmulTileParams{16, 2, 16};
+ }
+ }
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h
index 9720e64..5f5c11c 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h
@@ -15,6 +15,7 @@
#define IREE_UK_BUILD_X86_64_AVX2_FMA
#define IREE_UK_BUILD_X86_64_AVX512_BASE
#define IREE_UK_BUILD_X86_64_AVX512_VNNI
+#define IREE_UK_BUILD_X86_64_AVX512_BF16
#else // IREE_DEVICE_STANDALONE
// Compiling with the system toolchain. Include the configured header.
#include "iree/builtins/ukernel/arch/x86_64/config_x86_64.h"