Ukernels: mmt4d paths for arm64 fp16 extensions (#14490)

On the arm64 architecture, widely-available `fp16` and `fp16fml`
extensions allow greatly speeding up the `f16f16f16` and `f16f16f32`
cases respectively.

Along the way, this renames `fullfp16` -> `fp16` --- we follow LLVM's
nomenclature for ISA extensions and there is some confusion there as
both strings exist in LLVM, but from actually exercising this, it
appears that `fp16` is what we need here, and that conveniently allows
dropping those 4 characters, `full`.

Also along the way, reported and worked around
https://github.com/llvm/llvm-project/issues/64104

Results from `mmt4d_benchmark`, Gflop/s, single thread. Results with
these new code paths are **in bold** and are to be compared to the
results immediately above them (same element types, not using the new
code path).

benchmark | Arm Cortex-X2 | Arm Cortex-A510
--- | --- | ---
BM_mmt4d_f16f16f32_tile_8x8x1 | 47.1 | 9.5
BM_mmt4d_f16f16f32_tile_8x8x1_fp16fml | **92.0** | **12.1**
BM_mmt4d_f16f16f16_tile_8x8x1 | 46.8 | 9.49
BM_mmt4d_f16f16f16_tile_8x8x1_fp16 | **178.5** | **21.3**
diff --git a/runtime/src/iree/base/internal/cpu.c b/runtime/src/iree/base/internal/cpu.c
index 44d97716..08d99f2 100644
--- a/runtime/src/iree/base/internal/cpu.c
+++ b/runtime/src/iree/base/internal/cpu.c
@@ -61,7 +61,7 @@
   IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_LSE, hwcap,
                  IREE_HWCAP_ATOMICS);
   // LSE2/lse128 does not seem to be exposed in hwcaps.
-  IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_FULLFP16, hwcap,
+  IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_FP16, hwcap,
                  IREE_HWCAP_ASIMDHP);
   IREE_COPY_BITS(out_fields[0], IREE_CPU_DATA0_ARM_64_FP16FML, hwcap,
                  IREE_HWCAP_ASIMDFHM);
