Ukernels: mmt4d paths for arm64 fp16 extensions (#14490)
On the arm64 architecture, the widely available `fp16` and `fp16fml`
extensions allow greatly speeding up the `f16f16f16` and `f16f16f32`
cases, respectively (`fp16` for `f16f16f16`, `fp16fml` for `f16f16f32`).
Along the way, this renames `fullfp16` -> `fp16`. We follow LLVM's
nomenclature for ISA extensions; both strings exist in LLVM, which is
confusing, but actually exercising this shows that `fp16` is the one we
need here. Conveniently, that also allows dropping those 4 characters,
`full`.
Also along the way, this reports and works around the following LLVM issue:
https://github.com/llvm/llvm-project/issues/64104
Results from `mmt4d_benchmark`, Gflop/s, single thread. Results with
these new code paths are **in bold** and are to be compared to the
results immediately above them (same element types, not using the new
code path).
benchmark | Arm Cortex-X2 | Arm Cortex-A510
--- | --- | ---
BM_mmt4d_f16f16f32_tile_8x8x1 | 47.1 | 9.5
BM_mmt4d_f16f16f32_tile_8x8x1_fp16fml | **92.0** | **12.1**
BM_mmt4d_f16f16f16_tile_8x8x1 | 46.8 | 9.49
BM_mmt4d_f16f16f16_tile_8x8x1_fp16 | **178.5** | **21.3**
diff --git a/runtime/src/iree/base/internal/cpu.c b/runtime/src/iree/base/internal/cpu.c
index 44d97716..08d99f2 100644
--- a/runtime/src/iree/base/internal/cpu.c
+++ b/runtime/src/iree/base/internal/cpu.c
@@ -61,7 +61,7 @@
IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_LSE, hwcap,
IREE_HWCAP_ATOMICS);
// LSE2/lse128 does not seem to be exposed in hwcaps.
- IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_FULLFP16, hwcap,
+ IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_FP16, hwcap,
IREE_HWCAP_ASIMDHP);
IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_FP16FML, hwcap,
IREE_HWCAP_ASIMDFHM);
@@ -128,7 +128,7 @@
{
.sysctl_key = "hw.optional.arm.FEAT_FP16",
.out_field_index = 0,
- .out_field_bits = IREE_CPU_DATA0_ARM_64_FULLFP16,
+ .out_field_bits = IREE_CPU_DATA0_ARM_64_FP16,
},
{
.sysctl_key = "hw.optional.arm.FEAT_FHM",
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel b/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel
index 7ea1f99..3f0bffa 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel
@@ -67,6 +67,22 @@
)
iree_bitcode_library(
+ name = "ukernel_bitcode_arm_64_fp16",  # Bitcode for the f16f16f16 kernel that needs the fp16 extension.
+ srcs = ["mmt4d_arm_64_fp16.c"],
+ arch = "arm_64",
+ copts = ["-march=armv8.2-a+fp16"],  # Extension flag applies to this single source only.
+ internal_hdrs = UKERNEL_ARM_64_INTERNAL_HEADERS,
+)
+
+iree_bitcode_library(
+ name = "ukernel_bitcode_arm_64_fp16fml",  # Bitcode for the f16f16f32 kernel that needs the fp16fml extension.
+ srcs = ["mmt4d_arm_64_fp16fml.c"],
+ arch = "arm_64",
+ copts = ["-march=armv8.2-a+fp16fml"],  # Extension flag applies to this single source only.
+ internal_hdrs = UKERNEL_ARM_64_INTERNAL_HEADERS,
+)
+
+iree_bitcode_library(
name = "ukernel_bitcode_arm_64_dotprod",
srcs = ["mmt4d_arm_64_dotprod.c"],
arch = "arm_64",
@@ -86,6 +102,8 @@
name = "ukernel_bitcode_arm_64",
bitcode_files = [
"ukernel_bitcode_arm_64_base.bc",
+ "ukernel_bitcode_arm_64_fp16.bc",
+ "ukernel_bitcode_arm_64_fp16fml.bc",
"ukernel_bitcode_arm_64_dotprod.bc",
"ukernel_bitcode_arm_64_i8mm.bc",
],
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt
index 50b6131..c245d6b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt
@@ -38,6 +38,28 @@
iree_bitcode_library(
NAME
+ ukernel_bitcode_arm_64_fp16  # Bitcode for the f16f16f16 kernel that needs the fp16 extension.
+ ARCH
+ arm_64
+ SRCS
+ "mmt4d_arm_64_fp16.c"
+ COPTS
+ "-march=armv8.2-a+fp16"  # Same flag as the Bazel target of the same name.
+)
+
+iree_bitcode_library(
+ NAME
+ ukernel_bitcode_arm_64_fp16fml  # Bitcode for the f16f16f32 kernel that needs the fp16fml extension.
+ ARCH
+ arm_64
+ SRCS
+ "mmt4d_arm_64_fp16fml.c"
+ COPTS
+ "-march=armv8.2-a+fp16fml"  # Same flag as the Bazel target of the same name.
+)
+
+iree_bitcode_library(
+ NAME
ukernel_bitcode_arm_64_dotprod
ARCH
arm_64
@@ -64,6 +86,8 @@
SRCS
"ukernel_bitcode_arm_64_base.bc"
"ukernel_bitcode_arm_64_dotprod.bc"
+ "ukernel_bitcode_arm_64_fp16.bc"
+ "ukernel_bitcode_arm_64_fp16fml.bc"
"ukernel_bitcode_arm_64_i8mm.bc"
)
@@ -79,6 +103,16 @@
return()
endif()
+iree_select_compiler_opts(IREE_UK_COPTS_ARM_64_FP16  # Flags later passed as COPTS for mmt4d_arm_64_fp16.c.
+ CLANG_OR_GCC
+ "-march=armv8.2-a+fp16"
+)
+
+iree_select_compiler_opts(IREE_UK_COPTS_ARM_64_FP16FML  # Flags later passed as COPTS for mmt4d_arm_64_fp16fml.c.
+ CLANG_OR_GCC
+ "-march=armv8.2-a+fp16fml"
+)
+
iree_select_compiler_opts(IREE_UK_COPTS_ARM_64_DOTPROD
CLANG_OR_GCC
"-march=armv8.2-a+dotprod"
@@ -89,6 +123,8 @@
"-march=armv8.2-a+i8mm"
)
+check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_FP16}" IREE_UK_BUILD_ARM_64_FP16)  # Result gates the arm_64_fp16 library and feeds the #cmakedefine in config_arm_64.h.in.
+check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_FP16FML}" IREE_UK_BUILD_ARM_64_FP16FML)  # Result gates the arm_64_fp16fml library and feeds the #cmakedefine in config_arm_64.h.in.
check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_DOTPROD}" IREE_UK_BUILD_ARM_64_DOTPROD)
check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_I8MM}" IREE_UK_BUILD_ARM_64_I8MM)
configure_file("config_arm_64.h.in" "config_arm_64.h")
@@ -105,6 +141,34 @@
set(IREE_UK_ARM_64_DEPS "")
+if(IREE_UK_BUILD_ARM_64_FP16)  # Only when the compiler accepted the fp16 flags (see the check_cxx_compiler_flag probe).
+iree_cc_library(
+ NAME
+ arm_64_fp16
+ SRCS
+ "mmt4d_arm_64_fp16.c"
+ COPTS
+ "${IREE_UK_COPTS_ARM_64_FP16}"  # Extension flags are confined to this one-source library.
+ DEPS
+ iree::builtins::ukernel::internal_headers
+)
+list(APPEND IREE_UK_ARM_64_DEPS "::arm_64_fp16")  # Record the optional library in the deps list initialized above.
+endif() # IREE_UK_BUILD_ARM_64_FP16
+
+if(IREE_UK_BUILD_ARM_64_FP16FML)  # Only when the compiler accepted the fp16fml flags (see the check_cxx_compiler_flag probe).
+iree_cc_library(
+ NAME
+ arm_64_fp16fml
+ SRCS
+ "mmt4d_arm_64_fp16fml.c"
+ COPTS
+ "${IREE_UK_COPTS_ARM_64_FP16FML}"  # Extension flags are confined to this one-source library.
+ DEPS
+ iree::builtins::ukernel::internal_headers
+)
+list(APPEND IREE_UK_ARM_64_DEPS "::arm_64_fp16fml")  # Record the optional library in the deps list initialized above.
+endif() # IREE_UK_BUILD_ARM_64_FP16FML
+
if(IREE_UK_BUILD_ARM_64_DOTPROD)
iree_cc_library(
NAME
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h
index accec89..d3f4ad0 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h
@@ -12,6 +12,8 @@
#if defined(IREE_DEVICE_STANDALONE)
// Standalone builds (e.g. bitcode) use our own Clang, supporting everything.
+#define IREE_UK_BUILD_ARM_64_FP16
+#define IREE_UK_BUILD_ARM_64_FP16FML
#define IREE_UK_BUILD_ARM_64_DOTPROD
#define IREE_UK_BUILD_ARM_64_I8MM
#else
@@ -19,6 +21,19 @@
#include "iree/builtins/ukernel/arch/arm_64/config_arm_64.h"
#endif
+#if defined(IREE_UK_BUILD_ARM_64_FP16)  // Helper exists only when the fp16 code path is built.
+static inline bool iree_uk_cpu_supports_fp16(const iree_uk_uint64_t* cpu_data) {
+ return iree_uk_all_bits_set(cpu_data[0], IREE_CPU_DATA0_ARM_64_FP16);  // Tests the FP16 feature bit in word 0 of cpu_data.
+}
+#endif // IREE_UK_BUILD_ARM_64_FP16
+
+#if defined(IREE_UK_BUILD_ARM_64_FP16FML)  // Helper exists only when the fp16fml code path is built.
+static inline bool iree_uk_cpu_supports_fp16fml(
+ const iree_uk_uint64_t* cpu_data) {
+ return iree_uk_all_bits_set(cpu_data[0], IREE_CPU_DATA0_ARM_64_FP16FML);  // Tests the FP16FML feature bit in word 0 of cpu_data.
+}
+#endif // IREE_UK_BUILD_ARM_64_FP16FML
+
#if defined(IREE_UK_BUILD_ARM_64_DOTPROD)
static inline bool iree_uk_cpu_supports_dotprod(
const iree_uk_uint64_t* cpu_data) {
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in b/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in
index cc1143b..d86c776 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in
@@ -11,6 +11,8 @@
#ifndef IREE_BUILTINS_UKERNEL_ARCH_ARM_64_CONFIG_ARM_64_H_
#define IREE_BUILTINS_UKERNEL_ARCH_ARM_64_CONFIG_ARM_64_H_
+#cmakedefine IREE_UK_BUILD_ARM_64_FP16
+#cmakedefine IREE_UK_BUILD_ARM_64_FP16FML
#cmakedefine IREE_UK_BUILD_ARM_64_DOTPROD
#cmakedefine IREE_UK_BUILD_ARM_64_I8MM
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 30fff42..599313d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -49,6 +49,11 @@
iree_uk_mmt4d_select_tile_func_arm_64_f16f16f32(
const iree_uk_mmt4d_params_t* params) {
if (params->M0 == 8 && params->N0 == 8 && params->K0 == 1) {
+#ifdef IREE_UK_BUILD_ARM_64_FP16FML
+ if (iree_uk_cpu_supports_fp16fml(params->cpu_data)) {
+ return iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml;
+ }
+#endif
return iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64;
}
return 0;
@@ -57,9 +62,15 @@
static iree_uk_mmt4d_tile_func_t
iree_uk_mmt4d_select_tile_func_arm_64_f16f16f16(
const iree_uk_mmt4d_params_t* params) {
- if ((params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) &&
- params->M0 == 8 && params->N0 == 8 && params->K0 == 1) {
- return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64;
+ if (params->M0 == 8 && params->N0 == 8 && params->K0 == 1) {
+#ifdef IREE_UK_BUILD_ARM_64_FP16
+ if (iree_uk_cpu_supports_fp16(params->cpu_data)) {
+ return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16;
+ }
+#endif
+ if (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) {
+ return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64;
+ }
}
return 0;
}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c
new file mode 100644
index 0000000..cfeca5d
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c
@@ -0,0 +1,61 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/arch/arm_64/common_arm_64.h"
+#include "iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h"
+
+void iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16(  // mmt4d tile kernel, f16 inputs and f16 accumulators, using f16 vector FMA.
+ void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+ const void* IREE_UK_RESTRICT rhs_panel, iree_uk_int32_t K,
+ iree_uk_uint32_t flags, const iree_uk_mmt4d_params_t* params) {
+ (void)params;  // Unused by this kernel.
+ float16_t* IREE_UK_RESTRICT out_ptr = out_tile;
+ const float16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+ const float16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+ float16x8_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7;  // 8 vectors x 8 lanes = the 64 f16 values of the output tile.
+ if (flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {  // Accumulating: start from the values already in out_tile.
+ acc0 = vld1q_f16(out_ptr + 8 * 0);
+ acc1 = vld1q_f16(out_ptr + 8 * 1);
+ acc2 = vld1q_f16(out_ptr + 8 * 2);
+ acc3 = vld1q_f16(out_ptr + 8 * 3);
+ acc4 = vld1q_f16(out_ptr + 8 * 4);
+ acc5 = vld1q_f16(out_ptr + 8 * 5);
+ acc6 = vld1q_f16(out_ptr + 8 * 6);
+ acc7 = vld1q_f16(out_ptr + 8 * 7);
+ } else {  // Not accumulating: start from zero.
+ acc0 = vdupq_n_f16(0);
+ acc1 = vdupq_n_f16(0);
+ acc2 = vdupq_n_f16(0);
+ acc3 = vdupq_n_f16(0);
+ acc4 = vdupq_n_f16(0);
+ acc5 = vdupq_n_f16(0);
+ acc6 = vdupq_n_f16(0);
+ acc7 = vdupq_n_f16(0);
+ }
+ IREE_UK_ASSUME(K >= 1);  // NOTE(review): presumably an optimizer hint that the loop runs at least once; callers must honor K >= 1.
+ for (int k = 0; k < K; ++k) {  // Each iteration consumes 8 LHS and 8 RHS f16 values.
+ float16x8_t lhs = vld1q_f16(lhs_ptr);
+ lhs_ptr += 8;
+ float16x8_t rhs = vld1q_f16(rhs_ptr);
+ rhs_ptr += 8;
+ acc0 = vfmaq_lane_f16(acc0, rhs, vget_low_f16(lhs), 0);  // acc_i += rhs * lhs[i]; LHS elements 0..3 come from the low half.
+ acc1 = vfmaq_lane_f16(acc1, rhs, vget_low_f16(lhs), 1);
+ acc2 = vfmaq_lane_f16(acc2, rhs, vget_low_f16(lhs), 2);
+ acc3 = vfmaq_lane_f16(acc3, rhs, vget_low_f16(lhs), 3);
+ acc4 = vfmaq_lane_f16(acc4, rhs, vget_high_f16(lhs), 0);  // LHS elements 4..7 are lanes 0..3 of the high half.
+ acc5 = vfmaq_lane_f16(acc5, rhs, vget_high_f16(lhs), 1);
+ acc6 = vfmaq_lane_f16(acc6, rhs, vget_high_f16(lhs), 2);
+ acc7 = vfmaq_lane_f16(acc7, rhs, vget_high_f16(lhs), 3);
+ }
+ vst1q_f16(out_ptr + 8 * 0, acc0);  // Write all 8 accumulators back, mirroring the load offsets above.
+ vst1q_f16(out_ptr + 8 * 1, acc1);
+ vst1q_f16(out_ptr + 8 * 2, acc2);
+ vst1q_f16(out_ptr + 8 * 3, acc3);
+ vst1q_f16(out_ptr + 8 * 4, acc4);
+ vst1q_f16(out_ptr + 8 * 5, acc5);
+ vst1q_f16(out_ptr + 8 * 6, acc6);
+ vst1q_f16(out_ptr + 8 * 7, acc7);
+}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
new file mode 100644
index 0000000..2d1e890
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
@@ -0,0 +1,122 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/arch/arm_64/common_arm_64.h"
+#include "iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h"
+
+#ifdef __clang__
+// Work around https://github.com/llvm/llvm-project/issues/64104
+// Notes:
+// 1. The key to this work-around is the "x" constraint on the C operand,
+// which restricts it to registers v0 .. v15, as opposed to the "w"
+// constraint used here for the A and B operands, allowing v0 .. v31. See:
+// https://llvm.org/docs/LangRef.html#supported-constraint-code-list
+// 2. The ({...}) syntax is GCC-compatible "statement expressions". See:
+// https://gcc.gnu.org/onlinedocs/gcc/Statement-Exprs.html
+#define iree_workaround_vfmlalq_laneq_x_f16(INSTR, A, B, C, L) \
+ ({ \
+ asm(INSTR " %[a].4s, %[b].4h, %[c].h[%[l]]" \
+ : [a] "+w"(A) \
+ : [b] "w"(B), [c] "x"(C), [l] "i"(L) \
+ :); \
+ A; \
+ })
+#define iree_workaround_vfmlalq_laneq_low_f16(A, B, C, L) \
+ iree_workaround_vfmlalq_laneq_x_f16("fmlal", A, B, C, L)
+#define iree_workaround_vfmlalq_laneq_high_f16(A, B, C, L) \
+ iree_workaround_vfmlalq_laneq_x_f16("fmlal2", A, B, C, L)
+#else  // !__clang__: forward straight to the ACLE intrinsics.
+#define iree_workaround_vfmlalq_laneq_low_f16(A, X, Y, L) \
+ vfmlalq_laneq_low_f16(A, X, Y, L)
+#define iree_workaround_vfmlalq_laneq_high_f16(A, X, Y, L) \
+ vfmlalq_laneq_high_f16(A, X, Y, L)
+#endif  // __clang__
+
+void iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml(  // mmt4d tile kernel, f16 inputs and f32 accumulators, using the fp16fml widening multiply-accumulate.
+ void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+ const void* IREE_UK_RESTRICT rhs_panel, iree_uk_int32_t K,
+ iree_uk_uint32_t flags, const iree_uk_mmt4d_params_t* params) {
+ (void)params;  // Unused by this kernel.
+ float* IREE_UK_RESTRICT out_ptr = out_tile;
+ const float16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+ const float16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+ float32x4_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, acc8, acc9, acc10,  // 16 vectors x 4 lanes = the 64 f32 values of the output tile.
+ acc11, acc12, acc13, acc14, acc15;
+ if (flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {  // Accumulating: start from the values already in out_tile.
+ acc0 = vld1q_f32(out_ptr + 4 * 0);
+ acc1 = vld1q_f32(out_ptr + 4 * 1);
+ acc2 = vld1q_f32(out_ptr + 4 * 2);
+ acc3 = vld1q_f32(out_ptr + 4 * 3);
+ acc4 = vld1q_f32(out_ptr + 4 * 4);
+ acc5 = vld1q_f32(out_ptr + 4 * 5);
+ acc6 = vld1q_f32(out_ptr + 4 * 6);
+ acc7 = vld1q_f32(out_ptr + 4 * 7);
+ acc8 = vld1q_f32(out_ptr + 4 * 8);
+ acc9 = vld1q_f32(out_ptr + 4 * 9);
+ acc10 = vld1q_f32(out_ptr + 4 * 10);
+ acc11 = vld1q_f32(out_ptr + 4 * 11);
+ acc12 = vld1q_f32(out_ptr + 4 * 12);
+ acc13 = vld1q_f32(out_ptr + 4 * 13);
+ acc14 = vld1q_f32(out_ptr + 4 * 14);
+ acc15 = vld1q_f32(out_ptr + 4 * 15);
+ } else {  // Not accumulating: start from zero.
+ acc0 = vdupq_n_f32(0);
+ acc1 = vdupq_n_f32(0);
+ acc2 = vdupq_n_f32(0);
+ acc3 = vdupq_n_f32(0);
+ acc4 = vdupq_n_f32(0);
+ acc5 = vdupq_n_f32(0);
+ acc6 = vdupq_n_f32(0);
+ acc7 = vdupq_n_f32(0);
+ acc8 = vdupq_n_f32(0);
+ acc9 = vdupq_n_f32(0);
+ acc10 = vdupq_n_f32(0);
+ acc11 = vdupq_n_f32(0);
+ acc12 = vdupq_n_f32(0);
+ acc13 = vdupq_n_f32(0);
+ acc14 = vdupq_n_f32(0);
+ acc15 = vdupq_n_f32(0);
+ }
+ IREE_UK_ASSUME(K >= 1);  // NOTE(review): presumably an optimizer hint that the loop runs at least once; callers must honor K >= 1.
+ for (int k = 0; k < K; ++k) {  // Each iteration consumes 8 LHS and 8 RHS f16 values.
+ float16x8_t lhs = vld1q_f16(lhs_ptr);
+ lhs_ptr += 8;
+ float16x8_t rhs = vld1q_f16(rhs_ptr);
+ rhs_ptr += 8;
+ acc0 = iree_workaround_vfmlalq_laneq_low_f16(acc0, rhs, lhs, 0);  // "low" form: widens RHS elements 0..3, multiplies by lhs[0], adds into f32 lanes.
+ acc1 = iree_workaround_vfmlalq_laneq_high_f16(acc1, rhs, lhs, 0);  // "high" form: same with RHS elements 4..7; acc(2i), acc(2i+1) pair up for lhs[i].
+ acc2 = iree_workaround_vfmlalq_laneq_low_f16(acc2, rhs, lhs, 1);
+ acc3 = iree_workaround_vfmlalq_laneq_high_f16(acc3, rhs, lhs, 1);
+ acc4 = iree_workaround_vfmlalq_laneq_low_f16(acc4, rhs, lhs, 2);
+ acc5 = iree_workaround_vfmlalq_laneq_high_f16(acc5, rhs, lhs, 2);
+ acc6 = iree_workaround_vfmlalq_laneq_low_f16(acc6, rhs, lhs, 3);
+ acc7 = iree_workaround_vfmlalq_laneq_high_f16(acc7, rhs, lhs, 3);
+ acc8 = iree_workaround_vfmlalq_laneq_low_f16(acc8, rhs, lhs, 4);
+ acc9 = iree_workaround_vfmlalq_laneq_high_f16(acc9, rhs, lhs, 4);
+ acc10 = iree_workaround_vfmlalq_laneq_low_f16(acc10, rhs, lhs, 5);
+ acc11 = iree_workaround_vfmlalq_laneq_high_f16(acc11, rhs, lhs, 5);
+ acc12 = iree_workaround_vfmlalq_laneq_low_f16(acc12, rhs, lhs, 6);
+ acc13 = iree_workaround_vfmlalq_laneq_high_f16(acc13, rhs, lhs, 6);
+ acc14 = iree_workaround_vfmlalq_laneq_low_f16(acc14, rhs, lhs, 7);
+ acc15 = iree_workaround_vfmlalq_laneq_high_f16(acc15, rhs, lhs, 7);
+ }
+ vst1q_f32(out_ptr + 4 * 0, acc0);  // Write all 16 accumulators back, mirroring the load offsets above.
+ vst1q_f32(out_ptr + 4 * 1, acc1);
+ vst1q_f32(out_ptr + 4 * 2, acc2);
+ vst1q_f32(out_ptr + 4 * 3, acc3);
+ vst1q_f32(out_ptr + 4 * 4, acc4);
+ vst1q_f32(out_ptr + 4 * 5, acc5);
+ vst1q_f32(out_ptr + 4 * 6, acc6);
+ vst1q_f32(out_ptr + 4 * 7, acc7);
+ vst1q_f32(out_ptr + 4 * 8, acc8);
+ vst1q_f32(out_ptr + 4 * 9, acc9);
+ vst1q_f32(out_ptr + 4 * 10, acc10);
+ vst1q_f32(out_ptr + 4 * 11, acc11);
+ vst1q_f32(out_ptr + 4 * 12, acc12);
+ vst1q_f32(out_ptr + 4 * 13, acc13);
+ vst1q_f32(out_ptr + 4 * 14, acc14);
+ vst1q_f32(out_ptr + 4 * 15, acc15);
+}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 7a8c27e..91066f4 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -11,7 +11,9 @@
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f32f32f32_8x8x1_arm_64)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x1_arm_64)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod)
IREE_UK_MMT4D_TILE_FUNC_DECL(
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index ced9482..04e6306 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -133,8 +133,12 @@
"");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1,
"");
+ iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1,
+ "fp16fml");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1,
"");
+ iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1,
+ "fp16");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1,
"");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4,
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 680a873..3072812 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -388,8 +388,11 @@
// we use iree_uk_test_mmt4d_default_and_intrinsics to test both.
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "fp16fml");
iree_uk_test_mmt4d_default_and_skip_intermediate_roundings(
IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, "");
+ iree_uk_test_mmt4d_default_and_skip_intermediate_roundings(
+ IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, "fp16");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1, "");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4, "dotprod");
iree_uk_test_mmt4d_default_and_intrinsics(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8,
diff --git a/runtime/src/iree/schemas/cpu_feature_bits.inl b/runtime/src/iree/schemas/cpu_feature_bits.inl
index a8380c9..b6bf4e2 100644
--- a/runtime/src/iree/schemas/cpu_feature_bits.inl
+++ b/runtime/src/iree/schemas/cpu_feature_bits.inl
@@ -54,7 +54,7 @@
IREE_CPU_FEATURE_BIT(ARM_64, 0, 2, LSE128, "lse128") // Armv8.1 atomics
// SIMD features, not SVE-specific.
-IREE_CPU_FEATURE_BIT(ARM_64, 0, 10, FULLFP16, "fullfp16")
+IREE_CPU_FEATURE_BIT(ARM_64, 0, 10, FP16, "fp16")
IREE_CPU_FEATURE_BIT(ARM_64, 0, 11, FP16FML, "fp16fml")
IREE_CPU_FEATURE_BIT(ARM_64, 0, 12, DOTPROD, "dotprod")
IREE_CPU_FEATURE_BIT(ARM_64, 0, 13, I8MM, "i8mm")