Merge pull request #4767 from google/benvanik-hal-c

A 30KB drop, removes all allocs from the common command buffer recording path, and almost all operations are faster to boot (no extraneous ref counting). There's a lot less C++ magic too: 1k loc of template magic replaced with 140 loc of stupid simple macros and autogeneratable boilerplate.

Fixes #4678.
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index eaa23c2..a6f1e2e 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -591,7 +591,7 @@
     switch (desc.type) {
       case RawSignatureParser::Type::kBuffer: {
         iree_hal_buffer_view_t* buffer_view =
-            iree_hal_buffer_view_deref(&f_result.ref);
+            iree_hal_buffer_view_deref(f_result.ref);
         if (!buffer_view) {
           throw RaiseValueError(
               "Could not deref result buffer view (wrong type?)");
@@ -793,8 +793,8 @@
     if (iree_vm_variant_is_value(variant)) {
       results.push_back("i32=" + std::to_string(variant.i32));
     } else if (iree_vm_variant_is_ref(variant) &&
-               iree_hal_buffer_view_isa(&variant.ref)) {
-      auto buffer_view = iree_hal_buffer_view_deref(&variant.ref);
+               iree_hal_buffer_view_isa(variant.ref)) {
+      auto buffer_view = iree_hal_buffer_view_deref(variant.ref);
 
       std::string result_str(4096, '\0');
       iree_status_t status;
diff --git a/bindings/python/pyiree/rt/vm.cc b/bindings/python/pyiree/rt/vm.cc
index 8091ee2..af719f4 100644
--- a/bindings/python/pyiree/rt/vm.cc
+++ b/bindings/python/pyiree/rt/vm.cc
@@ -213,13 +213,13 @@
       absl::StrAppend(&s, variant.i32);
     } else if (iree_vm_variant_is_ref(variant)) {
       // Pretty print a subset of ABI impacting known types.
-      if (iree_hal_buffer_isa(&variant.ref)) {
-        auto* hal_buffer = iree_hal_buffer_deref(&variant.ref);
+      if (iree_hal_buffer_isa(variant.ref)) {
+        auto* hal_buffer = iree_hal_buffer_deref(variant.ref);
         assert(hal_buffer);
         absl::StrAppend(&s, "HalBuffer(",
                         iree_hal_buffer_byte_length(hal_buffer), ")");
-      } else if (iree_hal_buffer_view_isa(&variant.ref)) {
-        auto hal_bv = iree_hal_buffer_view_deref(&variant.ref);
+      } else if (iree_hal_buffer_view_isa(variant.ref)) {
+        auto hal_bv = iree_hal_buffer_view_deref(variant.ref);
         absl::StrAppend(&s, "HalBufferView(");
         absl::InlinedVector<int32_t, 5> shape(
             iree_hal_buffer_view_shape_rank(hal_bv));
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 9c40cdb..fa8f56f 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -317,11 +317,14 @@
       "/Gy"
       "/DNDEBUG"
       "/DIREE_STATUS_MODE=0"
+      "/PDB"
+      "/Os"
+      "/Oy"
   )
   iree_select_compiler_opts(IREE_SIZE_OPTIMIZED_DEFAULT_LINKOPTS
     MSVC_OR_CLANG_CL
-      "/LTCG"
-      "/opt:ref,icf"
+      "-LTCG"
+      "-opt:ref,icf"
   )
   # TODO(#898): make this only impact the runtime (IREE_RUNTIME_DEFAULT_...).
   set(IREE_DEFAULT_COPTS
diff --git a/iree/base/api.c b/iree/base/api.c
index f7c641b..e118a5c 100644
--- a/iree/base/api.c
+++ b/iree/base/api.c
@@ -595,7 +595,7 @@
 #if IREE_STATUS_FEATURES == 0
   // More advanced status code features like source location and messages are
   // disabled. All statuses are just the codes.
-  return (iree_status_t)(code & IREE_STATUS_CODE_MASK);
+  return iree_status_from_code(code);
 #else
   // No-op for OK statuses; we won't get these from the macros but may be called
   // with this from marshaling code.
diff --git a/iree/base/api.h b/iree/base/api.h
index d41ecc4..bdc2b3f 100644
--- a/iree/base/api.h
+++ b/iree/base/api.h
@@ -204,6 +204,7 @@
 
 // Size, in bytes, of a buffer on the host.
 typedef size_t iree_host_size_t;
+#define IREE_MAX_HOST_SIZE SIZE_MAX
 
 // Size, in bytes, of a buffer on devices.
 typedef uint64_t iree_device_size_t;
@@ -470,7 +471,8 @@
   ((iree_status_code_t)(((uintptr_t)(value)) & IREE_STATUS_CODE_MASK))
 
 // Macros to check the value of a status code.
-#define iree_status_is_ok(value) ((uintptr_t)(value) == IREE_STATUS_OK)
+#define iree_status_is_ok(value) \
+  IREE_LIKELY((uintptr_t)(value) == IREE_STATUS_OK)
 #define iree_status_is_cancelled(value) \
   (iree_status_code(value) == IREE_STATUS_CANCELLED)
 #define iree_status_is_unknown(value) \
@@ -578,13 +580,14 @@
 #define IREE_STATUS_IMPL_MAKE_(code, ...) \
   (iree_status_t)(uintptr_t)((code)&IREE_STATUS_CODE_MASK)
 #undef IREE_STATUS_IMPL_RETURN_IF_API_ERROR_
-#define IREE_STATUS_IMPL_RETURN_IF_API_ERROR_(var, expr, ...) \
-  iree_status_t var = (expr);                                 \
+#define IREE_STATUS_IMPL_RETURN_IF_API_ERROR_(var, ...)                      \
+  iree_status_t var = (IREE_STATUS_IMPL_IDENTITY_(                           \
+      IREE_STATUS_IMPL_IDENTITY_(IREE_STATUS_IMPL_GET_EXPR_)(__VA_ARGS__))); \
   if (IREE_UNLIKELY(var)) return var;
 #undef IREE_STATUS_IMPL_RETURN_AND_EVAL_IF_API_ERROR_
-#define IREE_STATUS_IMPL_RETURN_AND_EVAL_IF_API_ERROR_(tail_expr, var, expr, \
-                                                       ...)                  \
-  iree_status_t var = (expr);                                                \
+#define IREE_STATUS_IMPL_RETURN_AND_EVAL_IF_API_ERROR_(tail_expr, var, ...)  \
+  iree_status_t var = (IREE_STATUS_IMPL_IDENTITY_(                           \
+      IREE_STATUS_IMPL_IDENTITY_(IREE_STATUS_IMPL_GET_EXPR_)(__VA_ARGS__))); \
   if (IREE_UNLIKELY(var)) {                                                  \
     (tail_expr);                                                             \
     return var;                                                              \
diff --git a/iree/base/attributes.h b/iree/base/attributes.h
index 9482781..178030f 100644
--- a/iree/base/attributes.h
+++ b/iree/base/attributes.h
@@ -149,4 +149,14 @@
 #define IREE_UNLIKELY(x) (x)
 #endif  // IREE_HAVE_ATTRIBUTE(likely)
 
+//===----------------------------------------------------------------------===//
+// IREE_ATTRIBUTE_PACKED
+//===----------------------------------------------------------------------===//
+
+#if IREE_HAVE_ATTRIBUTE(packed) || (defined(__GNUC__) && !defined(__clang__))
+#define IREE_ATTRIBUTE_PACKED __attribute__((__packed__))
+#else
+#define IREE_ATTRIBUTE_PACKED
+#endif  // IREE_HAVE_ATTRIBUTE(packed)
+
 #endif  // IREE_BASE_ATTRIBUTES_H_
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index ca6a8a9..7c6be4e 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -429,13 +429,7 @@
     output << "#include \"" << include << "\"\n";
   };
 
-  printInclude("iree/vm/context.h");
-  printInclude("iree/vm/instance.h");
-  printInclude("iree/vm/native_module.h");
-  printInclude("iree/vm/ops.h");
-  printInclude("iree/vm/ref.h");
-  printInclude("iree/vm/shims.h");
-  printInclude("iree/vm/stack.h");
+  printInclude("iree/vm/api.h");
   output << "\n";
 
   printModuleComment(moduleOp, output);
diff --git a/iree/modules/hal/BUILD b/iree/modules/hal/BUILD
index 6c6d538..8e30060 100644
--- a/iree/modules/hal/BUILD
+++ b/iree/modules/hal/BUILD
@@ -20,17 +20,21 @@
 
 cc_library(
     name = "hal",
-    srcs = ["hal_module.cc"],
-    hdrs = ["hal_module.h"],
+    srcs = [
+        "hal_module.c",
+        "shims.c",
+        "shims.h",
+    ],
+    hdrs = [
+        "hal_module.h",
+    ],
+    textual_hdrs = [
+        "exports.inl",
+    ],
     deps = [
         "//iree/base:api",
         "//iree/base:tracing",
         "//iree/hal:api",
         "//iree/vm",
-        "//iree/vm:cc",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/iree/modules/hal/CMakeLists.txt b/iree/modules/hal/CMakeLists.txt
index 36800aa..6e602e2 100644
--- a/iree/modules/hal/CMakeLists.txt
+++ b/iree/modules/hal/CMakeLists.txt
@@ -19,17 +19,16 @@
     hal
   HDRS
     "hal_module.h"
+  TEXTUAL_HDRS
+    "exports.inl"
   SRCS
-    "hal_module.cc"
+    "hal_module.c"
+    "shims.c"
+    "shims.h"
   DEPS
-    absl::core_headers
-    absl::inlined_vector
-    absl::memory
-    absl::span
     iree::base::api
     iree::base::tracing
     iree::hal::api
     iree::vm
-    iree::vm::cc
   PUBLIC
 )
diff --git a/iree/modules/hal/exports.inl b/iree/modules/hal/exports.inl
new file mode 100644
index 0000000..881977b
--- /dev/null
+++ b/iree/modules/hal/exports.inl
@@ -0,0 +1,84 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===----------------------------------------------------------------------===//
+//
+//         ██     ██  █████  ██████  ███    ██ ██ ███    ██  ██████
+//         ██     ██ ██   ██ ██   ██ ████   ██ ██ ████   ██ ██
+//         ██  █  ██ ███████ ██████  ██ ██  ██ ██ ██ ██  ██ ██   ███
+//         ██ ███ ██ ██   ██ ██   ██ ██  ██ ██ ██ ██  ██ ██ ██    ██
+//          ███ ███  ██   ██ ██   ██ ██   ████ ██ ██   ████  ██████
+//
+//===----------------------------------------------------------------------===//
+//
+// This file will be auto generated from hal.imports.mlir in the future; for
+// now it's modified by hand but with strict alphabetical sorting required.
+// The order of these functions must be sorted ascending by name in a way
+// compatible with iree_string_view_compare.
+//
+// Users are meant to `#define EXPORT_FN` to be able to access the information.
+// #define EXPORT_FN(name, target_fn, arg_types, ret_types)
+
+// clang-format off
+
+EXPORT_FN("allocator.allocate", iree_hal_module_allocator_allocate, riii, r)
+EXPORT_FN("allocator.wrap.byte_buffer", iree_hal_module_allocator_wrap_byte_buffer, riirii, r)
+
+EXPORT_FN("buffer.allocator", iree_hal_module_buffer_allocator, r, r)
+EXPORT_FN("buffer.fill", iree_hal_module_buffer_fill, riii, v)
+EXPORT_FN("buffer.load", iree_hal_module_buffer_load, rii, i)
+EXPORT_FN("buffer.store", iree_hal_module_buffer_store, irii, v)
+EXPORT_FN("buffer.subspan", iree_hal_module_buffer_subspan, rii, r)
+
+EXPORT_FN("buffer_view.buffer", iree_hal_module_buffer_view_buffer, r, r)
+EXPORT_FN("buffer_view.byte_length", iree_hal_module_buffer_view_byte_length, r, i)
+EXPORT_FN("buffer_view.create", iree_hal_module_buffer_view_create, riCiD, r)
+EXPORT_FN("buffer_view.dim", iree_hal_module_buffer_view_dim, ri, i)
+EXPORT_FN("buffer_view.element_type", iree_hal_module_buffer_view_element_type, r, i)
+EXPORT_FN("buffer_view.rank", iree_hal_module_buffer_view_rank, r, i)
+EXPORT_FN("buffer_view.trace", iree_hal_module_buffer_view_trace, rCrD, v)
+
+EXPORT_FN("command_buffer.begin", iree_hal_module_command_buffer_begin, r, v)
+EXPORT_FN("command_buffer.bind_descriptor_set", iree_hal_module_command_buffer_bind_descriptor_set, rrirCiD, v)
+EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, rririi, v)
+EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, rii, r)
+EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v)
+EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriri, v)
+EXPORT_FN("command_buffer.end", iree_hal_module_command_buffer_end, r, v)
+EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
+EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rriii, v)
+EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v)
+EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiriiD, v)
+
+EXPORT_FN("descriptor_set.create", iree_hal_module_descriptor_set_create, rrCiriiD, r)
+
+EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r)
+
+EXPORT_FN("device.allocator", iree_hal_module_device_allocator, r, r)
+EXPORT_FN("device.match.id", iree_hal_module_device_match_id, rr, i)
+
+EXPORT_FN("ex.shared_device", iree_hal_module_ex_shared_device, v, r)
+EXPORT_FN("ex.submit_and_wait", iree_hal_module_ex_submit_and_wait, rr, v)
+
+EXPORT_FN("executable.create", iree_hal_module_executable_create, rirCrD, r)
+
+EXPORT_FN("executable_layout.create", iree_hal_module_executable_layout_create, riCrD, r)
+
+EXPORT_FN("semaphore.await", iree_hal_module_semaphore_await, ri, i)
+EXPORT_FN("semaphore.create", iree_hal_module_semaphore_create, ri, r)
+EXPORT_FN("semaphore.fail", iree_hal_module_semaphore_fail, r, i)
+EXPORT_FN("semaphore.query", iree_hal_module_semaphore_query, r, ii)
+EXPORT_FN("semaphore.signal", iree_hal_module_semaphore_signal, ri, v)
+
+// clang-format on
diff --git a/iree/modules/hal/hal_module.c b/iree/modules/hal/hal_module.c
new file mode 100644
index 0000000..06c0baf
--- /dev/null
+++ b/iree/modules/hal/hal_module.c
@@ -0,0 +1,992 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/modules/hal/hal_module.h"
+
+#include <inttypes.h>
+#include <stdio.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/shims.h"
+#include "iree/vm/api.h"
+
+// Limit the number of bindings we pass down through the HAL. This can be tuned
+// in the future but right now guards the stack from blowing up during calls.
+#define IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT ((iree_host_size_t)32)
+
+//===----------------------------------------------------------------------===//
+// Type registration
+//===----------------------------------------------------------------------===//
+
+static iree_vm_ref_type_descriptor_t iree_hal_allocator_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_buffer_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_buffer_view_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_command_buffer_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_layout_descriptor =
+    {0};
+static iree_vm_ref_type_descriptor_t iree_hal_device_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_event_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_executable_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_executable_layout_descriptor = {
+    0};
+static iree_vm_ref_type_descriptor_t iree_hal_semaphore_descriptor = {0};
+
+#define IREE_VM_REGISTER_HAL_C_TYPE(type, name, destroy_fn, descriptor)   \
+  descriptor.type_name = iree_make_cstring_view(name);                    \
+  descriptor.offsetof_counter = offsetof(iree_hal_resource_t, ref_count); \
+  descriptor.destroy = (iree_vm_ref_destroy_t)destroy_fn;                 \
+  IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_module_register_types() {
+  static bool has_registered = false;
+  if (has_registered) return iree_ok_status();
+
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_allocator_t, "hal.allocator",
+                              iree_hal_allocator_destroy,
+                              iree_hal_allocator_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_t, "hal.buffer",
+                              iree_hal_buffer_destroy,
+                              iree_hal_buffer_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_view_t, "hal.buffer_view",
+                              iree_hal_buffer_view_destroy,
+                              iree_hal_buffer_view_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_command_buffer_t, "hal.command_buffer",
+                              iree_hal_command_buffer_destroy,
+                              iree_hal_command_buffer_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_t, "hal.descriptor_set",
+                              iree_hal_descriptor_set_destroy,
+                              iree_hal_descriptor_set_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_layout_t,
+                              "hal.descriptor_set_layout",
+                              iree_hal_descriptor_set_layout_destroy,
+                              iree_hal_descriptor_set_layout_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_device_t, "hal.device",
+                              iree_hal_device_destroy,
+                              iree_hal_device_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_event_t, "hal.event",
+                              iree_hal_event_destroy,
+                              iree_hal_event_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_t, "hal.executable",
+                              iree_hal_executable_destroy,
+                              iree_hal_executable_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_layout_t,
+                              "hal.executable_layout",
+                              iree_hal_executable_layout_destroy,
+                              iree_hal_executable_layout_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_semaphore_t, "hal.semaphore",
+                              iree_hal_semaphore_destroy,
+                              iree_hal_semaphore_descriptor);
+
+  has_registered = true;
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Type wrappers
+//===----------------------------------------------------------------------===//
+
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_allocator, iree_hal_allocator_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer, iree_hal_buffer_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer_view, iree_hal_buffer_view_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_command_buffer,
+                             iree_hal_command_buffer_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set,
+                             iree_hal_descriptor_set_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout,
+                             iree_hal_descriptor_set_layout_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_event, iree_hal_event_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_layout,
+                             iree_hal_executable_layout_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t);
+
+//===----------------------------------------------------------------------===//
+// Module type definitions
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+  iree_allocator_t host_allocator;
+  iree_hal_device_t* shared_device;
+  // TODO(benvanik): types.
+} iree_hal_module_t;
+
+#define IREE_HAL_MODULE_CAST(module) \
+  ((iree_hal_module_t*)((uint8_t*)(module) + iree_vm_native_module_size()))
+
+typedef struct {
+  iree_allocator_t host_allocator;
+  iree_hal_device_t* shared_device;
+  iree_hal_executable_cache_t* executable_cache;
+
+  iree_vm_list_t* deferred_releases;
+} iree_hal_module_state_t;
+
+static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
+  iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+  iree_hal_device_release(module->shared_device);
+}
+
+static iree_status_t IREE_API_PTR
+iree_hal_module_alloc_state(void* self, iree_allocator_t host_allocator,
+                            iree_vm_module_state_t** out_module_state) {
+  iree_hal_module_t* module = IREE_HAL_MODULE_CAST(self);
+  iree_hal_module_state_t* state = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state));
+  memset(state, 0, sizeof(*state));
+  state->host_allocator = host_allocator;
+  state->shared_device = module->shared_device;
+  iree_hal_device_retain(state->shared_device);
+
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(
+      /*element_type=*/NULL, /*initial_capacity=*/512, state->host_allocator,
+      &state->deferred_releases));
+
+  IREE_RETURN_IF_ERROR(iree_hal_executable_cache_create(
+      state->shared_device, iree_string_view_empty(),
+      &state->executable_cache));
+
+  *out_module_state = (iree_vm_module_state_t*)state;
+  return iree_ok_status();
+}
+
+static void IREE_API_PTR
+iree_hal_module_free_state(void* self, iree_vm_module_state_t* module_state) {
+  iree_hal_module_state_t* state = (iree_hal_module_state_t*)module_state;
+  iree_vm_list_release(state->deferred_releases);
+  iree_hal_executable_cache_release(state->executable_cache);
+  iree_hal_device_release(state->shared_device);
+  iree_allocator_free(state->host_allocator, state);
+}
+
+//===----------------------------------------------------------------------===//
+// Experimental APIs
+//===----------------------------------------------------------------------===//
+// NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
+// using these APIs are not forward compatible.
+
+IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device, v, r) {
+  rets->r0 = iree_hal_device_retain_ref(state->shared_device);
+  return iree_ok_status();
+}
+
+void iree_hal_module_ex_defer_release(iree_hal_module_state_t* state,
+                                      const iree_vm_ref_t value) {
+  IREE_IGNORE_ERROR(
+      iree_vm_list_push_ref_retain(state->deferred_releases, &value));
+}
+
+// Submits the command buffer in r1 to the device in r0 and waits for it.
+IREE_VM_ABI_EXPORT(iree_hal_module_ex_submit_and_wait, rr, v) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r1, &command_buffer));
+
+  // Temporary semaphore we'll signal from 0->1.
+  iree_hal_semaphore_t* semaphore = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore));
+
+  // Batch with our single command buffer.
+  iree_hal_submission_batch_t batch;
+  memset(&batch, 0, sizeof(batch));
+
+  iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer};
+  batch.command_buffer_count = IREE_ARRAYSIZE(command_buffer_ptrs);
+  batch.command_buffers = command_buffer_ptrs;
+
+  iree_hal_semaphore_t* signal_semaphore_ptrs[] = {semaphore};
+  uint64_t signal_semaphore_values[] = {1ull};
+  batch.signal_semaphores.count = IREE_ARRAYSIZE(signal_semaphore_ptrs);
+  batch.signal_semaphores.semaphores = signal_semaphore_ptrs;
+  batch.signal_semaphores.payload_values = signal_semaphore_values;
+
+  iree_status_t status = iree_hal_device_queue_submit(
+      device, IREE_HAL_COMMAND_CATEGORY_ANY, 0, 1, &batch);
+
+  // Block and wait for the semaphore to be signaled (or fail).
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait_with_deadline(semaphore, 1ull,
+                                                   IREE_TIME_INFINITE_FUTURE);
+  }
+
+  // The semaphore is private to this call: drop our reference on every path,
+  // including success, so it is not leaked.
+  iree_hal_semaphore_release(semaphore);
+
+  // Drop all pending deferred releases (references to everything in flight).
+  // This will be replaced with resource sets in the future that are attached to
+  // each command buffer.
+  if (iree_status_is_ok(status)) {
+    status = iree_vm_list_resize(state->deferred_releases, 0);
+  }
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_allocator_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_allocator_allocate, riii, r) {
+  iree_hal_allocator_t* allocator = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator));
+  iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i1;
+  iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i2;
+  iree_vm_size_t allocation_size = (iree_vm_size_t)args->i3;
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+      allocator, memory_types, buffer_usage, allocation_size, &buffer));
+  rets->r0 = iree_hal_buffer_move_ref(buffer);
+  return iree_ok_status();
+}
+
+// Allocates a new buffer and copies a byte range of the source into it.
+IREE_VM_ABI_EXPORT(iree_hal_module_allocator_wrap_byte_buffer, riirii, r) {
+  iree_hal_allocator_t* allocator = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator));
+  iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i1;
+  iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i2;
+  iree_vm_ro_byte_buffer_t* source = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_ro_byte_buffer_check_deref(args->r3, &source));
+  iree_vm_size_t offset = (iree_vm_size_t)args->i4;
+  iree_vm_size_t length = (iree_vm_size_t)args->i5;
+
+  // TODO(benvanik): wrap when supported.
+
+  iree_host_size_t buffer_length = source->data.data_length;
+  if (length == -1) length = buffer_length;
+  // Check against the remaining bytes so offset + length cannot overflow.
+  if (length < 0 || offset < 0 || (iree_host_size_t)offset > buffer_length ||
+      (iree_host_size_t)length > buffer_length - (iree_host_size_t)offset) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "byte range out of bounds (requested %d-%" PRId64 " of available %zu)",
+        offset, (int64_t)offset + length - 1, buffer_length);
+  }
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_allocator_allocate_buffer(allocator, memory_types, buffer_usage,
+                                         length, &buffer),
+      "failed to allocate buffer of length %d", length);
+
+  iree_status_t status =
+      iree_hal_buffer_write_data(buffer, 0, source->data.data + offset, length);
+  if (iree_status_is_ok(status)) {
+    rets->r0 = iree_hal_buffer_move_ref(buffer);
+  } else {
+    iree_hal_buffer_release(buffer);
+  }
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_buffer_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_allocator, r, r) {
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &buffer));
+  rets->r0 = iree_hal_allocator_retain_ref(iree_hal_buffer_allocator(buffer));
+  return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_subspan, rii, r) {
+  iree_hal_buffer_t* source_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
+  iree_vm_size_t source_offset = (iree_vm_size_t)args->i1;
+  iree_vm_size_t length = (iree_vm_size_t)args->i2;
+
+  iree_hal_buffer_t* subspan_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_subspan(source_buffer, source_offset, length,
+                              &subspan_buffer),
+      "invalid subspan of an existing buffer (source_offset=%d, length=%d)",
+      source_offset, length);
+  rets->r0 = iree_hal_buffer_move_ref(subspan_buffer);
+  return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_fill, riii, v) {
+  // DEPRECATED: will be removed in future versions. Use command buffers.
+  iree_hal_buffer_t* target_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &target_buffer));
+  iree_vm_size_t target_offset = (iree_vm_size_t)args->i1;
+  iree_vm_size_t length = (iree_vm_size_t)args->i2;
+  uint32_t pattern = (uint32_t)args->i3;
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_fill(target_buffer, target_offset,
+                                            length, &pattern, sizeof(pattern)),
+                       "fill range failed (target_offset=%d, length=%d)",
+                       target_offset, length);
+  return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_load, rii, i) {
+  iree_hal_buffer_t* source_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
+  iree_vm_size_t source_offset = (iree_vm_size_t)args->i1;
+  iree_vm_size_t length = (iree_vm_size_t)args->i2;
+
+  uint32_t target_buffer = 0;
+  if (length > sizeof(target_buffer)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "load length byte count %d exceeds max", length);
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_read_data(source_buffer, source_offset,
+                                                 &target_buffer, length));
+
+  rets->i0 = target_buffer;
+  return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_store, irii, v) {
+  int32_t value = args->i0;
+  iree_hal_buffer_t* target_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer));
+  iree_vm_size_t target_offset = (iree_vm_size_t)args->i2;
+  iree_vm_size_t length = (iree_vm_size_t)args->i3;
+
+  if (length > sizeof(value)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "store length byte count %d exceeds max", length);
+  } else if (target_offset + length >
+             iree_hal_buffer_byte_length(target_buffer)) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "store out of bounds (target_offset=%d, length=%d into max %" PRIu64
+        ")",
+        target_offset, length, iree_hal_buffer_byte_length(target_buffer));
+  }
+
+  return iree_hal_buffer_write_data(target_buffer, target_offset, &value,
+                                    length);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_buffer_view_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_create, riCiD, r) {
+  iree_hal_buffer_t* source_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
+  iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i1;
+  iree_host_size_t shape_rank = 0;
+  iree_hal_dim_t* shape_dims = NULL;
+  IREE_VM_ABI_VLA_STACK_CAST(args, a2_count, a2, iree_hal_dim_t, 128,
+                             &shape_rank, &shape_dims);
+
+  iree_hal_buffer_view_t* buffer_view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      source_buffer, element_type, shape_dims, shape_rank, &buffer_view));
+  rets->r0 = iree_hal_buffer_view_move_ref(buffer_view);
+  return iree_ok_status();
+}
+
+// Returns, via rets->r0, a retained reference to the buffer reported by
+// iree_hal_buffer_view_buffer for the buffer view in args->r0.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_buffer, r, r) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->r0 = iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(view));
+  return iree_ok_status();
+}
+
+// Returns, via rets->i0, the byte length of the buffer view in args->r0 cast to
+// iree_vm_size_t.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_byte_length, r, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_byte_length(view);
+  return iree_ok_status();
+}
+
+// Returns, via rets->i0, the element type of the buffer view in args->r0 cast
+// to uint32_t.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_element_type, r, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->i0 = (uint32_t)iree_hal_buffer_view_element_type(view);
+  return iree_ok_status();
+}
+
+// Returns, via rets->i0, the shape rank of the buffer view in args->r0 cast to
+// iree_vm_size_t.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_rank, r, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_shape_rank(view);
+  return iree_ok_status();
+}
+
+// Returns, via rets->i0, the shape dimension at index args->i1 of the buffer
+// view in args->r0, cast to iree_vm_size_t. The index is forwarded to
+// iree_hal_buffer_view_shape_dim without a range check in this function.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_dim, ri, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  iree_vm_size_t dim_index = (iree_vm_size_t)args->i1;
+  rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_shape_dim(view, dim_index);
+  return iree_ok_status();
+}
+
+// Debug helper: prints the string in args->r0 to stderr as a "=== key ==="
+// header, then the formatted contents of every buffer view in the args->a1
+// list (one per line), then a blank line.
+//
+// Returns the failing status as soon as a ref fails to deref, the scratch
+// allocation fails, or formatting fails.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_trace, rCrD, v) {
+  iree_vm_ro_byte_buffer_t* label = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_ro_byte_buffer_check_deref(args->r0, &label));
+  iree_string_view_t label_str = iree_vm_ro_byte_buffer_as_string(label);
+  fprintf(stderr, "=== %.*s ===\n", (int)label_str.size, label_str.data);
+
+  // NOTE: this export is for debugging only and a no-op in min-size builds.
+  // We heap-alloc here because at the point this export is used performance
+  // is not a concern.
+  iree_host_size_t view_index = 0;
+  while (view_index < args->a1_count) {
+    iree_hal_buffer_view_t* view = NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_buffer_view_check_deref(args->a1[view_index].r0, &view));
+
+    // Probe with a zero-capacity buffer to query the total formatted length
+    // (excluding NUL terminator). Any status other than out-of-range is
+    // returned directly to the caller.
+    iree_host_size_t text_length = 0;
+    iree_status_t status =
+        iree_hal_buffer_view_format(view, SIZE_MAX, 0, NULL, &text_length);
+    if (!iree_status_is_out_of_range(status)) return status;
+    text_length += 1;  // include NUL
+
+    // Format into scratch heap memory, print on success, and always free the
+    // scratch memory before propagating a formatting failure.
+    char* text = NULL;
+    IREE_RETURN_IF_ERROR(iree_allocator_malloc(state->host_allocator,
+                                               text_length, (void**)&text));
+    status = iree_hal_buffer_view_format(view, SIZE_MAX, text_length, text,
+                                         &text_length);
+    if (iree_status_is_ok(status)) {
+      fprintf(stderr, "%.*s\n", (int)text_length, text);
+    }
+    iree_allocator_free(state->host_allocator, text);
+    IREE_RETURN_IF_ERROR(status);
+
+    ++view_index;
+  }
+  fprintf(stderr, "\n");
+
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_t
+//===----------------------------------------------------------------------===//
+
+// Creates a command buffer on the device in args->r0 with the modes from
+// args->i1 and the command categories from args->i2. On success the new
+// command buffer is moved into rets->r0.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_create, rii, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+
+  iree_hal_command_buffer_t* new_command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+      device, (iree_hal_command_buffer_mode_t)args->i1,
+      (iree_hal_command_category_t)args->i2, &new_command_buffer));
+  rets->r0 = iree_hal_command_buffer_move_ref(new_command_buffer);
+  return iree_ok_status();
+}
+
+// Calls iree_hal_command_buffer_begin on the command buffer in args->r0 and
+// returns its status.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_begin, r, v) {
+  iree_hal_command_buffer_t* target = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_check_deref(args->r0, &target));
+  return iree_hal_command_buffer_begin(target);
+}
+
+// Calls iree_hal_command_buffer_end on the command buffer in args->r0 and
+// returns its status.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_end, r, v) {
+  iree_hal_command_buffer_t* target = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_check_deref(args->r0, &target));
+  return iree_hal_command_buffer_end(target);
+}
+
+// Records an execution barrier on the command buffer in args->r0 using the
+// source/target stage masks in args->i1/args->i2 and the flags in args->i3.
+//
+// Barriers are not decoded from the arguments yet: one fixed memory barrier
+// (dispatch write -> dispatch read) is always passed, followed by an empty
+// trailing list (0, NULL).
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_execution_barrier, riii, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+
+  // TODO(benvanik): decode barriers.
+  iree_hal_memory_barrier_t memory_barrier;
+  memory_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE;
+  memory_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ;
+
+  return iree_hal_command_buffer_execution_barrier(
+      command_buffer, (iree_hal_execution_stage_t)args->i1,
+      (iree_hal_execution_stage_t)args->i2,
+      (iree_hal_execution_barrier_flags_t)args->i3, 1, &memory_barrier, 0,
+      NULL);
+}
+
+// Records a fill on the command buffer in args->r0: the range
+// [args->i2, args->i2 + args->i3) of the buffer in args->r1 is filled with the
+// 4-byte pattern in args->i4.
+//
+// The target buffer ref is handed to iree_hal_module_ex_defer_release before
+// the fill is recorded.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, rriii, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+  iree_hal_buffer_t* fill_target = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &fill_target));
+  iree_vm_size_t fill_offset = (iree_vm_size_t)args->i2;
+  iree_vm_size_t fill_length = (iree_vm_size_t)args->i3;
+  uint32_t fill_pattern = (uint32_t)args->i4;
+
+  iree_hal_module_ex_defer_release(state, args->r1);
+
+  return iree_hal_command_buffer_fill_buffer(
+      command_buffer, fill_target, fill_offset, fill_length, &fill_pattern,
+      sizeof(fill_pattern));
+}
+
+// Records a copy on the command buffer in args->r0 of args->i5 bytes from the
+// buffer in args->r1 (at offset args->i2) to the buffer in args->r3 (at offset
+// args->i4).
+//
+// Both buffer refs are handed to iree_hal_module_ex_defer_release before the
+// copy is recorded.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, rririi, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+
+  iree_hal_buffer_t* src = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &src));
+  iree_vm_size_t src_offset = (iree_vm_size_t)args->i2;
+
+  iree_hal_buffer_t* dst = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &dst));
+  iree_vm_size_t dst_offset = (iree_vm_size_t)args->i4;
+
+  iree_vm_size_t copy_length = (iree_vm_size_t)args->i5;
+
+  iree_hal_module_ex_defer_release(state, args->r1);
+  iree_hal_module_ex_defer_release(state, args->r3);
+
+  return iree_hal_command_buffer_copy_buffer(command_buffer, src, src_offset,
+                                             dst, dst_offset, copy_length);
+}
+
+// Pushes the values in the args->a3 list as constants onto the command buffer
+// in args->r0 for the executable layout in args->r1.
+//
+// args->i2 is an offset counted in 32-bit values; both the offset and the value
+// count are scaled by sizeof(uint32_t) before the HAL call.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_constants, rriCiD, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+  iree_hal_executable_layout_t* layout = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_layout_check_deref(args->r1, &layout));
+
+  iree_vm_size_t first_constant = (iree_vm_size_t)args->i2;
+  iree_host_size_t constant_count = args->a3_count;
+  // NOTE(review): treats the list as a contiguous uint32_t array starting at
+  // the first element's i0 field - relies on the list element layout, which is
+  // declared elsewhere.
+  const uint32_t* constant_values = (const uint32_t*)&args->a3[0].i0;
+
+  return iree_hal_command_buffer_push_constants(
+      command_buffer, layout, first_constant * sizeof(uint32_t),
+      constant_values, constant_count * sizeof(uint32_t));
+}
+
+// Pushes the descriptor bindings listed in args->a3 for set args->i2 onto the
+// command buffer in args->r0 using the executable layout in args->r1.
+//
+// Each list element supplies (binding i0, buffer r1, offset i2, length i3).
+// Lists larger than IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT are rejected
+// with IREE_STATUS_OUT_OF_RANGE before the stack scratch array is allocated.
+// Every bound buffer ref is handed to iree_hal_module_ex_defer_release.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_descriptor_set,
+                   rriCiriiD, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+  iree_hal_executable_layout_t* layout = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_layout_check_deref(args->r1, &layout));
+  iree_vm_size_t set_ordinal = args->i2;
+
+  iree_host_size_t count = args->a3_count;
+  if (IREE_UNLIKELY(count > IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "binding count %zu > %zu",
+                            count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+
+  // The count is bounded above, so the scratch array lives on the stack.
+  iree_hal_descriptor_set_binding_t* scratch =
+      (iree_hal_descriptor_set_binding_t*)iree_alloca(
+          count * sizeof(iree_hal_descriptor_set_binding_t));
+  for (iree_host_size_t i = 0; i < count; ++i) {
+    iree_hal_descriptor_set_binding_t* slot = &scratch[i];
+    IREE_RETURN_IF_ERROR(
+        iree_hal_buffer_check_deref(args->a3[i].r1, &slot->buffer));
+    slot->binding = (uint32_t)args->a3[i].i0;
+    slot->offset = (iree_device_size_t)args->a3[i].i2;
+    slot->length = (iree_device_size_t)args->a3[i].i3;
+    iree_hal_module_ex_defer_release(state, args->a3[i].r1);
+  }
+
+  return iree_hal_command_buffer_push_descriptor_set(
+      command_buffer, layout, set_ordinal, count, scratch);
+}
+
+// Binds the descriptor set in args->r3 as set args->i2 on the command buffer in
+// args->r0 using the executable layout in args->r1, with the dynamic offsets
+// supplied in the args->a4 list.
+//
+// The descriptor set ref is handed to iree_hal_module_ex_defer_release before
+// the bind is recorded.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_bind_descriptor_set, rrirCiD,
+                   v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+  iree_hal_executable_layout_t* layout = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_layout_check_deref(args->r1, &layout));
+  int32_t set_ordinal = args->i2;
+  iree_hal_descriptor_set_t* bound_set = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_descriptor_set_check_deref(args->r3, &bound_set));
+
+  // IREE_VM_ABI_VLA_STACK_CAST (defined elsewhere) fills the count/pointer pair
+  // from the a4 list as iree_device_size_t values.
+  iree_host_size_t offset_count = 0;
+  iree_device_size_t* offsets = NULL;
+  IREE_VM_ABI_VLA_STACK_CAST(args, a4_count, a4, iree_device_size_t, 64,
+                             &offset_count, &offsets);
+
+  iree_hal_module_ex_defer_release(state, args->r3);
+
+  return iree_hal_command_buffer_bind_descriptor_set(
+      command_buffer, layout, set_ordinal, bound_set, offset_count, offsets);
+}
+
+// Records a dispatch on the command buffer in args->r0 of entry point args->i2
+// in the executable args->r1 with workgroup counts (args->i3, args->i4,
+// args->i5) for x/y/z.
+//
+// The executable ref is handed to iree_hal_module_ex_defer_release before the
+// dispatch is recorded.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, rriiii, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+  iree_hal_executable_t* target_executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_check_deref(args->r1, &target_executable));
+  uint32_t entry_ordinal = (uint32_t)args->i2;
+  uint32_t count_x = (uint32_t)args->i3;
+  uint32_t count_y = (uint32_t)args->i4;
+  uint32_t count_z = (uint32_t)args->i5;
+
+  iree_hal_module_ex_defer_release(state, args->r1);
+
+  return iree_hal_command_buffer_dispatch(command_buffer, target_executable,
+                                          entry_ordinal, count_x, count_y,
+                                          count_z);
+}
+
+// Records an indirect dispatch on the command buffer in args->r0 of entry point
+// args->i2 in the executable args->r1, passing the buffer in args->r3 and the
+// offset in args->i4 as the workgroups buffer/offset.
+//
+// Both the executable ref and the workgroups buffer ref are handed to
+// iree_hal_module_ex_defer_release before the dispatch is recorded.
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, rriri, v) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+  iree_hal_executable_t* target_executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_check_deref(args->r1, &target_executable));
+  uint32_t entry_ordinal = (uint32_t)args->i2;
+  iree_hal_buffer_t* counts_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &counts_buffer));
+  iree_vm_size_t counts_offset = (iree_vm_size_t)args->i4;
+
+  iree_hal_module_ex_defer_release(state, args->r1);
+  iree_hal_module_ex_defer_release(state, args->r3);
+
+  return iree_hal_command_buffer_dispatch_indirect(
+      command_buffer, target_executable, entry_ordinal, counts_buffer,
+      counts_offset);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_descriptor_set_t
+//===----------------------------------------------------------------------===//
+
+// Creates a descriptor set on the device in args->r0 for the set layout in
+// args->r1 from the bindings listed in args->a2. On success the new descriptor
+// set is moved into rets->r0.
+//
+// Each list element supplies (binding i0, buffer r1, offset i2, length i3).
+// Lists larger than IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT are rejected
+// with IREE_STATUS_OUT_OF_RANGE before the stack scratch array is allocated.
+IREE_VM_ABI_EXPORT(iree_hal_module_descriptor_set_create, rrCiriiD, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  iree_hal_descriptor_set_layout_t* layout = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_descriptor_set_layout_check_deref(args->r1, &layout));
+
+  iree_host_size_t count = args->a2_count;
+  if (IREE_UNLIKELY(count > IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "binding count %zu > %zu",
+                            count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+
+  // The count is bounded above, so the scratch array lives on the stack.
+  iree_hal_descriptor_set_binding_t* scratch =
+      (iree_hal_descriptor_set_binding_t*)iree_alloca(
+          count * sizeof(iree_hal_descriptor_set_binding_t));
+  for (iree_host_size_t i = 0; i < count; ++i) {
+    iree_hal_descriptor_set_binding_t* slot = &scratch[i];
+    IREE_RETURN_IF_ERROR(
+        iree_hal_buffer_check_deref(args->a2[i].r1, &slot->buffer));
+    slot->binding = (uint32_t)args->a2[i].i0;
+    slot->offset = (iree_device_size_t)args->a2[i].i2;
+    slot->length = (iree_device_size_t)args->a2[i].i3;
+  }
+
+  iree_hal_descriptor_set_t* new_set = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_descriptor_set_create(device, layout, count, scratch, &new_set));
+  rets->r0 = iree_hal_descriptor_set_move_ref(new_set);
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates a descriptor set layout on the device in args->r0 with the usage
+// type in args->i1 and the layout bindings listed in args->a2. On success the
+// new layout is moved into rets->r0.
+//
+// Each list element supplies (binding i0, descriptor type i1, memory access
+// i2). Lists larger than IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT are
+// rejected with IREE_STATUS_OUT_OF_RANGE before the stack scratch array is
+// allocated.
+IREE_VM_ABI_EXPORT(iree_hal_module_descriptor_set_layout_create, riCiiiD, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  iree_hal_descriptor_set_layout_usage_type_t usage =
+      (iree_hal_descriptor_set_layout_usage_type_t)args->i1;
+
+  iree_host_size_t count = args->a2_count;
+  if (IREE_UNLIKELY(count > IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "binding count %zu > %zu",
+                            count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+
+  // The count is bounded above, so the scratch array lives on the stack.
+  iree_hal_descriptor_set_layout_binding_t* scratch =
+      (iree_hal_descriptor_set_layout_binding_t*)iree_alloca(
+          count * sizeof(iree_hal_descriptor_set_layout_binding_t));
+  for (iree_host_size_t i = 0; i < count; ++i) {
+    iree_hal_descriptor_set_layout_binding_t* slot = &scratch[i];
+    slot->binding = (uint32_t)args->a2[i].i0;
+    slot->type = (iree_hal_descriptor_type_t)args->a2[i].i1;
+    slot->access = (iree_hal_memory_access_t)args->a2[i].i2;
+  }
+
+  iree_hal_descriptor_set_layout_t* new_layout = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_layout_create(
+      device, usage, count, scratch, &new_layout));
+  rets->r0 = iree_hal_descriptor_set_layout_move_ref(new_layout);
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_device_t
+//===----------------------------------------------------------------------===//
+
+// Returns, via rets->r0, a retained reference to the allocator reported by
+// iree_hal_device_allocator for the device in args->r0.
+IREE_VM_ABI_EXPORT(iree_hal_module_device_allocator, r, r) {
+  iree_hal_device_t* target_device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &target_device));
+  rets->r0 =
+      iree_hal_allocator_retain_ref(iree_hal_device_allocator(target_device));
+  return iree_ok_status();
+}
+
+// Tests the ID of the device in args->r0 against the pattern string in
+// args->r1 using iree_string_view_match_pattern; rets->i0 is 1 on a match and 0
+// otherwise.
+IREE_VM_ABI_EXPORT(iree_hal_module_device_match_id, rr, i) {
+  iree_hal_device_t* target_device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &target_device));
+  iree_vm_ro_byte_buffer_t* pattern_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_vm_ro_byte_buffer_check_deref(args->r1, &pattern_buffer));
+
+  iree_string_view_t id_pattern =
+      iree_vm_ro_byte_buffer_as_string(pattern_buffer);
+  iree_string_view_t id = iree_hal_device_id(target_device);
+  if (iree_string_view_match_pattern(id, id_pattern)) {
+    rets->i0 = 1;
+  } else {
+    rets->i0 = 0;
+  }
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_executable_t
+//===----------------------------------------------------------------------===//
+
+// Prepares an executable through state->executable_cache from the data in
+// args->r2 (format in args->i1) and the executable layouts listed in args->a3.
+// The resulting executable (NULL on failure) is moved into rets->r0 and the
+// preparation status is returned.
+//
+// The device in args->r0 is only validated here. The layout pointer array is
+// temporary heap storage from state->host_allocator and is freed on every path
+// after its allocation.
+IREE_VM_ABI_EXPORT(iree_hal_module_executable_create, rirCrD, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  iree_hal_executable_format_t format =
+      (iree_hal_executable_format_t)args->i1;
+  iree_vm_ro_byte_buffer_t* binary = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_ro_byte_buffer_check_deref(args->r2, &binary));
+
+  // Gather the layout pointers into a scratch array, stopping at the first ref
+  // that fails to deref.
+  iree_host_size_t layout_count = args->a3_count;
+  iree_hal_executable_layout_t** layouts = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(state->host_allocator,
+                                             layout_count * sizeof(layouts[0]),
+                                             (void**)&layouts));
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < layout_count && iree_status_is_ok(status);
+       ++i) {
+    status =
+        iree_hal_executable_layout_check_deref(args->a3[i].r0, &layouts[i]);
+  }
+
+  iree_hal_executable_t* prepared = NULL;
+  if (iree_status_is_ok(status)) {
+    iree_hal_executable_spec_t spec;
+    iree_hal_executable_spec_initialize(&spec);
+    spec.executable_format = format;
+    spec.executable_data = binary->data;
+    spec.executable_layout_count = layout_count;
+    spec.executable_layouts = layouts;
+    status = iree_hal_executable_cache_prepare_executable(
+        state->executable_cache, &spec, &prepared);
+  }
+
+  iree_allocator_free(state->host_allocator, layouts);
+  rets->r0 = iree_hal_executable_move_ref(prepared);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_executable_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates an executable layout on the device in args->r0 with the push
+// constant count in args->i1 and the descriptor set layouts listed in
+// args->a2. On success the new layout is moved into rets->r0.
+IREE_VM_ABI_EXPORT(iree_hal_module_executable_layout_create, riCrD, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  int32_t push_constant_count = (int32_t)args->i1;
+
+  // IREE_VM_ABI_VLA_STACK_DEREF (defined elsewhere) fills the count/pointer
+  // pair from the refs in the a2 list.
+  iree_host_size_t layout_count = 0;
+  iree_hal_descriptor_set_layout_t** layouts = NULL;
+  IREE_VM_ABI_VLA_STACK_DEREF(args, a2_count, a2,
+                              iree_hal_descriptor_set_layout, 32, &layout_count,
+                              &layouts);
+
+  iree_hal_executable_layout_t* new_layout = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_executable_layout_create(
+      device, push_constant_count, layout_count, layouts, &new_layout));
+  rets->r0 = iree_hal_executable_layout_move_ref(new_layout);
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_semaphore_t
+//===----------------------------------------------------------------------===//
+
+// Creates a semaphore on the device in args->r0 with the initial value in
+// args->i1 (taken as uint32_t). On success the new semaphore is moved into
+// rets->r0.
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_create, ri, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  uint32_t start_value = (uint32_t)args->i1;
+
+  iree_hal_semaphore_t* new_semaphore = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_semaphore_create(device, start_value, &new_semaphore));
+  rets->r0 = iree_hal_semaphore_move_ref(new_semaphore);
+  return iree_ok_status();
+}
+
+// Queries the semaphore in args->r0.
+//
+// The query status is consumed and only its code is reported (rets->i0), along
+// with the queried value cast to uint32_t (rets->i1). Once the semaphore ref
+// has been dereferenced the export itself always returns OK.
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_query, r, ii) {
+  iree_hal_semaphore_t* target = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &target));
+
+  uint64_t current_value = 0;
+  rets->i0 =
+      iree_status_consume_code(iree_hal_semaphore_query(target, &current_value));
+  rets->i1 = (uint32_t)current_value;
+  return iree_ok_status();
+}
+
+// Signals the semaphore in args->r0 to the value in args->i1 (taken as
+// uint32_t) and returns the status of iree_hal_semaphore_signal.
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_signal, ri, v) {
+  iree_hal_semaphore_t* target = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &target));
+  uint32_t signal_value = (uint32_t)args->i1;
+  return iree_hal_semaphore_signal(target, signal_value);
+}
+
+// Fails the semaphore in args->r0 with a status built from the code in
+// args->i1 (masked with IREE_STATUS_CODE_MASK). The export itself returns OK
+// once the semaphore ref has been dereferenced.
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_fail, ri, v) {
+  iree_hal_semaphore_t* target = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &target));
+
+  iree_status_code_t failure_code =
+      (iree_status_code_t)(args->i1 & IREE_STATUS_CODE_MASK);
+  iree_hal_semaphore_fail(target, iree_make_status(failure_code));
+  return iree_ok_status();
+}
+
+// Waits, with IREE_TIME_INFINITE_FUTURE as the deadline, for the semaphore in
+// args->r0 to reach the value in args->i1 (taken as uint32_t and widened to
+// 64 bits).
+//
+// Results:
+//   - wait succeeded: rets->i0 = 0 and OK is returned.
+//   - deadline exceeded: rets->i0 = the status code and OK is returned so the
+//     VM program can observe the code as a value.
+//   - any other failure: the status is returned as-is; rets->i0 is not set.
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_await, ri, i) {
+  iree_hal_semaphore_t* semaphore = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
+  uint64_t new_value = (uint32_t)args->i1;
+
+  // TODO(benvanik): coroutine magic.
+  iree_status_t status = iree_hal_semaphore_wait_with_deadline(
+      semaphore, new_value, IREE_TIME_INFINITE_FUTURE);
+  if (iree_status_is_ok(status)) {
+    rets->i0 = 0;
+  } else if (iree_status_is_deadline_exceeded(status)) {
+    // Propagate deadline exceeded back to the VM as a result value.
+    // iree_status_consume_code consumes |status|, so the consumed handle must
+    // not be returned afterwards; report success for the call itself, as
+    // iree_hal_module_semaphore_query does after consuming its query status.
+    rets->i0 = (int32_t)iree_status_consume_code(status);
+    status = iree_ok_status();
+  }
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// VM module interface implementation
+//===----------------------------------------------------------------------===//
+
+// Native function pointer table: one {shim, target} entry is emitted for each
+// EXPORT_FN(...) line in exports.inl. The shim is chosen by token-pasting the
+// entry's argument/result type codes into iree_vm_shim_<args>_<rets>; the
+// target is the entry's target_fn cast to the generic target pointer type.
+// NOTE: this must match the ordering of the iree_hal_module_exports_ table.
+static const iree_vm_native_function_ptr_t iree_hal_module_funcs_[] = {
+#define EXPORT_FN(name, target_fn, arg_types, ret_types)       \
+  {                                                            \
+      .shim = (iree_vm_native_function_shim_t)                 \
+          iree_vm_shim_##arg_types##_##ret_types,              \
+      .target = (iree_vm_native_function_target_t)(target_fn), \
+  },
+#include "iree/modules/hal/exports.inl"
+#undef EXPORT_FN
+};
+
+// Placeholder import table. The module descriptor below sets import_count to 0,
+// so the single element declared here is only present to give the array a
+// valid size.
+// NOTE: 0 length, but can't express that in C.
+static const iree_vm_native_import_descriptor_t iree_hal_module_imports_[1];
+
+// Export descriptor table generated from the same exports.inl list as
+// iree_hal_module_funcs_, so both tables share one ordering. Each entry records
+// the export's local name and a calling convention string of the form
+// "0<arg_types>_<ret_types>"; no reflection attributes are attached.
+static const iree_vm_native_export_descriptor_t iree_hal_module_exports_[] = {
+#define EXPORT_FN(name, target_fn, arg_types, ret_types)           \
+  {                                                                \
+      .local_name = iree_string_view_literal(name),                \
+      .calling_convention =                                        \
+          iree_string_view_literal("0" #arg_types "_" #ret_types), \
+      .reflection_attr_count = 0,                                  \
+      .reflection_attrs = NULL,                                    \
+  },
+#include "iree/modules/hal/exports.inl"
+#undef EXPORT_FN
+};
+// Compile-time guard: the function pointer and export tables must have the
+// same number of entries.
+static_assert(IREE_ARRAYSIZE(iree_hal_module_funcs_) ==
+                  IREE_ARRAYSIZE(iree_hal_module_exports_),
+              "function pointer table must be 1:1 with exports");
+
+// Static descriptor for the "hal" native module. It ties together the
+// (empty) import table, the export descriptor table, and the function pointer
+// table defined above; it is passed to iree_vm_native_module_initialize by
+// iree_hal_module_create.
+static const iree_vm_native_module_descriptor_t iree_hal_module_descriptor_ = {
+    .module_name = iree_string_view_literal("hal"),
+    .import_count = 0,  // workaround for 0-length C struct
+    .imports = iree_hal_module_imports_,
+    .export_count = IREE_ARRAYSIZE(iree_hal_module_exports_),
+    .exports = iree_hal_module_exports_,
+    .function_count = IREE_ARRAYSIZE(iree_hal_module_funcs_),
+    .functions = iree_hal_module_funcs_,
+    .reflection_attr_count = 0,
+    .reflection_attrs = NULL,
+};
+
+// Creates the HAL VM module bound to |device|.
+//
+// Storage sized for the native module plus iree_hal_module_t is allocated from
+// |allocator|, zeroed, and initialized against iree_hal_module_descriptor_.
+// |device| is retained by the module and |allocator| is recorded as its host
+// allocator. On success |out_module| receives the module; on failure the
+// storage is freed, |out_module| stays NULL, and the failing status is
+// returned.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator,
+                       iree_vm_module_t** out_module) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(out_module);
+  *out_module = NULL;
+
+  // Only the hooks listed here are implemented by this module; any function
+  // left out is handled by the base native module.
+  static const iree_vm_module_t interface = {
+      .destroy = iree_hal_module_destroy,
+      .alloc_state = iree_hal_module_alloc_state,
+      .free_state = iree_hal_module_free_state,
+  };
+
+  // Allocate and zero the shared module storage.
+  iree_host_size_t storage_size =
+      iree_vm_native_module_size() + sizeof(iree_hal_module_t);
+  iree_vm_module_t* native_module = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(allocator, storage_size, (void**)&native_module));
+  memset(native_module, 0, storage_size);
+
+  iree_status_t status = iree_vm_native_module_initialize(
+      &interface, &iree_hal_module_descriptor_, allocator, native_module);
+  if (iree_status_is_ok(status)) {
+    iree_hal_module_t* module = IREE_HAL_MODULE_CAST(native_module);
+    module->host_allocator = allocator;
+    module->shared_device = device;
+    iree_hal_device_retain(module->shared_device);
+    *out_module = native_module;
+  } else {
+    iree_allocator_free(allocator, native_module);
+  }
+  return status;
+}
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
deleted file mode 100644
index 9f3c164..0000000
--- a/iree/modules/hal/hal_module.cc
+++ /dev/null
@@ -1,833 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/modules/hal/hal_module.h"
-
-#include <inttypes.h>
-
-#include "absl/base/macros.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/memory/memory.h"
-#include "absl/types/span.h"
-#include "iree/base/api.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/api.h"
-#include "iree/vm/native_module_cc.h"
-
-//===----------------------------------------------------------------------===//
-// Type registration
-//===----------------------------------------------------------------------===//
-
-static iree_vm_ref_type_descriptor_t iree_hal_allocator_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_buffer_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_buffer_view_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_command_buffer_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_layout_descriptor =
-    {0};
-static iree_vm_ref_type_descriptor_t iree_hal_device_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_event_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_executable_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_executable_cache_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_executable_layout_descriptor = {
-    0};
-static iree_vm_ref_type_descriptor_t iree_hal_semaphore_descriptor = {0};
-
-#define IREE_VM_REGISTER_HAL_C_TYPE(type, name, destroy_fn, descriptor)   \
-  descriptor.type_name = iree_make_cstring_view(name);                    \
-  descriptor.offsetof_counter = offsetof(iree_hal_resource_t, ref_count); \
-  descriptor.destroy = (iree_vm_ref_destroy_t)destroy_fn;                 \
-  IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_module_register_types() {
-  static bool has_registered = false;
-  if (has_registered) return iree_ok_status();
-
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_allocator_t, "hal.allocator",
-                              iree_hal_allocator_destroy,
-                              iree_hal_allocator_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_t, "hal.buffer",
-                              iree_hal_buffer_destroy,
-                              iree_hal_buffer_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_view_t, "hal.buffer_view",
-                              iree_hal_buffer_view_destroy,
-                              iree_hal_buffer_view_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_command_buffer_t, "hal.command_buffer",
-                              iree_hal_command_buffer_destroy,
-                              iree_hal_command_buffer_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_t, "hal.descriptor_set",
-                              iree_hal_descriptor_set_destroy,
-                              iree_hal_descriptor_set_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_layout_t,
-                              "hal.descriptor_set_layout",
-                              iree_hal_descriptor_set_layout_destroy,
-                              iree_hal_descriptor_set_layout_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_device_t, "hal.device",
-                              iree_hal_device_destroy,
-                              iree_hal_device_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_event_t, "hal.event",
-                              iree_hal_event_destroy,
-                              iree_hal_event_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_t, "hal.executable",
-                              iree_hal_executable_destroy,
-                              iree_hal_executable_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(
-      iree_hal_executable_cache_t, "hal.executable_cache",
-      iree_hal_executable_cache_destroy, iree_hal_executable_cache_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_layout_t,
-                              "hal.executable_layout",
-                              iree_hal_executable_layout_destroy,
-                              iree_hal_executable_layout_descriptor);
-  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_semaphore_t, "hal.semaphore",
-                              iree_hal_semaphore_destroy,
-                              iree_hal_semaphore_descriptor);
-
-  has_registered = true;
-  return iree_ok_status();
-}
-
-//===----------------------------------------------------------------------===//
-// Type wrappers
-//===----------------------------------------------------------------------===//
-
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_allocator, iree_hal_allocator_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer, iree_hal_buffer_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer_view, iree_hal_buffer_view_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_command_buffer,
-                             iree_hal_command_buffer_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set,
-                             iree_hal_descriptor_set_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout,
-                             iree_hal_descriptor_set_layout_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_event, iree_hal_event_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_cache,
-                             iree_hal_executable_cache_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_layout,
-                             iree_hal_executable_layout_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t);
-
-namespace iree {
-namespace hal {
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Module type definitions
-//===----------------------------------------------------------------------===//
-
-class HALModuleState final {
- public:
-  HALModuleState(iree_allocator_t allocator, iree_hal_device_t* shared_device)
-      : allocator_(allocator), shared_device_(shared_device) {
-    iree_hal_device_retain(shared_device_);
-  }
-
-  ~HALModuleState() {
-    for (auto& ref : deferred_releases_) {
-      iree_vm_ref_release(&ref);
-    }
-    deferred_releases_.clear();
-    iree_hal_executable_cache_release(executable_cache_);
-    iree_hal_device_release(shared_device_);
-  }
-
-  Status Initialize() {
-    IREE_TRACE_SCOPE0("HALModuleState::Initialize");
-
-    IREE_RETURN_IF_ERROR(iree_hal_executable_cache_create(
-        shared_device_, iree_string_view_empty(), &executable_cache_));
-
-    return OkStatus();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Experimental APIs
-  //===--------------------------------------------------------------------===//
-  // NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
-  // using these APIs are not forward compatible.
-
-  StatusOr<vm::ref<iree_hal_device_t>> ExSharedDevice() {
-    return vm::retain_ref(shared_device_);
-  }
-
-  template <typename T>
-  void ExDeferRelease(const vm::ref<T>& value) {
-    deferred_releases_.push_back({0});
-    iree_vm_ref_retain((iree_vm_ref_t*)&value, &deferred_releases_.back());
-  }
-
-  Status ExSubmitAndWait(
-      const vm::ref<iree_hal_device_t>& device,
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer) {
-    IREE_TRACE_SCOPE0("HALModuleState::ExSubmitAndWait");
-
-    vm::ref<iree_hal_semaphore_t> semaphore;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_semaphore_create(device.get(), 0ull, &semaphore));
-
-    iree_hal_submission_batch_t batch;
-    memset(&batch, 0, sizeof(batch));
-    batch.command_buffer_count = 1;
-    iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer.get()};
-    batch.command_buffers = command_buffer_ptrs;
-    batch.signal_semaphores.count = 1;
-    iree_hal_semaphore_t* semaphore_ptrs[] = {semaphore.get()};
-    batch.signal_semaphores.semaphores = semaphore_ptrs;
-    uint64_t signal_value = 1ull;
-    batch.signal_semaphores.payload_values = &signal_value;
-    IREE_RETURN_IF_ERROR(iree_hal_device_queue_submit(
-        device.get(), IREE_HAL_COMMAND_CATEGORY_ANY, 0, 1, &batch));
-
-    IREE_RETURN_IF_ERROR(iree_hal_semaphore_wait_with_deadline(
-        semaphore.get(), 1ull, IREE_TIME_INFINITE_FUTURE));
-
-    {
-      IREE_TRACE_SCOPE0("HALModuleState::DeferredReleases");
-      for (auto& ref : deferred_releases_) {
-        iree_vm_ref_release(&ref);
-      }
-      deferred_releases_.clear();
-    }
-
-    return OkStatus();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_allocator_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorAllocate(
-      const vm::ref<iree_hal_allocator_t>& allocator,
-      iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage,
-      int32_t allocation_size) {
-    IREE_TRACE_SCOPE0("HALModuleState::AllocatorAllocate");
-    vm::ref<iree_hal_buffer_t> buffer;
-    IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
-        allocator.get(), memory_types, buffer_usage, allocation_size, &buffer));
-    return std::move(buffer);
-  }
-
-  StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorWrapByteBuffer(
-      const vm::ref<iree_hal_allocator_t>& allocator,
-      iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage,
-      const vm::ref<iree_vm_ro_byte_buffer_t>& source, int32_t offset,
-      int32_t length) {
-    IREE_TRACE_SCOPE0("HALModuleState::AllocatorWrapByteBuffer");
-
-    // TODO(benvanik): wrap when supported.
-
-    buffer_usage |= IREE_HAL_BUFFER_USAGE_MAPPING;
-
-    size_t buffer_length = source->data.data_length;
-    if (length == -1) {
-      length = static_cast<size_t>(buffer_length);
-    }
-    if (length < 0 || offset < 0 || offset > buffer_length ||
-        offset + length > buffer_length) {
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "byte range out of bounds (requested %d-%d of available %zu" PRIu64
-          ")",
-          offset, (offset + length - 1), buffer_length);
-    }
-
-    vm::ref<iree_hal_buffer_t> buffer;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_allocator_allocate_buffer(allocator.get(), memory_types,
-                                           buffer_usage, length, &buffer),
-        "failed to allocate buffer");
-
-    IREE_RETURN_IF_ERROR(
-        iree_hal_buffer_write_data(buffer.get(), 0, source->data.data + offset,
-                                   length),
-        "writing constant data");
-
-    return buffer;
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_buffer_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_allocator_t>> BufferAllocator(
-      const vm::ref<iree_hal_buffer_t>& buffer) {
-    return vm::retain_ref(iree_hal_buffer_allocator(buffer.get()));
-  }
-
-  StatusOr<vm::ref<iree_hal_buffer_t>> BufferSubspan(
-      const vm::ref<iree_hal_buffer_t>& source_buffer, int32_t source_offset,
-      int32_t length) {
-    IREE_TRACE_SCOPE0("HALModuleState::BufferSubspan");
-    vm::ref<iree_hal_buffer_t> target_buffer;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_buffer_subspan(source_buffer.get(), source_offset, length,
-                                &target_buffer),
-        "subspan of an existing buffer (source_offset=%u, length=%u)",
-        source_offset, length);
-    return target_buffer;
-  }
-
-  Status BufferFill(const vm::ref<iree_hal_buffer_t>& target_buffer,
-                    int32_t target_offset, int32_t length, int32_t pattern) {
-    IREE_TRACE_SCOPE0("HALModuleState::BufferFill");
-    IREE_RETURN_IF_ERROR(
-        iree_hal_buffer_fill(target_buffer.get(), target_offset, length,
-                             &pattern, sizeof(pattern)),
-        "fill range failed (target_offset=%u, length=%u)", target_offset,
-        length);
-    return OkStatus();
-  }
-
-  StatusOr<int32_t> BufferLoad(const vm::ref<iree_hal_buffer_t>& source_buffer,
-                               int32_t source_offset, int32_t length) {
-    IREE_TRACE_SCOPE0("HALModuleState::BufferLoad");
-
-    uint32_t target_buffer = 0;
-    if (length > sizeof(target_buffer)) {
-      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                              "length %d exceeds max", length);
-    }
-
-    IREE_RETURN_IF_ERROR(
-        iree_hal_buffer_read_data(source_buffer.get(), source_offset,
-                                  &target_buffer, length),
-        "read failed");
-    return target_buffer;
-  }
-
-  Status BufferStore(int32_t value,
-                     const vm::ref<iree_hal_buffer_t>& target_buffer,
-                     int32_t target_offset, int32_t length) {
-    IREE_TRACE_SCOPE0("HALModuleState::BufferStore");
-
-    if (target_offset + length >
-        iree_hal_buffer_byte_length(target_buffer.get())) {
-      return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "out of bounds store");
-    } else if (length > sizeof(value)) {
-      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                              "length %d exceeds max", length);
-    }
-
-    IREE_RETURN_IF_ERROR(
-        iree_hal_buffer_write_data(target_buffer.get(), target_offset, &value,
-                                   length),
-        "write failed");
-    return OkStatus();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_buffer_view_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_buffer_view_t>> BufferViewCreate(
-      const vm::ref<iree_hal_buffer_t>& buffer,
-      iree_hal_element_type_t element_type, absl::Span<const int32_t> shape) {
-    vm::ref<iree_hal_buffer_view_t> buffer_view;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_buffer_view_create(buffer.get(), element_type, shape.data(),
-                                    shape.size(), &buffer_view),
-        "failed to create buffer view");
-    return std::move(buffer_view);
-  }
-
-  StatusOr<vm::ref<iree_hal_buffer_t>> BufferViewBuffer(
-      const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
-    return vm::retain_ref(iree_hal_buffer_view_buffer(buffer_view.get()));
-  }
-
-  StatusOr<int32_t> BufferViewByteLength(
-      const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
-    return iree_hal_buffer_view_byte_length(buffer_view.get());
-  }
-
-  StatusOr<int32_t> BufferViewElementType(
-      const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
-    return static_cast<int32_t>(
-        iree_hal_buffer_view_element_type(buffer_view.get()));
-  }
-
-  StatusOr<int32_t> BufferViewRank(
-      const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
-    return static_cast<int32_t>(
-        iree_hal_buffer_view_shape_rank(buffer_view.get()));
-  }
-
-  StatusOr<int32_t> BufferViewDim(
-      const vm::ref<iree_hal_buffer_view_t>& buffer_view, int32_t index) {
-    return static_cast<int32_t>(
-        iree_hal_buffer_view_shape_dim(buffer_view.get(), index));
-  }
-
-  Status BufferViewTrace(
-      absl::string_view key,
-      absl::Span<const vm::ref<iree_hal_buffer_view_t>> buffer_views) {
-    fprintf(stderr, "=== %s ===\n", std::string(key).c_str());
-    for (auto& view : buffer_views) {
-      std::string result_str(4096, '\0');
-      iree_status_t status;
-      auto max_element_count = std::numeric_limits<iree_host_size_t>::max();
-      do {
-        iree_host_size_t actual_length = 0;
-        status = iree_hal_buffer_view_format(view.get(), max_element_count,
-                                             result_str.size() + 1,
-                                             &result_str[0], &actual_length);
-        result_str.resize(actual_length);
-      } while (iree_status_is_out_of_range(status));
-      IREE_RETURN_IF_ERROR(status);
-      fprintf(stderr, "%s\n", result_str.c_str());
-    }
-    fprintf(stderr, "\n");
-    return OkStatus();
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_command_buffer_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_command_buffer_t>> CommandBufferCreate(
-      const vm::ref<iree_hal_device_t>& device,
-      iree_hal_command_buffer_mode_t modes,
-      iree_hal_command_category_t command_categories) {
-    vm::ref<iree_hal_command_buffer_t> command_buffer;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_command_buffer_create(device.get(), modes, command_categories,
-                                       &command_buffer),
-        "failed to create command buffer");
-    return command_buffer;
-  }
-
-  Status CommandBufferBegin(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer) {
-    return iree_hal_command_buffer_begin(command_buffer.get());
-  }
-
-  Status CommandBufferEnd(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer) {
-    return iree_hal_command_buffer_end(command_buffer.get());
-  }
-
-  Status CommandBufferExecutionBarrier(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      iree_hal_execution_stage_t source_stage_mask,
-      iree_hal_execution_stage_t target_stage_mask,
-      iree_hal_execution_barrier_flags_t flags) {
-    iree_hal_memory_barrier_t global_barrier;
-    global_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE;
-    global_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ;
-    return iree_hal_command_buffer_execution_barrier(
-        command_buffer.get(), source_stage_mask, target_stage_mask, flags, 1,
-        &global_barrier, 0, nullptr);
-  }
-
-  Status CommandBufferFillBuffer(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_buffer_t>& target_buffer, int32_t target_offset,
-      int32_t length, uint32_t pattern) {
-    ExDeferRelease(target_buffer);
-    return iree_hal_command_buffer_fill_buffer(
-        command_buffer.get(), target_buffer.get(), target_offset, length,
-        &pattern, sizeof(pattern));
-  }
-
-  Status CommandBufferCopyBuffer(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_buffer_t>& source_buffer, int32_t source_offset,
-      const vm::ref<iree_hal_buffer_t>& target_buffer, int32_t target_offset,
-      int32_t length) {
-    ExDeferRelease(source_buffer);
-    ExDeferRelease(target_buffer);
-    return iree_hal_command_buffer_copy_buffer(
-        command_buffer.get(), source_buffer.get(), source_offset,
-        target_buffer.get(), target_offset, length);
-  }
-
-  Status CommandBufferPushConstants(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_executable_layout_t>& executable_layout,
-      uint32_t offset, absl::Span<const uint32_t> values) {
-    ExDeferRelease(executable_layout);
-    return iree_hal_command_buffer_push_constants(
-        command_buffer.get(), executable_layout.get(),
-        offset * sizeof(uint32_t), values.data(),
-        values.size() * sizeof(uint32_t));
-  }
-
-  Status CommandBufferPushDescriptorSet(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_executable_layout_t>& executable_layout,
-      uint32_t set,
-      absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t,
-                                  int32_t>>
-          bindings) {
-    ExDeferRelease(executable_layout);
-    absl::InlinedVector<iree_hal_descriptor_set_binding_t, 16> binding_structs(
-        bindings.size());
-    for (int i = 0; i < bindings.size(); ++i) {
-      binding_structs[i] = {
-          std::get<0>(bindings[i]), std::get<1>(bindings[i]).get(),
-          static_cast<iree_device_size_t>(std::get<2>(bindings[i])),
-          static_cast<iree_device_size_t>(std::get<3>(bindings[i]))};
-      ExDeferRelease(std::get<1>(bindings[i]));
-    }
-    return iree_hal_command_buffer_push_descriptor_set(
-        command_buffer.get(), executable_layout.get(), set,
-        binding_structs.size(), binding_structs.data());
-  }
-
-  Status CommandBufferBindDescriptorSet(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_executable_layout_t>& executable_layout,
-      uint32_t set, const vm::ref<iree_hal_descriptor_set_t>& descriptor_set,
-      absl::Span<const int32_t> dynamic_offsets) {
-    ExDeferRelease(executable_layout);
-    ExDeferRelease(descriptor_set);
-    absl::InlinedVector<iree_device_size_t, 4> dynamic_offset_values(
-        dynamic_offsets.size());
-    for (int i = 0; i < dynamic_offsets.size(); ++i) {
-      dynamic_offset_values[i] =
-          static_cast<iree_device_size_t>(dynamic_offsets[i]);
-    }
-    return iree_hal_command_buffer_bind_descriptor_set(
-        command_buffer.get(), executable_layout.get(), set,
-        descriptor_set.get(), dynamic_offset_values.size(),
-        dynamic_offset_values.data());
-  }
-
-  Status CommandBufferDispatch(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_executable_t>& executable, int32_t entry_point,
-      uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
-    ExDeferRelease(executable);
-    return iree_hal_command_buffer_dispatch(
-        command_buffer.get(), executable.get(), entry_point, workgroup_x,
-        workgroup_y, workgroup_z);
-  }
-
-  Status CommandBufferDispatchIndirect(
-      const vm::ref<iree_hal_command_buffer_t>& command_buffer,
-      const vm::ref<iree_hal_executable_t>& executable, int32_t entry_point,
-      const vm::ref<iree_hal_buffer_t>& workgroups_buffer,
-      int32_t workgroups_offset) {
-    ExDeferRelease(executable);
-    ExDeferRelease(workgroups_buffer);
-    return iree_hal_command_buffer_dispatch_indirect(
-        command_buffer.get(), executable.get(), entry_point,
-        workgroups_buffer.get(), workgroups_offset);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_descriptor_set_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_descriptor_set_t>> DescriptorSetCreate(
-      const vm::ref<iree_hal_device_t>& device,
-      const vm::ref<iree_hal_descriptor_set_layout_t>& set_layout,
-      absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t,
-                                  int32_t>>
-          bindings) {
-    absl::InlinedVector<iree_hal_descriptor_set_binding_t, 4> binding_structs(
-        bindings.size());
-    for (int i = 0; i < bindings.size(); ++i) {
-      binding_structs[i] = {
-          /*ordinal=*/std::get<0>(bindings[i]),
-          /*buffer=*/std::get<1>(bindings[i]).get(),
-          /*offset=*/static_cast<iree_device_size_t>(std::get<2>(bindings[i])),
-          /*length=*/static_cast<iree_device_size_t>(std::get<3>(bindings[i]))};
-    }
-    vm::ref<iree_hal_descriptor_set_t> descriptor_set;
-    IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create(
-        device.get(), set_layout.get(), binding_structs.size(),
-        binding_structs.data(), &descriptor_set));
-    return std::move(descriptor_set);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_descriptor_set_layout_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_descriptor_set_layout_t>> DescriptorSetLayoutCreate(
-      const vm::ref<iree_hal_device_t>& device,
-      iree_hal_descriptor_set_layout_usage_type_t usage_type,
-      absl::Span<const std::tuple<uint32_t, iree_hal_descriptor_type_t,
-                                  iree_hal_memory_access_t>>
-          bindings) {
-    // TODO(benvanik): custom marshaling for the structs.
-    absl::InlinedVector<iree_hal_descriptor_set_layout_binding_t, 4>
-        binding_structs(bindings.size());
-    for (int i = 0; i < bindings.size(); ++i) {
-      binding_structs[i] = {std::get<0>(bindings[i]), std::get<1>(bindings[i]),
-                            std::get<2>(bindings[i])};
-    }
-    vm::ref<iree_hal_descriptor_set_layout_t> descriptor_set_layout;
-    IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_layout_create(
-        device.get(), usage_type, binding_structs.size(),
-        binding_structs.data(), &descriptor_set_layout));
-    return std::move(descriptor_set_layout);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_device_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_allocator_t>> DeviceAllocator(
-      const vm::ref<iree_hal_device_t>& device) {
-    return vm::retain_ref(iree_hal_device_allocator(device.get()));
-  }
-
-  StatusOr<int32_t> DeviceMatchID(const vm::ref<iree_hal_device_t>& device,
-                                  absl::string_view pattern) {
-    iree_string_view_t device_id = iree_hal_device_id(device.get());
-    return iree_string_view_match_pattern(
-               device_id, iree_string_view_t{pattern.data(), pattern.size()})
-               ? 1
-               : 0;
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_event_t
-  //===--------------------------------------------------------------------===//
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_executable_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_executable_t>> ExecutableCreate(
-      const vm::ref<iree_hal_device_t>& device,
-      iree_hal_executable_format_t executable_format,
-      const vm::ref<iree_vm_ro_byte_buffer_t>& executable_data,
-      absl::Span<const vm::ref<iree_hal_executable_layout_t>>
-          executable_layouts) {
-    iree_hal_executable_spec_t spec;
-    iree_hal_executable_spec_initialize(&spec);
-
-    spec.executable_format = executable_format;
-    spec.executable_data = executable_data->data;
-
-    spec.executable_layout_count = executable_layouts.size();
-    iree_hal_executable_layout_t** executable_layouts_ptr =
-        (iree_hal_executable_layout_t**)iree_alloca(
-            sizeof(executable_layouts_ptr[0]) * executable_layouts.size());
-    for (size_t i = 0; i < executable_layouts.size(); ++i) {
-      executable_layouts_ptr[i] = executable_layouts[i].get();
-    }
-    spec.executable_layouts = executable_layouts_ptr;
-
-    vm::ref<iree_hal_executable_t> executable;
-    IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable(
-        executable_cache_, &spec, &executable));
-    return std::move(executable);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_executable_layout_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_executable_layout_t>> ExecutableLayoutCreate(
-      const vm::ref<iree_hal_device_t>& device, int32_t push_constants,
-      absl::Span<const vm::ref<iree_hal_descriptor_set_layout_t>> set_layouts) {
-    iree_hal_descriptor_set_layout_t** set_layouts_ptr =
-        (iree_hal_descriptor_set_layout_t**)iree_alloca(
-            sizeof(set_layouts_ptr[0]) * set_layouts.size());
-    for (size_t i = 0; i < set_layouts.size(); ++i) {
-      set_layouts_ptr[i] = set_layouts[i].get();
-    }
-
-    vm::ref<iree_hal_executable_layout_t> executable_layout;
-    IREE_RETURN_IF_ERROR(iree_hal_executable_layout_create(
-        device.get(), push_constants, set_layouts.size(), set_layouts_ptr,
-        &executable_layout));
-    return std::move(executable_layout);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // iree_hal_semaphore_t
-  //===--------------------------------------------------------------------===//
-
-  StatusOr<vm::ref<iree_hal_semaphore_t>> SemaphoreCreate(
-      const vm::ref<iree_hal_device_t>& device, uint32_t initial_value) {
-    vm::ref<iree_hal_semaphore_t> semaphore;
-    IREE_RETURN_IF_ERROR(
-        iree_hal_semaphore_create(device.get(), initial_value, &semaphore));
-    return std::move(semaphore);
-  }
-
-  StatusOr<std::tuple<int32_t, uint32_t>> SemaphoreQuery(
-      const vm::ref<iree_hal_semaphore_t>& semaphore) {
-    uint64_t value = 0;
-    iree_status_t query_status =
-        iree_hal_semaphore_query(semaphore.get(), &value);
-    return std::make_tuple<int32_t, uint32_t>(iree_status_code(query_status),
-                                              static_cast<uint32_t>(value));
-  }
-
-  Status SemaphoreSignal(const vm::ref<iree_hal_semaphore_t>& semaphore,
-                         uint32_t new_value) {
-    return iree_hal_semaphore_signal(semaphore.get(), new_value);
-  }
-
-  Status SemaphoreFail(const vm::ref<iree_hal_semaphore_t>& semaphore,
-                       int32_t status_code) {
-    iree_status_t status = iree_make_status(
-        static_cast<iree_status_code_t>(status_code & IREE_STATUS_CODE_MASK));
-    iree_hal_semaphore_fail(semaphore.get(), status);
-    return OkStatus();
-  }
-
-  StatusOr<int32_t> SemaphoreAwait(
-      const vm::ref<iree_hal_semaphore_t>& semaphore, uint32_t new_value) {
-    // TODO(benvanik): coroutine magic.
-    iree_status_t status = iree_hal_semaphore_wait_with_deadline(
-        semaphore.get(), new_value, IREE_TIME_INFINITE_FUTURE);
-    if (iree_status_is_ok(status)) {
-      return 0;
-    } else if (iree_status_is_deadline_exceeded(status)) {
-      // Propagate deadline exceeded back to the VM.
-      return static_cast<int32_t>(iree_status_consume_code(status));
-    }
-    return Status(std::move(status));
-  }
-
- private:
-  iree_allocator_t allocator_;
-  iree_hal_device_t* shared_device_ = NULL;
-  iree_hal_executable_cache_t* executable_cache_ = NULL;
-
-  std::vector<iree_vm_ref_t> deferred_releases_;
-};
-
-//===----------------------------------------------------------------------===//
-// VM module interface implementation
-//===----------------------------------------------------------------------===//
-
-static const vm::NativeFunction<HALModuleState> kHALModuleFunctions[] = {
-    vm::MakeNativeFunction("ex.shared_device", &HALModuleState::ExSharedDevice),
-    vm::MakeNativeFunction("ex.submit_and_wait",
-                           &HALModuleState::ExSubmitAndWait),
-
-    vm::MakeNativeFunction("allocator.allocate",
-                           &HALModuleState::AllocatorAllocate),
-    vm::MakeNativeFunction("allocator.wrap.byte_buffer",
-                           &HALModuleState::AllocatorWrapByteBuffer),
-
-    vm::MakeNativeFunction("buffer.allocator",
-                           &HALModuleState::BufferAllocator),
-    vm::MakeNativeFunction("buffer.subspan", &HALModuleState::BufferSubspan),
-    vm::MakeNativeFunction("buffer.fill", &HALModuleState::BufferFill),
-    vm::MakeNativeFunction("buffer.load", &HALModuleState::BufferLoad),
-    vm::MakeNativeFunction("buffer.store", &HALModuleState::BufferStore),
-
-    vm::MakeNativeFunction("buffer_view.create",
-                           &HALModuleState::BufferViewCreate),
-    vm::MakeNativeFunction("buffer_view.buffer",
-                           &HALModuleState::BufferViewBuffer),
-    vm::MakeNativeFunction("buffer_view.byte_length",
-                           &HALModuleState::BufferViewByteLength),
-    vm::MakeNativeFunction("buffer_view.element_type",
-                           &HALModuleState::BufferViewElementType),
-    vm::MakeNativeFunction("buffer_view.rank", &HALModuleState::BufferViewRank),
-    vm::MakeNativeFunction("buffer_view.dim", &HALModuleState::BufferViewDim),
-    vm::MakeNativeFunction("buffer_view.trace",
-                           &HALModuleState::BufferViewTrace),
-
-    vm::MakeNativeFunction("command_buffer.create",
-                           &HALModuleState::CommandBufferCreate),
-    vm::MakeNativeFunction("command_buffer.begin",
-                           &HALModuleState::CommandBufferBegin),
-    vm::MakeNativeFunction("command_buffer.end",
-                           &HALModuleState::CommandBufferEnd),
-    vm::MakeNativeFunction("command_buffer.execution_barrier",
-                           &HALModuleState::CommandBufferExecutionBarrier),
-    vm::MakeNativeFunction("command_buffer.fill_buffer",
-                           &HALModuleState::CommandBufferFillBuffer),
-    vm::MakeNativeFunction("command_buffer.copy_buffer",
-                           &HALModuleState::CommandBufferCopyBuffer),
-    vm::MakeNativeFunction("command_buffer.push_constants",
-                           &HALModuleState::CommandBufferPushConstants),
-    vm::MakeNativeFunction("command_buffer.push_descriptor_set",
-                           &HALModuleState::CommandBufferPushDescriptorSet),
-    vm::MakeNativeFunction("command_buffer.bind_descriptor_set",
-                           &HALModuleState::CommandBufferBindDescriptorSet),
-    vm::MakeNativeFunction("command_buffer.dispatch",
-                           &HALModuleState::CommandBufferDispatch),
-    vm::MakeNativeFunction("command_buffer.dispatch.indirect",
-                           &HALModuleState::CommandBufferDispatchIndirect),
-
-    vm::MakeNativeFunction("descriptor_set.create",
-                           &HALModuleState::DescriptorSetCreate),
-    vm::MakeNativeFunction("descriptor_set_layout.create",
-                           &HALModuleState::DescriptorSetLayoutCreate),
-
-    vm::MakeNativeFunction("device.allocator",
-                           &HALModuleState::DeviceAllocator),
-    vm::MakeNativeFunction("device.match.id", &HALModuleState::DeviceMatchID),
-
-    vm::MakeNativeFunction("executable.create",
-                           &HALModuleState::ExecutableCreate),
-
-    vm::MakeNativeFunction("executable_layout.create",
-                           &HALModuleState::ExecutableLayoutCreate),
-
-    vm::MakeNativeFunction("semaphore.create",
-                           &HALModuleState::SemaphoreCreate),
-    vm::MakeNativeFunction("semaphore.query", &HALModuleState::SemaphoreQuery),
-    vm::MakeNativeFunction("semaphore.signal",
-                           &HALModuleState::SemaphoreSignal),
-    vm::MakeNativeFunction("semaphore.fail", &HALModuleState::SemaphoreFail),
-    vm::MakeNativeFunction("semaphore.await", &HALModuleState::SemaphoreAwait),
-};
-
-class HALModule final : public vm::NativeModule<HALModuleState> {
- public:
-  HALModule(iree_allocator_t allocator, iree_hal_device_t* shared_device)
-      : vm::NativeModule<HALModuleState>(
-            "hal", allocator, absl::MakeConstSpan(kHALModuleFunctions)),
-        shared_device_(shared_device) {
-    iree_hal_device_retain(shared_device_);
-  }
-
-  ~HALModule() { iree_hal_device_release(shared_device_); }
-
-  Status Initialize() {
-    IREE_TRACE_SCOPE0("HALModule::Initialize");
-    return OkStatus();
-  }
-
-  StatusOr<std::unique_ptr<HALModuleState>> CreateState(
-      iree_allocator_t allocator) override {
-    IREE_TRACE_SCOPE0("HALModule::CreateState");
-    auto state = std::make_unique<HALModuleState>(allocator, shared_device_);
-    IREE_RETURN_IF_ERROR(state->Initialize());
-    return state;
-  }
-
- private:
-  iree_hal_device_t* shared_device_ = NULL;
-};
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator,
-                       iree_vm_module_t** out_module) {
-  IREE_ASSERT_ARGUMENT(device);
-  IREE_ASSERT_ARGUMENT(out_module);
-  *out_module = nullptr;
-  auto module = std::make_unique<HALModule>(allocator, device);
-  IREE_RETURN_IF_ERROR(module->Initialize());
-  *out_module = module.release()->interface();
-  return iree_ok_status();
-}
-
-}  // namespace
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/modules/hal/shims.c b/iree/modules/hal/shims.c
new file mode 100644
index 0000000..28fa9ee
--- /dev/null
+++ b/iree/modules/hal/shims.c
@@ -0,0 +1,52 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/modules/hal/shims.h"
+
+IREE_VM_ABI_DEFINE_SHIM(irii, v);
+IREE_VM_ABI_DEFINE_SHIM(r, i);
+IREE_VM_ABI_DEFINE_SHIM(r, ii);
+IREE_VM_ABI_DEFINE_SHIM(r, iii);
+IREE_VM_ABI_DEFINE_SHIM(r, iiii);
+IREE_VM_ABI_DEFINE_SHIM(r, r);
+IREE_VM_ABI_DEFINE_SHIM(r, v);
+IREE_VM_ABI_DEFINE_SHIM(rCiD, i);
+IREE_VM_ABI_DEFINE_SHIM(rCrD, v);
+IREE_VM_ABI_DEFINE_SHIM(ri, i);
+IREE_VM_ABI_DEFINE_SHIM(ri, r);
+IREE_VM_ABI_DEFINE_SHIM(ri, v);
+IREE_VM_ABI_DEFINE_SHIM(riCiD, r);
+IREE_VM_ABI_DEFINE_SHIM(riCiiiD, r);
+IREE_VM_ABI_DEFINE_SHIM(riCrD, r);
+IREE_VM_ABI_DEFINE_SHIM(rii, i);
+IREE_VM_ABI_DEFINE_SHIM(rii, r);
+IREE_VM_ABI_DEFINE_SHIM(riii, r);
+IREE_VM_ABI_DEFINE_SHIM(riii, v);
+IREE_VM_ABI_DEFINE_SHIM(riirii, r);
+IREE_VM_ABI_DEFINE_SHIM(rirCrD, r);
+IREE_VM_ABI_DEFINE_SHIM(ririi, v);
+IREE_VM_ABI_DEFINE_SHIM(rr, i);
+IREE_VM_ABI_DEFINE_SHIM(rr, r);
+IREE_VM_ABI_DEFINE_SHIM(rr, v);
+IREE_VM_ABI_DEFINE_SHIM(rrCiriiD, r);
+IREE_VM_ABI_DEFINE_SHIM(rriCiD, v);
+IREE_VM_ABI_DEFINE_SHIM(rriCiriiD, v);
+IREE_VM_ABI_DEFINE_SHIM(rriii, v);
+IREE_VM_ABI_DEFINE_SHIM(rriiii, v);
+IREE_VM_ABI_DEFINE_SHIM(rrirCiD, v);
+IREE_VM_ABI_DEFINE_SHIM(rriri, v);
+IREE_VM_ABI_DEFINE_SHIM(rririi, v);
+IREE_VM_ABI_DEFINE_SHIM(v, i);
+IREE_VM_ABI_DEFINE_SHIM(v, r);
+IREE_VM_ABI_DEFINE_SHIM(v, v);
diff --git a/iree/modules/hal/shims.h b/iree/modules/hal/shims.h
new file mode 100644
index 0000000..31a0bb5
--- /dev/null
+++ b/iree/modules/hal/shims.h
@@ -0,0 +1,394 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_MODULES_HAL_SHIMS_H_
+#define IREE_MODULES_HAL_SHIMS_H_
+
+#include "iree/base/api.h"
+#include "iree/base/attributes.h"
+#include "iree/vm/api.h"
+
+//===----------------------------------------------------------------------===//
+// Argument/result struct utilities
+//===----------------------------------------------------------------------===//
+
+#define IREE_VM_ABI_TYPE_NAME(types) iree_vm_abi_##types##_t
+
+#define IREE_VM_ABI_FIXED_STRUCT(types, body) \
+  IREE_VM_ABI_FIXED_STRUCT_IMPL(types, IREE_VM_ABI_TYPE_NAME(types), body)
+
+#define IREE_VM_ABI_VLA_STRUCT(types, vla_count, vla_field, body) \
+  IREE_VM_ABI_VLA_STRUCT_IMPL(types, vla_count, vla_field,        \
+                              IREE_VM_ABI_TYPE_NAME(types), body)
+
+// Declares the fixed-size ABI struct |struct_type| (struct tag
+// iree_vm_abi_<types>_s, marked IREE_ATTRIBUTE_PACKED) with the given |body|
+// and two helpers for it:
+//   iree_vm_abi_<types>_checked_deref(buffer): returns |buffer|.data cast to
+//     the struct only when |buffer|.data_length exactly equals
+//     sizeof(struct_type); returns NULL otherwise.
+//   iree_vm_abi_<types>_reset(value): zero-fills the struct with memset.
+#define IREE_VM_ABI_FIXED_STRUCT_IMPL(types, struct_type, body)        \
+  typedef struct iree_vm_abi_##types##_s body IREE_ATTRIBUTE_PACKED    \
+      struct_type;                                                     \
+  static inline struct_type* iree_vm_abi_##types##_checked_deref(      \
+      iree_byte_span_t buffer) {                                       \
+    return IREE_LIKELY(buffer.data_length == sizeof(struct_type))      \
+               ? (struct_type*)buffer.data                             \
+               : NULL;                                                 \
+  }                                                                    \
+  static inline void iree_vm_abi_##types##_reset(struct_type* value) { \
+    memset(value, 0, sizeof(struct_type));                             \
+  }
+
+// Size in bytes of |member| within |type|. The NULL pointer is only used
+// inside sizeof, which does not evaluate its operand.
+#define IREE_VM_ABI_FIELD_SIZE(type, member) sizeof(((type*)NULL)->member)
+// Declares the variable-length ABI struct |struct_type| (struct tag
+// iree_vm_abi_<types>_s, marked IREE_ATTRIBUTE_PACKED) whose trailing
+// |vla_field| array holds |vla_count| elements, plus
+// iree_vm_abi_<types>_checked_deref(buffer), which validates in two steps:
+//   1. |buffer| must hold at least sizeof(struct_type) bytes so that the
+//      |vla_count| field can be read from it.
+//   2. |buffer|.data_length must equal sizeof(struct_type) plus |vla_count|
+//      times the size of one |vla_field| element.
+// NULL is returned when either check fails. Unlike the fixed-struct macro no
+// _reset helper is defined here.
+#define IREE_VM_ABI_VLA_STRUCT_IMPL(types, vla_count, vla_field, struct_type, \
+                                    body)                                     \
+  typedef struct iree_vm_abi_##types##_s body IREE_ATTRIBUTE_PACKED           \
+      struct_type;                                                            \
+  static inline struct_type* iree_vm_abi_##types##_checked_deref(             \
+      iree_byte_span_t buffer) {                                              \
+    return IREE_LIKELY(buffer.data_length >= sizeof(struct_type)) &&          \
+                   IREE_LIKELY(                                               \
+                       buffer.data_length ==                                  \
+                       sizeof(struct_type) +                                  \
+                           ((const struct_type*)buffer.data)->vla_count *     \
+                               IREE_VM_ABI_FIELD_SIZE(struct_type,            \
+                                                      vla_field[0]))          \
+               ? (struct_type*)buffer.data                                    \
+               : NULL;                                                        \
+  }
+
+//===----------------------------------------------------------------------===//
+// Shim function declaration/definition and accessor utilities
+//===----------------------------------------------------------------------===//
+
+typedef iree_status_t(IREE_API_PTR* iree_vm_native_function_target2_t)(
+    iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
+    void* IREE_RESTRICT module_state, const void* IREE_RESTRICT args,
+    void* IREE_RESTRICT rets);
+
+#define IREE_VM_ABI_DECLARE_SHIM(arg_types, ret_types)                         \
+  iree_status_t iree_vm_shim_##arg_types##_##ret_types(                        \
+      iree_vm_stack_t* IREE_RESTRICT stack,                                    \
+      const iree_vm_function_call_t* IREE_RESTRICT call,                       \
+      iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module, \
+      void* IREE_RESTRICT module_state,                                        \
+      iree_vm_execution_result_t* IREE_RESTRICT out_result);
+
+// Defines iree_vm_shim_<arg_types>_<ret_types>, the marshaling function
+// declared by IREE_VM_ABI_DECLARE_SHIM with the same type pair.
+//
+// The shim reinterprets call->arguments and call->results as the ABI structs
+// for |arg_types| and |ret_types| through their _checked_deref helpers. If
+// either helper returns NULL the shim fails with IREE_STATUS_INVALID_ARGUMENT.
+// Otherwise it resets the result struct and forwards to |target_fn|,
+// returning its status.
+//
+// |ret_types| must name a type that has an iree_vm_abi_<ret_types>_reset
+// helper.
+#define IREE_VM_ABI_DEFINE_SHIM(arg_types, ret_types)                          \
+  iree_status_t iree_vm_shim_##arg_types##_##ret_types(                        \
+      iree_vm_stack_t* IREE_RESTRICT stack,                                    \
+      const iree_vm_function_call_t* IREE_RESTRICT call,                       \
+      iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module, \
+      void* IREE_RESTRICT module_state,                                        \
+      iree_vm_execution_result_t* IREE_RESTRICT out_result) {                  \
+    const IREE_VM_ABI_TYPE_NAME(arg_types)* args =                             \
+        iree_vm_abi_##arg_types##_checked_deref(call->arguments);              \
+    IREE_VM_ABI_TYPE_NAME(ret_types)* rets =                                   \
+        iree_vm_abi_##ret_types##_checked_deref(call->results);                \
+    if (IREE_UNLIKELY(!args || !rets)) {                                       \
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,                    \
+                              "argument/result signature mismatch");           \
+    }                                                                          \
+    iree_vm_abi_##ret_types##_reset(rets);                                     \
+    return target_fn(stack, module, module_state, args, rets);                 \
+  }
+
+#define IREE_VM_ABI_EXPORT(function_name, arg_types, ret_types)         \
+  static iree_status_t function_name(                                   \
+      iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module, \
+      iree_hal_module_state_t* IREE_RESTRICT state,                     \
+      IREE_VM_ABI_TYPE_NAME(arg_types) * IREE_RESTRICT args,            \
+      IREE_VM_ABI_TYPE_NAME(ret_types) * IREE_RESTRICT rets)
+
+// Copies the variadic list |vla_field| of |args| into a temporary array of
+// |target_type|, casting the i0 member of each element.
+//
+// Writes the element count to *|out_count| (before the bounds check) and the
+// new array to *|out_ptrs|. If the count exceeds |max_count| the enclosing
+// function returns IREE_STATUS_OUT_OF_RANGE and no array is produced.
+// The array comes from iree_alloca, so it presumably lives on the caller's
+// stack and must not outlive the enclosing function.
+//
+// This expands to several statements and may `return`; use it only at
+// statement level inside a function returning iree_status_t.
+//
+// TODO(benvanik): special case when source type and target type match.
+#define IREE_VM_ABI_VLA_STACK_CAST(args, vla_count, vla_field, target_type, \
+                                   max_count, out_count, out_ptrs)          \
+  *(out_count) = (args)->vla_count;                                         \
+  if (IREE_UNLIKELY((args)->vla_count > (max_count))) {                     \
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "count %u > %u",      \
+                            (args)->vla_count, (uint32_t)(max_count));      \
+  }                                                                         \
+  *(out_ptrs) =                                                             \
+      (target_type*)iree_alloca((args)->vla_count * sizeof(target_type));   \
+  for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) {                \
+    (*(out_ptrs))[i] = (target_type)((args)->vla_field[i].i0);              \
+  }
+
+// Resolves the variadic ref list |vla_field| of |args| into a temporary array
+// of |ref_type|_t* pointers.
+//
+// Writes the element count to *|out_count| (before the bounds check) and the
+// new array to *|out_ptrs|. If the count exceeds |max_count| the enclosing
+// function returns IREE_STATUS_OUT_OF_RANGE. Each element's r0 ref is passed
+// to |ref_type|_check_deref; the first failure is returned from the enclosing
+// function. The array comes from iree_alloca, so it presumably lives on the
+// caller's stack and must not outlive the enclosing function.
+//
+// This expands to several statements and may `return`; use it only at
+// statement level inside a function returning iree_status_t.
+#define IREE_VM_ABI_VLA_STACK_DEREF(args, vla_count, vla_field, ref_type,     \
+                                    max_count, out_count, out_ptrs)           \
+  *(out_count) = (args)->vla_count;                                           \
+  if (IREE_UNLIKELY((args)->vla_count > (max_count))) {                       \
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,                         \
+                            "count %u of " #ref_type " > %u",                 \
+                            (args)->vla_count, (uint32_t)(max_count));        \
+  }                                                                           \
+  *(out_ptrs) =                                                               \
+      (ref_type##_t**)iree_alloca((args)->vla_count * sizeof(ref_type##_t*)); \
+  for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) {                  \
+    IREE_RETURN_IF_ERROR(                                                     \
+        ref_type##_check_deref((args)->vla_field[i].r0, &(*(out_ptrs))[i]));  \
+  }
+
+// Resolves the variadic ref list |vla_field| of |args| into an array of
+// |ref_type|_t* pointers allocated from |host_allocator|.
+//
+// Heap-backed counterpart of IREE_VM_ABI_VLA_STACK_DEREF; there is no
+// max_count check. Writes the element count to *|out_count| and the new array
+// to *|out_ptrs|. On success the caller owns *|out_ptrs| and must release it
+// with iree_allocator_free(host_allocator, ...).
+//
+// If the allocation fails the enclosing function returns that status. If any
+// |ref_type|_check_deref fails the array is freed, *|out_ptrs| is set to
+// NULL, and the enclosing function returns the failing status.
+//
+// This expands to several statements and may `return`; use it only at
+// statement level inside a function returning iree_status_t.
+//
+// NOTE(review): confirm iree_allocator_malloc accepts a zero byte length for
+// an empty list.
+#define IREE_VM_ABI_VLA_HEAP_DEREF(args, vla_count, vla_field, ref_type,      \
+                                   host_allocator, out_count, out_ptrs)       \
+  *(out_count) = (args)->vla_count;                                           \
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(                                 \
+      (host_allocator), (args)->vla_count * sizeof(ref_type##_t*),            \
+      (void**)(out_ptrs)));                                                   \
+  for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) {                  \
+    iree_status_t deref_status = ref_type##_check_deref(                      \
+        (args)->vla_field[i].r0, &(*(out_ptrs))[i]);                          \
+    if (IREE_UNLIKELY(!iree_status_is_ok(deref_status))) {                    \
+      /* Do not leak the heap array when bailing out of the caller. */        \
+      iree_allocator_free((host_allocator), *(out_ptrs));                     \
+      *(out_ptrs) = NULL;                                                     \
+      return deref_status;                                                    \
+    }                                                                         \
+  }
+
+//===----------------------------------------------------------------------===//
+// Structures used for arguments and results.
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_COMPILER_MSVC)
+#pragma pack(push, 1)
+#endif  // IREE_COMPILER_MSVC
+
+// Special case for void (empty args/rets) as C structs can't have a 0 length.
+//
+// Unlike the generated helpers, iree_vm_abi_v_checked_deref performs no
+// length validation and returns |buffer|.data as-is, and iree_vm_abi_v_reset
+// is a no-op.
+// NOTE(review): IREE_VM_ABI_DEFINE_SHIM treats a NULL deref result as a
+// signature mismatch, so callers presumably pass a non-NULL data pointer even
+// for empty argument/result buffers - confirm.
+typedef struct {
+  int unused;
+} iree_vm_abi_v_t;
+static inline iree_vm_abi_v_t* iree_vm_abi_v_checked_deref(
+    iree_byte_span_t buffer) {
+  return (iree_vm_abi_v_t*)buffer.data;
+}
+static inline void iree_vm_abi_v_reset(iree_vm_abi_v_t* value) {}
+
+IREE_VM_ABI_FIXED_STRUCT(i, { int32_t i0; });
+
+IREE_VM_ABI_FIXED_STRUCT(ii, {
+  int32_t i0;
+  int32_t i1;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(iii, {
+  int32_t i0;
+  int32_t i1;
+  int32_t i2;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(iiii, {
+  int32_t i0;
+  int32_t i1;
+  int32_t i2;
+  int32_t i3;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(irii, {
+  int32_t i0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  int32_t i3;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(r, { iree_vm_ref_t r0; });
+
+IREE_VM_ABI_FIXED_STRUCT(rr, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(ri, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(ririi, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  iree_vm_ref_t r2;
+  int32_t i3;
+  int32_t i4;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rii, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  int32_t i2;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(riii, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  int32_t i2;
+  int32_t i3;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(riirii, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  int32_t i2;
+  iree_vm_ref_t r3;
+  int32_t i4;
+  int32_t i5;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rriii, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  int32_t i3;
+  int32_t i4;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rriiii, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  int32_t i3;
+  int32_t i4;
+  int32_t i5;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rriri, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  iree_vm_ref_t r3;
+  int32_t i4;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rririi, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  iree_vm_ref_t r3;
+  int32_t i4;
+  int32_t i5;
+});
+
+IREE_VM_ABI_VLA_STRUCT(rCiD, a1_count, a1, {
+  iree_vm_ref_t r0;
+  iree_vm_size_t a1_count;
+  iree_vm_abi_i_t a1[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rCrD, a1_count, a1, {
+  iree_vm_ref_t r0;
+  iree_vm_size_t a1_count;
+  iree_vm_abi_r_t a1[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riCiD, a2_count, a2, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  iree_vm_size_t a2_count;
+  iree_vm_abi_i_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riCrD, a2_count, a2, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  iree_vm_size_t a2_count;
+  iree_vm_abi_r_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riiCriD, a3_count, a3, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  int32_t i2;
+  iree_vm_size_t a3_count;
+  iree_vm_abi_ri_t a3[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rirCrD, a3_count, a3, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  iree_vm_ref_t r2;
+  iree_vm_size_t a3_count;
+  iree_vm_abi_r_t a3[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rriCiD, a3_count, a3, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  iree_vm_size_t a3_count;
+  iree_vm_abi_i_t a3[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rrirCiD, a4_count, a4, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  iree_vm_ref_t r3;
+  iree_vm_size_t a4_count;
+  iree_vm_abi_i_t a4[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riCiiiD, a2_count, a2, {
+  iree_vm_ref_t r0;
+  int32_t i1;
+  iree_vm_size_t a2_count;
+  iree_vm_abi_iii_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rrCiriiD, a2_count, a2, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  iree_vm_size_t a2_count;
+  iree_vm_abi_irii_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rriCiriiD, a3_count, a3, {
+  iree_vm_ref_t r0;
+  iree_vm_ref_t r1;
+  int32_t i2;
+  iree_vm_size_t a3_count;
+  iree_vm_abi_irii_t a3[0];
+});
+
+#if defined(IREE_COMPILER_MSVC)
+#pragma pack(pop)
+#endif  // IREE_COMPILER_MSVC
+
+//===----------------------------------------------------------------------===//
+// Shims for marshaling arguments and results
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_DECLARE_SHIM(irii, v);
+IREE_VM_ABI_DECLARE_SHIM(r, i);
+IREE_VM_ABI_DECLARE_SHIM(r, ii);
+IREE_VM_ABI_DECLARE_SHIM(r, iii);
+IREE_VM_ABI_DECLARE_SHIM(r, iiii);
+IREE_VM_ABI_DECLARE_SHIM(r, r);
+IREE_VM_ABI_DECLARE_SHIM(r, v);
+IREE_VM_ABI_DECLARE_SHIM(rCiD, i);
+IREE_VM_ABI_DECLARE_SHIM(rCrD, v);
+IREE_VM_ABI_DECLARE_SHIM(ri, i);
+IREE_VM_ABI_DECLARE_SHIM(ri, r);
+IREE_VM_ABI_DECLARE_SHIM(ri, v);
+IREE_VM_ABI_DECLARE_SHIM(riCiD, r);
+IREE_VM_ABI_DECLARE_SHIM(riCiiiD, r);
+IREE_VM_ABI_DECLARE_SHIM(riCrD, r);
+IREE_VM_ABI_DECLARE_SHIM(rii, i);
+IREE_VM_ABI_DECLARE_SHIM(rii, r);
+IREE_VM_ABI_DECLARE_SHIM(riii, r);
+IREE_VM_ABI_DECLARE_SHIM(riii, v);
+IREE_VM_ABI_DECLARE_SHIM(riirii, r);
+IREE_VM_ABI_DECLARE_SHIM(rirCrD, r);
+IREE_VM_ABI_DECLARE_SHIM(ririi, v);
+IREE_VM_ABI_DECLARE_SHIM(rr, i);
+IREE_VM_ABI_DECLARE_SHIM(rr, r);
+IREE_VM_ABI_DECLARE_SHIM(rr, v);
+IREE_VM_ABI_DECLARE_SHIM(rrCiriiD, r);
+IREE_VM_ABI_DECLARE_SHIM(rriCiD, v);
+IREE_VM_ABI_DECLARE_SHIM(rriCiriiD, v);
+IREE_VM_ABI_DECLARE_SHIM(rriii, v);
+IREE_VM_ABI_DECLARE_SHIM(rriiii, v);
+IREE_VM_ABI_DECLARE_SHIM(rrirCiD, v);
+IREE_VM_ABI_DECLARE_SHIM(rriri, v);
+IREE_VM_ABI_DECLARE_SHIM(rririi, v);
+IREE_VM_ABI_DECLARE_SHIM(v, i);
+IREE_VM_ABI_DECLARE_SHIM(v, r);
+IREE_VM_ABI_DECLARE_SHIM(v, v);
+
+#endif  // IREE_MODULES_HAL_SHIMS_H_
diff --git a/iree/tools/utils/vm_util.cc b/iree/tools/utils/vm_util.cc
index 03ff229..b758d8d 100644
--- a/iree/tools/utils/vm_util.cc
+++ b/iree/tools/utils/vm_util.cc
@@ -212,7 +212,7 @@
               "variant %d has value type %d but descriptor information %s", i,
               (int)variant.type.value_type, desc_str.c_str());
         }
-        auto* buffer_view = iree_hal_buffer_view_deref(&variant.ref);
+        auto* buffer_view = iree_hal_buffer_view_deref(variant.ref);
         if (!buffer_view) {
           return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                   "failed dereferencing variant %d", i);
diff --git a/iree/vm/builtin_types.h b/iree/vm/builtin_types.h
index 4361a9f..c8144b5 100644
--- a/iree/vm/builtin_types.h
+++ b/iree/vm/builtin_types.h
@@ -31,6 +31,14 @@
   iree_vm_ref_destroy_t destroy;
 } iree_vm_ro_byte_buffer_t;
 
+// Returns a string view over the bytes of |value|, or an empty view if NULL.
+static inline iree_string_view_t iree_vm_ro_byte_buffer_as_string(
+    const iree_vm_ro_byte_buffer_t* value) {
+  if (!value) return iree_string_view_empty();
+  return iree_make_string_view((const char*)value->data.data,
+                               value->data.data_length);
+}
+
 // The built-in mutable buffer type.
 // This simply points at a span of memory. The memory could be owned (in which
 // case a destroy function must be provided) or unowned (NULL destroy function).
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index 32594c2..86c6509 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -827,7 +827,7 @@
     DISPATCH_OP(CORE, ListReserve, {
       bool list_is_move;
       iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-      iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+      iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
       if (IREE_UNLIKELY(!list)) {
         return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
       }
@@ -838,7 +838,7 @@
     DISPATCH_OP(CORE, ListSize, {
       bool list_is_move;
       iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-      iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+      iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
       if (IREE_UNLIKELY(!list)) {
         return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
       }
@@ -849,7 +849,7 @@
     DISPATCH_OP(CORE, ListResize, {
       bool list_is_move;
       iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-      iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+      iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
       if (IREE_UNLIKELY(!list)) {
         return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
       }
@@ -860,7 +860,7 @@
     DISPATCH_OP(CORE, ListGetI32, {
       bool list_is_move;
       iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-      iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+      iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
       if (IREE_UNLIKELY(!list)) {
         return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
       }
@@ -875,7 +875,7 @@
     DISPATCH_OP(CORE, ListSetI32, {
       bool list_is_move;
       iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-      iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+      iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
       if (!list) {
         return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
       }
@@ -1328,7 +1328,7 @@
       DISPATCH_OP(EXT_I64, ListGetI64, {
         bool list_is_move;
         iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-        iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+        iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
         if (IREE_UNLIKELY(!list)) {
           return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
         }
@@ -1343,7 +1343,7 @@
       DISPATCH_OP(EXT_I64, ListSetI64, {
         bool list_is_move;
         iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
-        iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+        iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
         if (IREE_UNLIKELY(!list)) {
           return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
         }
diff --git a/iree/vm/list.c b/iree/vm/list.c
index 17569d1..64fbdc8 100644
--- a/iree/vm/list.c
+++ b/iree/vm/list.c
@@ -494,7 +494,7 @@
   if (!iree_status_is_ok(iree_status_consume_code(status))) {
     return NULL;
   }
-  status = iree_vm_ref_check(&value, type_descriptor->type);
+  status = iree_vm_ref_check(value, type_descriptor->type);
   if (!iree_status_is_ok(iree_status_consume_code(status))) {
     return NULL;
   }
diff --git a/iree/vm/list_test.cc b/iree/vm/list_test.cc
index e17dd9d..4423cbc 100644
--- a/iree/vm/list_test.cc
+++ b/iree/vm/list_test.cc
@@ -146,8 +146,8 @@
   for (iree_host_size_t i = 0; i < 5; ++i) {
     iree_vm_ref_t ref_a{0};
     IREE_ASSERT_OK(iree_vm_list_get_ref_retain(list, i, &ref_a));
-    EXPECT_TRUE(test_a_isa(&ref_a));
-    auto* a = test_a_deref(&ref_a);
+    EXPECT_TRUE(test_a_isa(ref_a));
+    auto* a = test_a_deref(ref_a);
     EXPECT_EQ(i, a->data());
     iree_vm_ref_release(&ref_a);
   }
@@ -192,8 +192,8 @@
   for (iree_host_size_t i = 5; i < 10; ++i) {
     iree_vm_ref_t ref_a{0};
     IREE_ASSERT_OK(iree_vm_list_get_ref_retain(list, i, &ref_a));
-    EXPECT_TRUE(test_a_isa(&ref_a));
-    auto* a = test_a_deref(&ref_a);
+    EXPECT_TRUE(test_a_isa(ref_a));
+    auto* a = test_a_deref(ref_a);
     EXPECT_EQ(i, a->data());
     iree_vm_ref_release(&ref_a);
   }
diff --git a/iree/vm/native_module.c b/iree/vm/native_module.c
index 8e4b11d..5ad605d 100644
--- a/iree/vm/native_module.c
+++ b/iree/vm/native_module.c
@@ -26,10 +26,13 @@
   iree_vm_module_t base_interface;
 
   // Interface with optional user-provided function pointers.
-  // user_interface.self will contain the user's module pointer that must be
-  // passed to all functions.
   iree_vm_module_t user_interface;
 
+  // The self passed to user_interface functions. Will be either the value of
+  // user_interface.self when initialized or the base pointer of the base
+  // native module otherwise.
+  void* self;
+
   // Allocator this module was allocated with and must be freed with.
   iree_allocator_t allocator;
 
@@ -37,6 +40,10 @@
   const iree_vm_native_module_descriptor_t* descriptor;
 } iree_vm_native_module_t;
 
+IREE_API_EXPORT iree_host_size_t iree_vm_native_module_size() {
+  return sizeof(iree_vm_native_module_t);
+}
+
 #if defined(NDEBUG)
 static iree_status_t iree_vm_native_module_verify_descriptor(
     const iree_vm_native_module_descriptor_t* module_descriptor) {
@@ -69,17 +76,23 @@
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
 
   // Destroy the optional user-provided self.
-  if (module->user_interface.destroy) {
-    module->user_interface.destroy(module->user_interface.self);
+  if (module->self == module) {
+    iree_allocator_t allocator = module->allocator;
+    if (module->user_interface.destroy) {
+      module->user_interface.destroy(module->self);
+    }
+    iree_allocator_free(allocator, module);
+  } else {
+    if (module->user_interface.destroy) {
+      module->user_interface.destroy(module->self);
+    }
   }
-
-  iree_allocator_free(module->allocator, module);
 }
 
 static iree_string_view_t IREE_API_PTR iree_vm_native_module_name(void* self) {
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   if (module->user_interface.name) {
-    return module->user_interface.name(module->user_interface.self);
+    return module->user_interface.name(module->self);
   }
   return module->descriptor->module_name;
 }
@@ -88,7 +101,7 @@
 iree_vm_native_module_signature(void* self) {
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   if (module->user_interface.signature) {
-    return module->user_interface.signature(module->user_interface.self);
+    return module->user_interface.signature(module->self);
   }
   iree_vm_module_signature_t signature;
   memset(&signature, 0, sizeof(signature));
@@ -155,9 +168,8 @@
   if (out_name) memset(out_name, 0, sizeof(*out_name));
   if (out_signature) memset(out_signature, 0, sizeof(*out_signature));
   if (module->user_interface.get_function) {
-    return module->user_interface.get_function(module->user_interface.self,
-                                               linkage, ordinal, out_function,
-                                               out_name, out_signature);
+    return module->user_interface.get_function(
+        module->self, linkage, ordinal, out_function, out_name, out_signature);
   }
   switch (linkage) {
     case IREE_VM_FUNCTION_LINKAGE_IMPORT:
@@ -181,7 +193,7 @@
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   if (module->user_interface.get_function_reflection_attr) {
     return module->user_interface.get_function_reflection_attr(
-        module->user_interface.self, linkage, ordinal, index, key, value);
+        module->self, linkage, ordinal, index, key, value);
   }
   // TODO(benvanik): implement native module reflection.
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
@@ -194,8 +206,8 @@
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   memset(out_function, 0, sizeof(*out_function));
   if (module->user_interface.lookup_function) {
-    return module->user_interface.lookup_function(module->user_interface.self,
-                                                  linkage, name, out_function);
+    return module->user_interface.lookup_function(module->self, linkage, name,
+                                                  out_function);
   }
 
   if (IREE_UNLIKELY(linkage != IREE_VM_FUNCTION_LINKAGE_EXPORT)) {
@@ -234,8 +246,8 @@
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   *out_module_state = NULL;
   if (module->user_interface.alloc_state) {
-    return module->user_interface.alloc_state(module->user_interface.self,
-                                              allocator, out_module_state);
+    return module->user_interface.alloc_state(module->self, allocator,
+                                              out_module_state);
   }
   // Default to no state.
   return iree_ok_status();
@@ -245,8 +257,7 @@
     void* self, iree_vm_module_state_t* module_state) {
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   if (module->user_interface.free_state) {
-    module->user_interface.free_state(module->user_interface.self,
-                                      module_state);
+    module->user_interface.free_state(module->self, module_state);
     return;
   }
   // No-op in the default implementation.
@@ -260,9 +271,8 @@
     const iree_vm_function_signature_t* signature) {
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   if (module->user_interface.resolve_import) {
-    return module->user_interface.resolve_import(module->user_interface.self,
-                                                 module_state, ordinal,
-                                                 function, signature);
+    return module->user_interface.resolve_import(module->self, module_state,
+                                                 ordinal, function, signature);
   }
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                           "native module does not support imports");
@@ -282,8 +292,8 @@
                             module->descriptor->export_count);
   }
   if (module->user_interface.begin_call) {
-    return module->user_interface.begin_call(module->user_interface.self, stack,
-                                             call, out_result);
+    return module->user_interface.begin_call(module->self, stack, call,
+                                             out_result);
   }
 
   // NOTE: VM stack is currently unused. We could stash things here for the
@@ -320,8 +330,7 @@
                                   iree_vm_execution_result_t* out_result) {
   iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
   if (module->user_interface.resume_call) {
-    return module->user_interface.resume_call(module->user_interface.self,
-                                              stack, out_result);
+    return module->user_interface.resume_call(module->self, stack, out_result);
   }
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                           "native module does not support resume");
@@ -361,11 +370,52 @@
   iree_vm_native_module_t* module = NULL;
   IREE_RETURN_IF_ERROR(iree_allocator_malloc(
       allocator, sizeof(iree_vm_native_module_t), (void**)&module));
+
+  iree_status_t status = iree_vm_native_module_initialize(
+      interface, module_descriptor, allocator, (iree_vm_module_t*)module);
+  if (!iree_status_is_ok(status)) {
+    iree_allocator_free(allocator, module);
+    return status;
+  }
+
+  *out_module = &module->base_interface;
+  return iree_ok_status();
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_native_module_initialize(
+    const iree_vm_module_t* interface,
+    const iree_vm_native_module_descriptor_t* module_descriptor,
+    iree_allocator_t allocator, iree_vm_module_t* base_module) {
+  IREE_ASSERT_ARGUMENT(interface);
+  IREE_ASSERT_ARGUMENT(module_descriptor);
+  IREE_ASSERT_ARGUMENT(base_module);
+  iree_vm_native_module_t* module = (iree_vm_native_module_t*)base_module;
+
+  if (IREE_UNLIKELY(!interface->begin_call) &&
+      IREE_UNLIKELY(!module_descriptor->functions)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "native modules must provide call support or function pointers");
+  } else if (IREE_UNLIKELY(!interface->begin_call) &&
+             IREE_UNLIKELY(module_descriptor->export_count !=
+                           module_descriptor->function_count)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "native modules using the default call support "
+                            "must have 1:1 exports:function pointers");
+  }
+
+  // Perform some optional debug-only verification of the descriptor.
+  // Since native modules are designed to be compiled in we don't need to do
+  // this in release builds.
+  IREE_RETURN_IF_ERROR(
+      iree_vm_native_module_verify_descriptor(module_descriptor));
   module->allocator = allocator;
   module->descriptor = module_descriptor;
 
   // TODO(benvanik): version interface and copy only valid bytes.
   memcpy(&module->user_interface, interface, sizeof(*interface));
+  module->self =
+      module->user_interface.self ? module->user_interface.self : module;
 
   // Base interface that routes through our thunks.
   iree_vm_module_initialize(&module->base_interface, module);
@@ -383,6 +433,5 @@
   module->base_interface.begin_call = iree_vm_native_module_begin_call;
   module->base_interface.resume_call = iree_vm_native_module_resume_call;
 
-  *out_module = &module->base_interface;
   return iree_ok_status();
 }
diff --git a/iree/vm/native_module.h b/iree/vm/native_module.h
index 448abff..a1dd52e 100644
--- a/iree/vm/native_module.h
+++ b/iree/vm/native_module.h
@@ -102,6 +102,10 @@
   const iree_vm_reflection_attr_t* reflection_attrs;
 } iree_vm_native_module_descriptor_t;
 
+// Returns the size, in bytes, of the allocation required for native modules.
+// Callers may allocate more memory if they need additional storage.
+IREE_API_EXPORT iree_host_size_t iree_vm_native_module_size(void);
+
 // Creates a new native module with the metadata tables in |descriptor|.
 // These tables will be used for reflection and function lookup, and the
 // provided function pointers will be called when state needs to be managed or
@@ -119,6 +123,11 @@
     const iree_vm_native_module_descriptor_t* module_descriptor,
     iree_allocator_t allocator, iree_vm_module_t** out_module);
 
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_native_module_initialize(
+    const iree_vm_module_t* interface,
+    const iree_vm_native_module_descriptor_t* module_descriptor,
+    iree_allocator_t allocator, iree_vm_module_t* module);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/iree/vm/ref.c b/iree/vm/ref.c
index c8a69e5..58613a1 100644
--- a/iree/vm/ref.c
+++ b/iree/vm/ref.c
@@ -149,12 +149,6 @@
   return iree_ok_status();
 }
 
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_ref_check(iree_vm_ref_t* ref, iree_vm_ref_type_t type) {
-  return ref->type == type ? iree_ok_status()
-                           : iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
-}
-
 IREE_API_EXPORT void IREE_API_CALL iree_vm_ref_retain(iree_vm_ref_t* ref,
                                                       iree_vm_ref_t* out_ref) {
   if (ref != out_ref && ref->ptr != out_ref->ptr) {
diff --git a/iree/vm/ref.h b/iree/vm/ref.h
index 1ef7fd4..5e05f67 100644
--- a/iree/vm/ref.h
+++ b/iree/vm/ref.h
@@ -175,12 +175,11 @@
     void* ptr, iree_vm_ref_type_t type, iree_vm_ref_t* out_ref);
 
 // Checks that the given reference-counted pointer |ref| is of |type|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_ref_check(iree_vm_ref_t* ref, iree_vm_ref_type_t type);
-
-#define IREE_VM_DEREF_OR_RETURN(value_type, value, ref, type) \
-  IREE_RETURN_IF_ERROR(iree_vm_ref_check(ref, type));         \
-  value_type* value = (value_type*)(ref)->ptr;
+static inline iree_status_t iree_vm_ref_check(const iree_vm_ref_t ref,
+                                              iree_vm_ref_type_t type) {
+  if (ref.type != type) return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
+  return iree_ok_status();  // |ref| carries exactly the requested type id.
+}
 
 // Retains the reference-counted pointer |ref|.
 // |out_ref| will be released if it already contains a reference.
@@ -267,13 +266,13 @@
 #define IREE_VM_DECLARE_TYPE_ADAPTERS(name, T)                             \
   IREE_API_EXPORT iree_vm_ref_t IREE_API_CALL name##_retain_ref(T* value); \
   IREE_API_EXPORT iree_vm_ref_t IREE_API_CALL name##_move_ref(T* value);   \
-  IREE_API_EXPORT T* IREE_API_CALL name##_deref(iree_vm_ref_t* ref);       \
+  IREE_API_EXPORT T* IREE_API_CALL name##_deref(const iree_vm_ref_t ref);  \
   IREE_API_EXPORT iree_status_t IREE_API_CALL name##_check_deref(          \
-      iree_vm_ref_t* ref, T** out_ptr);                                    \
+      const iree_vm_ref_t ref, T** out_ptr);                               \
   IREE_API_EXPORT const iree_vm_ref_type_descriptor_t* IREE_API_CALL       \
       name##_get_descriptor();                                             \
-  inline bool name##_isa(iree_vm_ref_t* ref) {                             \
-    return name##_get_descriptor()->type == ref->type;                     \
+  static inline bool name##_isa(const iree_vm_ref_t ref) {                 \
+    return name##_get_descriptor()->type == ref.type;                      \
   }                                                                        \
   IREE_API_EXPORT iree_vm_ref_type_t IREE_API_CALL name##_type_id();       \
   IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
@@ -290,18 +289,18 @@
     iree_vm_ref_wrap_assign(value, name##_descriptor.type, &ref);           \
     return ref;                                                             \
   }                                                                         \
-  IREE_API_EXPORT T* IREE_API_CALL name##_deref(iree_vm_ref_t* ref) {       \
+  IREE_API_EXPORT T* IREE_API_CALL name##_deref(const iree_vm_ref_t ref) {  \
     iree_status_t status = iree_vm_ref_check(ref, name##_descriptor.type);  \
-    if (!iree_status_is_ok(status)) {                                       \
-      iree_status_ignore(status);                                           \
+    if (IREE_UNLIKELY(!iree_status_is_ok(status))) {                        \
+      IREE_IGNORE_ERROR(status);                                            \
       return NULL;                                                          \
     }                                                                       \
-    return (T*)ref->ptr;                                                    \
+    return (T*)ref.ptr;                                                     \
   }                                                                         \
   IREE_API_EXPORT iree_status_t IREE_API_CALL name##_check_deref(           \
-      iree_vm_ref_t* ref, T** out_ptr) {                                    \
+      const iree_vm_ref_t ref, T** out_ptr) {                               \
     IREE_RETURN_IF_ERROR(iree_vm_ref_check(ref, name##_descriptor.type));   \
-    *out_ptr = (T*)ref->ptr;                                                \
+    *out_ptr = (T*)ref.ptr;                                                 \
     return iree_ok_status();                                                \
   }                                                                         \
   IREE_API_EXPORT const iree_vm_ref_type_descriptor_t* IREE_API_CALL        \
diff --git a/iree/vm/ref_test.cc b/iree/vm/ref_test.cc
index 5be26b5..ac47cf3 100644
--- a/iree/vm/ref_test.cc
+++ b/iree/vm/ref_test.cc
@@ -161,18 +161,18 @@
 // Checking null refs is fine.
 TEST(VMRefTest, CheckNull) {
   iree_vm_ref_t null_ref = {0};
-  IREE_EXPECT_OK(iree_vm_ref_check(&null_ref, IREE_VM_REF_TYPE_NULL));
+  IREE_EXPECT_OK(iree_vm_ref_check(null_ref, IREE_VM_REF_TYPE_NULL));
   IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
                         ::iree::Status(iree_vm_ref_check(
-                            &null_ref, static_cast<iree_vm_ref_type_t>(1234))));
+                            null_ref, static_cast<iree_vm_ref_type_t>(1234))));
 }
 
 // Tests type checks.
 TEST(VMRefTest, Check) {
   iree_vm_ref_t a_ref = MakeRef<A>("AType");
-  IREE_EXPECT_OK(iree_vm_ref_check(&a_ref, A::kTypeID));
+  IREE_EXPECT_OK(iree_vm_ref_check(a_ref, A::kTypeID));
   IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
-                        ::iree::Status(iree_vm_ref_check(&a_ref, B::kTypeID)));
+                        ::iree::Status(iree_vm_ref_check(a_ref, B::kTypeID)));
   iree_vm_ref_release(&a_ref);
 }
 
@@ -248,7 +248,7 @@
   iree_vm_ref_t a_ref_0 = MakeRef<A>("AType");
   iree_vm_ref_t a_ref_1 = {0};
   iree_vm_ref_retain_or_move(/*is_move=*/1, &a_ref_0, &a_ref_1);
-  IREE_EXPECT_OK(iree_vm_ref_check(&a_ref_0, IREE_VM_REF_TYPE_NULL));
+  IREE_EXPECT_OK(iree_vm_ref_check(a_ref_0, IREE_VM_REF_TYPE_NULL));
   iree_vm_ref_release(&a_ref_1);
 }
 
@@ -269,7 +269,7 @@
 TEST(VMRefTest, RetainOrMoveMovingIntoSelf) {
   iree_vm_ref_t a_ref = MakeRef<A>("AType");
   iree_vm_ref_retain_or_move(/*is_move=*/1, &a_ref, &a_ref);
-  IREE_EXPECT_OK(iree_vm_ref_check(&a_ref, A::kTypeID));
+  IREE_EXPECT_OK(iree_vm_ref_check(a_ref, A::kTypeID));
   iree_vm_ref_release(&a_ref);
 }
 
@@ -413,7 +413,7 @@
   iree_vm_ref_t a_ref_0 = MakeRef<A>("AType");
   iree_vm_ref_t a_ref_1 = {0};
   iree_vm_ref_move(&a_ref_0, &a_ref_1);
-  IREE_EXPECT_OK(iree_vm_ref_check(&a_ref_0, IREE_VM_REF_TYPE_NULL));
+  IREE_EXPECT_OK(iree_vm_ref_check(a_ref_0, IREE_VM_REF_TYPE_NULL));
   iree_vm_ref_release(&a_ref_1);
 }
 
@@ -421,7 +421,7 @@
 TEST(VMRefTest, MovingIntoSelf) {
   iree_vm_ref_t a_ref = MakeRef<A>("AType");
   iree_vm_ref_move(&a_ref, &a_ref);
-  IREE_EXPECT_OK(iree_vm_ref_check(&a_ref, A::kTypeID));
+  IREE_EXPECT_OK(iree_vm_ref_check(a_ref, A::kTypeID));
   iree_vm_ref_release(&a_ref);
 }
 
diff --git a/iree/vm/value.h b/iree/vm/value.h
index d5aa167..5ac1b8e 100644
--- a/iree/vm/value.h
+++ b/iree/vm/value.h
@@ -21,6 +21,11 @@
 extern "C" {
 #endif  // __cplusplus
 
+// TODO(benvanik): support variable size in modules. vm.imports would need index
+// type and we'd have to make sure all native modules used this size type. It
+// would be a runtime flag for the compiler and a compile flag for the runtime.
+typedef int32_t iree_vm_size_t;
+
 // Defines the type of a primitive value.
 typedef enum {
   // Not a value type.