Dropping the VMLA runtime. (#5900)

Removes the deprecated VMLA (VM-based Linear Algebra) reference backend from the
runtime: the HAL driver and its registration, the VMLA executable/module loader
under iree/hal/local/loaders, and the iree/modules/vmla op kernel and module
libraries. The build flags, the Python bindings' default driver list, the TFLite
Java bindings, and the check module tests now reference VMVX in its place.

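Note for Python binding users: code that pinned the removed "vmla" driver should
switch to "vmvx" (or "dylib"/"vulkan"). A minimal, hypothetical sketch using the
IREE_DEFAULT_DRIVER environment variable that system_api.py consults for its
preferred driver list (exact lookup timing is an assumption here):

    # Hypothetical migration sketch: prefer the VMVX reference driver now that
    # "vmla" is no longer registered; "dylib" and "vulkan" remain available.
    import os
    os.environ["IREE_DEFAULT_DRIVER"] = "vmvx"  # replaces the removed "vmla" entry
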
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4bf2179..b1ac598 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -124,7 +124,6 @@
 set(IREE_ALL_HAL_DRIVERS
   Cuda
   DyLib
-  VMLA
   VMVX
   Vulkan
 )
diff --git a/bindings/python/iree/runtime/system_api.py b/bindings/python/iree/runtime/system_api.py
index 664cb97..78f1be9 100644
--- a/bindings/python/iree/runtime/system_api.py
+++ b/bindings/python/iree/runtime/system_api.py
@@ -50,12 +50,12 @@
 PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"
 
 # Default value for IREE_DRIVER
-DEFAULT_IREE_DRIVER_VALUE = "dylib,vulkan,vmla"
+DEFAULT_IREE_DRIVER_VALUE = "dylib,vulkan,vmvx"
 
 # Mapping from IREE target backends to their corresponding drivers.
 TARGET_BACKEND_TO_DRIVER = {
-    "vmla": "vmla",
     "dylib-llvm-aot": "dylib",
+    "vmvx": "vmvx",
     "vulkan-*": "vulkan",
 }
 
diff --git a/bindings/tflite/java/build.gradle b/bindings/tflite/java/build.gradle
index f1577e9..f226f6c 100644
--- a/bindings/tflite/java/build.gradle
+++ b/bindings/tflite/java/build.gradle
@@ -32,7 +32,7 @@
             cmake {
                 arguments "-DIREE_BUILD_BINDINGS_TFLITE=ON",
                         "-DIREE_BUILD_BINDINGS_TFLITE_JAVA=ON",
-                        "-DIREE_HAL_DRIVERS_TO_BUILD=VMLA",
+                        "-DIREE_HAL_DRIVERS_TO_BUILD=VMVX",
 
                         // Disable all but the runtime components needed for the
                         // java bindings.
diff --git a/iree/hal/drivers/BUILD b/iree/hal/drivers/BUILD
index 238d1a5..c825ac7 100644
--- a/iree/hal/drivers/BUILD
+++ b/iree/hal/drivers/BUILD
@@ -31,7 +31,6 @@
         # TODO(*): select() and only pull in based on build configuration.
         "//iree/hal/dylib/registration",
         "//iree/hal/dylib/registration:sync",
-        "//iree/hal/vmla/registration",
         "//iree/hal/vmvx/registration",
         "//iree/hal/vulkan/registration",
     ] + IREE_CUDA_DEPS,
diff --git a/iree/hal/drivers/CMakeLists.txt b/iree/hal/drivers/CMakeLists.txt
index db5d49d..64ea93b 100644
--- a/iree/hal/drivers/CMakeLists.txt
+++ b/iree/hal/drivers/CMakeLists.txt
@@ -23,9 +23,6 @@
   # TODO(benvanik): add a IREE_HAL_DRIVER_DYLIB_SYNC or global flag.
   list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration::sync)
 endif()
-if(${IREE_HAL_DRIVER_VMLA})
-  list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::registration)
-endif()
 if(${IREE_HAL_DRIVER_VMVX})
   list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmvx::registration)
 endif()
diff --git a/iree/hal/drivers/init.c b/iree/hal/drivers/init.c
index 819dc29..fe50a0c 100644
--- a/iree/hal/drivers/init.c
+++ b/iree/hal/drivers/init.c
@@ -28,10 +28,6 @@
 #include "iree/hal/dylib/registration/driver_module_sync.h"
 #endif  // IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE
 
-#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
-#include "iree/hal/vmla/registration/driver_module.h"
-#endif  // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
-
 #if defined(IREE_HAL_HAVE_VMVX_DRIVER_MODULE)
 #include "iree/hal/vmvx/registration/driver_module.h"
 #endif  // IREE_HAL_HAVE_VMVX_DRIVER_MODULE
@@ -59,11 +55,6 @@
       z0, iree_hal_dylib_sync_driver_module_register(registry));
 #endif  // IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE
 
-#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_vmla_driver_module_register(registry));
-#endif  // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
-
 #if defined(IREE_HAL_HAVE_VMVX_DRIVER_MODULE)
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_vmvx_driver_module_register(registry));
diff --git a/iree/hal/local/loaders/BUILD b/iree/hal/local/loaders/BUILD
index 1e936cc..22d8d24 100644
--- a/iree/hal/local/loaders/BUILD
+++ b/iree/hal/local/loaders/BUILD
@@ -93,40 +93,6 @@
 
 iree_cmake_extra_content(
     content = """
-if(${IREE_HAL_DRIVER_VMLA})
-""",
-    inline = True,
-)
-
-cc_library(
-    name = "vmla_module_loader",
-    srcs = ["vmla_module_loader.cc"],
-    hdrs = ["vmla_module_loader.h"],
-    defines = [
-        "IREE_HAL_HAVE_VMLA_MODULE_LOADER=1",
-    ],
-    deps = [
-        "//iree/base",
-        "//iree/base:tracing",
-        "//iree/base/internal:flatcc",
-        "//iree/hal",
-        "//iree/hal/local",
-        "//iree/modules/vmla:op_module",
-        "//iree/schemas:vmla_executable_def_c_fbs",
-        "//iree/vm",
-        "//iree/vm:bytecode_module",
-    ],
-)
-
-iree_cmake_extra_content(
-    content = """
-endif()
-""",
-    inline = True,
-)
-
-iree_cmake_extra_content(
-    content = """
 if(${IREE_HAL_DRIVER_VMVX})
 """,
     inline = True,
diff --git a/iree/hal/local/loaders/CMakeLists.txt b/iree/hal/local/loaders/CMakeLists.txt
index 6ac6bd1..37ecb04 100644
--- a/iree/hal/local/loaders/CMakeLists.txt
+++ b/iree/hal/local/loaders/CMakeLists.txt
@@ -86,32 +86,6 @@
   PUBLIC
 )
 
-if(${IREE_HAL_DRIVER_VMLA})
-
-iree_cc_library(
-  NAME
-    vmla_module_loader
-  HDRS
-    "vmla_module_loader.h"
-  SRCS
-    "vmla_module_loader.cc"
-  DEPS
-    iree::base
-    iree::base::internal::flatcc
-    iree::base::tracing
-    iree::hal
-    iree::hal::local
-    iree::modules::vmla::op_module
-    iree::schemas::vmla_executable_def_c_fbs
-    iree::vm
-    iree::vm::bytecode_module
-  DEFINES
-    "IREE_HAL_HAVE_VMLA_MODULE_LOADER=1"
-  PUBLIC
-)
-
-endif()
-
 if(${IREE_HAL_DRIVER_VMVX})
 
 iree_cc_library(
diff --git a/iree/hal/local/loaders/vmla_module_loader.cc b/iree/hal/local/loaders/vmla_module_loader.cc
deleted file mode 100644
index 930844f..0000000
--- a/iree/hal/local/loaders/vmla_module_loader.cc
+++ /dev/null
@@ -1,397 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/local/loaders/vmla_module_loader.h"
-
-#include "iree/base/tracing.h"
-#include "iree/hal/local/local_descriptor_set_layout.h"
-#include "iree/hal/local/local_executable.h"
-#include "iree/modules/vmla/op_module.h"
-#include "iree/vm/bytecode_module.h"
-
-// flatcc schemas:
-#include "iree/base/internal/flatcc.h"
-#include "iree/schemas/vmla_executable_def_reader.h"
-#include "iree/schemas/vmla_executable_def_verifier.h"
-
-//===----------------------------------------------------------------------===//
-// Verification and file utilities
-//===----------------------------------------------------------------------===//
-
-// Verifies the structure of the flatbuffer so that we can avoid doing so during
-// runtime. There are still some conditions we must be aware of (such as omitted
-// names on functions with internal linkage), however we shouldn't need to
-// bounds check anything within the flatbuffer after this succeeds.
-static iree_status_t iree_hal_vmla_executable_flatbuffer_verify(
-    iree_const_byte_span_t flatbuffer_data) {
-  // Special handling for valid but mismatching flatbuffers.
-  if (!flatbuffer_data.data || flatbuffer_data.data_length < 16 ||
-      !flatbuffers_has_identifier(flatbuffer_data.data,
-                                  iree_VMLAExecutableDef_file_identifier)) {
-    return iree_status_from_code(IREE_STATUS_CANCELLED);
-  }
-
-  // Run flatcc generated verification. This ensures all pointers are in-bounds
-  // and that we can safely walk the file, but not that the actual contents of
-  // the flatbuffer meet our expectations.
-  int verify_ret = iree_VMLAExecutableDef_verify_as_root(
-      flatbuffer_data.data, flatbuffer_data.data_length);
-  if (verify_ret != flatcc_verify_ok) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "flatbuffer verification failed: %s",
-                            flatcc_verify_error_string(verify_ret));
-  }
-
-  iree_VMLAExecutableDef_table_t executable_def =
-      iree_VMLAExecutableDef_as_root(flatbuffer_data.data);
-
-  if (flatbuffers_uint8_vec_len(
-          iree_VMLAExecutableDef_bytecode_module_get(executable_def)) < 0) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "executable bytecode_module is missing/empty");
-  }
-
-  // NOTE: we don't check the actual bytecode module contents here; it's opaque
-  // to us and passed on to the VM.
-  return iree_ok_status();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hal_vmla_executable_t
-//===----------------------------------------------------------------------===//
-
-typedef struct {
-  iree_hal_local_executable_t base;
-
-  // Context containing both the VMLA module and the loaded executable.
-  iree_vm_context_t* context;
-
-  // Resolved entry functions from the module.
-  iree_host_size_t entry_fn_count;
-  iree_vm_function_t entry_fns[];
-} iree_hal_vmla_executable_t;
-
-extern const iree_hal_local_executable_vtable_t iree_hal_vmla_executable_vtable;
-
-static iree_status_t iree_hal_vmla_executable_create(
-    iree_vm_context_t* context, iree_vm_module_t* bytecode_module,
-    iree_host_size_t executable_layout_count,
-    iree_hal_executable_layout_t* const* executable_layouts,
-    iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
-  IREE_ASSERT_ARGUMENT(context);
-  IREE_ASSERT_ARGUMENT(bytecode_module);
-  IREE_ASSERT_ARGUMENT(!executable_layout_count || executable_layouts);
-  IREE_ASSERT_ARGUMENT(out_executable);
-  *out_executable = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_host_size_t entry_count =
-      iree_vm_module_signature(bytecode_module).export_function_count;
-  if (entry_count != executable_layout_count) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "executable provides %zu entry points but caller "
-                            "provided %zu; must match",
-                            entry_count, executable_layout_count);
-  }
-
-  iree_hal_vmla_executable_t* executable = NULL;
-  iree_host_size_t total_size =
-      sizeof(*executable) + entry_count * sizeof(*executable->entry_fns) +
-      executable_layout_count * sizeof(iree_hal_local_executable_layout_t);
-  iree_status_t status =
-      iree_allocator_malloc(host_allocator, total_size, (void**)&executable);
-  if (iree_status_is_ok(status)) {
-    iree_hal_local_executable_layout_t** executable_layouts_ptr =
-        (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) +
-                                               sizeof(*executable) +
-                                               entry_count *
-                                                   sizeof(
-                                                       *executable->entry_fns));
-    iree_hal_local_executable_initialize(
-        &iree_hal_vmla_executable_vtable, executable_layout_count,
-        executable_layouts, executable_layouts_ptr, host_allocator,
-        &executable->base);
-    executable->context = context;
-    iree_vm_context_retain(executable->context);
-
-    executable->entry_fn_count = entry_count;
-    for (iree_host_size_t i = 0; i < executable->entry_fn_count; ++i) {
-      status = iree_vm_module_lookup_function_by_ordinal(
-          bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i,
-          &executable->entry_fns[i], NULL);
-      if (!iree_status_is_ok(status)) break;
-    }
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_executable = (iree_hal_executable_t*)executable;
-  } else {
-    iree_hal_executable_release((iree_hal_executable_t*)executable);
-  }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static void iree_hal_vmla_executable_destroy(
-    iree_hal_executable_t* base_executable) {
-  iree_hal_vmla_executable_t* executable =
-      (iree_hal_vmla_executable_t*)base_executable;
-  iree_allocator_t host_allocator = executable->base.host_allocator;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_vm_context_release(executable->context);
-  iree_hal_local_executable_deinitialize(
-      (iree_hal_local_executable_t*)base_executable);
-  iree_allocator_free(host_allocator, executable);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-static iree_status_t iree_hal_vmla_executable_issue_call(
-    iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
-    const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
-    const iree_hal_vec3_t* workgroup_id) {
-  iree_hal_vmla_executable_t* executable =
-      (iree_hal_vmla_executable_t*)base_executable;
-
-  if (IREE_UNLIKELY(ordinal >= executable->entry_fn_count)) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "entry point ordinal out of bounds");
-  }
-
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
-  iree_string_view_t entry_point_name =
-      iree_vm_function_name(&executable->entry_fns[ordinal]);
-  if (iree_string_view_is_empty(entry_point_name)) {
-    entry_point_name = iree_make_cstring_view("unknown_vmla_call");
-  }
-  IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, entry_point_name.data,
-                                      entry_point_name.size);
-#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
-
-  // We churn memory here but I don't rightly care: this entire VMLA approach is
-  // deprecated and will be going away at some point. There's about 100
-  // low-hanging branches we can hack at in the compiler before this extra
-  // allocation matters :)
-  iree_allocator_t host_allocator = executable->base.host_allocator;
-  iree::hal::vmla::Interface interface;
-  iree_vm_ref_t interface_ref = Interface_retain_ref(&interface);
-  iree_host_size_t input_list_size = iree_vm_list_storage_size(
-      /*element_type=*/NULL, /*interface*/ 1 + /*workgroup_xyz[3]*/ 3);
-  void* input_list_storage = iree_alloca(input_list_size);
-  iree_vm_list_t* input_list = NULL;
-  IREE_CHECK_OK(iree_vm_list_initialize(
-      iree_make_byte_span(input_list_storage, input_list_size),
-      /*element_type=*/NULL,
-      /*interface*/ 1 + /*workgroup_xyz[3]*/ 3, &input_list));
-  iree_vm_list_push_ref_retain(input_list, &interface_ref);
-  iree_vm_value_t workgroup_id_x = iree_vm_value_make_i32(workgroup_id->x);
-  iree_vm_value_t workgroup_id_y = iree_vm_value_make_i32(workgroup_id->y);
-  iree_vm_value_t workgroup_id_z = iree_vm_value_make_i32(workgroup_id->z);
-  iree_vm_list_push_value(input_list, &workgroup_id_x);
-  iree_vm_list_push_value(input_list, &workgroup_id_y);
-  iree_vm_list_push_value(input_list, &workgroup_id_z);
-
-  iree_hal_local_executable_layout_t* local_layout =
-      executable->base.executable_layouts[ordinal];
-  IREE_CHECK_EQ(local_layout->push_constants,
-                dispatch_state->push_constant_count);
-  IREE_CHECK_OK(interface.SetConstants(dispatch_state->push_constants,
-                                       dispatch_state->push_constant_count));
-
-  for (iree_host_size_t set_ordinal = 0;
-       set_ordinal < local_layout->set_layout_count; ++set_ordinal) {
-    iree_hal_local_descriptor_set_layout_t* local_set_layout =
-        iree_hal_local_descriptor_set_layout_cast(
-            local_layout->set_layouts[set_ordinal]);
-    for (iree_host_size_t i = 0; i < local_set_layout->binding_count; ++i) {
-      auto buffer_or = iree::hal::vmla::Buffer::WrapMutable(
-          dispatch_state->binding_ptrs[i], dispatch_state->binding_lengths[i],
-          iree_allocator_null());
-      if (!buffer_or.ok()) {
-        IREE_CHECK_OK(std::move(buffer_or).status());
-      }
-      IREE_CHECK_OK(interface.SetBinding(set_ordinal,
-                                         local_set_layout->bindings[i].binding,
-                                         {std::move(buffer_or.value())}));
-    }
-  }
-
-  iree_status_t status =
-      iree_vm_invoke(executable->context, executable->entry_fns[ordinal],
-                     /*policy=*/NULL, input_list,
-                     /*outputs=*/NULL, host_allocator);
-
-  iree_vm_list_deinitialize(input_list);
-  iree_vm_ref_release(&interface_ref);
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-const iree_hal_local_executable_vtable_t iree_hal_vmla_executable_vtable = {
-    /*.base=*/
-    {
-        /*.destroy=*/iree_hal_vmla_executable_destroy,
-    },
-    /*.issue_call=*/iree_hal_vmla_executable_issue_call,
-};
-
-//===----------------------------------------------------------------------===//
-// iree_hal_vmla_module_loader_t
-//===----------------------------------------------------------------------===//
-
-typedef struct {
-  iree_hal_executable_loader_t base;
-  iree_allocator_t host_allocator;
-  iree_vm_instance_t* instance;
-  iree_vm_module_t* vmla_module;
-} iree_hal_vmla_module_loader_t;
-
-extern const iree_hal_executable_loader_vtable_t
-    iree_hal_vmla_module_loader_vtable;
-
-iree_status_t iree_hal_vmla_module_loader_create(
-    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
-    iree_hal_executable_loader_t** out_executable_loader) {
-  IREE_ASSERT_ARGUMENT(instance);
-  IREE_ASSERT_ARGUMENT(out_executable_loader);
-  *out_executable_loader = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // A single VMLA module is shared across all loaded executables.
-  IREE_RETURN_IF_ERROR(iree::hal::vmla::ModuleRegisterTypes());
-  iree_vm_module_t* vmla_module = NULL;
-  IREE_RETURN_IF_ERROR(
-      iree::hal::vmla::ModuleCreate(host_allocator, &vmla_module));
-
-  iree_hal_vmla_module_loader_t* executable_loader = NULL;
-  iree_status_t status = iree_allocator_malloc(
-      host_allocator, sizeof(*executable_loader), (void**)&executable_loader);
-  if (iree_status_is_ok(status)) {
-    iree_hal_executable_loader_initialize(&iree_hal_vmla_module_loader_vtable,
-                                          &executable_loader->base);
-    executable_loader->host_allocator = host_allocator;
-    executable_loader->instance = instance;
-    iree_vm_instance_retain(executable_loader->instance);
-    executable_loader->vmla_module = vmla_module;
-    iree_vm_module_retain(executable_loader->vmla_module);
-    *out_executable_loader = (iree_hal_executable_loader_t*)executable_loader;
-  }
-
-  iree_vm_module_release(vmla_module);
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static void iree_hal_vmla_module_loader_destroy(
-    iree_hal_executable_loader_t* base_executable_loader) {
-  iree_hal_vmla_module_loader_t* executable_loader =
-      (iree_hal_vmla_module_loader_t*)base_executable_loader;
-  iree_allocator_t host_allocator = executable_loader->host_allocator;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  iree_vm_module_release(executable_loader->vmla_module);
-  iree_vm_instance_release(executable_loader->instance);
-  iree_allocator_free(host_allocator, executable_loader);
-
-  IREE_TRACE_ZONE_END(z0);
-}
-
-static bool iree_hal_vmla_module_loader_query_support(
-    iree_hal_executable_loader_t* base_executable_loader,
-    iree_hal_executable_caching_mode_t caching_mode,
-    iree_string_view_t executable_format) {
-  return iree_string_view_equal(executable_format,
-                                iree_make_cstring_view("VMLA"));
-}
-
-static iree_status_t iree_hal_vmla_module_loader_try_load(
-    iree_hal_executable_loader_t* base_executable_loader,
-    const iree_hal_executable_spec_t* executable_spec,
-    iree_hal_executable_t** out_executable) {
-  iree_hal_vmla_module_loader_t* executable_loader =
-      (iree_hal_vmla_module_loader_t*)base_executable_loader;
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  // Verify that we have a valid flatbuffer that contains a VMLA executable.
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
-                                    iree_hal_vmla_executable_flatbuffer_verify(
-                                        executable_spec->executable_data));
-  iree_VMLAExecutableDef_table_t executable_def =
-      iree_VMLAExecutableDef_as_root(executable_spec->executable_data.data);
-  flatbuffers_uint8_vec_t bytecode_module_vec =
-      iree_VMLAExecutableDef_bytecode_module_get(executable_def);
-  iree_const_byte_span_t bytecode_module_data = iree_make_const_byte_span(
-      bytecode_module_vec, flatbuffers_uint8_vec_len(bytecode_module_vec));
-
-  // If the caching mode allows for aliasing the existing flatbuffer data then
-  // we avoid allocations and just pass the pointer on through. The caller
-  // ensures that the data remains valid for the duration the executable is
-  // loaded. Otherwise, we clone it and let the bytecode module take ownership.
-  iree_allocator_t bytecode_module_allocator;
-  if (iree_all_bits_set(executable_spec->caching_mode,
-                        IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA)) {
-    // Zero-copy route.
-    bytecode_module_allocator = iree_allocator_null();
-  } else {
-    bytecode_module_allocator = executable_loader->host_allocator;
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_allocator_clone(executable_loader->host_allocator,
-                                 bytecode_module_data,
-                                 (void**)&bytecode_module_data.data));
-  }
-
-  // Load the user-provided bytecode module. We pass ownership of the data (if
-  // we have it) to the module to manage.
-  iree_vm_module_t* bytecode_module = NULL;
-  iree_status_t status = iree_vm_bytecode_module_create(
-      bytecode_module_data, bytecode_module_allocator,
-      executable_loader->host_allocator, &bytecode_module);
-
-  // Create the context tying together the shared VMLA module and the
-  // user-provided module that references it. If we wanted to allow custom
-  // modules here for user-provided functions we'd mix them in here.
-  iree_vm_context_t* context = NULL;
-  if (iree_status_is_ok(status)) {
-    iree_vm_module_t* modules[2] = {
-        executable_loader->vmla_module,
-        bytecode_module,
-    };
-    status = iree_vm_context_create_with_modules(
-        executable_loader->instance, modules, IREE_ARRAYSIZE(modules),
-        executable_loader->host_allocator, &context);
-  }
-
-  // Executable takes ownership of the entire context (including the bytecode
-  // module, which itself may own the underlying allocation).
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_vmla_executable_create(
-        context, bytecode_module, executable_spec->executable_layout_count,
-        executable_spec->executable_layouts, executable_loader->host_allocator,
-        out_executable);
-  }
-
-  iree_vm_context_release(context);
-  iree_vm_module_release(bytecode_module);
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-const iree_hal_executable_loader_vtable_t iree_hal_vmla_module_loader_vtable = {
-    /*.destroy=*/iree_hal_vmla_module_loader_destroy,
-    /*.query_support=*/iree_hal_vmla_module_loader_query_support,
-    /*.try_load=*/iree_hal_vmla_module_loader_try_load,
-};
diff --git a/iree/hal/local/loaders/vmla_module_loader.h b/iree/hal/local/loaders/vmla_module_loader.h
deleted file mode 100644
index 041b7d4..0000000
--- a/iree/hal/local/loaders/vmla_module_loader.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_LOCAL_LOADERS_VMLA_MODULE_LOADER_H_
-#define IREE_HAL_LOCAL_LOADERS_VMLA_MODULE_LOADER_H_
-
-#include <stdbool.h>
-#include <stdint.h>
-
-#include "iree/base/api.h"
-#include "iree/hal/local/executable_loader.h"
-#include "iree/vm/api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// Creates an executable loader that can load compiled IREE VM bytecode modules
-// using the VMLA module. |instance| will be used for all loaded contexts.
-iree_status_t iree_hal_vmla_module_loader_create(
-    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
-    iree_hal_executable_loader_t** out_executable_loader);
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-
-#endif  // IREE_HAL_LOCAL_LOADERS_VMLA_MODULE_LOADER_H_
diff --git a/iree/hal/vmla/BUILD b/iree/hal/vmla/BUILD
deleted file mode 100644
index f3d5aa6..0000000
--- a/iree/hal/vmla/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# A VMLA (VM-based Linear Algebra) runtime HAL backend.
-
-package(
-    default_visibility = ["//visibility:public"],
-    features = ["layering_check"],
-    licenses = ["notice"],  # Apache 2.0
-)
diff --git a/iree/hal/vmla/CMakeLists.txt b/iree/hal/vmla/CMakeLists.txt
deleted file mode 100644
index 130eea9..0000000
--- a/iree/hal/vmla/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
-# iree/hal/vmla/BUILD                                                          #
-#                                                                              #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
-# CMake-only content.                                                          #
-#                                                                              #
-# To disable autogeneration for this file entirely, delete this header.        #
-################################################################################
-
-iree_add_all_subdirs()
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/hal/vmla/registration/BUILD b/iree/hal/vmla/registration/BUILD
deleted file mode 100644
index f3119c5..0000000
--- a/iree/hal/vmla/registration/BUILD
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
-
-package(
-    default_visibility = ["//visibility:public"],
-    features = ["layering_check"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-iree_cmake_extra_content(
-    content = """
-if(${IREE_HAL_DRIVER_VMLA})
-""",
-    inline = True,
-)
-
-cc_library(
-    name = "registration",
-    srcs = ["driver_module.c"],
-    hdrs = ["driver_module.h"],
-    defines = [
-        "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1",
-    ],
-    deps = [
-        "//iree/hal",
-        "//iree/hal/local:task_driver",
-        "//iree/hal/local/loaders:vmla_module_loader",
-    ],
-)
-
-iree_cmake_extra_content(
-    content = """
-endif()
-""",
-    inline = True,
-)
diff --git a/iree/hal/vmla/registration/CMakeLists.txt b/iree/hal/vmla/registration/CMakeLists.txt
deleted file mode 100644
index 6b26c03..0000000
--- a/iree/hal/vmla/registration/CMakeLists.txt
+++ /dev/null
@@ -1,33 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
-# iree/hal/vmla/registration/BUILD                                             #
-#                                                                              #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
-# CMake-only content.                                                          #
-#                                                                              #
-# To disable autogeneration for this file entirely, delete this header.        #
-################################################################################
-
-iree_add_all_subdirs()
-
-if(${IREE_HAL_DRIVER_VMLA})
-
-iree_cc_library(
-  NAME
-    registration
-  HDRS
-    "driver_module.h"
-  SRCS
-    "driver_module.c"
-  DEPS
-    iree::hal
-    iree::hal::local::loaders::vmla_module_loader
-    iree::hal::local::task_driver
-  DEFINES
-    "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1"
-  PUBLIC
-)
-
-endif()
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/hal/vmla/registration/driver_module.c b/iree/hal/vmla/registration/driver_module.c
deleted file mode 100644
index 73aef1a..0000000
--- a/iree/hal/vmla/registration/driver_module.c
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vmla/registration/driver_module.h"
-
-#include <inttypes.h>
-
-#include "iree/hal/local/loaders/vmla_module_loader.h"
-#include "iree/hal/local/task_driver.h"
-
-// TODO(#4298): remove this driver registration and wrapper.
-
-#define IREE_HAL_VMLA_DRIVER_ID 0x564D4C41u  // VMLA
-
-static iree_status_t iree_hal_vmla_driver_factory_enumerate(
-    void* self, const iree_hal_driver_info_t** out_driver_infos,
-    iree_host_size_t* out_driver_info_count) {
-  static const iree_hal_driver_info_t driver_infos[1] = {
-      {
-          .driver_id = IREE_HAL_VMLA_DRIVER_ID,
-          .driver_name = iree_string_view_literal("vmla"),
-          .full_name =
-              iree_string_view_literal("Reference backend (deprecated)"),
-      },
-  };
-  *out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
-  *out_driver_infos = driver_infos;
-  return iree_ok_status();
-}
-
-static iree_status_t iree_hal_vmla_driver_factory_try_create(
-    void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
-    iree_hal_driver_t** out_driver) {
-  if (driver_id != IREE_HAL_VMLA_DRIVER_ID) {
-    return iree_make_status(IREE_STATUS_UNAVAILABLE,
-                            "no driver with ID %016" PRIu64
-                            " is provided by this factory",
-                            driver_id);
-  }
-
-  iree_hal_task_device_params_t default_params;
-  iree_hal_task_device_params_initialize(&default_params);
-
-  // NOTE: VMLA doesn't tile so we don't really need many workers - having
-  // multiple does make it easier to test overlapping execution, though.
-  iree_task_topology_t topology;
-  iree_task_topology_initialize_from_group_count(4, &topology);
-
-  iree_vm_instance_t* instance = NULL;
-  iree_status_t status = iree_vm_instance_create(allocator, &instance);
-
-  iree_hal_executable_loader_t* vmla_loader = NULL;
-  if (iree_status_is_ok(status)) {
-    status =
-        iree_hal_vmla_module_loader_create(instance, allocator, &vmla_loader);
-  }
-  iree_hal_executable_loader_t* loaders[1] = {vmla_loader};
-
-  iree_task_executor_t* executor = NULL;
-  if (iree_status_is_ok(status)) {
-    status = iree_task_executor_create(IREE_TASK_SCHEDULING_MODE_RESERVED,
-                                       &topology, allocator, &executor);
-  }
-
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_task_driver_create(
-        iree_make_cstring_view("vmla"), &default_params, executor,
-        IREE_ARRAYSIZE(loaders), loaders, allocator, out_driver);
-  }
-
-  iree_task_executor_release(executor);
-  iree_task_topology_deinitialize(&topology);
-  iree_hal_executable_loader_release(vmla_loader);
-  iree_vm_instance_release(instance);
-  return status;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_vmla_driver_module_register(iree_hal_driver_registry_t* registry) {
-  static const iree_hal_driver_factory_t factory = {
-      /*self=*/NULL,
-      iree_hal_vmla_driver_factory_enumerate,
-      iree_hal_vmla_driver_factory_try_create,
-  };
-  return iree_hal_driver_registry_register_factory(registry, &factory);
-}
diff --git a/iree/hal/vmla/registration/driver_module.h b/iree/hal/vmla/registration/driver_module.h
deleted file mode 100644
index d7288cb..0000000
--- a/iree/hal/vmla/registration/driver_module.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
-#define IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
-
-#include "iree/hal/api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// DEPRECATED: this entire driver will be removed soon.
-// TODO(#3580): remove this entire driver w/ iree_hal_executable_library_t.
-IREE_API_EXPORT iree_status_t
-iree_hal_vmla_driver_module_register(iree_hal_driver_registry_t* registry);
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
-
-#endif  // IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 85fbf53..732057c 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -14,8 +14,8 @@
 
 iree_add_all_subdirs()
 
-# Doesn't use bazel_to_cmake because IREE_HAL_DRIVER_VMLA filtering is custom logic
-if(${IREE_HAL_DRIVER_VMLA})
+# Doesn't use bazel_to_cmake because IREE_HAL_DRIVER_VMVX filtering is custom logic
+if(${IREE_HAL_DRIVER_VMVX})
   iree_cc_test(
     NAME
       check_test
diff --git a/iree/modules/vmla/BUILD b/iree/modules/vmla/BUILD
deleted file mode 100644
index d2cd9d1..0000000
--- a/iree/modules/vmla/BUILD
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
-
-package(
-    default_visibility = ["//visibility:public"],
-    features = ["layering_check"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-iree_cmake_extra_content(
-    content = """
-if(NOT ${IREE_HAL_DRIVER_VMLA})
-  return()
-endif()
-""",
-)
-
-cc_library(
-    name = "op_kernels",
-    hdrs = ["op_kernels.h"],
-    textual_hdrs = [
-        # TODO(benvanik): SIMD variants.
-        "op_kernels_generic.h",
-        "op_kernels_ruy.h",
-        "op_kernels_fft.h",
-    ],
-    deps = [
-        "//iree/base:status",
-        "//iree/base:tracing",
-        "@com_google_absl//absl/algorithm",
-        "@com_google_absl//absl/types:span",
-        "@com_google_ruy//ruy",
-        "@com_google_ruy//ruy:context",
-        "@pffft",
-    ],
-)
-
-cc_test(
-    name = "op_kernels_test",
-    srcs = ["op_kernels_test.cc"],
-    deps = [
-        ":op_kernels",
-        "//iree/base:core_headers",
-        "//iree/testing:gtest",
-        "//iree/testing:gtest_main",
-    ],
-)
-
-cc_library(
-    name = "op_module",
-    srcs = ["op_module.cc"],
-    hdrs = ["op_module.h"],
-    deps = [
-        ":op_kernels",
-        "//iree/base",
-        "//iree/base:core_headers",
-        "//iree/base:status",
-        "//iree/base:tracing",
-        "//iree/vm",
-        "//iree/vm:cc",
-        "@com_google_absl//absl/types:span",
-    ],
-)
diff --git a/iree/modules/vmla/CMakeLists.txt b/iree/modules/vmla/CMakeLists.txt
deleted file mode 100644
index b7bac0a..0000000
--- a/iree/modules/vmla/CMakeLists.txt
+++ /dev/null
@@ -1,67 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
-# iree/modules/vmla/BUILD                                                      #
-#                                                                              #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
-# CMake-only content.                                                          #
-#                                                                              #
-# To disable autogeneration for this file entirely, delete this header.        #
-################################################################################
-
-if(NOT ${IREE_HAL_DRIVER_VMLA})
-  return()
-endif()
-
-iree_add_all_subdirs()
-
-iree_cc_library(
-  NAME
-    op_kernels
-  HDRS
-    "op_kernels.h"
-  TEXTUAL_HDRS
-    "op_kernels_fft.h"
-    "op_kernels_generic.h"
-    "op_kernels_ruy.h"
-  DEPS
-    absl::algorithm
-    absl::span
-    iree::base::status
-    iree::base::tracing
-    pffft
-    ruy
-  PUBLIC
-)
-
-iree_cc_test(
-  NAME
-    op_kernels_test
-  SRCS
-    "op_kernels_test.cc"
-  DEPS
-    ::op_kernels
-    iree::base::core_headers
-    iree::testing::gtest
-    iree::testing::gtest_main
-)
-
-iree_cc_library(
-  NAME
-    op_module
-  HDRS
-    "op_module.h"
-  SRCS
-    "op_module.cc"
-  DEPS
-    ::op_kernels
-    absl::span
-    iree::base
-    iree::base::core_headers
-    iree::base::status
-    iree::base::tracing
-    iree::vm
-    iree::vm::cc
-  PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/modules/vmla/op_kernels.h b/iree/modules/vmla/op_kernels.h
deleted file mode 100644
index facac4b..0000000
--- a/iree/modules/vmla/op_kernels.h
+++ /dev/null
@@ -1,509 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Defines kernel functions and provides their implementation via one (or more)
-// included files.
-//
-// Kernels should do the simplest possible operation. Buffer validation is
-// handled by the dispatch logic and need not be checked. Kernels may optionally
-// accept arguments beyond just the buffers, depending on the required state
-// and attributes.
-//
-// Kernels may optionally have runtime state. This is state that is allocated
-// once for the entire Runtime (and stored on RuntimeState) and shared across
-// all fibers. This enables kernels that may require thread pools or device
-// handles to be shared while kernels that require transient storage to be safe
-// to use from multiple fibers concurrently.
-//
-// All kernels are templated to enable specialization of particular types or
-// type combinations. By default the op_kernels_generic.h will provide C++
-// semantics as reference and platform-specific versions can be implemented
-// as needed.
-
-#ifndef IREE_MODULES_VMLA_OP_KERNELS_H_
-#define IREE_MODULES_VMLA_OP_KERNELS_H_
-
-#include <cstdint>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-namespace vmla {
-namespace kernels {
-
-using ShapeSpan = absl::Span<const int32_t>;
-
-inline size_t GetElementCount(ShapeSpan shape) {
-  size_t count = 1;
-  for (size_t i = 0; i < shape.size(); ++i) {
-    count *= shape[i];
-  }
-  return count;
-}
-
-struct CompareEQ {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<uint8_t> dst_buffer);
-};
-struct CompareNE {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<uint8_t> dst_buffer);
-};
-struct CompareLT {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<uint8_t> dst_buffer);
-};
-struct CompareLE {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<uint8_t> dst_buffer);
-};
-struct CompareGT {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<uint8_t> dst_buffer);
-};
-struct CompareGE {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<uint8_t> dst_buffer);
-};
-
-struct Conv2D {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> input_buffer,
-                               ShapeSpan input_shape,
-                               absl::Span<const T> filter_buffer,
-                               ShapeSpan filter_shape, absl::Span<T> dst_buffer,
-                               ShapeSpan dst_shape, ShapeSpan strides,
-                               ShapeSpan pad_h, ShapeSpan pad_w,
-                               ShapeSpan lhs_dilation, ShapeSpan rhs_dilation,
-                               const int32_t groups);
-};
-
-struct Copy {
-  template <int element_size>
-  static iree_status_t Execute(absl::Span<const uint8_t> src_buffer,
-                               ShapeSpan src_shape,
-                               absl::Span<const int32_t> src_indices,
-                               absl::Span<uint8_t> dst_buffer,
-                               ShapeSpan dst_shape,
-                               absl::Span<const int32_t> dst_indices,
-                               absl::Span<const int32_t> lengths);
-};
-
-struct Select {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const uint8_t> cond_buffer,
-                               absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Finite {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<bool> dst_buffer);
-};
-
-struct Transpose {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               absl::Span<const int32_t> perm);
-};
-
-struct Pad {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> padding_value,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan dst_shape,
-                               absl::Span<const int32_t> edge_padding_low,
-                               absl::Span<const int32_t> edge_padding_high,
-                               absl::Span<const int32_t> interior_padding);
-};
-
-struct Gather {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const int32_t> indices_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan indices_shape, ShapeSpan dst_shape,
-                               const int32_t dim, const int32_t batch_dims);
-};
-
-struct Scatter {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const int32_t> indices_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan indices_shape, ShapeSpan dst_shape);
-};
-
-struct Reverse {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               absl::Span<const int32_t> dimensions);
-};
-
-struct Sort {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<int32_t> dst_buffer,
-                               ShapeSpan src_shape);
-};
-
-struct Broadcast {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Iota {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<T> dst_buffer);
-};
-
-struct Tile {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan dst_shape);
-};
-
-struct Not {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct And {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer, T rhs,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Or {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Xor {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer, T rhs,
-                               absl::Span<T> dst_buffer);
-};
-
-struct ShiftLeft {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct ShiftRight {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Add {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Sub {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Abs {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Neg {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Mul {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Div {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Rem {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Pow {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Exp {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Log {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Rsqrt {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Sqrt {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Cos {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Sin {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Tanh {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Atan2 {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Min {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Max {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> lhs_buffer,
-                               absl::Span<const T> rhs_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Clamp {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> min_buffer,
-                               absl::Span<const T> src_buffer,
-                               absl::Span<const T> max_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Floor {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Ceil {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Round {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer);
-};
-
-struct Convert {
-  template <typename SRC, typename DST>
-  static iree_status_t Execute(absl::Span<const SRC> src_buffer,
-                               absl::Span<DST> dst_buffer);
-};
-
-struct MatMul {
-  struct RuntimeState;
-
-  static std::unique_ptr<RuntimeState> CreateRuntimeState();
-
-  template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
-  struct Buffers {
-    ShapeSpan lhs_shape;
-    absl::Span<const LhsEl> lhs_buffer;
-    ShapeSpan rhs_shape;
-    absl::Span<const RhsEl> rhs_buffer;
-    ShapeSpan dst_shape;
-    absl::Span<DstEl> dst_buffer;
-
-    // Optional bias buffer.
-    absl::Span<const AccumEl> bias_buffer;
-
-    // Fixed-point multiplier mantissa/exponent. May be a single value (for
-    // uniform quantization) or one element per row of the destination matrix
-    // for per-channel.
-    absl::Span<const AccumEl> multiplier_mantissa_buffer;
-    absl::Span<const int32_t> multiplier_exponent_buffer;
-  };
-
-  template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
-  static iree_status_t Execute(
-      RuntimeState* runtime_state,
-      const Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers);
-};
-
-struct RuntimeState {
-  std::unique_ptr<MatMul::RuntimeState> mat_mul_state =
-      MatMul::CreateRuntimeState();
-};
-
-struct ReduceSum {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, int32_t dimension,
-                               ShapeSpan src_shape, ShapeSpan dst_shape);
-};
-
-struct ReduceMin {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, int32_t dimension,
-                               ShapeSpan src_shape, ShapeSpan dst_shape);
-};
-
-struct ReduceMax {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, int32_t dimension,
-                               ShapeSpan src_shape, ShapeSpan dst_shape);
-};
-
-struct ReduceAnd {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, int32_t dimension,
-                               ShapeSpan src_shape, ShapeSpan dst_shape);
-};
-
-struct ReduceOr {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, int32_t dimension,
-                               ShapeSpan src_shape, ShapeSpan dst_shape);
-};
-
-struct PoolingSum {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan dst_shape, ShapeSpan window_dimensions,
-                               ShapeSpan strides, ShapeSpan pad_low);
-};
-
-struct PoolingMin {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan dst_shape, ShapeSpan window_dimensions,
-                               ShapeSpan strides, ShapeSpan pad_low);
-};
-
-struct PoolingMax {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const T> init_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan dst_shape, ShapeSpan window_dimensions,
-                               ShapeSpan strides, ShapeSpan pad_low);
-};
-
-}  // namespace kernels
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
-
-// Inconsistent automated formatting here. Just disable clang-format (for now?).
-// clang-format off
-#include "iree/modules/vmla/op_kernels_generic.h"  // IWYU pragma: export
-#include "iree/modules/vmla/op_kernels_ruy.h"  // IWYU pragma: export
-#include "iree/modules/vmla/op_kernels_fft.h"  // IWYU pragma: export
-// clang-format on
-
-#endif  // IREE_HAL_VMLA_OP_KERNELS_H_
diff --git a/iree/modules/vmla/op_kernels_fft.h b/iree/modules/vmla/op_kernels_fft.h
deleted file mode 100644
index d925d97..0000000
--- a/iree/modules/vmla/op_kernels_fft.h
+++ /dev/null
@@ -1,182 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Defines kernel functions and provides their implementation via one (or more)
-// included files.
-//
-// Kernels should do the simplest possible operation. Buffer validation is
-// handled by the dispatch logic and need not be checked. Kernels may optionally
-// accept arguments beyond just the buffers, depending on the required state
-// and attributes.
-//
-// Kernels may optionally have runtime state. This is state that is allocated
-// once for the entire Runtime (and stored on RuntimeState) and shared across
-// all fibers. This enables kernels that may require thread pools or device
-// handles to be shared, while kernels that require transient storage remain
-// safe to use from multiple fibers concurrently.
-//
-// All kernels are templated to enable specialization of particular types or
-// type combinations. By default the op_kernels_generic.h will provide C++
-// semantics as reference and platform-specific versions can be implemented
-// as needed.
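Editorial note: as a concrete illustration of the specialization scheme described above, a kernel is declared once as a struct with a templated Execute, given a generic reference definition, and optionally pinned down for specific types in a platform header. The Square kernel below is hypothetical and only sketches the pattern.

// Declared in the umbrella header:
struct Square {
  template <typename T>
  static iree_status_t Execute(absl::Span<const T> src_buffer,
                               absl::Span<T> dst_buffer);
};

// Generic reference definition (op_kernels_generic.h style):
template <typename T>
iree_status_t Square::Execute(absl::Span<const T> src_buffer,
                              absl::Span<T> dst_buffer) {
  for (size_t i = 0; i < dst_buffer.size(); ++i) {
    dst_buffer[i] = src_buffer[i] * src_buffer[i];
  }
  return iree_ok_status();
}

// A platform-specific header may declare a full specialization instead,
// e.g. a vectorized float path:
template <>
iree_status_t Square::Execute<float>(absl::Span<const float> src_buffer,
                                     absl::Span<float> dst_buffer);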
-
-#ifndef IREE_MODULES_VMLA_OP_KERNELS_FFT_H_
-#define IREE_MODULES_VMLA_OP_KERNELS_FFT_H_
-
-#include "absl/types/span.h"
-#include "iree/base/logging.h"
-#include "iree/base/status.h"
-#include "pffft.h"
-
-namespace iree {
-namespace hal {
-namespace vmla {
-namespace kernels {
-
-using ShapeSpan = absl::Span<const int32_t>;
-
-struct Fft {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> real_src_buffer,
-                               absl::Span<const T> imag_src_buffer,
-                               absl::Span<T> real_dst_buffer,
-                               absl::Span<T> imag_dst_buffer,
-                               ShapeSpan real_src_shape,
-                               ShapeSpan imag_src_shape) {
-    PFFFT_Setup* fft_state =
-        pffft_new_setup(real_src_shape.back(), PFFFT_COMPLEX);
-    int element_count = real_src_buffer.size();
-    std::vector<T> complex_input;
-    complex_input.resize(element_count * 2);
-
-    // pffft requires the input to be an array of interleaved complex numbers
-    for (int i = 0; i < element_count; i++) {
-      complex_input[i * 2] = real_src_buffer[i];
-      complex_input[i * 2 + 1] = imag_src_buffer[i];
-    }
-
-    std::vector<T> complex_output;
-    complex_output.resize(element_count * 2);
-
-    pffft_transform_ordered(fft_state, &complex_input[0], &complex_output[0],
-                            NULL, PFFFT_FORWARD);
-
-    // Split the interleaved array back into real and imag vectors.
-    for (int i = 0; i < element_count; i++) {
-      real_dst_buffer[i] = complex_output[i * 2];
-      imag_dst_buffer[i] = complex_output[i * 2 + 1];
-    }
-    pffft_destroy_setup(fft_state);
-    return iree_ok_status();
-  }
-};
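Editorial note: pffft consumes complex data as a single interleaved array [re0, im0, re1, im1, ...], which is exactly what the loop above builds. A standalone sketch of that packing step (names are illustrative):

#include <vector>
#include "absl/types/span.h"

std::vector<float> InterleaveComplex(absl::Span<const float> real,
                                     absl::Span<const float> imag) {
  std::vector<float> interleaved(real.size() * 2);
  for (size_t i = 0; i < real.size(); ++i) {
    interleaved[2 * i] = real[i];
    interleaved[2 * i + 1] = imag[i];
  }
  return interleaved;
}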
-
-struct Ifft {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> real_src_buffer,
-                               absl::Span<const T> imag_src_buffer,
-                               absl::Span<T> real_dst_buffer,
-                               absl::Span<T> imag_dst_buffer,
-                               ShapeSpan real_src_shape,
-                               ShapeSpan imag_src_shape) {
-    PFFFT_Setup* fft_state =
-        pffft_new_setup(real_src_shape.back(), PFFFT_COMPLEX);
-    int element_count = real_src_buffer.size();
-    std::vector<T> complex_input;
-    complex_input.resize(element_count * 2);
-
-    // pffft requires the input to be an array of interleaved complex numbers
-    for (int i = 0; i < element_count; i++) {
-      complex_input[i * 2] = real_src_buffer[i];
-      complex_input[i * 2 + 1] = imag_src_buffer[i];
-    }
-
-    std::vector<T> complex_output;
-    complex_output.resize(element_count * 2);
-
-    pffft_transform_ordered(fft_state, &complex_input[0], &complex_output[0],
-                            NULL, PFFFT_BACKWARD);
-
-    // Split the interleaved array back into real and imag vectors and scale
-    // them.
-    for (int i = 0; i < element_count; i++) {
-      real_dst_buffer[i] = complex_output[i * 2] / element_count;
-      imag_dst_buffer[i] = complex_output[i * 2 + 1] / element_count;
-    }
-    pffft_destroy_setup(fft_state);
-    return iree_ok_status();
-  }
-};
-
-struct Rfft {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> real_src_buffer,
-                               absl::Span<T> real_dst_buffer,
-                               absl::Span<T> imag_dst_buffer,
-                               ShapeSpan real_src_shape) {
-    PFFFT_Setup* fft_state = pffft_new_setup(real_src_shape.back(), PFFFT_REAL);
-    int element_count = real_src_buffer.size() / 2 + 1;
-
-    std::vector<T> complex_output;
-    complex_output.resize(element_count * 4);
-
-    pffft_transform_ordered(fft_state, &real_src_buffer[0], &complex_output[0],
-                            NULL, PFFFT_FORWARD);
-
-    // Split the interleaved output back into real and imag vectors. pffft's
-    // ordered real output packs the Nyquist-bin real value into the second
-    // slot, so it is swapped into the last real element below.
-    for (int i = 0; i < element_count; i++) {
-      real_dst_buffer[i] = complex_output[i * 2];
-      imag_dst_buffer[i] = complex_output[i * 2 + 1];
-    }
-    auto temp = real_dst_buffer[element_count - 1];
-    real_dst_buffer[element_count - 1] = imag_dst_buffer[0];
-    imag_dst_buffer[0] = temp;
-    pffft_destroy_setup(fft_state);
-    return iree_ok_status();
-  }
-};
-
-struct Irfft {
-  template <typename T>
-  static iree_status_t Execute(absl::Span<const T> real_src_buffer,
-                               absl::Span<const T> imag_src_buffer,
-                               absl::Span<T> real_dst_buffer,
-                               ShapeSpan real_src_shape,
-                               ShapeSpan imag_src_shape) {
-    PFFFT_Setup* fft_state = pffft_new_setup(real_src_shape.back(), PFFFT_REAL);
-    int element_count = real_src_buffer.size();
-    std::vector<T> complex_input;
-    complex_input.resize(element_count * 2);
-
-    // pffft requires the input to be an array of interleaved complex numbers
-    for (int i = 0; i < element_count; i++) {
-      complex_input[i * 2] = real_src_buffer[i];
-      complex_input[i * 2 + 1] = imag_src_buffer[i];
-    }
-
-    pffft_transform_ordered(fft_state, &complex_input[0], &real_dst_buffer[0],
-                            NULL, PFFFT_BACKWARD);
-
-    pffft_destroy_setup(fft_state);
-    return iree_ok_status();
-  }
-};
-
-}  // namespace kernels
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_MODULES_VMLA_OP_KERNELS_FFT_H_
diff --git a/iree/modules/vmla/op_kernels_generic.h b/iree/modules/vmla/op_kernels_generic.h
deleted file mode 100644
index 827ae4a..0000000
--- a/iree/modules/vmla/op_kernels_generic.h
+++ /dev/null
@@ -1,1203 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_MODULES_VMLA_OP_KERNELS_GENERIC_H_
-#define IREE_MODULES_VMLA_OP_KERNELS_GENERIC_H_
-
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <cstring>
-#include <iostream>
-#include <iterator>
-#include <numeric>
-#include <unordered_set>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-namespace vmla {
-namespace kernels {
-
-template <typename T>
-iree_status_t CompareEQ::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<uint8_t> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] == rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t CompareNE::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<uint8_t> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] != rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t CompareLT::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<uint8_t> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] < rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t CompareLE::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<uint8_t> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] <= rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t CompareGT::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<uint8_t> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] > rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t CompareGE::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<uint8_t> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] >= rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-namespace impl {
-inline std::vector<size_t> ComputeCopyStrides(ShapeSpan shape,
-                                              size_t element_size) {
-  std::vector<size_t> strides(shape.size());
-  strides.back() = element_size;
-  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
-    strides[i] = strides[i + 1] * shape[i + 1];
-  }
-  return strides;
-}
-
-inline void CopyRegion(absl::Span<const uint8_t> src_buffer,
-                       absl::Span<const size_t> src_strides,
-                       absl::Span<const int32_t> src_indices,
-                       absl::Span<uint8_t> dst_buffer,
-                       absl::Span<const size_t> dst_strides,
-                       absl::Span<const int32_t> dst_indices,
-                       absl::Span<const int32_t> lengths) {
-  if (lengths.size() > 1) {
-    for (int32_t i = 0; i < lengths[0]; ++i) {
-      size_t src_offset = src_strides[0] * (src_indices[0] + i);
-      size_t dst_offset = dst_strides[0] * (dst_indices[0] + i);
-      CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1),
-                 src_indices.subspan(1), dst_buffer.subspan(dst_offset),
-                 dst_strides.subspan(1), dst_indices.subspan(1),
-                 lengths.subspan(1));
-    }
-  } else {
-    IREE_DCHECK_EQ(dst_strides.size(), 1);
-    IREE_DCHECK_EQ(src_strides.size(), 1);
-    IREE_DCHECK_EQ(src_indices.size(), 1);
-    IREE_DCHECK_EQ(dst_indices.size(), 1);
-    IREE_DCHECK_EQ(lengths.size(), 1);
-    auto src_offset = src_indices[0] * src_strides[0];
-    auto dst_offset = dst_indices[0] * dst_strides[0];
-    auto length = dst_strides[0] * lengths[0];
-    std::memcpy(dst_buffer.data() + dst_offset, src_buffer.data() + src_offset,
-                length);
-  }
-}
-}  // namespace impl
-
-// TODO(benvanik): replace with a real implementation once copy is defined.
-// TODO(gcmn): More consistent/principled handling for scalars.
-template <int element_size>
-iree_status_t Copy::Execute(absl::Span<const uint8_t> src_buffer,
-                            ShapeSpan src_shape,
-                            absl::Span<const int32_t> src_indices,
-                            absl::Span<uint8_t> dst_buffer, ShapeSpan dst_shape,
-                            absl::Span<const int32_t> dst_indices,
-                            absl::Span<const int32_t> lengths) {
-  IREE_DCHECK_EQ(src_indices.size(), lengths.size());
-  IREE_DCHECK_EQ(dst_indices.size(), lengths.size());
-  IREE_DCHECK_EQ(src_shape.size(), lengths.size());
-  IREE_DCHECK_EQ(dst_shape.size(), lengths.size());
-  if (lengths.empty()) {
-    std::memcpy(dst_buffer.data(), src_buffer.data(), element_size);
-    return iree_ok_status();
-  }
-
-  // TODO(gcmn) Maybe we can fast-path earlier if we detect contiguous memory
-  // across multiple rows.
-  auto src_strides = impl::ComputeCopyStrides(src_shape, element_size);
-  auto dst_strides = impl::ComputeCopyStrides(dst_shape, element_size);
-  IREE_DCHECK_EQ(src_strides.size(), lengths.size());
-  IREE_DCHECK_EQ(dst_strides.size(), lengths.size());
-  impl::CopyRegion(src_buffer, src_strides, src_indices, dst_buffer,
-                   dst_strides, dst_indices, lengths);
-  return iree_ok_status();
-}
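Editorial note: a hypothetical call of the kernel above (e.g. inside a test body, assuming the umbrella op_kernels.h is included and the iree::hal::vmla namespace is in scope): copy the 2x2 top-left int32 region of a 3x3 source into the bottom-right of a 3x3 destination, with element_size = 4 bytes.

std::vector<uint8_t> src(9 * 4), dst(9 * 4);
const int32_t src_shape[] = {3, 3}, dst_shape[] = {3, 3};
const int32_t src_indices[] = {0, 0}, dst_indices[] = {1, 1};
const int32_t lengths[] = {2, 2};
iree_status_t status = kernels::Copy::Execute<4>(
    absl::MakeConstSpan(src), src_shape, src_indices, absl::MakeSpan(dst),
    dst_shape, dst_indices, lengths);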
-
-template <typename T>
-iree_status_t Conv2D::Execute(absl::Span<const T> input_buffer,
-                              ShapeSpan input_shape,
-                              absl::Span<const T> filter_buffer,
-                              ShapeSpan filter_shape, absl::Span<T> dst_buffer,
-                              ShapeSpan dst_shape, ShapeSpan window_strides,
-                              ShapeSpan pad_h, ShapeSpan pad_w,
-                              ShapeSpan lhs_dilation, ShapeSpan rhs_dilation,
-                              const int32_t groups) {
-  const std::array<int32_t, 3> input_strides = {input_shape[1] * input_shape[2],
-                                                input_shape[2], 1};
-  const std::array<int32_t, 4> filter_strides = {
-      filter_shape[1] * filter_shape[2] * filter_shape[3],
-      filter_shape[2] * filter_shape[3], filter_shape[3], 1};
-  const std::array<int32_t, 3> dst_strides = {dst_shape[1] * dst_shape[2],
-                                              dst_shape[2], 1};
-  // Direct 2d (grouped) convolution slow implementation. Ref:
-  // https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
-  // TODO(ataei): Implement tiled GEMM based implementation.
-  const int output_group_size = dst_shape[2] / groups;
-  const int input_group_size = input_shape[2] / groups;
-  for (int ho = 0; ho < dst_shape[0]; ho++) {
-    for (int wo = 0; wo < dst_shape[1]; wo++) {
-      for (int g = 0; g < groups; ++g) {
-        for (int kh = 0; kh < filter_shape[0]; kh++) {
-          int ih = ho * window_strides[0] + kh * rhs_dilation[0] - pad_h[0];
-          // top-bottom padding condition.
-          if (ih < 0 || ih % lhs_dilation[0]) continue;
-          ih = ih / lhs_dilation[0];
-          if (ih >= input_shape[0]) continue;
-          for (int kw = 0; kw < filter_shape[1]; kw++) {
-            // left-right padding condition.
-            int iw = wo * window_strides[1] + kw * rhs_dilation[1] - pad_w[0];
-            if (iw < 0 || iw % lhs_dilation[1]) continue;
-            iw = iw / lhs_dilation[1];
-            if (iw >= input_shape[1]) continue;
-            for (int co = 0; co < output_group_size; co++) {
-              const int cg_o = g * output_group_size + co;
-              const int y_i = ho * dst_strides[0] + wo * dst_strides[1] + cg_o;
-              T dst_value = T(0);
-              for (int ci = 0; ci < input_group_size; ci++) {
-                const int cg_i = g * input_group_size + ci;
-                const int w_i = kh * filter_strides[0] +
-                                kw * filter_strides[1] +
-                                cg_i * filter_strides[2] + co;
-                const int x_i =
-                    ih * input_strides[0] + iw * input_strides[1] + cg_i;
-                dst_value += input_buffer[x_i] * filter_buffer[w_i];
-              }
-              dst_buffer[y_i] += dst_value;
-            }
-          }
-        }
-      }
-    }
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Select::Execute(absl::Span<const uint8_t> cond_buffer,
-                              absl::Span<const T> lhs_buffer,
-                              absl::Span<const T> rhs_buffer,
-                              absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = cond_buffer[i] ? lhs_buffer[i] : rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Finite::Execute(absl::Span<const T> src_buffer,
-                              absl::Span<bool> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::isfinite(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-void TransposeRecurse(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer,
-                      ShapeSpan src_shape, ShapeSpan dst_shape,
-                      absl::Span<const int> src_strides,
-                      absl::Span<const int> dst_strides,
-                      absl::Span<const int32_t> perm, int rank, int dim_i,
-                      size_t src_base_offset, size_t dst_base_offset) {
-  // Two cases:
-  // -- dim_i < rank - 1: iterate on dim_i; set offsets and recurse on dim_i + 1
-  // -- dim_i = rank - 1: base case, fast copy with strides and offsets
-
-  int src_stride = src_strides[perm[dim_i]];
-  int dst_stride = dst_strides[dim_i];
-  if (dim_i < rank - 1) {
-    int recurse_dim_i = dim_i + 1;
-    for (size_t i = 0; i < dst_shape[dim_i]; ++i) {
-      size_t src_offset = src_base_offset + i * src_stride;
-      size_t dst_offset = dst_base_offset + i * dst_stride;
-      TransposeRecurse(src_buffer, dst_buffer, src_shape, dst_shape,
-                       src_strides, dst_strides, perm, rank, recurse_dim_i,
-                       src_offset, dst_offset);
-    }
-  } else {
-    for (size_t i = 0; i < dst_shape[dim_i]; ++i) {
-      size_t src_i = src_base_offset + i * src_stride;
-      // Stride for the last dim of dst is always 1.
-      size_t dst_i = dst_base_offset + i;
-      dst_buffer[dst_i] = src_buffer[src_i];
-    }
-  }
-}
-
-template <typename T>
-iree_status_t Transpose::Execute(absl::Span<const T> src_buffer,
-                                 absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                                 absl::Span<const int32_t> perm) {
-  int rank = src_shape.size();
-
-  std::vector<int> src_strides(rank);
-  std::vector<int> dst_strides(rank);
-  std::vector<int32_t> dst_shape(rank);
-  size_t src_stride = 1;
-  size_t dst_stride = 1;
-  for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
-    src_strides[dim_i] = src_stride;
-    dst_strides[dim_i] = dst_stride;
-    src_stride *= src_shape[dim_i];
-    dst_stride *= src_shape[perm[dim_i]];
-    dst_shape[dim_i] = src_shape[perm[dim_i]];
-  }
-
-  // Recurse starting from the first dimension with 0 offsets.
-  int dim_i = 0;
-  size_t src_base_offset = 0;
-  size_t dst_base_offset = 0;
-  TransposeRecurse(src_buffer, dst_buffer, src_shape, dst_shape, src_strides,
-                   dst_strides, perm, rank, dim_i, src_base_offset,
-                   dst_base_offset);
-  return iree_ok_status();
-}
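Editorial note: a hypothetical usage of Transpose::Execute (assuming the umbrella header is included and the kernels namespace is reachable): transpose a row-major 2x3 matrix into a 3x2 one with permutation {1, 0}.

const std::vector<float> src = {1, 2, 3, 4, 5, 6};
std::vector<float> dst(6);
const int32_t src_shape[] = {2, 3};
const int32_t perm[] = {1, 0};
iree_status_t status = kernels::Transpose::Execute<float>(
    absl::MakeConstSpan(src), absl::MakeSpan(dst), src_shape, perm);
// dst is now {1, 4, 2, 5, 3, 6}.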
-
-namespace impl {
-inline void IncrementShapeIndex(absl::Span<int32_t> indices, ShapeSpan shape) {
-  for (int i = indices.size() - 1; i >= 0; --i) {
-    if (++indices[i] < shape[i]) return;
-    indices[i] = 0;
-  }
-}
-
-inline bool IsPadding(absl::Span<const int32_t> indices, ShapeSpan shape,
-                      absl::Span<const int32_t> edge_padding_low,
-                      absl::Span<const int32_t> edge_padding_high,
-                      absl::Span<const int32_t> interior_padding) {
-  for (int i = 0; i < indices.size(); ++i) {
-    auto index = indices[i];
-    if (index < edge_padding_low[i] ||
-        index >= shape[i] - edge_padding_high[i] ||
-        (index - edge_padding_low[i]) % (interior_padding[i] + 1) != 0) {
-      return true;
-    }
-  }
-
-  return false;
-}
-}  // namespace impl
-
-template <typename T>
-iree_status_t Pad::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<const T> padding_value_buffer,
-                           absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                           ShapeSpan dst_shape,
-                           absl::Span<const int32_t> edge_padding_low,
-                           absl::Span<const int32_t> edge_padding_high,
-                           absl::Span<const int32_t> interior_padding) {
-  // This implementation is not at all fast, as it iterates every index in the
-  // destination buffer individually. Potential improvements:
-  // 1. Fill the dst buffer with the padding value initially. Then we only
-  //    need to iterate through the source buffer and can exit early.
-  // 2. Use striding to advance through larger swaths of the buffer with a
-  //    memcpy from src and filling (or skipping) padded indices. Especially
-  //    useful when e.g. entire rows are padded.
-
-  // TODO(b/140836672) support negative padding
-
-  if (padding_value_buffer.size() != 1) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "padding value buffer is larger than one element");
-  }
-  auto padding_value = padding_value_buffer.front();
-
-  std::vector<int> dst_indices(src_shape.size(), 0);
-
-  const T* src_ptr = src_buffer.begin();
-  T* dst_ptr = dst_buffer.begin();
-  while (dst_ptr != dst_buffer.end()) {
-    if (impl::IsPadding(dst_indices, dst_shape, edge_padding_low,
-                        edge_padding_high, interior_padding)) {
-      *dst_ptr++ = padding_value;
-    } else {
-      IREE_DCHECK(src_ptr != src_buffer.end());
-      *dst_ptr++ = *src_ptr++;
-    }
-    impl::IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape);
-  }
-
-  return iree_ok_status();
-}
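Editorial note: a hypothetical usage of Pad::Execute that makes the edge/interior padding semantics concrete (e.g. inside a test body): pad {1, 2, 3} with one element of edge padding on each side and one element of interior padding between values, using 0 as the padding value. dst_shape must already account for the padding, so it is 7 here.

const std::vector<float> src = {1, 2, 3};
const std::vector<float> pad_value = {0};
std::vector<float> dst(7);
const int32_t src_shape[] = {3}, dst_shape[] = {7};
const int32_t edge_low[] = {1}, edge_high[] = {1}, interior[] = {1};
iree_status_t status = kernels::Pad::Execute<float>(
    absl::MakeConstSpan(src), absl::MakeConstSpan(pad_value),
    absl::MakeSpan(dst), src_shape, dst_shape, edge_low, edge_high, interior);
// dst is now {0, 1, 0, 2, 0, 3, 0}.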
-
-template <typename T>
-iree_status_t Gather::Execute(absl::Span<const T> src_buffer,
-                              absl::Span<const int32_t> indices_buffer,
-                              absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                              ShapeSpan indices_shape, ShapeSpan dst_shape,
-                              const int32_t dim, const int32_t batch_dims) {
-  std::vector<int32_t> output_strides(dst_shape.size(), 1);
-  std::vector<int32_t> input_strides(src_shape.size(), 1);
-  std::vector<int32_t> indices_strides(indices_shape.size(), 1);
-  auto compute_strides = [](ShapeSpan shape, std::vector<int32_t>& strides) {
-    for (int i = shape.size() - 2; i >= 0; --i) {
-      strides[i] = strides[i + 1] * shape[i + 1];
-    }
-  };
-  compute_strides(dst_shape, output_strides);
-  compute_strides(src_shape, input_strides);
-  compute_strides(indices_shape, indices_strides);
-  size_t outer_size = 1, batching_size = 1;
-  for (size_t i = 0; i < batch_dims; ++i) {
-    batching_size *= src_shape[i];
-  }
-  for (size_t i = batch_dims; i < dim; ++i) {
-    outer_size *= src_shape[i];
-  }
-  // stride for batch outer dims.
-  size_t batch_stride = 1;
-  for (size_t i = batch_dims; i > 0; --i) {
-    batch_stride *= src_shape[i];
-  }
-  const size_t input_stride =
-      dim > 0 ? input_strides[dim - 1] : input_strides[0];
-  const size_t output_stride =
-      dim > 0 ? output_strides[dim - 1] : output_strides[0];
-  const size_t slice_size = input_strides[dim];
-  const int indices_size =
-      indices_shape.size() == 0
-          ? 1
-          : indices_shape[batch_dims] * indices_strides[batch_dims];
-  // This is equivalent to the linearized version of the following array
-  // expression:
-  // clang-format off
-  // dst[d_0,...,d_{dim-1},                     i_B,...,i_{M-1}, d_{dim+1},...,d_{N-1}] =
-  // src[d_0,...,d_{dim-1},indices[d_0,...,d_1, i_B,...,i_{M-1}, d_{dim+1},...,d_{N-1}]
-  // clang-format on
-  // See: https://www.tensorflow.org/api_docs/python/tf/gather
-  // TODO(ataei): Shrink inner loop by scanning indices_buffer for contiguous
-  // indices and coalescing the copies of those slices.
-  for (size_t b = 0; b < batching_size; ++b) {
-    for (size_t i = 0; i < outer_size; ++i) {
-      const int index = b * outer_size + i;
-      for (size_t j = 0; j < indices_size; ++j) {
-        const int indices_batching_stride =
-            batch_dims > 0 ? indices_strides[batch_dims - 1] : 1;
-        const int indices_index = b * indices_batching_stride + j;
-        const size_t dst_offset = index * output_stride + j * slice_size;
-        const size_t src_offset =
-            index * input_stride + indices_buffer[indices_index] * slice_size;
-        std::memcpy(dst_buffer.data() + dst_offset,
-                    src_buffer.data() + src_offset, sizeof(T) * slice_size);
-      }
-    }
-  }
-  return iree_ok_status();
-}
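Editorial note: a hypothetical usage matching the array expression above (e.g. inside a test body): gather rows 2 and 0 of a 4x2 matrix along dim 0 with no batch dims.

const std::vector<float> src = {1, 2, 3, 4, 5, 6, 7, 8};  // [[1,2],[3,4],[5,6],[7,8]]
const std::vector<int32_t> indices = {2, 0};
std::vector<float> dst(4);
const int32_t src_shape[] = {4, 2}, indices_shape[] = {2}, dst_shape[] = {2, 2};
iree_status_t status = kernels::Gather::Execute<float>(
    absl::MakeConstSpan(src), absl::MakeConstSpan(indices), absl::MakeSpan(dst),
    src_shape, indices_shape, dst_shape, /*dim=*/0, /*batch_dims=*/0);
// dst is now {5, 6, 1, 2}, i.e. [[5,6],[1,2]].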
-
-namespace impl {
-template <typename T>
-iree_status_t ScatterCopy(absl::Span<const T> src_buffer,
-                          absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                          ShapeSpan dst_shape) {
-  if (src_shape.empty()) {
-    dst_buffer[0] = src_buffer[0];
-    return iree_ok_status();
-  }
-
-  // Scatter cannot subscatter; it must be legal across the entire shape.
-  // Therefore if the src and dst shape match we can copy the full bytes over.
-  if (src_shape == dst_shape) {
-    memcpy(dst_buffer.data(), src_buffer.data(), src_buffer.size() * sizeof(T));
-    return iree_ok_status();
-  }
-
-  auto src_stride = 1;
-  for (auto size : src_shape.subspan(1)) {
-    src_stride *= size;
-  }
-
-  auto dst_stride = 1;
-  for (auto size : dst_shape.subspan(1)) {
-    dst_stride *= size;
-  }
-
-  for (int i = 0; i < src_shape[0]; i++) {
-    IREE_RETURN_IF_ERROR(
-        ScatterCopy(src_buffer.subspan(i * src_stride, src_stride),
-                    dst_buffer.subspan(i * dst_stride, dst_stride),
-                    src_shape.subspan(1), dst_shape.subspan(1)));
-  }
-
-  return iree_ok_status();
-}
-
-// Scatter helper that computes the offset into the dst buffer, removing the
-// dependency on the indices buffer.
-template <typename T>
-iree_status_t ScatterHelper(absl::Span<const T> src_buffer,
-                            absl::Span<const int32_t> indices_buffer,
-                            absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                            ShapeSpan dst_shape) {
-  size_t offset = 0;
-  for (int i = 0; i < indices_buffer.size(); i++) {
-    offset = offset * dst_shape[i] + indices_buffer[i];
-  }
-
-  for (int i = indices_buffer.size(); i < dst_shape.size(); i++) {
-    offset *= dst_shape[i];
-  }
-
-  if ((src_shape.size() + indices_buffer.size()) != dst_shape.size()) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "attempting to scatter to differing dimensions");
-  }
-
-  IREE_RETURN_IF_ERROR(ScatterCopy(src_buffer, dst_buffer.subspan(offset),
-                                   src_shape,
-                                   dst_shape.subspan(indices_buffer.size())));
-
-  return iree_ok_status();
-}
-}  // namespace impl
-
-template <typename T>
-iree_status_t Scatter::Execute(absl::Span<const T> src_buffer,
-                               absl::Span<const int32_t> indices_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               ShapeSpan indices_shape, ShapeSpan dst_shape) {
-  int indices_rank = indices_shape.size();
-
-  // The first dimension of indices is the number of batch updates.
-  int32_t batch_size = 1;
-  if (indices_rank > 0) {
-    batch_size = indices_shape[0];
-  }
-
-  // The second dimension of indices is the index offset to scatter along.
-  int32_t indices_size = 1;
-  if (indices_rank > 1) {
-    indices_size = indices_shape[1];
-  }
-
-  // Compute the source size per scatter.
-  int32_t src_size = 1;
-  for (auto val : src_shape.subspan(1)) {
-    src_size *= val;
-  }
-
-  for (int i = 0; i < batch_size; i++) {
-    IREE_RETURN_IF_ERROR(impl::ScatterHelper(
-        src_buffer.subspan(i * src_size, src_size),
-        indices_buffer.subspan(i * indices_size, indices_size), dst_buffer,
-        src_shape.subspan(1), dst_shape));
-  }
-
-  return iree_ok_status();
-}
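Editorial note: a hypothetical usage (e.g. inside a test body) that shows how the batch/index shapes above line up: scatter one 3-element update into row 2 of a 4x3 destination.

const std::vector<float> update = {7, 8, 9};
const std::vector<int32_t> indices = {2};  // One batch, one index.
std::vector<float> dst(12, 0);
const int32_t src_shape[] = {1, 3};        // [batch, update size]
const int32_t indices_shape[] = {1, 1};    // [batch, index rank]
const int32_t dst_shape[] = {4, 3};
iree_status_t status = kernels::Scatter::Execute<float>(
    absl::MakeConstSpan(update), absl::MakeConstSpan(indices),
    absl::MakeSpan(dst), src_shape, indices_shape, dst_shape);
// Row 2 of dst is now {7, 8, 9}; all other elements remain 0.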
-
-template <typename T>
-iree_status_t Reverse::Execute(absl::Span<const T> src_buffer,
-                               absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                               absl::Span<const int32_t> dimensions) {
-  // This implementation is not fast either.
-  int rank = src_shape.size();
-  std::vector<int> strides(rank);
-  size_t stride = 1;
-  for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
-    strides[dim_i] = stride;
-    stride *= src_shape[dim_i];
-  }
-  std::unordered_set<int32_t> dims_set(dimensions.begin(), dimensions.end());
-  for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
-    size_t src_i = 0;
-    size_t t = dst_i;
-    for (int dim_i = 0; dim_i < rank; ++dim_i) {
-      size_t ratio = t / strides[dim_i];
-      t -= ratio * strides[dim_i];
-      bool do_reverse = dims_set.count(dim_i) > 0;
-      src_i += (do_reverse ? (src_shape[dim_i] - 1 - ratio) : ratio) *
-               strides[dim_i];
-    }
-    dst_buffer[dst_i] = src_buffer[src_i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Sort::Execute(absl::Span<const T> src_buffer,
-                            absl::Span<int32_t> dst_buffer,
-                            ShapeSpan src_shape) {
-  int elements = src_buffer.size();
-  const int sort_size = src_shape.back();
-
-  for (int i = 0; i < elements; i += sort_size) {
-    auto src_subspan = src_buffer.subspan(i, sort_size);
-    auto dst_subspan = dst_buffer.subspan(i, sort_size);
-    std::iota(dst_subspan.begin(), dst_subspan.end(), 0);
-    std::stable_sort(dst_subspan.begin(), dst_subspan.end(),
-                     [&src_subspan](int32_t i1, int32_t i2) {
-                       return src_subspan[i1] < src_subspan[i2];
-                     });
-  }
-
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Broadcast::Execute(absl::Span<const T> src_buffer,
-                                 absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = src_buffer[0];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Iota::Execute(absl::Span<T> dst_buffer) {
-  T value = 0;
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = value;
-    value += 1;
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Tile::Execute(absl::Span<const T> src_buffer,
-                            absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                            ShapeSpan dst_shape) {
-  // This implementation is .... not fast.
-  int rank = dst_shape.size();
-  std::vector<int> src_strides(rank);
-  std::vector<int> dst_strides(rank);
-  size_t src_stride = 1;
-  size_t dst_stride = 1;
-  for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
-    src_strides[dim_i] = src_stride;
-    dst_strides[dim_i] = dst_stride;
-    src_stride *= src_shape[dim_i];
-    dst_stride *= dst_shape[dim_i];
-  }
-  for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
-    size_t src_i = 0;
-    size_t t = dst_i;
-    for (int dim_i = 0; dim_i < rank; ++dim_i) {
-      src_i += t / dst_strides[dim_i] % src_shape[dim_i] * src_strides[dim_i];
-      t %= dst_strides[dim_i];
-    }
-    dst_buffer[dst_i] = src_buffer[src_i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Not::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = ~src_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t And::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] & rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t And::Execute(absl::Span<const T> lhs_buffer, T rhs,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] & rhs;
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Or::Execute(absl::Span<const T> lhs_buffer,
-                          absl::Span<const T> rhs_buffer,
-                          absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] | rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Xor::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] ^ rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Xor::Execute(absl::Span<const T> lhs_buffer, T rhs,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] ^ rhs;
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t ShiftLeft::Execute(absl::Span<const T> lhs_buffer,
-                                 absl::Span<const T> rhs_buffer,
-                                 absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] << rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t ShiftRight::Execute(absl::Span<const T> lhs_buffer,
-                                  absl::Span<const T> rhs_buffer,
-                                  absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] >> rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Add::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Sub::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] - rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Abs::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::abs(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Neg::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = -src_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Mul::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] * rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Div::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = lhs_buffer[i] / rhs_buffer[i];
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Rem::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::remainder(lhs_buffer[i], rhs_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Pow::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::pow(lhs_buffer[i], rhs_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Exp::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::exp(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Rsqrt::Execute(absl::Span<const T> src_buffer,
-                             absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = 1.0 / std::sqrt(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Sqrt::Execute(absl::Span<const T> src_buffer,
-                            absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::sqrt(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Log::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::log(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Cos::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::cos(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Sin::Execute(absl::Span<const T> src_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::sin(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Tanh::Execute(absl::Span<const T> src_buffer,
-                            absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::tanh(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Atan2::Execute(absl::Span<const T> lhs_buffer,
-                             absl::Span<const T> rhs_buffer,
-                             absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::atan2(lhs_buffer[i], rhs_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Min::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::min(lhs_buffer[i], rhs_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Max::Execute(absl::Span<const T> lhs_buffer,
-                           absl::Span<const T> rhs_buffer,
-                           absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::max(lhs_buffer[i], rhs_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Clamp::Execute(absl::Span<const T> min_buffer,
-                             absl::Span<const T> src_buffer,
-                             absl::Span<const T> max_buffer,
-                             absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    T src = src_buffer[i];
-    T min = min_buffer[i];
-    T max = max_buffer[i];
-    dst_buffer[i] = src <= min ? min : src >= max ? max : src;
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Floor::Execute(absl::Span<const T> src_buffer,
-                             absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::floor(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Ceil::Execute(absl::Span<const T> src_buffer,
-                            absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::ceil(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename T>
-iree_status_t Round::Execute(absl::Span<const T> src_buffer,
-                             absl::Span<T> dst_buffer) {
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = std::round(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-template <typename SRC, typename DST>
-iree_status_t Convert::Execute(absl::Span<const SRC> src_buffer,
-                               absl::Span<DST> dst_buffer) {
-  IREE_DCHECK_EQ(src_buffer.size(), dst_buffer.size());
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    dst_buffer[i] = static_cast<DST>(src_buffer[i]);
-  }
-  return iree_ok_status();
-}
-
-namespace impl {
-
-struct SumKernel {
-  template <typename T>
-  inline void operator()(T* value0, const T value1) {
-    *value0 += value1;
-  }
-};
-
-struct MinKernel {
-  template <typename T>
-  inline void operator()(T* value0, const T value1) {
-    *value0 = std::min(*value0, value1);
-  }
-};
-
-struct MaxKernel {
-  template <typename T>
-  inline void operator()(T* value0, const T value1) {
-    *value0 = std::max(*value0, value1);
-  }
-};
-
-struct AndKernel {
-  template <typename T>
-  inline void operator()(T* value0, const T value1) {
-    *value0 = *value0 && value1;
-  }
-};
-
-struct OrKernel {
-  template <typename T>
-  inline void operator()(T* value0, const T value1) {
-    *value0 = *value0 || value1;
-  }
-};
-
-template <typename T, typename KernelImpl>
-inline void ReduceDimension(absl::Span<const T> src_buffer,
-                            absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                            absl::Span<const int32_t> reduce_dims,
-                            absl::Span<const int> dst_strides, int dim,
-                            absl::Span<int> src_indices, size_t flat_src_i,
-                            size_t src_stride) {
-  if (dim < 0) {
-    // Base case of the recursion - figure out which elements should be acted
-    // upon and apply the reduction kernel to them.
-
-    // Derive destination indices from source indices.
-    // For example,
-    //     reduce_dims: [1, 2]
-    //     src_indices: [2, 1, 3, 0]
-    //                      ^  ^
-    //                      |  |
-    //                      |----- remove these dimensions
-    //     dst_indices: [2, 0]
-    //
-    // TODO(scotttodd): Clean this up somehow, share across recursion levels?
-    size_t dst_size = src_shape.size() - reduce_dims.size();
-    std::vector<int> dst_indices;
-    dst_indices.reserve(src_indices.size());
-    for (size_t i = 0; i < src_indices.size(); ++i) {
-      if (std::find(std::begin(reduce_dims), std::end(reduce_dims), i) ==
-          std::end(reduce_dims)) {
-        dst_indices.push_back(src_indices[i]);
-      }
-    }
-    // Compute the flattened index into dst_buffer at [dst_indices].
-    size_t dst_i = 0;
-    for (size_t i = 0; i < dst_indices.size(); ++i) {
-      dst_i += dst_indices[i] * dst_strides[dst_size - 1 - i];
-    }
-
-    // Flattened src and dst indices have been computed, invoke the kernel.
-    KernelImpl()(&dst_buffer[dst_i], src_buffer[flat_src_i]);
-    return;
-  }
-
-  // Iterate through the current dimension in the source shape, recursing
-  // down one dimension at a time.
-  //
-  // This touches each element in the source buffer once, tracking complete
-  // dimensions within the shaped source buffer and using them to compute
-  // the corresponding indices (shaped and flattened) within the destination
-  // buffer. Each element in the destination buffer will be touched multiple
-  // times.
-  //
-  // Note that cache coherency isn't considered here, and some computations
-  // are redundant, so this could be optimized substantially.
-  for (size_t dim_i = 0; dim_i < src_shape[dim]; ++dim_i) {
-    src_indices[dim] = dim_i;
-
-    // Recurse down to the next dimension (e.g. 2 -> 1 -> 0 -> base case)
-    //   * Add the current stride to flat_src_i
-    //   * Multiply src_stride by this dimension's shape
-    ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape,
-                                   reduce_dims, dst_strides, dim - 1,
-                                   src_indices, flat_src_i + dim_i * src_stride,
-                                   src_stride * src_shape[dim]);
-  }
-}
-
-template <typename T, typename KernelImpl>
-iree_status_t GenericReduce(absl::Span<const T> src_buffer,
-                            absl::Span<const T> init_buffer,
-                            absl::Span<T> dst_buffer, int32_t dimension,
-                            ShapeSpan src_shape, ShapeSpan dst_shape) {
-  // Initialize using init_buffer, which is expected to be a scalar.
-  std::fill_n(dst_buffer.data(), dst_buffer.size(), init_buffer[0]);
-
-  // Precompute destination strides.
-  int dst_rank = dst_shape.size();
-  std::vector<int> dst_strides;
-  size_t dst_stride = 1;
-  for (int dim_i = dst_rank - 1; dim_i >= 0; --dim_i) {
-    dst_strides.push_back(dst_stride);
-    dst_stride *= dst_shape[dim_i];
-  }
-
-  // Call the helper (recursive) function, starting with:
-  //   * source index [0, 0, ..., 0]
-  //   * the innermost dimension (last in the shape)
-  //   * flat_src_i of 0 (corresponds to [0, 0, ..., 0] above)
-  //   * source stride 1
-  std::vector<int> src_indices(src_shape.size(), 0);
-  ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, {dimension},
-                                 absl::MakeSpan(dst_strides),
-                                 src_shape.size() - 1,
-                                 absl::MakeSpan(src_indices), 0, 1);
-
-  return iree_ok_status();
-}
-
-}  // namespace impl
-
-template <typename T>
-iree_status_t ReduceSum::Execute(absl::Span<const T> src_buffer,
-                                 absl::Span<const T> init_buffer,
-                                 absl::Span<T> dst_buffer, int32_t dimension,
-                                 ShapeSpan src_shape, ShapeSpan dst_shape) {
-  return impl::GenericReduce<T, impl::SumKernel>(
-      src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
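Editorial note: a hypothetical usage of the generic reduction above (e.g. inside a test body): sum a 2x3 matrix along dimension 1, with a scalar init value of 0.

const std::vector<float> src = {1, 2, 3, 4, 5, 6};
const std::vector<float> init = {0};
std::vector<float> dst(2);
const int32_t src_shape[] = {2, 3}, dst_shape[] = {2};
iree_status_t status = kernels::ReduceSum::Execute<float>(
    absl::MakeConstSpan(src), absl::MakeConstSpan(init), absl::MakeSpan(dst),
    /*dimension=*/1, src_shape, dst_shape);
// dst is now {6, 15}.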
-
-template <typename T>
-iree_status_t ReduceMin::Execute(absl::Span<const T> src_buffer,
-                                 absl::Span<const T> init_buffer,
-                                 absl::Span<T> dst_buffer, int32_t dimension,
-                                 ShapeSpan src_shape, ShapeSpan dst_shape) {
-  return impl::GenericReduce<T, impl::MinKernel>(
-      src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-template <typename T>
-iree_status_t ReduceMax::Execute(absl::Span<const T> src_buffer,
-                                 absl::Span<const T> init_buffer,
-                                 absl::Span<T> dst_buffer, int32_t dimension,
-                                 ShapeSpan src_shape, ShapeSpan dst_shape) {
-  return impl::GenericReduce<T, impl::MaxKernel>(
-      src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-template <typename T>
-iree_status_t ReduceAnd::Execute(absl::Span<const T> src_buffer,
-                                 absl::Span<const T> init_buffer,
-                                 absl::Span<T> dst_buffer, int32_t dimension,
-                                 ShapeSpan src_shape, ShapeSpan dst_shape) {
-  return impl::GenericReduce<T, impl::AndKernel>(
-      src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-template <typename T>
-iree_status_t ReduceOr::Execute(absl::Span<const T> src_buffer,
-                                absl::Span<const T> init_buffer,
-                                absl::Span<T> dst_buffer, int32_t dimension,
-                                ShapeSpan src_shape, ShapeSpan dst_shape) {
-  return impl::GenericReduce<T, impl::OrKernel>(
-      src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-namespace impl {
-
-template <typename T, typename KernelImpl>
-void ComputePoolingWindow(absl::Span<const T> src_buffer,
-                          absl::Span<const int> src_indices,
-                          ShapeSpan src_shape, T init_value,
-                          ShapeSpan window_dimensions, T* dst_value) {
-  size_t rank = src_shape.size();
-  std::vector<int> window_indices(rank, 0);
-  auto getSrcValue = [&]() -> T {
-    size_t flat_idx = 0;
-    for (size_t i = 0; i < rank; ++i) {
-      const int64_t idx = src_indices[i] + window_indices[i];
-      if (idx < 0 || idx >= src_shape[i]) return init_value;
-      flat_idx = flat_idx * src_shape[i] + idx;
-    }
-    return src_buffer[flat_idx];
-  };
-
-  *dst_value = init_value;
-  for (size_t i = 0, e = GetElementCount(window_dimensions); i < e; ++i) {
-    KernelImpl()(dst_value, getSrcValue());
-    IncrementShapeIndex(absl::MakeSpan(window_indices), window_dimensions);
-  }
-}
-
-template <typename T, typename KernelImpl>
-iree_status_t GenericPooling(absl::Span<const T> src_buffer,
-                             absl::Span<const T> init_buffer,
-                             absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                             ShapeSpan dst_shape, ShapeSpan window_dimensions,
-                             ShapeSpan strides, ShapeSpan pad_low) {
-  size_t rank = src_shape.size();
-  std::vector<int> src_indices(rank, 0);
-  std::vector<int> dst_indices(rank, 0);
-  for (size_t i = 0, e = GetElementCount(dst_shape); i < e; ++i) {
-    for (size_t j = 0; j < rank; ++j) {
-      src_indices[j] = dst_indices[j] * strides[j] - pad_low[j];
-    }
-    ComputePoolingWindow<T, KernelImpl>(src_buffer, src_indices, src_shape,
-                                        init_buffer[0], window_dimensions,
-                                        &dst_buffer[i]);
-    IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape);
-  }
-  return iree_ok_status();
-}
-
-}  // namespace impl
-
-template <typename T>
-iree_status_t PoolingSum::Execute(absl::Span<const T> src_buffer,
-                                  absl::Span<const T> init_buffer,
-                                  absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                                  ShapeSpan dst_shape,
-                                  ShapeSpan window_dimensions,
-                                  ShapeSpan strides, ShapeSpan pad_low) {
-  return impl::GenericPooling<T, impl::SumKernel>(
-      src_buffer, init_buffer, dst_buffer, src_shape, dst_shape,
-      window_dimensions, strides, pad_low);
-}
-
-template <typename T>
-iree_status_t PoolingMin::Execute(absl::Span<const T> src_buffer,
-                                  absl::Span<const T> init_buffer,
-                                  absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                                  ShapeSpan dst_shape,
-                                  ShapeSpan window_dimensions,
-                                  ShapeSpan strides, ShapeSpan pad_low) {
-  return impl::GenericPooling<T, impl::MinKernel>(
-      src_buffer, init_buffer, dst_buffer, src_shape, dst_shape,
-      window_dimensions, strides, pad_low);
-}
-
-template <typename T>
-iree_status_t PoolingMax::Execute(absl::Span<const T> src_buffer,
-                                  absl::Span<const T> init_buffer,
-                                  absl::Span<T> dst_buffer, ShapeSpan src_shape,
-                                  ShapeSpan dst_shape,
-                                  ShapeSpan window_dimensions,
-                                  ShapeSpan strides, ShapeSpan pad_low) {
-  return impl::GenericPooling<T, impl::MaxKernel>(
-      src_buffer, init_buffer, dst_buffer, src_shape, dst_shape,
-      window_dimensions, strides, pad_low);
-}
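Editorial note: a hypothetical usage of the pooling kernels (e.g. inside a test body, <limits> assumed included): max-pool a 1-D buffer with window 2 and stride 2, using -infinity as the init value so padding never wins.

const std::vector<float> src = {1, 3, 2, 5};
const std::vector<float> init = {-std::numeric_limits<float>::infinity()};
std::vector<float> dst(2);
const int32_t src_shape[] = {4}, dst_shape[] = {2};
const int32_t window[] = {2}, strides[] = {2}, pad_low[] = {0};
iree_status_t status = kernels::PoolingMax::Execute<float>(
    absl::MakeConstSpan(src), absl::MakeConstSpan(init), absl::MakeSpan(dst),
    src_shape, dst_shape, window, strides, pad_low);
// dst is now {3, 5}.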
-
-}  // namespace kernels
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_MODULES_VMLA_OP_KERNELS_GENERIC_H_
diff --git a/iree/modules/vmla/op_kernels_ruy.h b/iree/modules/vmla/op_kernels_ruy.h
deleted file mode 100644
index 4970a58..0000000
--- a/iree/modules/vmla/op_kernels_ruy.h
+++ /dev/null
@@ -1,131 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_MODULES_VMLA_OP_KERNELS_RUY_H_
-#define IREE_MODULES_VMLA_OP_KERNELS_RUY_H_
-
-#include <memory>
-#include <type_traits>
-
-#include "iree/base/status.h"
-#include "ruy/context.h"
-#include "ruy/mul_params.h"
-#include "ruy/ruy.h"
-
-namespace iree {
-namespace hal {
-namespace vmla {
-namespace kernels {
-
-// TODO(benvanik): something more clever for making this shareable.
-// Maybe a factory fn based on the impl selected?
-struct MatMul::RuntimeState {
-  // TODO(benvanik): share the thread pool but keep context per-fiber?
-  ruy::Context context;
-};
-
-inline std::unique_ptr<MatMul::RuntimeState> MatMul::CreateRuntimeState() {
-  return std::make_unique<RuntimeState>();
-}
-
-// Floating-point case.
-template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
-struct MakeRuyMulParamsImpl {
-  static_assert(std::is_floating_point<LhsEl>::value, "");
-  static_assert(std::is_floating_point<RhsEl>::value, "");
-  static_assert(std::is_floating_point<AccumEl>::value, "");
-  static_assert(std::is_floating_point<DstEl>::value, "");
-  static void Run(const MatMul::Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers,
-                  ruy::MulParams<AccumEl, DstEl>* mul_params) {
-    mul_params->set_bias(buffers.bias_buffer.data());
-  }
-};
-
-// Raw integer case with int32 destination. This case does not support any
-// output operation besides bias-addition.
-template <typename LhsEl, typename RhsEl>
-struct MakeRuyMulParamsImpl<LhsEl, RhsEl, std::int32_t, std::int32_t> {
-  static void Run(
-      const MatMul::Buffers<LhsEl, RhsEl, std::int32_t, std::int32_t>& buffers,
-      ruy::MulParams<std::int32_t, std::int32_t>* mul_params) {
-    mul_params->set_bias(buffers.bias_buffer.data());
-  }
-};
-
-// Integer quantized case with downquantization to a destination DstEl narrower
-// than int32.
-template <typename LhsEl, typename RhsEl, typename DstEl>
-struct MakeRuyMulParamsImpl<LhsEl, RhsEl, std::int32_t, DstEl> {
-  static_assert(std::is_integral<LhsEl>::value, "");
-  static_assert(std::is_integral<RhsEl>::value, "");
-  static_assert(std::is_integral<DstEl>::value, "");
-  static_assert(sizeof(DstEl) < sizeof(std::int32_t), "");
-  static void Run(
-      const MatMul::Buffers<LhsEl, RhsEl, std::int32_t, DstEl>& buffers,
-      ruy::MulParams<std::int32_t, DstEl>* mul_params) {
-    mul_params->set_bias(buffers.bias_buffer.data());
-    if (buffers.multiplier_mantissa_buffer.size() == 1) {
-      mul_params->set_multiplier_fixedpoint(
-          buffers.multiplier_mantissa_buffer[0]);
-      mul_params->set_multiplier_exponent(
-          buffers.multiplier_exponent_buffer[0]);
-    } else {
-      mul_params->set_multiplier_fixedpoint_perchannel(
-          buffers.multiplier_mantissa_buffer.data());
-      mul_params->set_multiplier_exponent_perchannel(
-          buffers.multiplier_exponent_buffer.data());
-    }
-  }
-};
-
-template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
-void MakeRuyMulParams(
-    const MatMul::Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers,
-    ruy::MulParams<AccumEl, DstEl>* mul_params) {
-  MakeRuyMulParamsImpl<LhsEl, RhsEl, AccumEl, DstEl>::Run(buffers, mul_params);
-}
-
-template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
-iree_status_t MatMul::Execute(
-    RuntimeState* runtime_state,
-    const Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers) {
-  ruy::Matrix<LhsEl> lhs;
-  lhs.set_data(buffers.lhs_buffer.data());
-  ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1],
-                        ruy::Order::kRowMajor, lhs.mutable_layout());
-
-  ruy::Matrix<RhsEl> rhs;
-  rhs.set_data(buffers.rhs_buffer.data());
-  ruy::MakeSimpleLayout(buffers.rhs_shape[1], buffers.rhs_shape[0],
-                        ruy::Order::kColMajor, rhs.mutable_layout());
-
-  ruy::Matrix<DstEl> dst;
-  dst.set_data(buffers.dst_buffer.data());
-  ruy::MakeSimpleLayout(buffers.dst_shape[1], buffers.dst_shape[0],
-                        ruy::Order::kColMajor, dst.mutable_layout());
-
-  ruy::MulParams<AccumEl, DstEl> mul_params;
-  MakeRuyMulParams(buffers, &mul_params);
-
-  ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);
-
-  return iree_ok_status();
-}
-
-}  // namespace kernels
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_MODULES_VMLA_OP_KERNELS_RUY_H_
diff --git a/iree/modules/vmla/op_kernels_test.cc b/iree/modules/vmla/op_kernels_test.cc
deleted file mode 100644
index e96cfd0..0000000
--- a/iree/modules/vmla/op_kernels_test.cc
+++ /dev/null
@@ -1,598 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/modules/vmla/op_kernels.h"
-
-#include <vector>
-
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace vmla {
-namespace kernels {
-
-namespace {
-
-constexpr float kEpsilon = 0.0001f;
-
-using Shape = std::vector<int32_t>;
-
-// reinterpret_cast for Spans, preserving byte size.
-template <typename T, typename U>
-constexpr absl::Span<const T> ReinterpretSpan(absl::Span<const U> value) {
-  return absl::MakeSpan(reinterpret_cast<const T*>(value.data()),
-                        (value.size() * sizeof(U)) / sizeof(T));
-}
-template <typename T, typename U>
-constexpr absl::Span<T> ReinterpretSpan(absl::Span<U> value) {
-  return absl::MakeSpan(reinterpret_cast<T*>(value.data()),
-                        (value.size() * sizeof(U)) / sizeof(T));
-}
-
-size_t GetShapeElementCount(const Shape& shape) {
-  size_t count = 1;
-  for (size_t i = 0; i < shape.size(); ++i) {
-    count *= shape[i];
-  }
-  return count;
-}
-
-template <typename T>
-std::vector<T> MakeIota(size_t size) {
-  std::vector<T> v(size);
-  std::iota(v.begin(), v.end(), static_cast<T>(1));
-  return v;
-}
-
-TEST(Copy, WholeBuffer) {
-  Shape src_shape = {2, 2};
-  auto src_buffer = MakeIota<uint8_t>(4);
-  std::vector<int32_t> src_indices = {0, 0};
-  const Shape& dst_shape = src_shape;
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {0, 0};
-  std::vector<int32_t> lengths = {2, 2};
-  const auto& expected_dst = src_buffer;
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, FirstRow) {
-  Shape src_shape = {3, 4};
-  auto src_buffer = MakeIota<uint8_t>(12);
-  std::vector<int32_t> src_indices = {0, 0};
-  Shape dst_shape = {1, 4};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {0, 0};
-  std::vector<int32_t> lengths = {1, 4};
-  std::vector<uint8_t> expected_dst = {1, 2, 3, 4};
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, RowPart) {
-  Shape src_shape = {3, 4};
-  auto src_buffer = MakeIota<uint8_t>(12);
-  std::vector<int32_t> src_indices = {1, 1};
-  Shape dst_shape = {1, 2};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {0, 0};
-  std::vector<int32_t> lengths = {1, 2};
-  std::vector<uint8_t> expected_dst = {6, 7};
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, MultiRow) {
-  Shape src_shape = {3, 4};
-  auto src_buffer = MakeIota<uint8_t>(12);
-  std::vector<int32_t> src_indices = {1, 0};
-  Shape dst_shape = {2, 4};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {0, 0};
-  std::vector<int32_t> lengths = {2, 4};
-  std::vector<uint8_t> expected_dst = {5, 6, 7, 8, 9, 10, 11, 12};
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, NonContiguous) {
-  Shape src_shape = {3, 4};
-  auto src_buffer = MakeIota<uint8_t>(12);
-  std::vector<int32_t> src_indices = {1, 1};
-  Shape dst_shape = {2, 2};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {0, 0};
-  std::vector<int32_t> lengths = {2, 2};
-  std::vector<uint8_t> expected_dst = {6, 7, 10, 11};
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, MultiByte) {
-  Shape src_shape = {3, 4};
-  auto src_vals = MakeIota<int32_t>(12);
-  auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals));
-  std::vector<int32_t> src_indices = {1, 1};
-  Shape dst_shape = {2, 2};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape) *
-                                  sizeof(int32_t));
-  std::vector<int32_t> dst_indices = {0, 0};
-  std::vector<int32_t> lengths = {2, 2};
-  std::vector<int32_t> expected_dst = {6, 7, 10, 11};
-
-  IREE_EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-
-  absl::Span<int32_t> dst_buffer_int32_t =
-      ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer));
-
-  EXPECT_EQ(dst_buffer_int32_t, expected_dst);
-}
-
-TEST(Copy, NotFullDst) {
-  Shape src_shape = {3, 4};
-  auto src_buffer = MakeIota<uint8_t>(12);
-  std::vector<int32_t> src_indices = {0, 0};
-  Shape dst_shape = {4, 3};
-  std::vector<uint8_t> dst_buffer(12, 42);
-  std::vector<int32_t> dst_indices = {1, 1};
-  std::vector<int32_t> lengths = {2, 2};
-  // clang-format off
-  std::vector<uint8_t> expected_dst = {42, 42, 42,
-                                     42,  1,  2,
-                                     42,  5,  6,
-                                     42, 42, 42};
-  // clang-format on
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, HighRank) {
-  Shape src_shape = {3, 3, 3, 3};
-  auto src_buffer = MakeIota<uint8_t>(81);
-  std::vector<int32_t> src_indices = {1, 1, 1, 1};
-  Shape dst_shape = {2, 2, 2, 2};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {0, 0, 0, 0};
-  std::vector<int32_t> lengths = {2, 2, 2, 2};
-  std::vector<uint8_t> expected_dst = {41, 42, 44, 45, 50, 51, 53, 54,
-                                       68, 69, 71, 72, 77, 78, 80, 81};
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, Scalar) {
-  Shape src_shape = {};
-  std::vector<uint8_t> src_buffer = {42};
-  std::vector<int32_t> src_indices = {};
-  Shape dst_shape = {};
-  std::vector<uint8_t> dst_buffer(GetShapeElementCount(dst_shape));
-  std::vector<int32_t> dst_indices = {};
-  std::vector<int32_t> lengths = {};
-  std::vector<uint8_t> expected_dst = {42};
-
-  IREE_EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, ScalarMultiByte) {
-  Shape src_shape = {};
-  std::vector<int32_t> src_vals = {INT32_MAX};
-  auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals));
-  std::vector<int32_t> src_indices = {};
-  Shape dst_shape = {};
-  std::vector<uint8_t> dst_buffer(sizeof(int32_t));
-  std::vector<int32_t> dst_indices = {};
-  std::vector<int32_t> lengths = {};
-  std::vector<int32_t> expected_dst = {INT32_MAX};
-
-  IREE_EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices,
-                                  absl::MakeSpan(dst_buffer), dst_shape,
-                                  dst_indices, lengths));
-
-  absl::Span<int32_t> dst_buffer_int32_t =
-      ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer));
-
-  EXPECT_EQ(dst_buffer_int32_t, expected_dst);
-}
-
-TEST(Pad, NoPadding) {
-  Shape src_shape = {2, 3};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {0, 0};
-  std::vector<int32_t> edge_padding_high = {0, 0};
-  std::vector<int32_t> interior_padding = {0, 0};
-  const Shape& dst_shape = src_shape;
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  const auto& expected_dst = src_buffer;
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, LowHighPadding) {
-  Shape src_shape = {2, 3};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {0, 1};
-  std::vector<int32_t> edge_padding_high = {1, 2};
-  std::vector<int32_t> interior_padding = {0, 0};
-  Shape dst_shape = {3, 6};
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  // clang-format off
-  std::vector<uint16_t> expected_dst = {0, 1, 2, 3, 0, 0,
-                                      0, 4, 5, 6, 0, 0,
-                                      0, 0, 0, 0, 0, 0};
-  // clang-format on
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, OnlyHighPadding) {
-  Shape src_shape = {2, 3};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {0, 0};
-  std::vector<int32_t> edge_padding_high = {1, 3};
-  std::vector<int32_t> interior_padding = {0, 0};
-  Shape dst_shape = {3, 6};
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  // clang-format off
-  std::vector<uint16_t> expected_dst = {1, 2, 3, 0, 0, 0,
-                                      4, 5, 6, 0, 0, 0,
-                                      0, 0, 0, 0, 0, 0};
-  // clang-format on
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, OnlyLowPadding) {
-  Shape src_shape = {2, 3};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {1, 3};
-  std::vector<int32_t> edge_padding_high = {0, 0};
-  std::vector<int32_t> interior_padding = {0, 0};
-  Shape dst_shape = {3, 6};
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  // clang-format off
-  std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0,
-                                      0, 0, 0, 1, 2, 3,
-                                      0, 0, 0, 4, 5, 6};
-  // clang-format on
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, OnlyInteriorPadding) {
-  Shape src_shape = {2, 3};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {0, 0};
-  std::vector<int32_t> edge_padding_high = {0, 0};
-  std::vector<int32_t> interior_padding = {1, 1};
-  Shape dst_shape = {3, 5};
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  // clang-format off
-  std::vector<uint16_t> expected_dst = {1, 0, 2, 0, 3,
-                                      0, 0, 0, 0, 0,
-                                      4, 0, 5, 0, 6};
-  // clang-format on
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, AllPaddingTypes) {
-  Shape src_shape = {2, 3};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {1, 1};
-  std::vector<int32_t> edge_padding_high = {1, 2};
-  std::vector<int32_t> interior_padding = {1, 1};
-  Shape dst_shape = {5, 8};
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  // clang-format off
-  std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, 0, 0,
-                                      0, 1, 0, 2, 0, 3, 0, 0,
-                                      0, 0, 0, 0, 0, 0, 0, 0,
-                                      0, 4, 0, 5, 0, 6, 0, 0,
-                                      0, 0, 0, 0, 0, 0, 0, 0};
-  // clang-format on
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, HighRank) {
-  Shape src_shape = {2, 2, 2, 2};
-  auto src_buffer = MakeIota<uint16_t>(GetShapeElementCount(src_shape));
-  std::vector<uint16_t> pad_value_buffer = {0};
-  std::vector<int32_t> edge_padding_low = {1, 0, 0, 0};
-  std::vector<int32_t> edge_padding_high = {0, 1, 0, 0};
-  std::vector<int32_t> interior_padding = {0, 0, 1, 0};
-  Shape dst_shape = {3, 3, 3, 2};
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-  // clang-format off
-  std::vector<uint16_t> expected_dst = { 0,  0,   0, 0,   0,  0,
-                                       0,  0,   0, 0,   0,  0,
-                                       0,  0,   0, 0,   0,  0,
-
-                                       1,  2,   0, 0,   3,  4,
-                                       5,  6,   0, 0,   7,  8,
-                                       0,  0,   0, 0,   0,  0,
-
-                                       9, 10,   0, 0,  11, 12,
-                                      13, 14,   0, 0,  15, 16,
-                                       0,  0,   0, 0,   0,  0};
-  // clang-format on
-
-  ASSERT_EQ(dst_buffer.size(), expected_dst.size());
-
-  IREE_EXPECT_OK(Pad::Execute<uint16_t>(
-      src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
-      dst_shape, edge_padding_low, edge_padding_high, interior_padding));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(ReduceSum, Scalar) {
-  Shape src_shape = {5};
-  int32_t dimension = 0;
-  Shape dst_shape = {1};
-  std::vector<float> src_buffer = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
-  std::vector<float> init_buffer = {0.0f};
-  std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
-  std::vector<float> expected_dst = {5.0f};
-
-  IREE_EXPECT_OK(ReduceSum::Execute<float>(src_buffer, init_buffer,
-                                           absl::MakeSpan(dst_buffer),
-                                           dimension, src_shape, dst_shape));
-
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
-  }
-}
-
-TEST(ReduceMin, TwoDimensionsToOne) {
-  Shape src_shape = {3, 3};
-  int32_t dimension = 0;
-  Shape dst_shape = {3};
-  std::vector<float> src_buffer =
-      MakeIota<float>(GetShapeElementCount(src_shape));
-  std::vector<float> init_buffer = {std::numeric_limits<float>::max()};
-  std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
-  std::vector<float> expected_dst = {1.0f, 2.0f, 3.0f};
-
-  IREE_EXPECT_OK(ReduceMin::Execute<float>(src_buffer, init_buffer,
-                                           absl::MakeSpan(dst_buffer),
-                                           dimension, src_shape, dst_shape));
-
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
-  }
-}
-
-TEST(PoolingMax, NoOverlapping) {
-  Shape src_shape = {1, 4, 6, 1};
-  Shape dst_shape = {1, 2, 2, 1};
-  Shape window_sizes = {1, 2, 3, 1};
-  Shape strides = {1, 2, 3, 1};
-  Shape pad_low = {0, 0, 0, 0};
-  std::vector<int> src_buffer = MakeIota<int>(GetShapeElementCount(src_shape));
-  std::vector<int> init_buffer(1, 0);
-  std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0);
-  std::vector<int> expected_dst = {9, 12, 21, 24};
-
-  IREE_EXPECT_OK(PoolingMax::Execute<int>(
-      src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
-      window_sizes, strides, pad_low));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(PoolingMin, Padding) {
-  // Padded input:
-  // 100 100 100 100
-  // 100   1   2   3
-  // 100   4   5   6
-  Shape src_shape = {2, 3};
-  Shape dst_shape = {2, 3};
-  Shape window_sizes = {2, 2};
-  Shape strides = {1, 1};
-  Shape pad_low = {1, 1};
-  std::vector<int> src_buffer = MakeIota<int>(GetShapeElementCount(src_shape));
-  std::vector<int> init_buffer(1, 100);
-  std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0);
-  std::vector<int> expected_dst = {1, 1, 2, 1, 1, 2};
-
-  IREE_EXPECT_OK(PoolingMin::Execute<int>(
-      src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
-      window_sizes, strides, pad_low));
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(PoolingSum, Overlapping) {
-  Shape src_shape = {3, 4};
-  Shape dst_shape = {2, 2};
-  Shape window_sizes = {2, 3};
-  Shape strides = {1, 1};
-  Shape pad_low = {0, 0};
-  std::vector<float> src_buffer =
-      MakeIota<float>(GetShapeElementCount(src_shape));
-  std::vector<float> init_buffer(1, 0.0f);
-  std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
-  std::vector<float> expected_dst = {24, 30, 48, 54};
-
-  IREE_EXPECT_OK(PoolingSum::Execute<float>(
-      src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
-      window_sizes, strides, pad_low));
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
-  }
-}
-
-TEST(Conv2d, NoDilation) {
-  Shape input_shape = {4, 5, 2};
-  Shape filter_shape = {3, 2, 2, 1};
-  Shape dst_shape = {2, 4, 1};
-  Shape strides = {1, 1};
-  Shape pad_h = {0, 0};
-  Shape pad_w = {0, 0};
-  Shape lhs_dilation = {1, 1};
-  Shape rhs_dilation = {1, 1};
-  std::vector<float> input_buffer(GetShapeElementCount(input_shape));
-  std::vector<float> filter_buffer(GetShapeElementCount(filter_shape));
-  std::vector<float> expected_dst = {1310, 1466, 1622, 1778,
-                                     2090, 2246, 2402, 2558};
-  for (size_t i = 0; i < GetShapeElementCount(input_shape); ++i) {
-    input_buffer[i] = static_cast<float>(i + 1);
-    if (i < GetShapeElementCount(filter_shape)) {
-      filter_buffer[i] = static_cast<float>(i + 1);
-    }
-  }
-  std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
-
-  IREE_EXPECT_OK(Conv2D::Execute<float>(
-      input_buffer, input_shape, filter_buffer, filter_shape,
-      absl::MakeSpan(dst_buffer), dst_shape, strides, pad_h, pad_w,
-      lhs_dilation, rhs_dilation, 1));
-
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
-  }
-}
-
-TEST(Conv2d, DepthwiseConv) {
-  Shape input_shape = {4, 5, 2};
-  Shape filter_shape = {3, 2, 2, 2};
-  Shape dst_shape = {2, 4, 4};
-  Shape strides = {1, 1};
-  Shape pad_h = {0, 0};
-  Shape pad_w = {0, 0};
-  Shape lhs_dilation = {1, 1};
-  Shape rhs_dilation = {1, 1};
-  std::vector<float> input_buffer(GetShapeElementCount(input_shape));
-  std::vector<float> filter_buffer(GetShapeElementCount(filter_shape));
-  std::vector<float> expected_dst = {
-      1124, 1196, 1346, 1424, 1256, 1340, 1502, 1592, 1388, 1484, 1658,
-      1760, 1520, 1628, 1814, 1928, 1784, 1916, 2126, 2264, 1916, 2060,
-      2282, 2432, 2048, 2204, 2438, 2600, 2180, 2348, 2594, 2768};
-  for (size_t i = 0; i < GetShapeElementCount(input_shape); ++i) {
-    input_buffer[i] = i + 1;
-    if (i < GetShapeElementCount(filter_shape)) {
-      filter_buffer[i] = i + 1;
-    }
-  }
-  std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
-
-  IREE_EXPECT_OK(Conv2D::Execute<float>(
-      input_buffer, input_shape, filter_buffer, filter_shape,
-      absl::MakeSpan(dst_buffer), dst_shape, strides, pad_h, pad_w,
-      lhs_dilation, rhs_dilation, 2));
-
-  for (size_t i = 0; i < dst_buffer.size(); ++i) {
-    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
-  }
-}
-
-TEST(Transpose, 2Dimen) {
-  Shape src_shape = {2, 3};
-  Shape dst_shape = {3, 2};
-  std::vector<int32_t> perm = {1, 0};
-  // clang-format off
-  std::vector<uint16_t> src_buffer = {1,  2,  3,
-                                      4,  5,  6};
-  std::vector<uint16_t> expected_dst = {1,  4,
-                                        2,  5,
-                                        3,  6};
-  // clang-format on
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-
-  IREE_EXPECT_OK(Transpose::Execute<uint16_t>(
-      src_buffer, absl::Span<uint16_t>(dst_buffer), src_shape, perm));
-
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Transpose, 3Dimen) {
-  Shape src_shape = {2, 2, 3};
-  Shape dst_shape = {2, 3, 2};
-  std::vector<int32_t> perm = {0, 2, 1};
-  // clang-format off
-  std::vector<uint16_t> src_buffer = { 1,  2,  3,
-                                       4,  5,  6,
-                                       7,  8,  9,
-                                      10, 11, 12};
-  std::vector<uint16_t> expected_dst = {1,  4,
-                                        2,  5,
-                                        3,  6,
-                                        7, 10,
-                                        8, 11,
-                                        9, 12};
-  // clang-format on
-  std::vector<uint16_t> dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX);
-
-  IREE_EXPECT_OK(Transpose::Execute<uint16_t>(
-      src_buffer, absl::Span<uint16_t>(dst_buffer), src_shape, perm));
-
-  EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-}  // namespace
-}  // namespace kernels
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/modules/vmla/op_module.cc b/iree/modules/vmla/op_module.cc
deleted file mode 100644
index 02a9e4b..0000000
--- a/iree/modules/vmla/op_module.cc
+++ /dev/null
@@ -1,1142 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/modules/vmla/op_module.h"
-
-#include <cstdint>
-
-#include "absl/types/span.h"
-#include "iree/base/tracing.h"
-#include "iree/modules/vmla/op_kernels.h"
-#include "iree/vm/module_abi_packing.h"
-
-//===----------------------------------------------------------------------===//
-// Type registration
-//===----------------------------------------------------------------------===//
-
-static iree_vm_ref_type_descriptor_t Buffer_descriptor = {0};
-static iree_vm_ref_type_descriptor_t Interface_descriptor = {0};
-
-IREE_VM_DEFINE_TYPE_ADAPTERS(Buffer, iree::hal::vmla::Buffer);
-IREE_VM_DEFINE_TYPE_ADAPTERS(Interface, iree::hal::vmla::Interface);
-
-#define IREE_VMLA_REGISTER_CC_TYPE(type, name, descriptor)     \
-  descriptor.type_name = iree_make_cstring_view(name);         \
-  descriptor.offsetof_counter = type::offsetof_counter();      \
-  descriptor.destroy = type::DirectDestroy;                    \
-  IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor), \
-                       "failed to register type %s", name);
-
-namespace iree {
-namespace hal {
-namespace vmla {
-
-Status ModuleRegisterTypes() {
-  static bool has_registered = false;
-  if (has_registered) return OkStatus();
-
-  IREE_VMLA_REGISTER_CC_TYPE(Buffer, "vmla.buffer", Buffer_descriptor);
-  IREE_VMLA_REGISTER_CC_TYPE(Interface, "vmla.interface", Interface_descriptor);
-
-  has_registered = true;
-  return OkStatus();
-}
-
-//===----------------------------------------------------------------------===//
-// API type implementations
-//===----------------------------------------------------------------------===//
-
-// static
-StatusOr<vm::ref<Buffer>> Buffer::Allocate(size_t byte_length,
-                                           iree_allocator_t allocator) {
-  void* data = nullptr;
-  IREE_RETURN_IF_ERROR(iree_allocator_malloc(allocator, byte_length, &data),
-                       "failed to allocate buffer of size %zu", byte_length);
-
-  auto buffer = vm::assign_ref(new Buffer());
-  buffer->data_ = data;
-  buffer->data_length_ = byte_length;
-  buffer->allocator_ = allocator;
-  return std::move(buffer);
-}
-
-// static
-StatusOr<vm::ref<Buffer>> Buffer::Wrap(const void* data, size_t data_length,
-                                       iree_allocator_t allocator) {
-  auto buffer = vm::assign_ref(new Buffer());
-  buffer->data_ = const_cast<void*>(data);
-  buffer->data_length_ = data_length;
-  buffer->allocator_ = allocator;
-  return std::move(buffer);
-}
-
-// static
-StatusOr<vm::ref<Buffer>> Buffer::WrapMutable(void* data, size_t data_length,
-                                              iree_allocator_t allocator) {
-  auto buffer = vm::assign_ref(new Buffer());
-  buffer->data_ = data;
-  buffer->data_length_ = data_length;
-  buffer->allocator_ = allocator;
-  return std::move(buffer);
-}
-
-Buffer::~Buffer() {
-  if (!parent_) {
-    iree_allocator_free(allocator_, data_);
-    data_ = nullptr;
-  }
-  parent_.reset();
-}
-
-Status Buffer::MakeRange(iree_vmla_size_t byte_offset,
-                         iree_vmla_size_t byte_length,
-                         absl::Span<uint8_t>* out_range) const {
-  if (byte_length == kVMLAWholeBuffer) {
-    byte_length = size() - byte_offset;
-  }
-  if (byte_offset > size()) {
-    return iree_make_status(
-        IREE_STATUS_OUT_OF_RANGE,
-        "attempted to access an address off the end of the valid buffer range "
-        "(offset=%u, length=%u, byte_length=%zu)",
-        byte_offset, byte_length, size());
-  }
-  size_t end = byte_offset + byte_length - 1;
-  if (end >= size()) {
-    return iree_make_status(
-        IREE_STATUS_OUT_OF_RANGE,
-        "attempted to access an address off the end of the valid buffer range "
-        "(offset=%u, length=%u, end=%zu, byte_length=%zu)",
-        byte_offset, byte_length, end, size());
-  }
-  uint8_t* data = reinterpret_cast<uint8_t*>(data_) + byte_offset;
-  size_t data_length = byte_length;
-  *out_range = absl::MakeSpan(data, data_length);
-  return OkStatus();
-}
-
-constexpr int Interface::kMaxConstants;
-constexpr int Interface::kMaxSets;
-constexpr int Interface::kMaxBindings;
-
-void Interface::Reset() {
-  for (size_t i = 0; i < bindings_.size(); ++i) {
-    for (size_t j = 0; j < bindings_[i].size(); ++j) {
-      bindings_[i][j] = {};
-    }
-  }
-}
-
-StatusOr<uint32_t> Interface::GetConstant(uint32_t offset) const {
-  if (offset >= kMaxConstants) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "invalid constant offset=%u", offset);
-  }
-  return constants_[offset];
-}
-
-Status Interface::SetConstants(const uint32_t* values,
-                               iree_host_size_t value_count) {
-  if (value_count > kMaxConstants) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "constant value overflow; have %zu but max is %d",
-                            value_count, kMaxConstants);
-  }
-  for (size_t i = 0; i < value_count; ++i) {
-    constants_[i] = values[i];
-  }
-  return OkStatus();
-}
-
-StatusOr<const Interface::Binding> Interface::GetBinding(
-    uint32_t set, uint32_t binding) const {
-  if (set >= kMaxSets || binding >= kMaxBindings) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "invalid binding set=%u, binding=%u", set, binding);
-  }
-  return bindings_[set][binding];
-}
-
-Status Interface::SetBinding(uint32_t set, uint32_t binding, Binding value) {
-  if (set >= kMaxSets || binding >= kMaxBindings) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "invalid binding set=%u, binding=%u", set, binding);
-  }
-  bindings_[set][binding] = std::move(value);
-  return OkStatus();
-}
-
-//===----------------------------------------------------------------------===//
-// Module state and method implementation
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-// Per-executable VMLA module state.
-// This provides the exported kernel functions to the VM and is instantiated
-// one or more times per executable used within a device. Any state here can be
-// treated as workgroup-local memory.
-//
-// Thread-compatible.
-class VMLAModuleState final {
- public:
-  VMLAModuleState(iree_allocator_t allocator,
-                  kernels::RuntimeState* kernel_state)
-      : allocator_(allocator), kernel_state_(kernel_state) {}
-
-  ~VMLAModuleState() = default;
-
-  //===--------------------------------------------------------------------===//
-  // vmla.interface.*
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<uint32_t> InterfaceConst(vm::ref<Interface> interface,
-                                    uint32_t offset) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::InterfaceConst");
-    return interface->GetConstant(offset);
-  }
-
-  StatusOr<vm::ref<Buffer>> InterfaceBinding(vm::ref<Interface> interface,
-                                             uint32_t set, uint32_t binding) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::InterfaceBinding");
-    IREE_ASSIGN_OR_RETURN(const auto& value,
-                          interface->GetBinding(set, binding));
-    return vm::retain_ref(value.buffer);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // vmla.buffer.*
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<Buffer>> BufferConst(vm::ref<iree_vm_buffer_t> value) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferConst");
-    iree_allocator_t external_allocator = {0};
-    external_allocator.self = vm::retain_ref(value).release();
-    external_allocator.free = +[](void* self, void* ptr) {
-      vm::assign_ref(reinterpret_cast<iree_vm_buffer_t*>(self)).reset();
-    };
-    return Buffer::Wrap(value->data.data, value->data.data_length,
-                        external_allocator);
-  }
-
-  StatusOr<vm::ref<Buffer>> BufferAlloc(iree_vmla_size_t byte_length) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferAlloc");
-    return Buffer::Allocate(byte_length, allocator_);
-  }
-
-  StatusOr<vm::ref<Buffer>> BufferClone(const vm::ref<Buffer>& src) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferClone");
-    IREE_ASSIGN_OR_RETURN(auto dst, Buffer::Allocate(src->size(), allocator_));
-    std::memcpy(dst->data(), src->data(), dst->size());
-    return std::move(dst);
-  }
-
-  StatusOr<iree_vmla_size_t> BufferByteLength(const vm::ref<Buffer>& buffer) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferByteLength");
-    return buffer->size();
-  }
-
-  StatusOr<vm::ref<Buffer>> BufferView(const vm::ref<Buffer>& src,
-                                       iree_vmla_size_t byte_offset,
-                                       iree_vmla_size_t byte_length) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferView");
-
-    if (byte_length == kVMLAWholeBuffer) {
-      byte_length = src->size() - byte_offset;
-    }
-
-    if (byte_offset == 0 && byte_length == src->size()) {
-      // Asking for the same buffer.
-      return vm::retain_ref(src);
-    } else if (byte_offset > src->size()) {
-      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
-                              "attempted to access an address off the end of "
-                              "the valid buffer range "
-                              "(offset=%u, length=%u, byte_length=%zu)",
-                              byte_offset, byte_length, src->size());
-    }
-    size_t end = byte_offset + byte_length - 1;
-    if (end >= src->size()) {
-      return iree_make_status(
-          IREE_STATUS_OUT_OF_RANGE,
-          "attempted to access an address off the end of the valid buffer "
-          "range (offset=%u, length=%u, end=%zu, byte_length=%zu)",
-          byte_offset, byte_length, end, src->size());
-    }
-    uint8_t* data = reinterpret_cast<uint8_t*>(src->data()) + byte_offset;
-    size_t data_length = byte_length;
-
-    iree_allocator_t external_allocator = {0};
-    external_allocator.self = vm::retain_ref(src).release();
-    external_allocator.free = +[](void* self, void* ptr) {
-      vm::assign_ref(reinterpret_cast<Buffer*>(self)).reset();
-    };
-    return Buffer::Wrap(data, data_length, external_allocator);
-  }
-
-  Status BufferCopy(const vm::ref<Buffer>& src,
-                    iree_vmla_size_t src_byte_offset,
-                    const vm::ref<Buffer>& dst,
-                    iree_vmla_size_t dst_byte_offset,
-                    iree_vmla_size_t byte_length) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferCopy");
-    if (byte_length == kVMLAWholeBuffer) {
-      byte_length = src->size() - src_byte_offset;
-    }
-    IREE_ASSIGN_OR_RETURN(auto src_bytes, src->RangeAs<const uint8_t>(
-                                              src_byte_offset, byte_length));
-    IREE_ASSIGN_OR_RETURN(auto dst_bytes,
-                          dst->RangeAs<uint8_t>(dst_byte_offset, byte_length));
-    std::memcpy(dst_bytes.data(), src_bytes.data(), dst_bytes.size());
-    return OkStatus();
-  }
-
-  Status BufferFill(const vm::ref<Buffer>& value, const vm::ref<Buffer>& dst) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferFill");
-    if (value->size() == 1) {
-      // Fast-path for single-byte memset values.
-      std::memset(dst->data(), value->As<uint8_t>()[0], dst->size());
-      return OkStatus();
-    } else if (dst->size() % value->size() != 0) {
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "fill value length (%zu) must divide evenly into buffer length (%zu)",
-          value->size(), dst->size());
-    }
-    auto value_bytes = value->As<uint8_t>();
-    auto dst_bytes = dst->As<uint8_t>();
-    for (size_t i = 0; i < dst_bytes.size(); i += value_bytes.size()) {
-      std::memcpy(dst_bytes.data() + i, value_bytes.data(), value_bytes.size());
-    }
-    return OkStatus();
-  }
-
-  StatusOr<int32_t> BufferLoadI32(const vm::ref<Buffer>& src,
-                                  iree_vmla_size_t byte_offset) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BufferLoadI32");
-    IREE_ASSIGN_OR_RETURN(auto data,
-                          src->RangeAs<int32_t>(byte_offset, sizeof(int32_t)));
-    return data[0];
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Common helpers for defining ops
-  //===--------------------------------------------------------------------===//
-
-#define IREE_VMLA_NONARY_OP(name, kernel, type)    \
-  Status name(const vm::ref<Buffer>& dst) {        \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);  \
-    return kernel::Execute<type>(dst->As<type>()); \
-  }
-
-#define IREE_VMLA_UNARY_OP(name, kernel, type)                          \
-  Status name(const vm::ref<Buffer>& src, const vm::ref<Buffer>& dst) { \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                       \
-    return kernel::Execute<type>(src->As<type>(), dst->As<type>());     \
-  }
-
-#define IREE_VMLA_BINARY_OP(name, kernel, type)                       \
-  Status name(const vm::ref<Buffer>& lhs, const vm::ref<Buffer>& rhs, \
-              const vm::ref<Buffer>& dst) {                           \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                     \
-    return kernel::Execute<type>(lhs->As<type>(), rhs->As<type>(),    \
-                                 dst->As<type>());                    \
-  }
-
-#define IREE_VMLA_BINARY_BROADCAST_OP(name, kernel, type)                 \
-  Status name(const vm::ref<Buffer>& lhs, int32_t rhs,                    \
-              const vm::ref<Buffer>& dst) {                               \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                         \
-    return kernel::Execute<type>(lhs->As<type>(), static_cast<type>(rhs), \
-                                 dst->As<type>());                        \
-  }
-
-#define IREE_VMLA_TERNARY_OP(name, kernel, type)                              \
-  Status name(const vm::ref<Buffer>& a, const vm::ref<Buffer>& b,             \
-              const vm::ref<Buffer>& c, const vm::ref<Buffer>& dst) {         \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                             \
-    return kernel::Execute<type>(a->As<type>(), b->As<type>(), c->As<type>(), \
-                                 dst->As<type>());                            \
-  }
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: comparison
-  //===--------------------------------------------------------------------===//
-
-  enum class CmpPredicate : uint32_t {
-    kEQ = 0,
-    kNE = 1,
-    kLT = 2,
-    kLE = 3,
-    kGT = 4,
-    kGE = 5,
-  };
-
-#define IREE_VMLA_COMPARE_OP(name, type)                                \
-  Status name(int32_t predicate, const vm::ref<Buffer>& lhs,            \
-              const vm::ref<Buffer>& rhs, const vm::ref<Buffer>& dst) { \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                       \
-    switch (static_cast<CmpPredicate>(predicate)) {                     \
-      case CmpPredicate::kEQ:                                           \
-        return kernels::CompareEQ::Execute<type>(                       \
-            lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>());      \
-      case CmpPredicate::kNE:                                           \
-        return kernels::CompareNE::Execute<type>(                       \
-            lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>());      \
-      case CmpPredicate::kLT:                                           \
-        return kernels::CompareLT::Execute<type>(                       \
-            lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>());      \
-      case CmpPredicate::kLE:                                           \
-        return kernels::CompareLE::Execute<type>(                       \
-            lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>());      \
-      case CmpPredicate::kGT:                                           \
-        return kernels::CompareGT::Execute<type>(                       \
-            lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>());      \
-      case CmpPredicate::kGE:                                           \
-        return kernels::CompareGE::Execute<type>(                       \
-            lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>());      \
-      default:                                                          \
-        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,           \
-                                "unsupported predicate %d", predicate); \
-    }                                                                   \
-  }
-  IREE_VMLA_COMPARE_OP(CmpI8, int8_t);
-  IREE_VMLA_COMPARE_OP(CmpI16, int16_t);
-  IREE_VMLA_COMPARE_OP(CmpI32, int32_t);
-  IREE_VMLA_COMPARE_OP(CmpF32, float);
-
-#define IREE_VMLA_SELECT_OP(name, type)                                     \
-  Status name(const vm::ref<Buffer>& cond, const vm::ref<Buffer>& lhs,      \
-              const vm::ref<Buffer>& rhs, const vm::ref<Buffer>& dst) {     \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                           \
-    return kernels::Select::Execute<type>(cond->As<uint8_t>(),              \
-                                          lhs->As<type>(), rhs->As<type>(), \
-                                          dst->As<type>());                 \
-  }
-  IREE_VMLA_SELECT_OP(SelectX8, uint8_t);
-  IREE_VMLA_SELECT_OP(SelectX16, uint16_t);
-  IREE_VMLA_SELECT_OP(SelectX32, uint32_t);
-
-#define IREE_VMLA_UNARY_PREDICATE_OP(name, kernel, type)                \
-  Status name(const vm::ref<Buffer>& src, const vm::ref<Buffer>& dst) { \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                       \
-    return kernel::Execute<type>(src->As<type>(), dst->As<bool>());     \
-  }
-  IREE_VMLA_UNARY_PREDICATE_OP(FiniteF32, kernels::Finite, float);
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: shape/structure
-  //===--------------------------------------------------------------------===//
-
-#define IREE_VMLA_COPY_OP(name, size)                                     \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,    \
-              absl::Span<const int32_t> src_indices,                      \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape,    \
-              absl::Span<const int32_t> dst_indices,                      \
-              absl::Span<const int32_t> lengths) {                        \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                         \
-    return kernels::Copy::Execute<size>(src->As<uint8_t>(), src_shape,    \
-                                        src_indices, dst->As<uint8_t>(),  \
-                                        dst_shape, dst_indices, lengths); \
-  }
-  IREE_VMLA_COPY_OP(CopyX8, sizeof(uint8_t));
-  IREE_VMLA_COPY_OP(CopyX16, sizeof(uint16_t));
-  IREE_VMLA_COPY_OP(CopyX32, sizeof(uint32_t));
-
-#define IREE_VMLA_TRANSPOSE_OP(name, type)                                     \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,         \
-              absl::Span<const int32_t> permutation,                           \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape) {       \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                              \
-    return kernels::Transpose::Execute<type>(src->As<type>(), dst->As<type>(), \
-                                             src_shape, permutation);          \
-  }
-  IREE_VMLA_TRANSPOSE_OP(TransposeX8, uint8_t);
-  IREE_VMLA_TRANSPOSE_OP(TransposeX16, uint16_t);
-  IREE_VMLA_TRANSPOSE_OP(TransposeX32, uint32_t);
-
-#define IREE_VMLA_REVERSE_OP(name, type)                                     \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,       \
-              absl::Span<const int32_t> dims, const vm::ref<Buffer>& dst,    \
-              iree_vmla_shape_t dst_shape) {                                 \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                            \
-    return kernels::Reverse::Execute<type>(src->As<type>(), dst->As<type>(), \
-                                           src_shape, dims);                 \
-  }
-  IREE_VMLA_REVERSE_OP(ReverseX8, uint8_t);
-  IREE_VMLA_REVERSE_OP(ReverseX16, uint16_t);
-  IREE_VMLA_REVERSE_OP(ReverseX32, uint32_t);
-
-#define IREE_VMLA_PAD_OP(name, type)                                       \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,     \
-              const vm::ref<Buffer>& value, iree_vmla_shape_t value_shape, \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape,     \
-              absl::Span<const int32_t> edge_padding_low,                  \
-              absl::Span<const int32_t> edge_padding_high,                 \
-              absl::Span<const int32_t> interior_padding) {                \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                          \
-    return kernels::Pad::Execute<type>(                                    \
-        src->As<type>(), value->As<type>(), dst->As<type>(), src_shape,    \
-        dst_shape, edge_padding_low, edge_padding_high, interior_padding); \
-  }
-  IREE_VMLA_PAD_OP(PadX8, uint8_t);
-  IREE_VMLA_PAD_OP(PadX16, uint16_t);
-  IREE_VMLA_PAD_OP(PadX32, uint32_t);
-
-#define IREE_VMLA_GATHER_OP(name, type)                                        \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,         \
-              const vm::ref<Buffer>& indices, iree_vmla_shape_t indices_shape, \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape,         \
-              const int32_t dim, const int32_t batch_dims) {                   \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                              \
-    return kernels::Gather::Execute<type>(                                     \
-        src->As<type>(), indices->As<int>(), dst->As<type>(), src_shape,       \
-        indices_shape, dst_shape, dim, batch_dims);                            \
-  }
-  IREE_VMLA_GATHER_OP(GatherX8, uint8_t);
-  IREE_VMLA_GATHER_OP(GatherX16, uint16_t);
-  IREE_VMLA_GATHER_OP(GatherX32, uint32_t);
-
-#define IREE_VMLA_SCATTER_OP(name, type)                                       \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,         \
-              const vm::ref<Buffer>& indices, iree_vmla_shape_t indices_shape, \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape) {       \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                              \
-    return kernels::Scatter::Execute<type>(                                    \
-        src->As<type>(), indices->As<int>(), dst->As<type>(), src_shape,       \
-        indices_shape, dst_shape);                                             \
-  }
-  IREE_VMLA_SCATTER_OP(ScatterX8, uint8_t);
-  IREE_VMLA_SCATTER_OP(ScatterX16, uint16_t);
-  IREE_VMLA_SCATTER_OP(ScatterX32, uint32_t);
-
-#define IREE_VMLA_BROADCAST_OP(name, type)                               \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,   \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape) { \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                        \
-    return kernels::Broadcast::Execute<type>(src->As<type>(),            \
-                                             dst->As<type>());           \
-  }
-  IREE_VMLA_BROADCAST_OP(BroadcastX8, uint8_t);
-  IREE_VMLA_BROADCAST_OP(BroadcastX16, uint16_t);
-  IREE_VMLA_BROADCAST_OP(BroadcastX32, uint32_t);
-
-  IREE_VMLA_NONARY_OP(IotaI8, kernels::Iota, int8_t);
-  IREE_VMLA_NONARY_OP(IotaI16, kernels::Iota, int16_t);
-  IREE_VMLA_NONARY_OP(IotaI32, kernels::Iota, int32_t);
-  IREE_VMLA_NONARY_OP(IotaF32, kernels::Iota, float_t);
-
-#define IREE_VMLA_TILE_OP(name, type)                                     \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,    \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape) {  \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                         \
-    return kernels::Tile::Execute<type>(src->As<type>(), dst->As<type>(), \
-                                        src_shape, dst_shape);            \
-  }
-  IREE_VMLA_TILE_OP(TileX8, uint8_t);
-  IREE_VMLA_TILE_OP(TileX16, uint16_t);
-  IREE_VMLA_TILE_OP(TileX32, uint32_t);
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: bit manipulation
-  //===--------------------------------------------------------------------===//
-
-  IREE_VMLA_UNARY_OP(NotX8, kernels::Not, uint8_t);
-  IREE_VMLA_UNARY_OP(NotX16, kernels::Not, uint16_t);
-  IREE_VMLA_UNARY_OP(NotX32, kernels::Not, uint32_t);
-  IREE_VMLA_BINARY_OP(AndX8, kernels::And, uint8_t);
-  IREE_VMLA_BINARY_OP(AndX16, kernels::And, uint16_t);
-  IREE_VMLA_BINARY_OP(AndX32, kernels::And, uint32_t);
-  IREE_VMLA_BINARY_BROADCAST_OP(AndBroadcastX8, kernels::And, uint8_t);
-  IREE_VMLA_BINARY_BROADCAST_OP(AndBroadcastX16, kernels::And, uint16_t);
-  IREE_VMLA_BINARY_BROADCAST_OP(AndBroadcastX32, kernels::And, uint32_t);
-  IREE_VMLA_BINARY_OP(OrX8, kernels::Or, uint8_t);
-  IREE_VMLA_BINARY_OP(OrX16, kernels::Or, uint16_t);
-  IREE_VMLA_BINARY_OP(OrX32, kernels::Or, uint32_t);
-  IREE_VMLA_BINARY_OP(XorX8, kernels::Xor, uint8_t);
-  IREE_VMLA_BINARY_OP(XorX16, kernels::Xor, uint16_t);
-  IREE_VMLA_BINARY_OP(XorX32, kernels::Xor, uint32_t);
-  IREE_VMLA_BINARY_BROADCAST_OP(XorBroadcastX8, kernels::Xor, uint8_t);
-  IREE_VMLA_BINARY_BROADCAST_OP(XorBroadcastX16, kernels::Xor, uint16_t);
-  IREE_VMLA_BINARY_BROADCAST_OP(XorBroadcastX32, kernels::Xor, uint32_t);
-  IREE_VMLA_BINARY_OP(ShlX8, kernels::ShiftLeft, uint8_t);
-  IREE_VMLA_BINARY_OP(ShlX16, kernels::ShiftLeft, uint16_t);
-  IREE_VMLA_BINARY_OP(ShlX32, kernels::ShiftLeft, uint32_t);
-  IREE_VMLA_BINARY_OP(ShrU8, kernels::ShiftRight, uint8_t);
-  IREE_VMLA_BINARY_OP(ShrU16, kernels::ShiftRight, uint16_t);
-  IREE_VMLA_BINARY_OP(ShrU32, kernels::ShiftRight, uint32_t);
-  IREE_VMLA_BINARY_OP(ShrI8, kernels::ShiftRight, int8_t);
-  IREE_VMLA_BINARY_OP(ShrI16, kernels::ShiftRight, int16_t);
-  IREE_VMLA_BINARY_OP(ShrI32, kernels::ShiftRight, int32_t);
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: arithmetic
-  //===--------------------------------------------------------------------===//
-
-  IREE_VMLA_BINARY_OP(AddI8, kernels::Add, int8_t);
-  IREE_VMLA_BINARY_OP(AddI16, kernels::Add, int16_t);
-  IREE_VMLA_BINARY_OP(AddI32, kernels::Add, int32_t);
-  IREE_VMLA_BINARY_OP(AddF32, kernels::Add, float);
-  IREE_VMLA_BINARY_OP(SubI8, kernels::Sub, int8_t);
-  IREE_VMLA_BINARY_OP(SubI16, kernels::Sub, int16_t);
-  IREE_VMLA_BINARY_OP(SubI32, kernels::Sub, int32_t);
-  IREE_VMLA_BINARY_OP(SubF32, kernels::Sub, float);
-  IREE_VMLA_UNARY_OP(AbsI8, kernels::Abs, int8_t);
-  IREE_VMLA_UNARY_OP(AbsI16, kernels::Abs, int16_t);
-  IREE_VMLA_UNARY_OP(AbsI32, kernels::Abs, int32_t);
-  IREE_VMLA_UNARY_OP(AbsF32, kernels::Abs, float);
-  IREE_VMLA_UNARY_OP(NegI8, kernels::Neg, int8_t);
-  IREE_VMLA_UNARY_OP(NegI16, kernels::Neg, int16_t);
-  IREE_VMLA_UNARY_OP(NegI32, kernels::Neg, int32_t);
-  IREE_VMLA_UNARY_OP(NegF32, kernels::Neg, float);
-  IREE_VMLA_BINARY_OP(MulI8, kernels::Mul, int8_t);
-  IREE_VMLA_BINARY_OP(MulI16, kernels::Mul, int16_t);
-  IREE_VMLA_BINARY_OP(MulI32, kernels::Mul, int32_t);
-  IREE_VMLA_BINARY_OP(MulF32, kernels::Mul, float);
-  IREE_VMLA_BINARY_OP(DivI8, kernels::Div, int8_t);
-  IREE_VMLA_BINARY_OP(DivI16, kernels::Div, int16_t);
-  IREE_VMLA_BINARY_OP(DivI32, kernels::Div, int32_t);
-  IREE_VMLA_BINARY_OP(DivU8, kernels::Div, uint8_t);
-  IREE_VMLA_BINARY_OP(DivU16, kernels::Div, uint16_t);
-  IREE_VMLA_BINARY_OP(DivU32, kernels::Div, uint32_t);
-  IREE_VMLA_BINARY_OP(DivF32, kernels::Div, float);
-  IREE_VMLA_BINARY_OP(RemI8, kernels::Rem, int8_t);
-  IREE_VMLA_BINARY_OP(RemI16, kernels::Rem, int16_t);
-  IREE_VMLA_BINARY_OP(RemI32, kernels::Rem, int32_t);
-  IREE_VMLA_BINARY_OP(RemU8, kernels::Rem, uint8_t);
-  IREE_VMLA_BINARY_OP(RemU16, kernels::Rem, uint16_t);
-  IREE_VMLA_BINARY_OP(RemU32, kernels::Rem, uint32_t);
-  IREE_VMLA_BINARY_OP(RemF32, kernels::Rem, float);
-  IREE_VMLA_BINARY_OP(PowF32, kernels::Pow, float);
-  IREE_VMLA_UNARY_OP(ExpF32, kernels::Exp, float);
-  IREE_VMLA_UNARY_OP(LogF32, kernels::Log, float);
-  IREE_VMLA_UNARY_OP(RsqrtF32, kernels::Rsqrt, float);
-  IREE_VMLA_UNARY_OP(SqrtF32, kernels::Sqrt, float);
-  IREE_VMLA_UNARY_OP(CosF32, kernels::Cos, float);
-  IREE_VMLA_UNARY_OP(SinF32, kernels::Sin, float);
-  IREE_VMLA_UNARY_OP(TanhF32, kernels::Tanh, float);
-  IREE_VMLA_BINARY_OP(Atan2F32, kernels::Atan2, float);
-
-  IREE_VMLA_BINARY_OP(MinI8, kernels::Min, int8_t);
-  IREE_VMLA_BINARY_OP(MinI16, kernels::Min, int16_t);
-  IREE_VMLA_BINARY_OP(MinI32, kernels::Min, int32_t);
-  IREE_VMLA_BINARY_OP(MinF32, kernels::Min, float);
-  IREE_VMLA_BINARY_OP(MaxI8, kernels::Max, int8_t);
-  IREE_VMLA_BINARY_OP(MaxI16, kernels::Max, int16_t);
-  IREE_VMLA_BINARY_OP(MaxI32, kernels::Max, int32_t);
-  IREE_VMLA_BINARY_OP(MaxF32, kernels::Max, float);
-  IREE_VMLA_TERNARY_OP(ClampI8, kernels::Clamp, int8_t);
-  IREE_VMLA_TERNARY_OP(ClampI16, kernels::Clamp, int16_t);
-  IREE_VMLA_TERNARY_OP(ClampI32, kernels::Clamp, int32_t);
-  IREE_VMLA_TERNARY_OP(ClampF32, kernels::Clamp, float);
-  IREE_VMLA_UNARY_OP(FloorF32, kernels::Floor, float);
-  IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);
-  IREE_VMLA_UNARY_OP(RoundF32, kernels::Round, float);
-
-#define IREE_VMLA_SORT_OP(name, type)                                        \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,       \
-              const vm::ref<Buffer>& dst) {                                  \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                            \
-    return kernels::Sort::Execute<type>(src->As<type>(), dst->As<int32_t>(), \
-                                        src_shape);                          \
-  }
-
-  IREE_VMLA_SORT_OP(SortI8, int8_t);
-  IREE_VMLA_SORT_OP(SortI16, int16_t);
-  IREE_VMLA_SORT_OP(SortI32, int32_t);
-  IREE_VMLA_SORT_OP(SortF32, float);
-
-#define IREE_VMLA_COMPLEX_INPUT_FFT_OP(name, op)                            \
-  Status name(                                                              \
-      const vm::ref<Buffer>& real_src, iree_vmla_shape_t real_src_shape,    \
-      const vm::ref<Buffer>& imag_src, iree_vmla_shape_t imag_src_shape,    \
-      const vm::ref<Buffer>& real_dst, const vm::ref<Buffer>& imag_dst) {   \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                           \
-    return op::Execute<float>(real_src->As<float>(), imag_src->As<float>(), \
-                              real_dst->As<float>(), imag_dst->As<float>(), \
-                              real_src_shape, imag_src_shape);              \
-  }
-
-  IREE_VMLA_COMPLEX_INPUT_FFT_OP(FftF32, kernels::Fft);
-  IREE_VMLA_COMPLEX_INPUT_FFT_OP(IfftF32, kernels::Ifft);
-
-  Status RfftF32(const vm::ref<Buffer>& real_src,
-                 iree_vmla_shape_t real_src_shape,
-                 const vm::ref<Buffer>& real_dst,
-                 const vm::ref<Buffer>& imag_dst) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::RfftF32");
-    IREE_RETURN_IF_ERROR(kernels::Rfft::Execute<float>(
-        real_src->As<float>(), real_dst->As<float>(), imag_dst->As<float>(),
-        real_src_shape));
-    return OkStatus();
-  }
-
-  Status IrfftF32(const vm::ref<Buffer>& real_src,
-                  iree_vmla_shape_t real_src_shape,
-                  const vm::ref<Buffer>& imag_src,
-                  iree_vmla_shape_t imag_src_shape,
-                  const vm::ref<Buffer>& real_dst) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::IrfftF32");
-    IREE_RETURN_IF_ERROR(kernels::Irfft::Execute<float>(
-        real_src->As<float>(), imag_src->As<float>(), real_dst->As<float>(),
-        real_src_shape, imag_src_shape));
-    return OkStatus();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: conversion
-  //===--------------------------------------------------------------------===//
-
-#define IREE_VMLA_CONVERSION_OP(name, src_type, dst_type)                      \
-  Status name(const vm::ref<Buffer>& src, const vm::ref<Buffer>& dst) {        \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                              \
-    return kernels::Convert::Execute<src_type, dst_type>(src->As<src_type>(),  \
-                                                         dst->As<dst_type>()); \
-  }
-  IREE_VMLA_CONVERSION_OP(ConvertI8I16, int8_t, int16_t);
-  IREE_VMLA_CONVERSION_OP(ConvertI8I32, int8_t, int32_t);
-  IREE_VMLA_CONVERSION_OP(ConvertI8F32, int8_t, float);
-  IREE_VMLA_CONVERSION_OP(ConvertI16I8, int16_t, int8_t);
-  IREE_VMLA_CONVERSION_OP(ConvertI16I32, int16_t, int32_t);
-  IREE_VMLA_CONVERSION_OP(ConvertI16F32, int16_t, float);
-  IREE_VMLA_CONVERSION_OP(ConvertI32I8, int32_t, int8_t);
-  IREE_VMLA_CONVERSION_OP(ConvertI32I16, int32_t, int16_t);
-  IREE_VMLA_CONVERSION_OP(ConvertI32F32, int32_t, float);
-  IREE_VMLA_CONVERSION_OP(ConvertF32I8, float, int8_t);
-  IREE_VMLA_CONVERSION_OP(ConvertF32I16, float, int16_t);
-  IREE_VMLA_CONVERSION_OP(ConvertF32I32, float, int32_t);
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: Convolution
-  //===--------------------------------------------------------------------===//
-
-  Status ConvF32F32F32(
-      const vm::ref<Buffer>& input, iree_vmla_shape_t input_shape,
-      const vm::ref<Buffer>& filter, iree_vmla_shape_t filter_shape,
-      const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape,
-      absl::Span<const int32_t> window_strides,
-      absl::Span<const int32_t> padding, absl::Span<const int32_t> lhs_dilation,
-      absl::Span<const int32_t> rhs_dilation, const int32_t feature_group_count,
-      const int32_t batch_group_count) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::ConvF32F32F32");
-    if (input_shape.size() != 4 || filter_shape.size() != 4 ||
-        dst_shape.size() != 4) {
-      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                              "expecting 4-d tensors for Conv2D kernel");
-    }
-
-    const int32_t batch_size = input_shape[0];
-    const auto input_example_shape = input_shape.subspan(1, 3);
-    const auto output_example_shape = dst_shape.subspan(1, 3);
-    const auto filter_shape_4d = filter_shape.subspan(0, 4);
-    const auto pad_h = padding.subspan(0, 2);
-    const auto pad_w = padding.subspan(2, 2);
-    const auto window_strides_2d = window_strides.subspan(0, 2);
-
-    const float* raw_inputs_data = input->As<float>().data();
-    const float* raw_filter_data = filter->As<float>().data();
-    float* raw_dst_data = dst->As<float>().data();
-    auto filter_buffer = absl::MakeConstSpan(
-        raw_filter_data, kernels::GetElementCount(filter_shape_4d));
-
-    const size_t input_stride = kernels::GetElementCount(input_example_shape);
-    const size_t output_stride = kernels::GetElementCount(output_example_shape);
-
-    for (int i = 0; i < batch_size; ++i) {
-      auto input_example =
-          absl::MakeConstSpan(raw_inputs_data + i * input_stride, input_stride);
-      auto output_example =
-          absl::MakeSpan(raw_dst_data + i * output_stride, output_stride);
-      IREE_RETURN_IF_ERROR(kernels::Conv2D::Execute(
-          input_example, input_example_shape, filter_buffer, filter_shape_4d,
-          output_example, output_example_shape, window_strides_2d, pad_h, pad_w,
-          lhs_dilation.subspan(0, 2), rhs_dilation.subspan(0, 2),
-          feature_group_count));
-    }
-    return OkStatus();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: GEMM/GEMV
-  //===--------------------------------------------------------------------===//
-
-  template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
-  Status BatchMatMul(const vm::ref<Buffer>& lhs, iree_vmla_shape_t lhs_shape,
-                     const vm::ref<Buffer>& rhs, iree_vmla_shape_t rhs_shape,
-                     const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape) {
-    IREE_TRACE_SCOPE0("VMLAModuleState::BatchMatMul");
-    // Compiler guarantees. Here for documentation purposes.
-    assert(lhs_shape.size() == 3 && rhs_shape.size() == 3 &&
-           dst_shape.size() == 3);
-    assert(lhs_shape[0] == rhs_shape[0] && rhs_shape[0] == dst_shape[0]);
-
-    iree_vmla_shape_t lhs_batch_element_shape = lhs_shape.subspan(1);
-    iree_vmla_shape_t rhs_batch_element_shape = rhs_shape.subspan(1);
-    iree_vmla_shape_t dst_batch_element_shape = dst_shape.subspan(1);
-    auto lhs_batch_element_shape2 = lhs_batch_element_shape.subspan(0, 2);
-    auto rhs_batch_element_shape2 = rhs_batch_element_shape.subspan(0, 2);
-    auto dst_batch_element_shape2 = dst_batch_element_shape.subspan(0, 2);
-    size_t lhs_batch_stride = kernels::GetElementCount(lhs_batch_element_shape);
-    size_t rhs_batch_stride = kernels::GetElementCount(rhs_batch_element_shape);
-    size_t dst_batch_stride = kernels::GetElementCount(dst_batch_element_shape);
-    LhsEl* lhs_batch_base = lhs->As<LhsEl>().data();
-    RhsEl* rhs_batch_base = rhs->As<RhsEl>().data();
-    DstEl* dst_batch_base = dst->As<DstEl>().data();
-    int32_t batch_dim = lhs_shape[0];
-    for (int i = 0; i < batch_dim; i++) {
-      kernels::MatMul::Buffers<LhsEl, RhsEl, AccumEl, DstEl> buffers;
-      buffers.lhs_buffer = absl::MakeSpan(lhs_batch_base + i * lhs_batch_stride,
-                                          lhs_batch_stride);
-      buffers.lhs_shape = lhs_batch_element_shape2;
-      buffers.rhs_buffer = absl::MakeSpan(rhs_batch_base + i * rhs_batch_stride,
-                                          rhs_batch_stride);
-      buffers.rhs_shape = rhs_batch_element_shape2;
-      buffers.dst_buffer = absl::MakeSpan(dst_batch_base + i * dst_batch_stride,
-                                          dst_batch_stride);
-      buffers.dst_shape = dst_batch_element_shape2;
-
-      IREE_RETURN_IF_ERROR(kernels::MatMul::Execute(
-          kernel_state_->mat_mul_state.get(), buffers));
-    }
-    return OkStatus();
-  }
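  // Note (illustrative, not from the removed file): the loop above flattens
  // the leading batch dimension. For a 3-D shape [B, d0, d1] the per-batch
  // stride is GetElementCount({d0, d1}) = d0 * d1, so batch i is the
  // contiguous slice
  //   base + i * (d0 * d1)
  // and each iteration dispatches an ordinary 2-D MatMul on those slices --
  // the same per-example slicing ConvF32F32F32 uses above.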
-
-  //===--------------------------------------------------------------------===//
-  // VMLA Ops: reduction
-  //===--------------------------------------------------------------------===//
-
-#define IREE_VMLA_REDUCTION_OP(name, kernel, type)                       \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,   \
-              const vm::ref<Buffer>& init, iree_vmla_shape_t init_shape, \
-              int32_t dimension, const vm::ref<Buffer>& dst,             \
-              iree_vmla_shape_t dst_shape) {                             \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                        \
-    return kernel::Execute<type>(src->As<type>(), init->As<type>(),      \
-                                 dst->As<type>(), dimension, src_shape,  \
-                                 dst_shape);                             \
-  }
-  IREE_VMLA_REDUCTION_OP(ReduceSumI8, kernels::ReduceSum, int8_t);
-  IREE_VMLA_REDUCTION_OP(ReduceSumI16, kernels::ReduceSum, int16_t);
-  IREE_VMLA_REDUCTION_OP(ReduceSumI32, kernels::ReduceSum, int32_t);
-  IREE_VMLA_REDUCTION_OP(ReduceSumF32, kernels::ReduceSum, float);
-  IREE_VMLA_REDUCTION_OP(ReduceMinI8, kernels::ReduceMin, int8_t);
-  IREE_VMLA_REDUCTION_OP(ReduceMinI16, kernels::ReduceMin, int16_t);
-  IREE_VMLA_REDUCTION_OP(ReduceMinI32, kernels::ReduceMin, int32_t);
-  IREE_VMLA_REDUCTION_OP(ReduceMinF32, kernels::ReduceMin, float);
-  IREE_VMLA_REDUCTION_OP(ReduceMaxI8, kernels::ReduceMax, int8_t);
-  IREE_VMLA_REDUCTION_OP(ReduceMaxI16, kernels::ReduceMax, int16_t);
-  IREE_VMLA_REDUCTION_OP(ReduceMaxI32, kernels::ReduceMax, int32_t);
-  IREE_VMLA_REDUCTION_OP(ReduceMaxF32, kernels::ReduceMax, float);
-  IREE_VMLA_REDUCTION_OP(ReduceAndI8, kernels::ReduceAnd, int8_t);
-  IREE_VMLA_REDUCTION_OP(ReduceOrI8, kernels::ReduceOr, int8_t);
-
-#define IREE_VMLA_POOLING_OP(name, kernel, type)                              \
-  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,        \
-              const vm::ref<Buffer>& init, iree_vmla_shape_t init_shape,      \
-              const vm::ref<Buffer>& dst, iree_vmla_shape_t dst_shape,        \
-              iree_vmla_shape_t window_dimensions, iree_vmla_shape_t strides, \
-              iree_vmla_shape_t pad_low) {                                    \
-    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                             \
-    return kernel::Execute<type>(src->As<type>(), init->As<type>(),           \
-                                 dst->As<type>(), src_shape, dst_shape,       \
-                                 window_dimensions, strides, pad_low);        \
-  }
-  IREE_VMLA_POOLING_OP(PoolingSumI8, kernels::PoolingSum, int8_t);
-  IREE_VMLA_POOLING_OP(PoolingSumI16, kernels::PoolingSum, int16_t);
-  IREE_VMLA_POOLING_OP(PoolingSumI32, kernels::PoolingSum, int32_t);
-  IREE_VMLA_POOLING_OP(PoolingSumF32, kernels::PoolingSum, float);
-  IREE_VMLA_POOLING_OP(PoolingMinI8, kernels::PoolingMin, int8_t);
-  IREE_VMLA_POOLING_OP(PoolingMinI16, kernels::PoolingMin, int16_t);
-  IREE_VMLA_POOLING_OP(PoolingMinI32, kernels::PoolingMin, int32_t);
-  IREE_VMLA_POOLING_OP(PoolingMinF32, kernels::PoolingMin, float);
-  IREE_VMLA_POOLING_OP(PoolingMaxI8, kernels::PoolingMax, int8_t);
-  IREE_VMLA_POOLING_OP(PoolingMaxI16, kernels::PoolingMax, int16_t);
-  IREE_VMLA_POOLING_OP(PoolingMaxI32, kernels::PoolingMax, int32_t);
-  IREE_VMLA_POOLING_OP(PoolingMaxF32, kernels::PoolingMax, float);
-
- private:
-  iree_allocator_t allocator_;
-
-  // NOTE: kernel state must be externally synchronized as it is shared across
-  // all contexts using the VMLA module. This is fine in our current design as
-  // we only ever execute a single context at a time but if we start to allow
-  // concurrency across contexts we'll need to introduce locks.
-  kernels::RuntimeState* kernel_state_ = nullptr;
-};
-
-//===----------------------------------------------------------------------===//
-// VM module interface implementation
-//===----------------------------------------------------------------------===//
-
-static const vm::NativeFunction<VMLAModuleState> kVMLAModuleFunctions[] = {
-    vm::MakeNativeFunction("interface.const", &VMLAModuleState::InterfaceConst),
-    vm::MakeNativeFunction("interface.binding",
-                           &VMLAModuleState::InterfaceBinding),
-
-    vm::MakeNativeFunction("buffer.const", &VMLAModuleState::BufferConst),
-    vm::MakeNativeFunction("buffer.alloc", &VMLAModuleState::BufferAlloc),
-    vm::MakeNativeFunction("buffer.clone", &VMLAModuleState::BufferClone),
-    vm::MakeNativeFunction("buffer.view", &VMLAModuleState::BufferView),
-    vm::MakeNativeFunction("buffer.copy", &VMLAModuleState::BufferCopy),
-    vm::MakeNativeFunction("buffer.fill", &VMLAModuleState::BufferFill),
-    vm::MakeNativeFunction("buffer.load.i32", &VMLAModuleState::BufferLoadI32),
-
-    vm::MakeNativeFunction("cmp.i8", &VMLAModuleState::CmpI8),
-    vm::MakeNativeFunction("cmp.i16", &VMLAModuleState::CmpI16),
-    vm::MakeNativeFunction("cmp.i32", &VMLAModuleState::CmpI32),
-    vm::MakeNativeFunction("cmp.f32", &VMLAModuleState::CmpF32),
-    vm::MakeNativeFunction("select.x8", &VMLAModuleState::SelectX8),
-    vm::MakeNativeFunction("select.x16", &VMLAModuleState::SelectX16),
-    vm::MakeNativeFunction("select.x32", &VMLAModuleState::SelectX32),
-
-    vm::MakeNativeFunction("broadcast.x8", &VMLAModuleState::BroadcastX8),
-    vm::MakeNativeFunction("broadcast.x16", &VMLAModuleState::BroadcastX16),
-    vm::MakeNativeFunction("broadcast.x32", &VMLAModuleState::BroadcastX32),
-    vm::MakeNativeFunction("copy.x8", &VMLAModuleState::CopyX8),
-    vm::MakeNativeFunction("copy.x16", &VMLAModuleState::CopyX16),
-    vm::MakeNativeFunction("copy.x32", &VMLAModuleState::CopyX32),
-    vm::MakeNativeFunction("transpose.x8", &VMLAModuleState::TransposeX8),
-    vm::MakeNativeFunction("transpose.x16", &VMLAModuleState::TransposeX16),
-    vm::MakeNativeFunction("transpose.x32", &VMLAModuleState::TransposeX32),
-    vm::MakeNativeFunction("reverse.x8", &VMLAModuleState::ReverseX8),
-    vm::MakeNativeFunction("reverse.x16", &VMLAModuleState::ReverseX16),
-    vm::MakeNativeFunction("reverse.x32", &VMLAModuleState::ReverseX32),
-    vm::MakeNativeFunction("pad.x8", &VMLAModuleState::PadX8),
-    vm::MakeNativeFunction("pad.x16", &VMLAModuleState::PadX16),
-    vm::MakeNativeFunction("pad.x32", &VMLAModuleState::PadX32),
-    vm::MakeNativeFunction("gather.x8", &VMLAModuleState::GatherX8),
-    vm::MakeNativeFunction("gather.x16", &VMLAModuleState::GatherX16),
-    vm::MakeNativeFunction("gather.x32", &VMLAModuleState::GatherX32),
-    vm::MakeNativeFunction("scatter.x8", &VMLAModuleState::ScatterX8),
-    vm::MakeNativeFunction("scatter.x16", &VMLAModuleState::ScatterX16),
-    vm::MakeNativeFunction("scatter.x32", &VMLAModuleState::ScatterX32),
-    vm::MakeNativeFunction("iota.i8", &VMLAModuleState::IotaI8),
-    vm::MakeNativeFunction("iota.i16", &VMLAModuleState::IotaI16),
-    vm::MakeNativeFunction("iota.i32", &VMLAModuleState::IotaI32),
-    vm::MakeNativeFunction("iota.f32", &VMLAModuleState::IotaF32),
-    vm::MakeNativeFunction("tile.x8", &VMLAModuleState::TileX8),
-    vm::MakeNativeFunction("tile.x16", &VMLAModuleState::TileX16),
-    vm::MakeNativeFunction("tile.x32", &VMLAModuleState::TileX32),
-
-    vm::MakeNativeFunction("not.x8", &VMLAModuleState::NotX8),
-    vm::MakeNativeFunction("not.x16", &VMLAModuleState::NotX16),
-    vm::MakeNativeFunction("not.x32", &VMLAModuleState::NotX32),
-    vm::MakeNativeFunction("and.x8", &VMLAModuleState::AndX8),
-    vm::MakeNativeFunction("and.x16", &VMLAModuleState::AndX16),
-    vm::MakeNativeFunction("and.x32", &VMLAModuleState::AndX32),
-    vm::MakeNativeFunction("and.broadcast.x8",
-                           &VMLAModuleState::AndBroadcastX8),
-    vm::MakeNativeFunction("and.broadcast.x16",
-                           &VMLAModuleState::AndBroadcastX16),
-    vm::MakeNativeFunction("and.broadcast.x32",
-                           &VMLAModuleState::AndBroadcastX32),
-    vm::MakeNativeFunction("or.x8", &VMLAModuleState::OrX8),
-    vm::MakeNativeFunction("or.x16", &VMLAModuleState::OrX16),
-    vm::MakeNativeFunction("or.x32", &VMLAModuleState::OrX32),
-    vm::MakeNativeFunction("xor.x8", &VMLAModuleState::XorX8),
-    vm::MakeNativeFunction("xor.x16", &VMLAModuleState::XorX16),
-    vm::MakeNativeFunction("xor.x32", &VMLAModuleState::XorX32),
-    vm::MakeNativeFunction("xor.broadcast.x8",
-                           &VMLAModuleState::XorBroadcastX8),
-    vm::MakeNativeFunction("xor.broadcast.x16",
-                           &VMLAModuleState::XorBroadcastX16),
-    vm::MakeNativeFunction("xor.broadcast.x32",
-                           &VMLAModuleState::XorBroadcastX32),
-    vm::MakeNativeFunction("shl.x8", &VMLAModuleState::ShlX8),
-    vm::MakeNativeFunction("shl.x16", &VMLAModuleState::ShlX16),
-    vm::MakeNativeFunction("shl.x32", &VMLAModuleState::ShlX32),
-    vm::MakeNativeFunction("shr.u8", &VMLAModuleState::ShrU8),
-    vm::MakeNativeFunction("shr.u16", &VMLAModuleState::ShrU16),
-    vm::MakeNativeFunction("shr.u32", &VMLAModuleState::ShrU32),
-    vm::MakeNativeFunction("shr.i8", &VMLAModuleState::ShrI8),
-    vm::MakeNativeFunction("shr.i16", &VMLAModuleState::ShrI16),
-    vm::MakeNativeFunction("shr.i32", &VMLAModuleState::ShrI32),
-
-    vm::MakeNativeFunction("add.i8", &VMLAModuleState::AddI8),
-    vm::MakeNativeFunction("add.i16", &VMLAModuleState::AddI16),
-    vm::MakeNativeFunction("add.i32", &VMLAModuleState::AddI32),
-    vm::MakeNativeFunction("add.f32", &VMLAModuleState::AddF32),
-    vm::MakeNativeFunction("sub.i8", &VMLAModuleState::SubI8),
-    vm::MakeNativeFunction("sub.i16", &VMLAModuleState::SubI16),
-    vm::MakeNativeFunction("sub.i32", &VMLAModuleState::SubI32),
-    vm::MakeNativeFunction("sub.f32", &VMLAModuleState::SubF32),
-    vm::MakeNativeFunction("abs.i8", &VMLAModuleState::AbsI8),
-    vm::MakeNativeFunction("abs.i16", &VMLAModuleState::AbsI16),
-    vm::MakeNativeFunction("abs.i32", &VMLAModuleState::AbsI32),
-    vm::MakeNativeFunction("abs.f32", &VMLAModuleState::AbsF32),
-    vm::MakeNativeFunction("neg.i8", &VMLAModuleState::NegI8),
-    vm::MakeNativeFunction("neg.i16", &VMLAModuleState::NegI16),
-    vm::MakeNativeFunction("neg.i32", &VMLAModuleState::NegI32),
-    vm::MakeNativeFunction("neg.f32", &VMLAModuleState::NegF32),
-    vm::MakeNativeFunction("mul.i8", &VMLAModuleState::MulI8),
-    vm::MakeNativeFunction("mul.i16", &VMLAModuleState::MulI16),
-    vm::MakeNativeFunction("mul.i32", &VMLAModuleState::MulI32),
-    vm::MakeNativeFunction("mul.f32", &VMLAModuleState::MulF32),
-    vm::MakeNativeFunction("div.i8", &VMLAModuleState::DivI8),
-    vm::MakeNativeFunction("div.i16", &VMLAModuleState::DivI16),
-    vm::MakeNativeFunction("div.i32", &VMLAModuleState::DivI32),
-    vm::MakeNativeFunction("div.u8", &VMLAModuleState::DivU8),
-    vm::MakeNativeFunction("div.u16", &VMLAModuleState::DivU16),
-    vm::MakeNativeFunction("div.u32", &VMLAModuleState::DivU32),
-    vm::MakeNativeFunction("div.f32", &VMLAModuleState::DivF32),
-    vm::MakeNativeFunction("rem.i8", &VMLAModuleState::RemI8),
-    vm::MakeNativeFunction("rem.i16", &VMLAModuleState::RemI16),
-    vm::MakeNativeFunction("rem.i32", &VMLAModuleState::RemI32),
-    vm::MakeNativeFunction("rem.u8", &VMLAModuleState::RemU8),
-    vm::MakeNativeFunction("rem.u16", &VMLAModuleState::RemU16),
-    vm::MakeNativeFunction("rem.u32", &VMLAModuleState::RemU32),
-    vm::MakeNativeFunction("rem.f32", &VMLAModuleState::RemF32),
-    vm::MakeNativeFunction("pow.f32", &VMLAModuleState::PowF32),
-    vm::MakeNativeFunction("exp.f32", &VMLAModuleState::ExpF32),
-    vm::MakeNativeFunction("log.f32", &VMLAModuleState::LogF32),
-    vm::MakeNativeFunction("rsqrt.f32", &VMLAModuleState::RsqrtF32),
-    vm::MakeNativeFunction("sqrt.f32", &VMLAModuleState::SqrtF32),
-    vm::MakeNativeFunction("cos.f32", &VMLAModuleState::CosF32),
-    vm::MakeNativeFunction("sin.f32", &VMLAModuleState::SinF32),
-    vm::MakeNativeFunction("tanh.f32", &VMLAModuleState::TanhF32),
-    vm::MakeNativeFunction("atan2.f32", &VMLAModuleState::Atan2F32),
-
-    vm::MakeNativeFunction("min.i8", &VMLAModuleState::MinI8),
-    vm::MakeNativeFunction("min.i16", &VMLAModuleState::MinI16),
-    vm::MakeNativeFunction("min.i32", &VMLAModuleState::MinI32),
-    vm::MakeNativeFunction("min.f32", &VMLAModuleState::MinF32),
-    vm::MakeNativeFunction("max.i8", &VMLAModuleState::MaxI8),
-    vm::MakeNativeFunction("max.i16", &VMLAModuleState::MaxI16),
-    vm::MakeNativeFunction("max.i32", &VMLAModuleState::MaxI32),
-    vm::MakeNativeFunction("max.f32", &VMLAModuleState::MaxF32),
-    vm::MakeNativeFunction("clamp.i8", &VMLAModuleState::ClampI8),
-    vm::MakeNativeFunction("clamp.i16", &VMLAModuleState::ClampI16),
-    vm::MakeNativeFunction("clamp.i32", &VMLAModuleState::ClampI32),
-    vm::MakeNativeFunction("clamp.f32", &VMLAModuleState::ClampF32),
-    vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
-    vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
-    vm::MakeNativeFunction("round.f32", &VMLAModuleState::RoundF32),
-    vm::MakeNativeFunction("sort.i8", &VMLAModuleState::SortI8),
-    vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16),
-    vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32),
-    vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32),
-    vm::MakeNativeFunction("fft.f32", &VMLAModuleState::FftF32),
-    vm::MakeNativeFunction("ifft.f32", &VMLAModuleState::IfftF32),
-    vm::MakeNativeFunction("rfft.f32", &VMLAModuleState::RfftF32),
-    vm::MakeNativeFunction("irfft.f32", &VMLAModuleState::IrfftF32),
-    vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
-
-    vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
-    vm::MakeNativeFunction("convert.i8.i32", &VMLAModuleState::ConvertI8I32),
-    vm::MakeNativeFunction("convert.i8.f32", &VMLAModuleState::ConvertI8F32),
-    vm::MakeNativeFunction("convert.i16.i8", &VMLAModuleState::ConvertI16I8),
-    vm::MakeNativeFunction("convert.i16.i32", &VMLAModuleState::ConvertI16I32),
-    vm::MakeNativeFunction("convert.i16.f32", &VMLAModuleState::ConvertI16F32),
-    vm::MakeNativeFunction("convert.i32.i8", &VMLAModuleState::ConvertI32I8),
-    vm::MakeNativeFunction("convert.i32.i16", &VMLAModuleState::ConvertI32I16),
-    vm::MakeNativeFunction("convert.i32.f32", &VMLAModuleState::ConvertI32F32),
-    vm::MakeNativeFunction("convert.f32.i8", &VMLAModuleState::ConvertF32I8),
-    vm::MakeNativeFunction("convert.f32.i16", &VMLAModuleState::ConvertF32I16),
-    vm::MakeNativeFunction("convert.f32.i32", &VMLAModuleState::ConvertF32I32),
-
-    vm::MakeNativeFunction("reduce.sum.i8", &VMLAModuleState::ReduceSumI8),
-    vm::MakeNativeFunction("reduce.sum.i16", &VMLAModuleState::ReduceSumI16),
-    vm::MakeNativeFunction("reduce.sum.i32", &VMLAModuleState::ReduceSumI32),
-    vm::MakeNativeFunction("reduce.sum.f32", &VMLAModuleState::ReduceSumF32),
-    vm::MakeNativeFunction("reduce.min.i8", &VMLAModuleState::ReduceMinI8),
-    vm::MakeNativeFunction("reduce.min.i16", &VMLAModuleState::ReduceMinI16),
-    vm::MakeNativeFunction("reduce.min.i32", &VMLAModuleState::ReduceMinI32),
-    vm::MakeNativeFunction("reduce.min.f32", &VMLAModuleState::ReduceMinF32),
-    vm::MakeNativeFunction("reduce.max.i8", &VMLAModuleState::ReduceMaxI8),
-    vm::MakeNativeFunction("reduce.max.i16", &VMLAModuleState::ReduceMaxI16),
-    vm::MakeNativeFunction("reduce.max.i32", &VMLAModuleState::ReduceMaxI32),
-    vm::MakeNativeFunction("reduce.max.f32", &VMLAModuleState::ReduceMaxF32),
-    vm::MakeNativeFunction("reduce.and.i8", &VMLAModuleState::ReduceAndI8),
-    vm::MakeNativeFunction("reduce.or.i8", &VMLAModuleState::ReduceOrI8),
-
-    vm::MakeNativeFunction("pooling.sum.i8", &VMLAModuleState::PoolingSumI8),
-    vm::MakeNativeFunction("pooling.sum.i16", &VMLAModuleState::PoolingSumI16),
-    vm::MakeNativeFunction("pooling.sum.i32", &VMLAModuleState::PoolingSumI32),
-    vm::MakeNativeFunction("pooling.sum.f32", &VMLAModuleState::PoolingSumF32),
-    vm::MakeNativeFunction("pooling.min.i8", &VMLAModuleState::PoolingMinI8),
-    vm::MakeNativeFunction("pooling.min.i16", &VMLAModuleState::PoolingMinI16),
-    vm::MakeNativeFunction("pooling.min.i32", &VMLAModuleState::PoolingMinI32),
-    vm::MakeNativeFunction("pooling.min.f32", &VMLAModuleState::PoolingMinF32),
-    vm::MakeNativeFunction("pooling.max.i8", &VMLAModuleState::PoolingMaxI8),
-    vm::MakeNativeFunction("pooling.max.i16", &VMLAModuleState::PoolingMaxI16),
-    vm::MakeNativeFunction("pooling.max.i32", &VMLAModuleState::PoolingMaxI32),
-    vm::MakeNativeFunction("pooling.max.f32", &VMLAModuleState::PoolingMaxF32),
-
-    vm::MakeNativeFunction(
-        "batch.matmul.f32f32.f32",
-        &VMLAModuleState::BatchMatMul<float, float, float, float>),
-
-    vm::MakeNativeFunction(
-        "batch.matmul.i32i32.i32",
-        &VMLAModuleState::BatchMatMul<int32_t, int32_t, int32_t, int32_t>),
-
-    vm::MakeNativeFunction(
-        "batch.matmul.i8i8.i32",
-        &VMLAModuleState::BatchMatMul<int8_t, int8_t, int32_t, int32_t>),
-
-    vm::MakeNativeFunction(
-        "batch.matmul.i16i16.i32",
-        &VMLAModuleState::BatchMatMul<int16_t, int16_t, int32_t, int32_t>),
-
-    vm::MakeNativeFunction("conv.f32f32.f32", &VMLAModuleState::ConvF32F32F32)};
-
-// Per-device VMLA module.
-// One of these will be created per device and be shared across all executables
-// that are created within that device. Large shared kernel state can go here
-// (such as thread pools/caches/etc), though note that they must be either
-// thread-safe or internally synchronized.
-//
-// Thread-safe.
-class VMLAModule final : public vm::NativeModule<VMLAModuleState> {
- public:
-  explicit VMLAModule(iree_allocator_t allocator)
-      : vm::NativeModule<VMLAModuleState>(
-            "vmla", allocator, absl::MakeConstSpan(kVMLAModuleFunctions)) {}
-  ~VMLAModule() = default;
-
-  Status Initialize() {
-    IREE_TRACE_SCOPE0("VMLAModule::Initialize");
-    return OkStatus();
-  }
-
-  StatusOr<std::unique_ptr<VMLAModuleState>> CreateState(
-      iree_allocator_t allocator) override {
-    IREE_TRACE_SCOPE0("VMLAModule::CreateState");
-    auto state = std::make_unique<VMLAModuleState>(allocator, &kernel_state_);
-    return state;
-  }
-
- private:
-  // NOTE: shared across all contexts with the VMLA module loaded. See
-  // VMLAModuleState::kernel_state_ for more information.
-  kernels::RuntimeState kernel_state_;
-};
-
-}  // namespace
-
-Status ModuleCreate(iree_allocator_t allocator, iree_vm_module_t** out_module) {
-  if (!out_module) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "out_module must not be null");
-  }
-  *out_module = nullptr;
-  auto module = std::make_unique<VMLAModule>(allocator);
-  IREE_RETURN_IF_ERROR(module->Initialize());
-  *out_module = module.release()->interface();
-  return OkStatus();
-}
-
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
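
For reference, the removed iree/modules/vmla/op_module.cc above is an instance of the standard C++ vm::NativeModule boilerplate: a per-context state class, a static table built with vm::MakeNativeFunction, a NativeModule subclass that stamps out per-context state, and a ModuleCreate-style factory. A minimal sketch of the same pattern follows; it is illustrative only -- the "sample" module name, the SampleModuleState/SampleModule classes, and the add.i32 export are hypothetical, and it assumes the same iree/vm/native_module_cc.h helpers the removed file included.

// Minimal native-module sketch (illustrative only; "sample" names hypothetical).
#include <cstdint>
#include <memory>

#include "absl/types/span.h"
#include "iree/base/api.h"
#include "iree/base/status.h"
#include "iree/vm/api.h"
#include "iree/vm/native_module_cc.h"

namespace iree {
namespace samples {
namespace {

// Per-context state: one instance per VM context that loads the module.
class SampleModuleState final {
 public:
  explicit SampleModuleState(iree_allocator_t allocator)
      : allocator_(allocator) {}

  // Exported below as "sample.add.i32".
  StatusOr<int32_t> AddI32(int32_t a, int32_t b) { return a + b; }

 private:
  iree_allocator_t allocator_;
};

// Table mapping VM function names to state methods; compiled modules import
// and call these by name.
static const vm::NativeFunction<SampleModuleState> kSampleModuleFunctions[] = {
    vm::MakeNativeFunction("add.i32", &SampleModuleState::AddI32),
};

// Shared module object; CreateState stamps out the per-context state.
class SampleModule final : public vm::NativeModule<SampleModuleState> {
 public:
  explicit SampleModule(iree_allocator_t allocator)
      : vm::NativeModule<SampleModuleState>(
            "sample", allocator, absl::MakeConstSpan(kSampleModuleFunctions)) {}

  StatusOr<std::unique_ptr<SampleModuleState>> CreateState(
      iree_allocator_t allocator) override {
    return std::make_unique<SampleModuleState>(allocator);
  }
};

}  // namespace

// C-compatible factory, mirroring the ModuleCreate shown above.
Status SampleModuleCreate(iree_allocator_t allocator,
                          iree_vm_module_t** out_module) {
  if (!out_module) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "out_module must not be null");
  }
  *out_module = nullptr;
  auto module = std::make_unique<SampleModule>(allocator);
  *out_module = module.release()->interface();
  return OkStatus();
}

}  // namespace samples
}  // namespace iree
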
diff --git a/iree/modules/vmla/op_module.h b/iree/modules/vmla/op_module.h
deleted file mode 100644
index 5c1a4e1..0000000
--- a/iree/modules/vmla/op_module.h
+++ /dev/null
@@ -1,146 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// NOTE: unlike most VM modules we are only ever created and used from C++ code
-// linked into the same library, because of this we can avoid the C shims and
-// directly use C++ types.
-
-#ifndef IREE_MODULES_VMLA_OP_MODULE_H_
-#define IREE_MODULES_VMLA_OP_MODULE_H_
-
-#include <cstdint>
-
-#include "absl/types/span.h"
-#include "iree/base/api.h"
-#include "iree/base/status.h"
-#include "iree/vm/api.h"
-#include "iree/vm/native_module_cc.h"
-#include "iree/vm/ref_cc.h"
-
-namespace iree {
-namespace hal {
-namespace vmla {
-
-using iree_vmla_size_t = uint32_t;
-using iree_vmla_shape_t = absl::Span<const int32_t>;
-
-// Sentinel indicating that the remaining buffer after any offset has been
-// applied should be used as the length.
-constexpr iree_vmla_size_t kVMLAWholeBuffer = -1;
-
-// A lightweight buffer lifetime management type.
-// This is exported to modules as `vmla.buffer`. It can be used to provide
-// views into existing rdata buffers (by specifying iree_allocator_null()),
-// views into parent buffers (parents retained via a reference), or dedicated
-// allocations from an allocator.
-//
-// The provided data pointer and length is always for the buffer itself; it'll
-// already be offset/clamped to parent buffer bounds when a view.
-class Buffer final : public iree::vm::RefObject<Buffer> {
- public:
-  static StatusOr<vm::ref<Buffer>> Allocate(size_t byte_length,
-                                            iree_allocator_t allocator);
-
-  static StatusOr<vm::ref<Buffer>> Wrap(const void* data, size_t data_length,
-                                        iree_allocator_t allocator);
-
-  static StatusOr<vm::ref<Buffer>> WrapMutable(void* data, size_t data_length,
-                                               iree_allocator_t allocator);
-
-  ~Buffer();
-
-  constexpr const void* data() const { return data_; }
-  constexpr void* data() { return data_; }
-  constexpr size_t size() const { return data_length_; }
-
-  template <typename T>
-  absl::Span<const T> As() const {
-    return absl::MakeConstSpan(reinterpret_cast<const T*>(data_),
-                               data_length_ / sizeof(T));
-  }
-
-  template <typename T>
-  absl::Span<T> As() {
-    return absl::MakeSpan(reinterpret_cast<T*>(data_),
-                          data_length_ / sizeof(T));
-  }
-
-  template <typename T>
-  StatusOr<absl::Span<T>> RangeAs(iree_vmla_size_t byte_offset,
-                                  iree_vmla_size_t byte_length) {
-    absl::Span<uint8_t> byte_range;
-    IREE_RETURN_IF_ERROR(MakeRange(byte_offset, byte_length, &byte_range));
-    return ReinterpretSpan<T>(byte_range);
-  }
-
- private:
-  // reinterpret_cast for Spans, preserving byte size.
-  template <typename T, typename U>
-  static constexpr absl::Span<T> ReinterpretSpan(absl::Span<U> value) {
-    return absl::MakeSpan(reinterpret_cast<T*>(value.data()),
-                          (value.size() * sizeof(U)) / sizeof(T));
-  }
-
-  Status MakeRange(iree_vmla_size_t byte_offset, iree_vmla_size_t byte_length,
-                   absl::Span<uint8_t>* out_range) const;
-
-  vm::ref<Buffer> parent_;
-  void* data_ = nullptr;
-  size_t data_length_ = 0;
-  iree_allocator_t allocator_;
-};
-
-class Interface final : public iree::vm::RefObject<Interface> {
- public:
-  static constexpr int kMaxConstants = 32;
-  static constexpr int kMaxSets = 4;
-  static constexpr int kMaxBindings = 32;
-
-  struct Binding {
-    vm::ref<Buffer> buffer;
-    // TODO(benvanik): other descriptor set information.
-  };
-
-  // Resets all bindings on the interface.
-  void Reset();
-
-  // Gets the value from the push constants block at the given element offset.
-  StatusOr<uint32_t> GetConstant(uint32_t offset) const;
-
-  // Sets the push constant block contents to the given values.
-  Status SetConstants(const uint32_t* values, iree_host_size_t value_count);
-
-  // Gets the binding within a set. Note that the buffer may be null.
-  StatusOr<const Binding> GetBinding(uint32_t set, uint32_t binding) const;
-
-  // Sets a binding within a set to the given buffer value (possibly null).
-  Status SetBinding(uint32_t set, uint32_t binding, Binding value);
-
- private:
-  std::array<uint32_t, kMaxConstants> constants_;
-  std::array<std::array<Binding, kMaxBindings>, kMaxSets> bindings_;
-};
-
-Status ModuleRegisterTypes();
-
-Status ModuleCreate(iree_allocator_t allocator, iree_vm_module_t** out_module);
-
-}  // namespace vmla
-}  // namespace hal
-}  // namespace iree
-
-IREE_VM_DECLARE_TYPE_ADAPTERS(Buffer, iree::hal::vmla::Buffer);
-IREE_VM_DECLARE_TYPE_ADAPTERS(Interface, iree::hal::vmla::Interface);
-
-#endif  // IREE_MODULES_VMLA_OP_MODULE_H_
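
For reference, the Buffer and Interface types declared in the removed header are plain C++ ref-counted objects. The sketch below shows how host code could drive them; it is illustrative only -- FillExample and the constant/binding values are hypothetical, and it assumes the IREE_ASSIGN_OR_RETURN and IREE_ARRAYSIZE helpers from iree/base alongside this header.

// Illustrative usage of the removed Buffer/Interface types; not part of the patch.
iree::Status FillExample(iree_allocator_t allocator) {
  using iree::hal::vmla::Buffer;
  using iree::hal::vmla::Interface;

  // Dedicated 16-element f32 allocation owned by |allocator|.
  IREE_ASSIGN_OR_RETURN(auto buffer,
                        Buffer::Allocate(16 * sizeof(float), allocator));

  // Reinterpret the raw bytes as floats and fill them.
  auto span = buffer->As<float>();
  for (size_t i = 0; i < span.size(); ++i) span[i] = static_cast<float>(i);

  // Bounds-checked byte-range view over the second half of the buffer.
  IREE_ASSIGN_OR_RETURN(
      auto tail, buffer->RangeAs<float>(8 * sizeof(float), 8 * sizeof(float)));
  (void)tail;

  // Interfaces carry the per-dispatch push constants and buffer bindings that
  // the interface.const / interface.binding ops read at runtime.
  Interface iface;
  iface.Reset();
  const uint32_t constants[2] = {128, 4};  // hypothetical values
  IREE_RETURN_IF_ERROR(
      iface.SetConstants(constants, IREE_ARRAYSIZE(constants)));
  Interface::Binding binding;
  binding.buffer = std::move(buffer);
  IREE_RETURN_IF_ERROR(
      iface.SetBinding(/*set=*/0, /*binding=*/0, std::move(binding)));
  return iree::OkStatus();
}
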
diff --git a/iree/samples/custom_modules/BUILD b/iree/samples/custom_modules/BUILD
index 6694077..fe28780 100644
--- a/iree/samples/custom_modules/BUILD
+++ b/iree/samples/custom_modules/BUILD
@@ -23,8 +23,8 @@
 
 iree_cmake_extra_content(
     content = """
-if(NOT "${IREE_TARGET_BACKEND_VMLA}" OR
-   NOT "${IREE_HAL_DRIVER_VMLA}")
+if(NOT "${IREE_TARGET_BACKEND_VMVX}" OR
+   NOT "${IREE_HAL_DRIVER_VMVX}")
   return()
 endif()
 """,
diff --git a/iree/samples/custom_modules/CMakeLists.txt b/iree/samples/custom_modules/CMakeLists.txt
index 605ebe0..22ffc26 100644
--- a/iree/samples/custom_modules/CMakeLists.txt
+++ b/iree/samples/custom_modules/CMakeLists.txt
@@ -8,8 +8,8 @@
 # To disable autogeneration for this file entirely, delete this header.        #
 ################################################################################
 
-if(NOT "${IREE_TARGET_BACKEND_VMLA}" OR
-   NOT "${IREE_HAL_DRIVER_VMLA}")
+if(NOT "${IREE_TARGET_BACKEND_VMVX}" OR
+   NOT "${IREE_HAL_DRIVER_VMVX}")
   return()
 endif()