@@ -128,7 +128,7 @@
       {
           .sysctl_key = "hw.optional.arm.FEAT_FP16",
           .out_field_index = 0,
-          .out_field_bits = IREE_CPU_DATA0_ARM_64_FULLFP16,
+          .out_field_bits = IREE_CPU_DATA0_ARM_64_FP16,
       },
       {
           .sysctl_key = "hw.optional.arm.FEAT_FHM",
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel b/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel
index 7ea1f99..3f0bffa 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/BUILD.bazel
@@ -67,6 +67,22 @@
 )
 
 iree_bitcode_library(
+    name = "ukernel_bitcode_arm_64_fp16",
+    srcs = ["mmt4d_arm_64_fp16.c"],
+    arch = "arm_64",
+    copts = ["-march=armv8.2-a+fp16"],
+    internal_hdrs = UKERNEL_ARM_64_INTERNAL_HEADERS,
+)
+
+iree_bitcode_library(
+    name = "ukernel_bitcode_arm_64_fp16fml",
+    srcs = ["mmt4d_arm_64_fp16fml.c"],
+    arch = "arm_64",
+    copts = ["-march=armv8.2-a+fp16fml"],
+    internal_hdrs = UKERNEL_ARM_64_INTERNAL_HEADERS,
+)
+
+iree_bitcode_library(
     name = "ukernel_bitcode_arm_64_dotprod",
     srcs = ["mmt4d_arm_64_dotprod.c"],
     arch = "arm_64",
@@ -86,6 +102,8 @@
     name = "ukernel_bitcode_arm_64",
     bitcode_files = [
         "ukernel_bitcode_arm_64_base.bc",
+        "ukernel_bitcode_arm_64_fp16.bc",
+        "ukernel_bitcode_arm_64_fp16fml.bc",
         "ukernel_bitcode_arm_64_dotprod.bc",
         "ukernel_bitcode_arm_64_i8mm.bc",
     ],
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt
index 50b6131..c245d6b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/CMakeLists.txt
@@ -38,6 +38,28 @@
 
 iree_bitcode_library(
   NAME
+    ukernel_bitcode_arm_64_fp16
+  ARCH
+    arm_64
+  SRCS
+    "mmt4d_arm_64_fp16.c"
+  COPTS
+    "-march=armv8.2-a+fp16"
+)
+
+iree_bitcode_library(
+  NAME
+    ukernel_bitcode_arm_64_fp16fml
+  ARCH
+    arm_64
+  SRCS
+    "mmt4d_arm_64_fp16fml.c"
+  COPTS
+    "-march=armv8.2-a+fp16fml"
+)
+
+iree_bitcode_library(
+  NAME
     ukernel_bitcode_arm_64_dotprod
   ARCH
     arm_64
@@ -64,6 +86,8 @@
   SRCS
     "ukernel_bitcode_arm_64_base.bc"
     "ukernel_bitcode_arm_64_dotprod.bc"
+    "ukernel_bitcode_arm_64_fp16.bc"
+    "ukernel_bitcode_arm_64_fp16fml.bc"
     "ukernel_bitcode_arm_64_i8mm.bc"
 
 )
@@ -79,6 +103,16 @@
   return()
 endif()
 
+iree_select_compiler_opts(IREE_UK_COPTS_ARM_64_FP16
+  CLANG_OR_GCC
+    "-march=armv8.2-a+fp16"
+)
+
+iree_select_compiler_opts(IREE_UK_COPTS_ARM_64_FP16FML
+  CLANG_OR_GCC
+    "-march=armv8.2-a+fp16fml"
+)
+
 iree_select_compiler_opts(IREE_UK_COPTS_ARM_64_DOTPROD
   CLANG_OR_GCC
     "-march=armv8.2-a+dotprod"
@@ -89,6 +123,8 @@
     "-march=armv8.2-a+i8mm"
 )
 
+check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_FP16}" IREE_UK_BUILD_ARM_64_FP16)
+check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_FP16FML}" IREE_UK_BUILD_ARM_64_FP16FML)
 check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_DOTPROD}" IREE_UK_BUILD_ARM_64_DOTPROD)
 check_cxx_compiler_flag("${IREE_UK_COPTS_ARM_64_I8MM}" IREE_UK_BUILD_ARM_64_I8MM)
 configure_file("config_arm_64.h.in" "config_arm_64.h")
@@ -105,6 +141,34 @@
 
 set(IREE_UK_ARM_64_DEPS "")
 
+if(IREE_UK_BUILD_ARM_64_FP16)
+iree_cc_library(
+  NAME
+    arm_64_fp16
+  SRCS
+    "mmt4d_arm_64_fp16.c"
+  COPTS
+    "${IREE_UK_COPTS_ARM_64_FP16}"
+  DEPS
+    iree::builtins::ukernel::internal_headers
+)
+list(APPEND IREE_UK_ARM_64_DEPS "::arm_64_fp16")
+endif()  # IREE_UK_BUILD_ARM_64_FP16
+
+if(IREE_UK_BUILD_ARM_64_FP16FML)
+iree_cc_library(
+  NAME
+    arm_64_fp16fml
+  SRCS
+    "mmt4d_arm_64_fp16fml.c"
+  COPTS
+    "${IREE_UK_COPTS_ARM_64_FP16FML}"
+  DEPS
+    iree::builtins::ukernel::internal_headers
+)
+list(APPEND IREE_UK_ARM_64_DEPS "::arm_64_fp16fml")
+endif()  # IREE_UK_BUILD_ARM_64_FP16FML
+
 if(IREE_UK_BUILD_ARM_64_DOTPROD)
 iree_cc_library(
   NAME
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h
index accec89..d3f4ad0 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64_entry_point.h
@@ -12,6 +12,8 @@
 
 #if defined(IREE_DEVICE_STANDALONE)
 // Standalone builds (e.g. bitcode) use our own Clang, supporting everything.
+#define IREE_UK_BUILD_ARM_64_FP16
+#define IREE_UK_BUILD_ARM_64_FP16FML
 #define IREE_UK_BUILD_ARM_64_DOTPROD
 #define IREE_UK_BUILD_ARM_64_I8MM
 #else
@@ -19,6 +21,19 @@
 #include "iree/builtins/ukernel/arch/arm_64/config_arm_64.h"
 #endif
 
+#if defined(IREE_UK_BUILD_ARM_64_FP16)
+static inline bool iree_uk_cpu_supports_fp16(const iree_uk_uint64_t* cpu_data) {
+  return iree_uk_all_bits_set(cpu_data[0], IREE_CPU_DATA0_ARM_64_FP16);
+}
+#endif  // IREE_UK_BUILD_ARM_64_FP16
+
+#if defined(IREE_UK_BUILD_ARM_64_FP16FML)
+static inline bool iree_uk_cpu_supports_fp16fml(
+    const iree_uk_uint64_t* cpu_data) {
+  return iree_uk_all_bits_set(cpu_data[0], IREE_CPU_DATA0_ARM_64_FP16FML);
+}
+#endif  // IREE_UK_BUILD_ARM_64_FP16FML
+
 #if defined(IREE_UK_BUILD_ARM_64_DOTPROD)
 static inline bool iree_uk_cpu_supports_dotprod(
     const iree_uk_uint64_t* cpu_data) {
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in b/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in
index cc1143b..d86c776 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/config_arm_64.h.in
@@ -11,6 +11,8 @@
 #ifndef IREE_BUILTINS_UKERNEL_ARCH_ARM_64_CONFIG_ARM_64_H_
 #define IREE_BUILTINS_UKERNEL_ARCH_ARM_64_CONFIG_ARM_64_H_
 
+#cmakedefine IREE_UK_BUILD_ARM_64_FP16
+#cmakedefine IREE_UK_BUILD_ARM_64_FP16FML
 #cmakedefine IREE_UK_BUILD_ARM_64_DOTPROD
 #cmakedefine IREE_UK_BUILD_ARM_64_I8MM
 
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 30fff42..599313d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -49,6 +49,11 @@
 iree_uk_mmt4d_select_tile_func_arm_64_f16f16f32(
     const iree_uk_mmt4d_params_t* params) {
   if (params->M0 == 8 && params->N0 == 8 && params->K0 == 1) {
+#ifdef IREE_UK_BUILD_ARM_64_FP16FML
+    if (iree_uk_cpu_supports_fp16fml(params->cpu_data)) {
+      return iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml;
+    }
+#endif
     return iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64;
   }
   return 0;
@@ -57,9 +62,15 @@
 static iree_uk_mmt4d_tile_func_t
 iree_uk_mmt4d_select_tile_func_arm_64_f16f16f16(
     const iree_uk_mmt4d_params_t* params) {
-  if ((params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) &&
-      params->M0 == 8 && params->N0 == 8 && params->K0 == 1) {
-    return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64;
+  if (params->M0 == 8 && params->N0 == 8 && params->K0 == 1) {
+#ifdef IREE_UK_BUILD_ARM_64_FP16
+    if (iree_uk_cpu_supports_fp16(params->cpu_data)) {
+      return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16;
+    }
+#endif
+    if (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) {
+      return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64;
+    }
   }
   return 0;
 }
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c
new file mode 100644
index 0000000..cfeca5d
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c
@@ -0,0 +1,61 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/arch/arm_64/common_arm_64.h"
+#include "iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h"
+
+void iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel, iree_uk_int32_t K,
+    iree_uk_uint32_t flags, const iree_uk_mmt4d_params_t* params) {
+  (void)params;
+  float16_t* IREE_UK_RESTRICT out_ptr = out_tile;
+  const float16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const float16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+  float16x8_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7;
+  if (flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+    acc0 = vld1q_f16(out_ptr + 8 * 0);
+    acc1 = vld1q_f16(out_ptr + 8 * 1);
+    acc2 = vld1q_f16(out_ptr + 8 * 2);
+    acc3 = vld1q_f16(out_ptr + 8 * 3);
+    acc4 = vld1q_f16(out_ptr + 8 * 4);
+    acc5 = vld1q_f16(out_ptr + 8 * 5);
+    acc6 = vld1q_f16(out_ptr + 8 * 6);
+    acc7 = vld1q_f16(out_ptr + 8 * 7);
+  } else {
+    acc0 = vdupq_n_f16(0);
+    acc1 = vdupq_n_f16(0);
+    acc2 = vdupq_n_f16(0);
+    acc3 = vdupq_n_f16(0);
+    acc4 = vdupq_n_f16(0);
+    acc5 = vdupq_n_f16(0);
+    acc6 = vdupq_n_f16(0);
+    acc7 = vdupq_n_f16(0);
+  }
+  IREE_UK_ASSUME(K >= 1);
+  for (int k = 0; k < K; ++k) {
+    float16x8_t lhs = vld1q_f16(lhs_ptr);
+    lhs_ptr += 8;
+    float16x8_t rhs = vld1q_f16(rhs_ptr);
+    rhs_ptr += 8;
+    acc0 = vfmaq_lane_f16(acc0, rhs, vget_low_f16(lhs), 0);
+    acc1 = vfmaq_lane_f16(acc1, rhs, vget_low_f16(lhs), 1);
+    acc2 = vfmaq_lane_f16(acc2, rhs, vget_low_f16(lhs), 2);
+    acc3 = vfmaq_lane_f16(acc3, rhs, vget_low_f16(lhs), 3);
+    acc4 = vfmaq_lane_f16(acc4, rhs, vget_high_f16(lhs), 0);
+    acc5 = vfmaq_lane_f16(acc5, rhs, vget_high_f16(lhs), 1);
+    acc6 = vfmaq_lane_f16(acc6, rhs, vget_high_f16(lhs), 2);
+    acc7 = vfmaq_lane_f16(acc7, rhs, vget_high_f16(lhs), 3);
+  }
+  vst1q_f16(out_ptr + 8 * 0, acc0);
+  vst1q_f16(out_ptr + 8 * 1, acc1);
+  vst1q_f16(out_ptr + 8 * 2, acc2);
+  vst1q_f16(out_ptr + 8 * 3, acc3);
+  vst1q_f16(out_ptr + 8 * 4, acc4);
+  vst1q_f16(out_ptr + 8 * 5, acc5);
+  vst1q_f16(out_ptr + 8 * 6, acc6);
+  vst1q_f16(out_ptr + 8 * 7, acc7);
+}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
new file mode 100644
index 0000000..2d1e890
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
@@ -0,0 +1,122 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/arch/arm_64/common_arm_64.h"
+#include "iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h"
+
+#ifdef __clang__
+// Work around https://github.com/llvm/llvm-project/issues/64104
+// Notes:
+// 1. The key to this work-around is the "x" constraint on the C operand,
+//    which restricts it to registers v0 .. v15, as opposed to the "w"
+//    constraint used here for the A and B operands, allowing v0 .. v31. See:
+//      https://llvm.org/docs/LangRef.html#supported-constraint-code-list
+// 2. The ({...}) syntax is GCC-compatible "statement expressions". See:
+//      https://gcc.gnu.org/onlinedocs/gcc/Statement-Exprs.html
+#define iree_workaround_vfmlalq_laneq_x_f16(INSTR, A, B, C, L) \
+  ({                                                           \
+    asm(INSTR " %[a].4s, %[b].4h, %[c].h[%[l]]"                \
+        : [a] "+w"(A)                                          \
+        : [b] "w"(B), [c] "x"(C), [l] "i"(L)                   \
+        :);                                                    \
+    A;                                                         \
+  })
+#define iree_workaround_vfmlalq_laneq_low_f16(A, B, C, L) \
+  iree_workaround_vfmlalq_laneq_x_f16("fmlal", A, B, C, L)
+#define iree_workaround_vfmlalq_laneq_high_f16(A, B, C, L) \
+  iree_workaround_vfmlalq_laneq_x_f16("fmlal2", A, B, C, L)
+#else
+#define iree_workaround_vfmlalq_laneq_low_f16(A, X, Y, L) \
+  vfmlalq_laneq_low_f16(A, X, Y, L)
+#define iree_workaround_vfmlalq_laneq_high_f16(A, X, Y, L) \
+  vfmlalq_laneq_high_f16(A, X, Y, L)
+#endif
+
+void iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel, iree_uk_int32_t K,
+    iree_uk_uint32_t flags, const iree_uk_mmt4d_params_t* params) {
+  (void)params;
+  float* IREE_UK_RESTRICT out_ptr = out_tile;
+  const float16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const float16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+  float32x4_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, acc8, acc9, acc10,
+      acc11, acc12, acc13, acc14, acc15;
+  if (flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+    acc0 = vld1q_f32(out_ptr + 4 * 0);
+    acc1 = vld1q_f32(out_ptr + 4 * 1);
+    acc2 = vld1q_f32(out_ptr + 4 * 2);
+    acc3 = vld1q_f32(out_ptr + 4 * 3);
+    acc4 = vld1q_f32(out_ptr + 4 * 4);
+    acc5 = vld1q_f32(out_ptr + 4 * 5);
+    acc6 = vld1q_f32(out_ptr + 4 * 6);
+    acc7 = vld1q_f32(out_ptr + 4 * 7);
+    acc8 = vld1q_f32(out_ptr + 4 * 8);
+    acc9 = vld1q_f32(out_ptr + 4 * 9);
+    acc10 = vld1q_f32(out_ptr + 4 * 10);
+    acc11 = vld1q_f32(out_ptr + 4 * 11);
+    acc12 = vld1q_f32(out_ptr + 4 * 12);
+    acc13 = vld1q_f32(out_ptr + 4 * 13);
+    acc14 = vld1q_f32(out_ptr + 4 * 14);
+    acc15 = vld1q_f32(out_ptr + 4 * 15);
+  } else {
+    acc0 = vdupq_n_f32(0);
+    acc1 = vdupq_n_f32(0);
+    acc2 = vdupq_n_f32(0);
+    acc3 = vdupq_n_f32(0);
+    acc4 = vdupq_n_f32(0);
+    acc5 = vdupq_n_f32(0);
+    acc6 = vdupq_n_f32(0);
+    acc7 = vdupq_n_f32(0);
+    acc8 = vdupq_n_f32(0);
+    acc9 = vdupq_n_f32(0);
+    acc10 = vdupq_n_f32(0);
+    acc11 = vdupq_n_f32(0);
+    acc12 = vdupq_n_f32(0);
+    acc13 = vdupq_n_f32(0);
+    acc14 = vdupq_n_f32(0);
+    acc15 = vdupq_n_f32(0);
+  }
+  IREE_UK_ASSUME(K >= 1);
+  for (int k = 0; k < K; ++k) {
+    float16x8_t lhs = vld1q_f16(lhs_ptr);
+    lhs_ptr += 8;
+    float16x8_t rhs = vld1q_f16(rhs_ptr);
+    rhs_ptr += 8;
+    acc0 = iree_workaround_vfmlalq_laneq_low_f16(acc0, rhs, lhs, 0);
+    acc1 = iree_workaround_vfmlalq_laneq_high_f16(acc1, rhs, lhs, 0);
+    acc2 = iree_workaround_vfmlalq_laneq_low_f16(acc2, rhs, lhs, 1);
+    acc3 = iree_workaround_vfmlalq_laneq_high_f16(acc3, rhs, lhs, 1);
+    acc4 = iree_workaround_vfmlalq_laneq_low_f16(acc4, rhs, lhs, 2);
+    acc5 = iree_workaround_vfmlalq_laneq_high_f16(acc5, rhs, lhs, 2);
+    acc6 = iree_workaround_vfmlalq_laneq_low_f16(acc6, rhs, lhs, 3);
+    acc7 = iree_workaround_vfmlalq_laneq_high_f16(acc7, rhs, lhs, 3);
+    acc8 = iree_workaround_vfmlalq_laneq_low_f16(acc8, rhs, lhs, 4);
+    acc9 = iree_workaround_vfmlalq_laneq_high_f16(acc9, rhs, lhs, 4);
+    acc10 = iree_workaround_vfmlalq_laneq_low_f16(acc10, rhs, lhs, 5);
+    acc11 = iree_workaround_vfmlalq_laneq_high_f16(acc11, rhs, lhs, 5);
+    acc12 = iree_workaround_vfmlalq_laneq_low_f16(acc12, rhs, lhs, 6);
+    acc13 = iree_workaround_vfmlalq_laneq_high_f16(acc13, rhs, lhs, 6);
+    acc14 = iree_workaround_vfmlalq_laneq_low_f16(acc14, rhs, lhs, 7);
+    acc15 = iree_workaround_vfmlalq_laneq_high_f16(acc15, rhs, lhs, 7);
+  }
+  vst1q_f32(out_ptr + 4 * 0, acc0);
+  vst1q_f32(out_ptr + 4 * 1, acc1);
+  vst1q_f32(out_ptr + 4 * 2, acc2);
+  vst1q_f32(out_ptr + 4 * 3, acc3);
+  vst1q_f32(out_ptr + 4 * 4, acc4);
+  vst1q_f32(out_ptr + 4 * 5, acc5);
+  vst1q_f32(out_ptr + 4 * 6, acc6);
+  vst1q_f32(out_ptr + 4 * 7, acc7);
+  vst1q_f32(out_ptr + 4 * 8, acc8);
+  vst1q_f32(out_ptr + 4 * 9, acc9);
+  vst1q_f32(out_ptr + 4 * 10, acc10);
+  vst1q_f32(out_ptr + 4 * 11, acc11);
+  vst1q_f32(out_ptr + 4 * 12, acc12);
+  vst1q_f32(out_ptr + 4 * 13, acc13);
+  vst1q_f32(out_ptr + 4 * 14, acc14);
+  vst1q_f32(out_ptr + 4 * 15, acc15);
+}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 7a8c27e..91066f4 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -11,7 +11,9 @@
 
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f32f32f32_8x8x1_arm_64)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x1_arm_64)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod)
 IREE_UK_MMT4D_TILE_FUNC_DECL(
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index ced9482..04e6306 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -133,8 +133,12 @@
                                    "");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1,
                                    "");
+  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1,
+                                   "fp16fml");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1,
                                    "");
+  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1,
+                                   "fp16");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1,
                                    "");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4,
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 680a873..3072812 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -388,8 +388,11 @@
   // we use iree_uk_test_mmt4d_default_and_intrinsics to test both.
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "fp16fml");
   iree_uk_test_mmt4d_default_and_skip_intermediate_roundings(
       IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, "");
+  iree_uk_test_mmt4d_default_and_skip_intermediate_roundings(
+      IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, "fp16");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1, "");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4, "dotprod");
   iree_uk_test_mmt4d_default_and_intrinsics(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8,
diff --git a/runtime/src/iree/schemas/cpu_feature_bits.inl b/runtime/src/iree/schemas/cpu_feature_bits.inl
index a8380c9..b6bf4e2 100644
--- a/runtime/src/iree/schemas/cpu_feature_bits.inl
+++ b/runtime/src/iree/schemas/cpu_feature_bits.inl
@@ -54,7 +54,7 @@
 IREE_CPU_FEATURE_BIT(ARM_64, 0, 2, LSE128, "lse128")  // Armv8.1 atomics
 
 // SIMD features, not SVE-specific.
-IREE_CPU_FEATURE_BIT(ARM_64, 0, 10, FULLFP16, "fullfp16")
+IREE_CPU_FEATURE_BIT(ARM_64, 0, 10, FP16, "fp16")
 IREE_CPU_FEATURE_BIT(ARM_64, 0, 11, FP16FML, "fp16fml")
 IREE_CPU_FEATURE_BIT(ARM_64, 0, 12, DOTPROD, "dotprod")
 IREE_CPU_FEATURE_BIT(ARM_64, 0, 13, I8MM, "i8mm")