Merge pull request #4767 from google/benvanik-hal-c
A 30KB drop, removes all allocs from the common command buffer recording path, and almost all operations are faster to boot (no extraneous ref counting). There's a lot less C++ magic too: 1k loc of template magic replaced with 140 loc of stupid simple macros and autogeneratable boilerplate.
Fixes #4678.
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index eaa23c2..a6f1e2e 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -591,7 +591,7 @@
switch (desc.type) {
case RawSignatureParser::Type::kBuffer: {
iree_hal_buffer_view_t* buffer_view =
- iree_hal_buffer_view_deref(&f_result.ref);
+ iree_hal_buffer_view_deref(f_result.ref);
if (!buffer_view) {
throw RaiseValueError(
"Could not deref result buffer view (wrong type?)");
@@ -793,8 +793,8 @@
if (iree_vm_variant_is_value(variant)) {
results.push_back("i32=" + std::to_string(variant.i32));
} else if (iree_vm_variant_is_ref(variant) &&
- iree_hal_buffer_view_isa(&variant.ref)) {
- auto buffer_view = iree_hal_buffer_view_deref(&variant.ref);
+ iree_hal_buffer_view_isa(variant.ref)) {
+ auto buffer_view = iree_hal_buffer_view_deref(variant.ref);
std::string result_str(4096, '\0');
iree_status_t status;
diff --git a/bindings/python/pyiree/rt/vm.cc b/bindings/python/pyiree/rt/vm.cc
index 8091ee2..af719f4 100644
--- a/bindings/python/pyiree/rt/vm.cc
+++ b/bindings/python/pyiree/rt/vm.cc
@@ -213,13 +213,13 @@
absl::StrAppend(&s, variant.i32);
} else if (iree_vm_variant_is_ref(variant)) {
// Pretty print a subset of ABI impacting known types.
- if (iree_hal_buffer_isa(&variant.ref)) {
- auto* hal_buffer = iree_hal_buffer_deref(&variant.ref);
+ if (iree_hal_buffer_isa(variant.ref)) {
+ auto* hal_buffer = iree_hal_buffer_deref(variant.ref);
assert(hal_buffer);
absl::StrAppend(&s, "HalBuffer(",
iree_hal_buffer_byte_length(hal_buffer), ")");
- } else if (iree_hal_buffer_view_isa(&variant.ref)) {
- auto hal_bv = iree_hal_buffer_view_deref(&variant.ref);
+ } else if (iree_hal_buffer_view_isa(variant.ref)) {
+ auto hal_bv = iree_hal_buffer_view_deref(variant.ref);
absl::StrAppend(&s, "HalBufferView(");
absl::InlinedVector<int32_t, 5> shape(
iree_hal_buffer_view_shape_rank(hal_bv));
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 9c40cdb..fa8f56f 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -317,11 +317,14 @@
"/Gy"
"/DNDEBUG"
"/DIREE_STATUS_MODE=0"
+ "/PDB"
+ "/Os"
+ "/Oy"
)
iree_select_compiler_opts(IREE_SIZE_OPTIMIZED_DEFAULT_LINKOPTS
MSVC_OR_CLANG_CL
- "/LTCG"
- "/opt:ref,icf"
+ "-LTCG"
+ "-opt:ref,icf"
)
# TODO(#898): make this only impact the runtime (IREE_RUNTIME_DEFAULT_...).
set(IREE_DEFAULT_COPTS
diff --git a/iree/base/api.c b/iree/base/api.c
index f7c641b..e118a5c 100644
--- a/iree/base/api.c
+++ b/iree/base/api.c
@@ -595,7 +595,7 @@
#if IREE_STATUS_FEATURES == 0
// More advanced status code features like source location and messages are
// disabled. All statuses are just the codes.
- return (iree_status_t)(code & IREE_STATUS_CODE_MASK);
+ return iree_status_from_code(code);
#else
// No-op for OK statuses; we won't get these from the macros but may be called
// with this from marshaling code.
diff --git a/iree/base/api.h b/iree/base/api.h
index d41ecc4..bdc2b3f 100644
--- a/iree/base/api.h
+++ b/iree/base/api.h
@@ -204,6 +204,7 @@
// Size, in bytes, of a buffer on the host.
typedef size_t iree_host_size_t;
+#define IREE_MAX_HOST_SIZE SIZE_MAX
// Size, in bytes, of a buffer on devices.
typedef uint64_t iree_device_size_t;
@@ -470,7 +471,8 @@
((iree_status_code_t)(((uintptr_t)(value)) & IREE_STATUS_CODE_MASK))
// Macros to check the value of a status code.
-#define iree_status_is_ok(value) ((uintptr_t)(value) == IREE_STATUS_OK)
+#define iree_status_is_ok(value) \
+ IREE_LIKELY((uintptr_t)(value) == IREE_STATUS_OK)
#define iree_status_is_cancelled(value) \
(iree_status_code(value) == IREE_STATUS_CANCELLED)
#define iree_status_is_unknown(value) \
@@ -578,13 +580,14 @@
#define IREE_STATUS_IMPL_MAKE_(code, ...) \
(iree_status_t)(uintptr_t)((code)&IREE_STATUS_CODE_MASK)
#undef IREE_STATUS_IMPL_RETURN_IF_API_ERROR_
-#define IREE_STATUS_IMPL_RETURN_IF_API_ERROR_(var, expr, ...) \
- iree_status_t var = (expr); \
+#define IREE_STATUS_IMPL_RETURN_IF_API_ERROR_(var, ...) \
+ iree_status_t var = (IREE_STATUS_IMPL_IDENTITY_( \
+ IREE_STATUS_IMPL_IDENTITY_(IREE_STATUS_IMPL_GET_EXPR_)(__VA_ARGS__))); \
if (IREE_UNLIKELY(var)) return var;
#undef IREE_STATUS_IMPL_RETURN_AND_EVAL_IF_API_ERROR_
-#define IREE_STATUS_IMPL_RETURN_AND_EVAL_IF_API_ERROR_(tail_expr, var, expr, \
- ...) \
- iree_status_t var = (expr); \
+#define IREE_STATUS_IMPL_RETURN_AND_EVAL_IF_API_ERROR_(tail_expr, var, ...) \
+ iree_status_t var = (IREE_STATUS_IMPL_IDENTITY_( \
+ IREE_STATUS_IMPL_IDENTITY_(IREE_STATUS_IMPL_GET_EXPR_)(__VA_ARGS__))); \
if (IREE_UNLIKELY(var)) { \
(tail_expr); \
return var; \
diff --git a/iree/base/attributes.h b/iree/base/attributes.h
index 9482781..178030f 100644
--- a/iree/base/attributes.h
+++ b/iree/base/attributes.h
@@ -149,4 +149,14 @@
#define IREE_UNLIKELY(x) (x)
#endif // IREE_HAVE_ATTRIBUTE(likely)
+//===----------------------------------------------------------------------===//
+// IREE_ATTRIBUTE_PACKED
+//===----------------------------------------------------------------------===//
+
+#if IREE_HAVE_ATTRIBUTE(packed) || (defined(__GNUC__) && !defined(__clang__))
+#define IREE_ATTRIBUTE_PACKED __attribute__((__packed__))
+#else
+#define IREE_ATTRIBUTE_PACKED
+#endif // IREE_HAVE_ATTRIBUTE(packed)
+
#endif // IREE_BASE_ATTRIBUTES_H_
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index ca6a8a9..7c6be4e 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -429,13 +429,7 @@
output << "#include \"" << include << "\"\n";
};
- printInclude("iree/vm/context.h");
- printInclude("iree/vm/instance.h");
- printInclude("iree/vm/native_module.h");
- printInclude("iree/vm/ops.h");
- printInclude("iree/vm/ref.h");
- printInclude("iree/vm/shims.h");
- printInclude("iree/vm/stack.h");
+ printInclude("iree/vm/api.h");
output << "\n";
printModuleComment(moduleOp, output);
diff --git a/iree/modules/hal/BUILD b/iree/modules/hal/BUILD
index 6c6d538..8e30060 100644
--- a/iree/modules/hal/BUILD
+++ b/iree/modules/hal/BUILD
@@ -20,17 +20,21 @@
cc_library(
name = "hal",
- srcs = ["hal_module.cc"],
- hdrs = ["hal_module.h"],
+ srcs = [
+ "hal_module.c",
+ "shims.c",
+ "shims.h",
+ ],
+ hdrs = [
+ "hal_module.h",
+ ],
+ textual_hdrs = [
+ "exports.inl",
+ ],
deps = [
"//iree/base:api",
"//iree/base:tracing",
"//iree/hal:api",
"//iree/vm",
- "//iree/vm:cc",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
],
)
diff --git a/iree/modules/hal/CMakeLists.txt b/iree/modules/hal/CMakeLists.txt
index 36800aa..6e602e2 100644
--- a/iree/modules/hal/CMakeLists.txt
+++ b/iree/modules/hal/CMakeLists.txt
@@ -19,17 +19,16 @@
hal
HDRS
"hal_module.h"
+ TEXTUAL_HDRS
+ "exports.inl"
SRCS
- "hal_module.cc"
+ "hal_module.c"
+ "shims.c"
+ "shims.h"
DEPS
- absl::core_headers
- absl::inlined_vector
- absl::memory
- absl::span
iree::base::api
iree::base::tracing
iree::hal::api
iree::vm
- iree::vm::cc
PUBLIC
)
diff --git a/iree/modules/hal/exports.inl b/iree/modules/hal/exports.inl
new file mode 100644
index 0000000..881977b
--- /dev/null
+++ b/iree/modules/hal/exports.inl
@@ -0,0 +1,84 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===----------------------------------------------------------------------===//
+//
+// ██ ██ █████ ██████ ███ ██ ██ ███ ██ ██████
+// ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
+// ██ █ ██ ███████ ██████ ██ ██ ██ ██ ██ ██ ██ ██ ███
+// ██ ███ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
+// ███ ███ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
+//
+//===----------------------------------------------------------------------===//
+//
+// This file will be auto generated from hal.imports.mlir in the future; for
+// now it's modified by hand but with strict alphabetical sorting required.
+// The order of these functions must be sorted ascending by name in a way
+// compatible with iree_string_view_compare.
+//
+// Users are meant to `#define EXPORT_FN` to be able to access the information.
+// #define EXPORT_FN(name, target_fn, arg_types, ret_types)
+
+// clang-format off
+
+EXPORT_FN("allocator.allocate", iree_hal_module_allocator_allocate, riii, r)
+EXPORT_FN("allocator.wrap.byte_buffer", iree_hal_module_allocator_wrap_byte_buffer, riirii, r)
+
+EXPORT_FN("buffer.allocator", iree_hal_module_buffer_allocator, r, r)
+EXPORT_FN("buffer.fill", iree_hal_module_buffer_fill, riii, v)
+EXPORT_FN("buffer.load", iree_hal_module_buffer_load, rii, i)
+EXPORT_FN("buffer.store", iree_hal_module_buffer_store, irii, v)
+EXPORT_FN("buffer.subspan", iree_hal_module_buffer_subspan, rii, r)
+
+EXPORT_FN("buffer_view.buffer", iree_hal_module_buffer_view_buffer, r, r)
+EXPORT_FN("buffer_view.byte_length", iree_hal_module_buffer_view_byte_length, r, i)
+EXPORT_FN("buffer_view.create", iree_hal_module_buffer_view_create, riCiD, r)
+EXPORT_FN("buffer_view.dim", iree_hal_module_buffer_view_dim, ri, i)
+EXPORT_FN("buffer_view.element_type", iree_hal_module_buffer_view_element_type, r, i)
+EXPORT_FN("buffer_view.rank", iree_hal_module_buffer_view_rank, r, i)
+EXPORT_FN("buffer_view.trace", iree_hal_module_buffer_view_trace, rCrD, v)
+
+EXPORT_FN("command_buffer.begin", iree_hal_module_command_buffer_begin, r, v)
+EXPORT_FN("command_buffer.bind_descriptor_set", iree_hal_module_command_buffer_bind_descriptor_set, rrirCiD, v)
+EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, rririi, v)
+EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, rii, r)
+EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v)
+EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriri, v)
+EXPORT_FN("command_buffer.end", iree_hal_module_command_buffer_end, r, v)
+EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
+EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rriii, v)
+EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v)
+EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiriiD, v)
+
+EXPORT_FN("descriptor_set.create", iree_hal_module_descriptor_set_create, rrCiriiD, r)
+
+EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r)
+
+EXPORT_FN("device.allocator", iree_hal_module_device_allocator, r, r)
+EXPORT_FN("device.match.id", iree_hal_module_device_match_id, rr, i)
+
+EXPORT_FN("ex.shared_device", iree_hal_module_ex_shared_device, v, r)
+EXPORT_FN("ex.submit_and_wait", iree_hal_module_ex_submit_and_wait, rr, v)
+
+EXPORT_FN("executable.create", iree_hal_module_executable_create, rirCrD, r)
+
+EXPORT_FN("executable_layout.create", iree_hal_module_executable_layout_create, riCrD, r)
+
+EXPORT_FN("semaphore.await", iree_hal_module_semaphore_await, ri, i)
+EXPORT_FN("semaphore.create", iree_hal_module_semaphore_create, ri, r)
+EXPORT_FN("semaphore.fail", iree_hal_module_semaphore_fail, r, i)
+EXPORT_FN("semaphore.query", iree_hal_module_semaphore_query, r, ii)
+EXPORT_FN("semaphore.signal", iree_hal_module_semaphore_signal, ri, v)
+
+// clang-format on
diff --git a/iree/modules/hal/hal_module.c b/iree/modules/hal/hal_module.c
new file mode 100644
index 0000000..06c0baf
--- /dev/null
+++ b/iree/modules/hal/hal_module.c
@@ -0,0 +1,992 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/modules/hal/hal_module.h"
+
+#include <inttypes.h>
+#include <stdio.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/shims.h"
+#include "iree/vm/api.h"
+
+// Limit the number of bindings we pass down through the HAL. This can be tuned
+// in the future but right now guards the stack from blowing up during calls.
+#define IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT ((iree_host_size_t)32)
+
+//===----------------------------------------------------------------------===//
+// Type registration
+//===----------------------------------------------------------------------===//
+
+// VM ref type descriptors, one per HAL object type exposed to the VM. They are
+// zero-initialized here and populated by iree_hal_module_register_types().
+static iree_vm_ref_type_descriptor_t iree_hal_allocator_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_buffer_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_buffer_view_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_command_buffer_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_layout_descriptor =
+    {0};
+static iree_vm_ref_type_descriptor_t iree_hal_device_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_event_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_executable_descriptor = {0};
+static iree_vm_ref_type_descriptor_t iree_hal_executable_layout_descriptor = {
+    0};
+static iree_vm_ref_type_descriptor_t iree_hal_semaphore_descriptor = {0};
+
+// Fills |descriptor| with the VM-visible type |name|, the offset of the
+// reference counter inside iree_hal_resource_t, and |destroy_fn|, then
+// registers it with the VM. Expands to several statements and propagates a
+// registration failure via IREE_RETURN_IF_ERROR, so it must be used at
+// statement scope inside a function returning iree_status_t.
+// |type| is not referenced by the expansion.
+#define IREE_VM_REGISTER_HAL_C_TYPE(type, name, destroy_fn, descriptor)   \
+  descriptor.type_name = iree_make_cstring_view(name);                    \
+  descriptor.offsetof_counter = offsetof(iree_hal_resource_t, ref_count); \
+  descriptor.destroy = (iree_vm_ref_destroy_t)destroy_fn;                 \
+  IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
+
+// Registers every HAL ref type with the VM. Registration runs once per
+// process: later calls return OK immediately once a previous call completed.
+// If any single registration fails the error is returned and the flag stays
+// unset, so a later call will attempt the whole sequence again.
+// NOTE(review): |has_registered| is a plain static bool with no visible
+// synchronization; confirm callers serialize the first call.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_module_register_types() {
+  static bool has_registered = false;
+  if (has_registered) return iree_ok_status();
+
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_allocator_t, "hal.allocator",
+                              iree_hal_allocator_destroy,
+                              iree_hal_allocator_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_t, "hal.buffer",
+                              iree_hal_buffer_destroy,
+                              iree_hal_buffer_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_view_t, "hal.buffer_view",
+                              iree_hal_buffer_view_destroy,
+                              iree_hal_buffer_view_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_command_buffer_t, "hal.command_buffer",
+                              iree_hal_command_buffer_destroy,
+                              iree_hal_command_buffer_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_t, "hal.descriptor_set",
+                              iree_hal_descriptor_set_destroy,
+                              iree_hal_descriptor_set_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_layout_t,
+                              "hal.descriptor_set_layout",
+                              iree_hal_descriptor_set_layout_destroy,
+                              iree_hal_descriptor_set_layout_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_device_t, "hal.device",
+                              iree_hal_device_destroy,
+                              iree_hal_device_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_event_t, "hal.event",
+                              iree_hal_event_destroy,
+                              iree_hal_event_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_t, "hal.executable",
+                              iree_hal_executable_destroy,
+                              iree_hal_executable_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_layout_t,
+                              "hal.executable_layout",
+                              iree_hal_executable_layout_destroy,
+                              iree_hal_executable_layout_descriptor);
+  IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_semaphore_t, "hal.semaphore",
+                              iree_hal_semaphore_destroy,
+                              iree_hal_semaphore_descriptor);
+
+  // Only mark as done once every type registered successfully.
+  has_registered = true;
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Type wrappers
+//===----------------------------------------------------------------------===//
+
+// One adapter instantiation per HAL ref type registered above.
+// NOTE(review): presumably these emit the <prefix>_check_deref /
+// <prefix>_retain_ref / <prefix>_move_ref helpers used by the exports below;
+// confirm against the IREE_VM_DEFINE_TYPE_ADAPTERS definition in iree/vm.
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_allocator, iree_hal_allocator_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer, iree_hal_buffer_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer_view, iree_hal_buffer_view_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_command_buffer,
+                             iree_hal_command_buffer_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set,
+                             iree_hal_descriptor_set_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout,
+                             iree_hal_descriptor_set_layout_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_event, iree_hal_event_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_layout,
+                             iree_hal_executable_layout_t);
+IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t);
+
+//===----------------------------------------------------------------------===//
+// Module type definitions
+//===----------------------------------------------------------------------===//
+
+// Per-module data shared by all states created from the module.
+typedef struct {
+  // Host allocator associated with the module.
+  iree_allocator_t host_allocator;
+  // Device handed to each module state; released in iree_hal_module_destroy.
+  iree_hal_device_t* shared_device;
+  // TODO(benvanik): types.
+} iree_hal_module_t;
+
+// Yields the iree_hal_module_t located iree_vm_native_module_size() bytes past
+// the start of the base native module pointer |module|. Written as a pure,
+// fully parenthesized expression with no trailing semicolon so it composes
+// safely in initializers and larger expressions.
+#define IREE_HAL_MODULE_CAST(module) \
+  ((iree_hal_module_t*)((uint8_t*)(module) + iree_vm_native_module_size()))
+
+// Per-context module state created by iree_hal_module_alloc_state and torn
+// down by iree_hal_module_free_state.
+typedef struct {
+  // Allocator the state itself was allocated from; used to free it.
+  iree_allocator_t host_allocator;
+  // Retained reference to the module's shared device.
+  iree_hal_device_t* shared_device;
+  // Executable cache created against |shared_device|.
+  iree_hal_executable_cache_t* executable_cache;
+
+  // References held until the next ex.submit_and_wait completes, at which
+  // point the list is resized to zero.
+  iree_vm_list_t* deferred_releases;
+} iree_hal_module_state_t;
+
+// Module teardown hook: drops the module's reference on the shared device.
+static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
+  iree_hal_module_t* hal_module = IREE_HAL_MODULE_CAST(base_module);
+  iree_hal_device_release(hal_module->shared_device);
+}
+
+// Allocates and initializes the per-context module state.
+//
+// On success |out_module_state| receives a state that owns a retained
+// reference to the module's shared device, an empty deferred-release list and
+// an executable cache. On failure nothing is written to |out_module_state|
+// and every partially-created resource (list, device reference, state memory)
+// is released before the error is returned.
+static iree_status_t IREE_API_PTR
+iree_hal_module_alloc_state(void* self, iree_allocator_t host_allocator,
+                            iree_vm_module_state_t** out_module_state) {
+  iree_hal_module_t* module = IREE_HAL_MODULE_CAST(self);
+  iree_hal_module_state_t* state = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state));
+  memset(state, 0, sizeof(*state));
+  state->host_allocator = host_allocator;
+  state->shared_device = module->shared_device;
+  iree_hal_device_retain(state->shared_device);
+
+  iree_status_t status = iree_vm_list_create(
+      /*element_type=*/NULL, /*initial_capacity=*/512, state->host_allocator,
+      &state->deferred_releases);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_executable_cache_create(state->shared_device,
+                                              iree_string_view_empty(),
+                                              &state->executable_cache);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    // Unwind in reverse order of acquisition so a failed initialization does
+    // not leak the list, the device reference or the state allocation.
+    if (state->deferred_releases) {
+      iree_vm_list_release(state->deferred_releases);
+    }
+    iree_hal_device_release(state->shared_device);
+    iree_allocator_free(host_allocator, state);
+    return status;
+  }
+
+  *out_module_state = (iree_vm_module_state_t*)state;
+  return iree_ok_status();
+}
+
+// Releases everything owned by |module_state| and then frees the state itself
+// using the allocator recorded at creation time.
+static void IREE_API_PTR
+iree_hal_module_free_state(void* self, iree_vm_module_state_t* module_state) {
+  iree_hal_module_state_t* hal_state = (iree_hal_module_state_t*)module_state;
+
+  // Drop deferred references first, then the cache and finally the device.
+  iree_vm_list_release(hal_state->deferred_releases);
+  iree_hal_executable_cache_release(hal_state->executable_cache);
+  iree_hal_device_release(hal_state->shared_device);
+
+  iree_allocator_free(hal_state->host_allocator, hal_state);
+}
+
+//===----------------------------------------------------------------------===//
+// Experimental APIs
+//===----------------------------------------------------------------------===//
+// NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
+// using these APIs are not forward compatible.
+
+// hal.ex.shared_device: returns a retained reference to the state's shared
+// device in |rets->r0|. Takes no arguments.
+IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device, v, r) {
+  rets->r0 = iree_hal_device_retain_ref(state->shared_device);
+  return iree_ok_status();
+}
+
+// Pushes |value| onto the state's deferred-release list so the referenced
+// object stays alive until the list is cleared (see ex.submit_and_wait).
+// Best-effort: a failure to push is deliberately ignored.
+void iree_hal_module_ex_defer_release(iree_hal_module_state_t* state,
+                                      const iree_vm_ref_t value) {
+  IREE_IGNORE_ERROR(
+      iree_vm_list_push_ref_retain(state->deferred_releases, &value));
+}
+
+// hal.ex.submit_and_wait: submits the command buffer in |r1| to the device in
+// |r0| and blocks until it completes.
+//
+// A temporary semaphore is created, signaled 0->1 by the submission and then
+// waited on with an infinite deadline. The semaphore is released on every
+// path (submit failure, wait failure and success) so it is never leaked. Once
+// the wait succeeds all deferred releases are dropped.
+IREE_VM_ABI_EXPORT(iree_hal_module_ex_submit_and_wait, rr, v) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_command_buffer_check_deref(args->r1, &command_buffer));
+
+  // Temporary semaphore we'll signal from 0->1.
+  iree_hal_semaphore_t* semaphore = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore));
+
+  // Batch with our single command buffer.
+  iree_hal_submission_batch_t batch;
+  memset(&batch, 0, sizeof(batch));
+
+  iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer};
+  batch.command_buffer_count = IREE_ARRAYSIZE(command_buffer_ptrs);
+  batch.command_buffers = command_buffer_ptrs;
+
+  iree_hal_semaphore_t* signal_semaphore_ptrs[] = {semaphore};
+  uint64_t signal_semaphore_values[] = {1ull};
+  batch.signal_semaphores.count = IREE_ARRAYSIZE(signal_semaphore_ptrs);
+  batch.signal_semaphores.semaphores = signal_semaphore_ptrs;
+  batch.signal_semaphores.payload_values = signal_semaphore_values;
+
+  iree_status_t status = iree_hal_device_queue_submit(
+      device, IREE_HAL_COMMAND_CATEGORY_ANY, 0, 1, &batch);
+
+  // Block and wait for the semaphore to be signaled (or fail).
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_semaphore_wait_with_deadline(semaphore, 1ull,
+                                                   IREE_TIME_INFINITE_FUTURE);
+  }
+
+  // The temporary semaphore is no longer needed regardless of the outcome.
+  iree_hal_semaphore_release(semaphore);
+  IREE_RETURN_IF_ERROR(status);
+
+  // Drop all pending deferred releases (references to everything in flight).
+  // This will be replaced with resource sets in the future that are attached to
+  // each command buffer.
+  IREE_RETURN_IF_ERROR(iree_vm_list_resize(state->deferred_releases, 0));
+
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_allocator_t
+//===----------------------------------------------------------------------===//
+
+// hal.allocator.allocate: allocates a buffer of |i3| bytes from the allocator
+// in |r0| with memory types |i1| and buffer usage |i2|. Ownership of the new
+// buffer is moved into |rets->r0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_allocator_allocate, riii, r) {
+  iree_hal_allocator_t* allocator = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator));
+
+  // Decode the scalar arguments.
+  const iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i1;
+  const iree_hal_buffer_usage_t buffer_usage =
+      (iree_hal_buffer_usage_t)args->i2;
+  const iree_vm_size_t allocation_size = (iree_vm_size_t)args->i3;
+
+  iree_hal_buffer_t* allocated = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+      allocator, memory_types, buffer_usage, allocation_size, &allocated));
+  rets->r0 = iree_hal_buffer_move_ref(allocated);
+  return iree_ok_status();
+}
+
+// hal.allocator.wrap.byte_buffer: creates a HAL buffer holding the bytes
+// [|i4|, |i4| + |i5|) of the read-only byte buffer in |r3|.
+//
+// The buffer is allocated from the allocator in |r0| with memory types |i1|
+// and buffer usage |i2|, and the source bytes are copied into it (true
+// wrapping is not implemented yet). A length of -1 is replaced with the full
+// source length. Returns INVALID_ARGUMENT if the requested range does not lie
+// within the source buffer. On success ownership of the new buffer is moved
+// into |rets->r0|; on a failed copy the buffer is released.
+IREE_VM_ABI_EXPORT(iree_hal_module_allocator_wrap_byte_buffer, riirii, r) {
+  iree_hal_allocator_t* allocator = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator));
+  iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i1;
+  iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i2;
+  iree_vm_ro_byte_buffer_t* source = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_ro_byte_buffer_check_deref(args->r3, &source));
+  iree_vm_size_t offset = (iree_vm_size_t)args->i4;
+  iree_vm_size_t length = (iree_vm_size_t)args->i5;
+
+  // TODO(benvanik): wrap when supported.
+
+  iree_host_size_t buffer_length = source->data.data_length;
+  if (length == -1) {
+    length = buffer_length;
+  }
+  // Validate without ever computing offset + length in the (signed) VM size
+  // type: once offset <= buffer_length is known, the subtraction on the right
+  // cannot wrap and the comparison is equivalent to the summed form.
+  if (length < 0 || offset < 0 || (iree_host_size_t)offset > buffer_length ||
+      (iree_host_size_t)length > buffer_length - (iree_host_size_t)offset) {
+    // The end of the range is widened to 64 bits so that reporting a bad
+    // request cannot itself overflow.
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "byte range out of bounds (requested %d-%" PRId64
+        " of available %zu)",
+        offset, (int64_t)offset + (int64_t)length - 1, buffer_length);
+  }
+
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_allocator_allocate_buffer(allocator, memory_types, buffer_usage,
+                                         length, &buffer),
+      "failed to allocate buffer of length %d", length);
+
+  iree_status_t status =
+      iree_hal_buffer_write_data(buffer, 0, source->data.data + offset, length);
+  if (iree_status_is_ok(status)) {
+    rets->r0 = iree_hal_buffer_move_ref(buffer);
+  } else {
+    iree_hal_buffer_release(buffer);
+  }
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_buffer_t
+//===----------------------------------------------------------------------===//
+
+// hal.buffer.allocator: returns a retained reference to the allocator of the
+// buffer in |r0| via |rets->r0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_allocator, r, r) {
+  iree_hal_buffer_t* buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &buffer));
+  rets->r0 = iree_hal_allocator_retain_ref(iree_hal_buffer_allocator(buffer));
+  return iree_ok_status();
+}
+
+// hal.buffer.subspan: creates a buffer referencing |i2| bytes of the buffer in
+// |r0| starting at byte offset |i1|. Ownership of the subspan is moved into
+// |rets->r0|; range errors from the HAL are annotated with the request.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_subspan, rii, r) {
+  iree_hal_buffer_t* parent = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &parent));
+
+  const iree_vm_size_t source_offset = (iree_vm_size_t)args->i1;
+  const iree_vm_size_t length = (iree_vm_size_t)args->i2;
+
+  iree_hal_buffer_t* subspan = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_subspan(parent, source_offset, length, &subspan),
+      "invalid subspan of an existing buffer (source_offset=%d, length=%d)",
+      source_offset, length);
+
+  rets->r0 = iree_hal_buffer_move_ref(subspan);
+  return iree_ok_status();
+}
+
+// hal.buffer.fill: fills |i2| bytes of the buffer in |r0| starting at byte
+// offset |i1| with the 32-bit pattern |i3|.
+// DEPRECATED: will be removed in future versions. Use command buffers.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_fill, riii, v) {
+  iree_hal_buffer_t* destination = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &destination));
+
+  const iree_vm_size_t target_offset = (iree_vm_size_t)args->i1;
+  const iree_vm_size_t length = (iree_vm_size_t)args->i2;
+  uint32_t fill_pattern = (uint32_t)args->i3;
+
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_fill(destination, target_offset, length, &fill_pattern,
+                           sizeof(fill_pattern)),
+      "fill range failed (target_offset=%d, length=%d)", target_offset,
+      length);
+  return iree_ok_status();
+}
+
+// hal.buffer.load: reads |i2| bytes (at most 4) from the buffer in |r0| at
+// byte offset |i1| into a zero-initialized 32-bit value returned in
+// |rets->i0|. Longer loads are rejected with INVALID_ARGUMENT.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_load, rii, i) {
+  iree_hal_buffer_t* source_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
+
+  const iree_vm_size_t source_offset = (iree_vm_size_t)args->i1;
+  const iree_vm_size_t length = (iree_vm_size_t)args->i2;
+
+  // Bytes not covered by |length| remain zero.
+  uint32_t loaded_value = 0;
+  if (length > sizeof(loaded_value)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "load length byte count %d exceeds max", length);
+  }
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_read_data(source_buffer, source_offset,
+                                                 &loaded_value, length));
+
+  rets->i0 = loaded_value;
+  return iree_ok_status();
+}
+
+// hal.buffer.store: writes the low |i3| bytes (at most 4) of the value |i0|
+// into the buffer in |r1| at byte offset |i2|.
+//
+// Returns INVALID_ARGUMENT if the length exceeds the size of the value and
+// OUT_OF_RANGE if the target range is not fully inside the buffer. The range
+// check rejects negative offsets explicitly and never adds offset and length
+// in the signed VM size type, matching the check in
+// hal.allocator.wrap.byte_buffer.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_store, irii, v) {
+  int32_t value = args->i0;
+  iree_hal_buffer_t* target_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer));
+  iree_vm_size_t target_offset = (iree_vm_size_t)args->i2;
+  iree_vm_size_t length = (iree_vm_size_t)args->i3;
+
+  iree_device_size_t buffer_length = iree_hal_buffer_byte_length(target_buffer);
+  if (length > sizeof(value)) {
+    // A negative length converts to a huge unsigned value and lands here too.
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "store length byte count %d exceeds max", length);
+  } else if (target_offset < 0 ||
+             (iree_device_size_t)target_offset > buffer_length ||
+             (iree_device_size_t)length >
+                 buffer_length - (iree_device_size_t)target_offset) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "store out of bounds (target_offset=%d, length=%d into max %" PRIu64
+        ")",
+        target_offset, length, buffer_length);
+  }
+
+  return iree_hal_buffer_write_data(target_buffer, target_offset, &value,
+                                    length);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_buffer_view_t
+//===----------------------------------------------------------------------===//
+
+// hal.buffer_view.create: creates a buffer view over the buffer in |r0| with
+// element type |i1| and the shape given by the variadic i32 list |a2|.
+// Ownership of the new view is moved into |rets->r0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_create, riCiD, r) {
+  iree_hal_buffer_t* source_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
+  iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i1;
+  iree_host_size_t shape_rank = 0;
+  iree_hal_dim_t* shape_dims = NULL;
+  // NOTE(review): presumably converts the variadic dims into a stack array of
+  // iree_hal_dim_t with 128 as the maximum supported rank; confirm against
+  // the macro definition in shims.h.
+  IREE_VM_ABI_VLA_STACK_CAST(args, a2_count, a2, iree_hal_dim_t, 128,
+                             &shape_rank, &shape_dims);
+
+  iree_hal_buffer_view_t* buffer_view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      source_buffer, element_type, shape_dims, shape_rank, &buffer_view));
+  rets->r0 = iree_hal_buffer_view_move_ref(buffer_view);
+  return iree_ok_status();
+}
+
+// hal.buffer_view.buffer: returns a retained reference to the buffer backing
+// the view in |r0| via |rets->r0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_buffer, r, r) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->r0 = iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(view));
+  return iree_ok_status();
+}
+
+// hal.buffer_view.byte_length: returns the byte length of the view in |r0|,
+// narrowed to the VM size type, via |rets->i0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_byte_length, r, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_byte_length(view);
+  return iree_ok_status();
+}
+
+// hal.buffer_view.element_type: returns the element type of the view in |r0|
+// as a 32-bit integer via |rets->i0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_element_type, r, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->i0 = (uint32_t)iree_hal_buffer_view_element_type(view);
+  return iree_ok_status();
+}
+
+// hal.buffer_view.rank: returns the shape rank of the view in |r0| via
+// |rets->i0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_rank, r, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_shape_rank(view);
+  return iree_ok_status();
+}
+
+// hal.buffer_view.dim: returns dimension |i1| of the shape of the view in
+// |r0| via |rets->i0|.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_dim, ri, i) {
+  iree_hal_buffer_view_t* view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(args->r0, &view));
+  const iree_vm_size_t dim_index = (iree_vm_size_t)args->i1;
+  rets->i0 = (iree_vm_size_t)iree_hal_buffer_view_shape_dim(view, dim_index);
+  return iree_ok_status();
+}
+
+// hal.buffer_view.trace: debugging aid that prints the key string in |r0|
+// followed by a formatted dump of each buffer view in the variadic list |a1|
+// to stderr. Stops at, and returns, the first error encountered.
+IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_trace, rCrD, v) {
+  iree_vm_ro_byte_buffer_t* key = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_ro_byte_buffer_check_deref(args->r0, &key));
+  iree_string_view_t key_str = iree_vm_ro_byte_buffer_as_string(key);
+
+  fprintf(stderr, "=== %.*s ===\n", (int)key_str.size, key_str.data);
+  for (iree_host_size_t i = 0; i < args->a1_count; ++i) {
+    iree_hal_buffer_view_t* buffer_view = NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_buffer_view_check_deref(args->a1[i].r0, &buffer_view));
+
+    // NOTE: this export is for debugging only and a no-op in min-size builds.
+    // We heap-alloc here because at the point this export is used performance
+    // is not a concern.
+
+    // Query total length (excluding NUL terminator).
+    // The first call passes a zero-capacity buffer and is expected to report
+    // OUT_OF_RANGE while filling in |result_length|; any other status
+    // (including OK) ends the trace early and is returned as-is.
+    // NOTE(review): the OUT_OF_RANGE status is overwritten below without
+    // being explicitly ignored/freed — confirm it carries no heap payload.
+    iree_host_size_t result_length = 0;
+    iree_status_t status = iree_hal_buffer_view_format(buffer_view, SIZE_MAX, 0,
+                                                       NULL, &result_length);
+    if (!iree_status_is_out_of_range(status)) {
+      return status;
+    }
+    ++result_length;  // include NUL
+
+    // Allocate scratch heap memory to contain the result and format into it.
+    char* result_str = NULL;
+    IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+        state->host_allocator, result_length, (void**)&result_str));
+    status = iree_hal_buffer_view_format(buffer_view, SIZE_MAX, result_length,
+                                         result_str, &result_length);
+    if (iree_status_is_ok(status)) {
+      fprintf(stderr, "%.*s\n", (int)result_length, result_str);
+    }
+    // Free the scratch buffer before propagating any formatting failure.
+    iree_allocator_free(state->host_allocator, result_str);
+    IREE_RETURN_IF_ERROR(status);
+  }
+  fprintf(stderr, "\n");
+
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_create, rii, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_command_buffer_mode_t modes =
+ (iree_hal_command_buffer_mode_t)args->i1;
+ iree_hal_command_category_t command_categories =
+ (iree_hal_command_category_t)args->i2;
+ // Creates a command buffer on device r0 using modes i1 and categories i2.
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
+ device, modes, command_categories, &command_buffer));
+ rets->r0 = iree_hal_command_buffer_move_ref(command_buffer); // result ref
+ return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_begin, r, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ // Forwards the command buffer in r0 to the HAL begin call; status passes through.
+ return iree_hal_command_buffer_begin(command_buffer);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_end, r, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ // Forwards the command buffer in r0 to the HAL end call; status passes through.
+ return iree_hal_command_buffer_end(command_buffer);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_execution_barrier, riii, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_execution_stage_t source_stage_mask =
+ (iree_hal_execution_stage_t)args->i1;
+ iree_hal_execution_stage_t target_stage_mask =
+ (iree_hal_execution_stage_t)args->i2;
+ iree_hal_execution_barrier_flags_t flags =
+ (iree_hal_execution_barrier_flags_t)args->i3;
+ // Only the stage masks and flags are caller-supplied; the barrier below is fixed.
+ // TODO(benvanik): decode barriers.
+ iree_hal_memory_barrier_t global_barrier;
+ global_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE;
+ global_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ;
+ // Passes 1 memory barrier; the trailing 0/NULL presumably mean no buffer barriers.
+ return iree_hal_command_buffer_execution_barrier(
+ command_buffer, source_stage_mask, target_stage_mask, flags, 1,
+ &global_barrier, 0, NULL);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, rriii, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_buffer_t* target_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer));
+ iree_vm_size_t target_offset = (iree_vm_size_t)args->i2;
+ iree_vm_size_t length = (iree_vm_size_t)args->i3;
+ uint32_t pattern = (uint32_t)args->i4;
+ // NOTE(review): presumably keeps r1 alive until the commands have executed.
+ iree_hal_module_ex_defer_release(state, args->r1);
+ // The 32-bit pattern is passed by address together with sizeof(pattern).
+ return iree_hal_command_buffer_fill_buffer(command_buffer, target_buffer,
+ target_offset, length, &pattern,
+ sizeof(pattern));
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, rririi, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_buffer_t* source_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &source_buffer));
+ iree_vm_size_t source_offset = (iree_vm_size_t)args->i2;
+ iree_hal_buffer_t* target_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer));
+ iree_vm_size_t target_offset = (iree_vm_size_t)args->i4;
+ iree_vm_size_t length = (iree_vm_size_t)args->i5;
+ // Both the source (r1) and target (r3) refs go to the defer-release helper.
+ iree_hal_module_ex_defer_release(state, args->r1);
+ iree_hal_module_ex_defer_release(state, args->r3);
+ // The single `length` (i5) applies to both the source and the target range.
+ return iree_hal_command_buffer_copy_buffer(command_buffer, source_buffer,
+ source_offset, target_buffer,
+ target_offset, length);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_constants, rriCiD, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_executable_layout_t* executable_layout = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_executable_layout_check_deref(args->r1, &executable_layout));
+ iree_vm_size_t offset = (iree_vm_size_t)args->i2;
+ iree_host_size_t value_count = args->a3_count;
+ const uint32_t* values = (const uint32_t*)&args->a3[0].i0; // NOTE(review): assumes packed i32s
+ // i2 and the a3 count are in 32-bit units; both are scaled to bytes below.
+ return iree_hal_command_buffer_push_constants(
+ command_buffer, executable_layout, offset * sizeof(uint32_t), values,
+ value_count * sizeof(uint32_t));
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_descriptor_set,
+ rriCiriiD, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_executable_layout_t* executable_layout = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_executable_layout_check_deref(args->r1, &executable_layout));
+ iree_vm_size_t set = args->i2;
+ // a3 holds one (i0 binding, r1 buffer, i2 offset, i3 length) tuple per binding.
+ iree_host_size_t binding_count = args->a3_count;
+ if (IREE_UNLIKELY(binding_count >
+ IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "binding count %zu > %zu",
+ binding_count,
+ IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+ }
+ iree_hal_descriptor_set_binding_t* bindings =
+ (iree_hal_descriptor_set_binding_t*)iree_alloca(
+ binding_count * sizeof(iree_hal_descriptor_set_binding_t));
+ for (iree_host_size_t i = 0; i < binding_count; ++i) {
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref(args->a3[i].r1, &bindings[i].buffer));
+ bindings[i].binding = (uint32_t)args->a3[i].i0;
+ bindings[i].offset = (iree_device_size_t)args->a3[i].i2;
+ bindings[i].length = (iree_device_size_t)args->a3[i].i3;
+ iree_hal_module_ex_defer_release(state, args->a3[i].r1);
+ }
+ // bindings[] is iree_alloca'd; presumably why binding_count is capped above.
+ return iree_hal_command_buffer_push_descriptor_set(
+ command_buffer, executable_layout, set, binding_count, bindings);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_bind_descriptor_set, rrirCiD,
+ v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_executable_layout_t* executable_layout = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_executable_layout_check_deref(args->r1, &executable_layout));
+ int32_t set = args->i2;
+ iree_hal_descriptor_set_t* descriptor_set = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_descriptor_set_check_deref(args->r3, &descriptor_set));
+ iree_host_size_t dynamic_offset_count = 0;
+ iree_device_size_t* dynamic_offsets = NULL;
+ IREE_VM_ABI_VLA_STACK_CAST(args, a4_count, a4, iree_device_size_t, 64,
+ &dynamic_offset_count, &dynamic_offsets);
+ // Only the descriptor set (r3) is defer-released; the layout (r1) is not.
+ iree_hal_module_ex_defer_release(state, args->r3);
+ // dynamic_offsets/_count are presumably filled from a4 by the macro above.
+ return iree_hal_command_buffer_bind_descriptor_set(
+ command_buffer, executable_layout, set, descriptor_set,
+ dynamic_offset_count, dynamic_offsets);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, rriiii, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_executable_t* executable = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable));
+ uint32_t entry_point = (uint32_t)args->i2;
+ uint32_t workgroup_x = (uint32_t)args->i3;
+ uint32_t workgroup_y = (uint32_t)args->i4;
+ uint32_t workgroup_z = (uint32_t)args->i5;
+ // i2 selects the entry point; i3..i5 are the workgroup x/y/z values.
+ iree_hal_module_ex_defer_release(state, args->r1);
+ // The executable (r1) was defer-released above; the HAL status is returned as-is.
+ return iree_hal_command_buffer_dispatch(command_buffer, executable,
+ entry_point, workgroup_x, workgroup_y,
+ workgroup_z);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, rriri, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ iree_hal_executable_t* executable = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable));
+ uint32_t entry_point = (uint32_t)args->i2;
+ iree_hal_buffer_t* workgroups_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref(args->r3, &workgroups_buffer));
+ iree_vm_size_t workgroups_offset = (iree_vm_size_t)args->i4;
+ // The workgroup parameters are given as a buffer (r3) plus an offset (i4).
+ iree_hal_module_ex_defer_release(state, args->r1);
+ iree_hal_module_ex_defer_release(state, args->r3);
+ // Both the executable and the workgroups buffer were defer-released above.
+ return iree_hal_command_buffer_dispatch_indirect(
+ command_buffer, executable, entry_point, workgroups_buffer,
+ workgroups_offset);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_descriptor_set_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_descriptor_set_create, rrCiriiD, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_descriptor_set_layout_t* set_layout = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_descriptor_set_layout_check_deref(args->r1, &set_layout));
+ // a2 holds one (i0 binding, r1 buffer, i2 offset, i3 length) tuple per binding.
+ iree_host_size_t binding_count = args->a2_count;
+ if (IREE_UNLIKELY(binding_count >
+ IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "binding count %zu > %zu",
+ binding_count,
+ IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+ }
+ iree_hal_descriptor_set_binding_t* bindings =
+ (iree_hal_descriptor_set_binding_t*)iree_alloca(
+ binding_count * sizeof(iree_hal_descriptor_set_binding_t));
+ for (iree_host_size_t i = 0; i < binding_count; ++i) {
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref(args->a2[i].r1, &bindings[i].buffer));
+ bindings[i].binding = (uint32_t)args->a2[i].i0;
+ bindings[i].offset = (iree_device_size_t)args->a2[i].i2;
+ bindings[i].length = (iree_device_size_t)args->a2[i].i3;
+ }
+ // Unlike push_descriptor_set, no buffer here is passed to defer_release.
+ iree_hal_descriptor_set_t* descriptor_set = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create(
+ device, set_layout, binding_count, bindings, &descriptor_set));
+ rets->r0 = iree_hal_descriptor_set_move_ref(descriptor_set);
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_descriptor_set_layout_create, riCiiiD, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_descriptor_set_layout_usage_type_t usage_type =
+ (iree_hal_descriptor_set_layout_usage_type_t)args->i1;
+ // a2 holds one (i0 binding, i1 descriptor type, i2 memory access) tuple each.
+ iree_host_size_t binding_count = args->a2_count;
+ if (IREE_UNLIKELY(binding_count >
+ IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "binding count %zu > %zu",
+ binding_count,
+ IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+ }
+ iree_hal_descriptor_set_layout_binding_t* bindings =
+ (iree_hal_descriptor_set_layout_binding_t*)iree_alloca(
+ binding_count * sizeof(iree_hal_descriptor_set_layout_binding_t));
+ for (iree_host_size_t i = 0; i < binding_count; ++i) {
+ bindings[i].binding = (uint32_t)args->a2[i].i0;
+ bindings[i].type = (iree_hal_descriptor_type_t)args->a2[i].i1;
+ bindings[i].access = (iree_hal_memory_access_t)args->a2[i].i2;
+ }
+ // The new layout is handed back in r0 via move_ref once creation succeeds.
+ iree_hal_descriptor_set_layout_t* descriptor_set_layout = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_layout_create(
+ device, usage_type, binding_count, bindings, &descriptor_set_layout));
+ rets->r0 = iree_hal_descriptor_set_layout_move_ref(descriptor_set_layout);
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_device_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_device_allocator, r, r) {
+ iree_hal_device_t* device = NULL; // r0: device whose allocator is requested
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ rets->r0 = iree_hal_allocator_retain_ref(iree_hal_device_allocator(device));
+ return iree_ok_status(); // r0 = retain_ref'd allocator of the device
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_device_match_id, rr, i) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_vm_ro_byte_buffer_t* pattern = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_ro_byte_buffer_check_deref(args->r1, &pattern));
+ iree_string_view_t pattern_str = iree_vm_ro_byte_buffer_as_string(pattern);
+ // i0 = 1 if the device id matches the pattern string from r1, otherwise 0.
+ iree_string_view_t device_id = iree_hal_device_id(device);
+ rets->i0 = iree_string_view_match_pattern(device_id, pattern_str) ? 1 : 0;
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_executable_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_executable_create, rirCrD, r) {
+ iree_hal_device_t* device = NULL; // validated only; the cache comes from state
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_executable_format_t executable_format =
+ (iree_hal_executable_format_t)args->i1;
+ iree_vm_ro_byte_buffer_t* executable_data = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_vm_ro_byte_buffer_check_deref(args->r2, &executable_data));
+ iree_host_size_t executable_layout_count = args->a3_count;
+ iree_hal_executable_layout_t** executable_layouts = NULL;
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+ state->host_allocator,
+ executable_layout_count * sizeof(executable_layouts[0]),
+ (void**)&executable_layouts));
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < executable_layout_count; ++i) {
+ status = iree_hal_executable_layout_check_deref(args->a3[i].r0,
+ &executable_layouts[i]);
+ if (!iree_status_is_ok(status)) break;
+ }
+ // The executable is only prepared if every layout ref in a3 dereferenced.
+ iree_hal_executable_t* executable = NULL;
+ if (iree_status_is_ok(status)) {
+ iree_hal_executable_spec_t spec;
+ iree_hal_executable_spec_initialize(&spec);
+ spec.executable_format = executable_format;
+ spec.executable_data = executable_data->data;
+ spec.executable_layout_count = executable_layout_count;
+ spec.executable_layouts = executable_layouts;
+ status = iree_hal_executable_cache_prepare_executable(
+ state->executable_cache, &spec, &executable);
+ }
+ // Scratch array is freed on every path; after a deref failure executable is NULL.
+ iree_allocator_free(state->host_allocator, executable_layouts);
+ rets->r0 = iree_hal_executable_move_ref(executable);
+ return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_executable_layout_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_executable_layout_create, riCrD, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ int32_t push_constants = (int32_t)args->i1;
+ iree_host_size_t set_layout_count = 0;
+ iree_hal_descriptor_set_layout_t** set_layouts = NULL;
+ IREE_VM_ABI_VLA_STACK_DEREF(args, a2_count, a2,
+ iree_hal_descriptor_set_layout, 32,
+ &set_layout_count, &set_layouts);
+ // set_layouts/_count are presumably filled from a2 by the macro above.
+ iree_hal_executable_layout_t* executable_layout = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_executable_layout_create(
+ device, push_constants, set_layout_count, set_layouts,
+ &executable_layout));
+ rets->r0 = iree_hal_executable_layout_move_ref(executable_layout);
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_semaphore_t
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_create, ri, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ uint32_t initial_value = (uint32_t)args->i1;
+ // i1 is reinterpreted as an unsigned 32-bit initial value for the semaphore.
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_semaphore_create(device, initial_value, &semaphore));
+ rets->r0 = iree_hal_semaphore_move_ref(semaphore);
+ return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_query, r, ii) {
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
+ // A failed query does not fail the call: its status code is reported in i0.
+ uint64_t value = 0;
+ iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
+ rets->i0 = iree_status_consume_code(query_status);
+ rets->i1 = (uint32_t)value; // NOTE: truncates the 64-bit value to 32 bits
+ return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_signal, ri, v) {
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
+ uint32_t new_value = (uint32_t)args->i1;
+ // i1 is reinterpreted as unsigned; the HAL signal status is returned as-is.
+ return iree_hal_semaphore_signal(semaphore, new_value);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_fail, ri, v) {
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
+ iree_status_code_t status_code =
+ (iree_status_code_t)(args->i1 & IREE_STATUS_CODE_MASK);
+ // i1 is masked down to a status code; this export itself always returns OK.
+ iree_hal_semaphore_fail(semaphore, iree_make_status(status_code));
+ return iree_ok_status();
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_await, ri, i) {
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
+ uint64_t new_value = (uint32_t)args->i1;
+ // Waits for r0 to reach i1. A timeout is reported to the VM via i0 with OK.
+ // TODO(benvanik): coroutine magic.
+ iree_status_t status = iree_hal_semaphore_wait_with_deadline(
+ semaphore, new_value, IREE_TIME_INFINITE_FUTURE);
+ if (iree_status_is_ok(status)) {
+ rets->i0 = 0;
+ } else if (iree_status_is_deadline_exceeded(status)) {
+ rets->i0 = (int32_t)iree_status_consume_code(status);
+ status = iree_ok_status(); // consumed above; never return a dead status
+ }
+ return status;
+}
+
+//===----------------------------------------------------------------------===//
+// VM module interface implementation
+//===----------------------------------------------------------------------===//
+
+// NOTE: this must match the ordering of the iree_hal_module_exports_ table.
+static const iree_vm_native_function_ptr_t iree_hal_module_funcs_[] = {
+#define EXPORT_FN(name, target_fn, arg_types, ret_types) \
+ { \
+ .shim = (iree_vm_native_function_shim_t) \
+ iree_vm_shim_##arg_types##_##ret_types, \
+ .target = (iree_vm_native_function_target_t)(target_fn), \
+ },
+#include "iree/modules/hal/exports.inl" // presumably one EXPORT_FN(...) per export
+#undef EXPORT_FN // undefined so the exports table can define its own EXPORT_FN
+}; // each entry pairs the shim named by the arg/ret type codes with its target
+
+// NOTE: 0 length, but can't express that in C.
+static const iree_vm_native_import_descriptor_t iree_hal_module_imports_[1]; // placeholder; the descriptor sets import_count to 0
+
+static const iree_vm_native_export_descriptor_t iree_hal_module_exports_[] = {
+#define EXPORT_FN(name, target_fn, arg_types, ret_types) \
+ { \
+ .local_name = iree_string_view_literal(name), \
+ .calling_convention = \
+ iree_string_view_literal("0" #arg_types "_" #ret_types), \
+ .reflection_attr_count = 0, \
+ .reflection_attrs = NULL, \
+ },
+#include "iree/modules/hal/exports.inl" // calling convention = "0<args>_<rets>"
+#undef EXPORT_FN
+}; // built from the same .inl as iree_hal_module_funcs_, so indices line up
+static_assert(IREE_ARRAYSIZE(iree_hal_module_funcs_) ==
+ IREE_ARRAYSIZE(iree_hal_module_exports_),
+ "function pointer table must be 1:1 with exports");
+
+static const iree_vm_native_module_descriptor_t iree_hal_module_descriptor_ = {
+ .module_name = iree_string_view_literal("hal"),
+ .import_count = 0, // workaround for 0-length C struct
+ .imports = iree_hal_module_imports_, // placeholder array; count above is 0
+ .export_count = IREE_ARRAYSIZE(iree_hal_module_exports_),
+ .exports = iree_hal_module_exports_,
+ .function_count = IREE_ARRAYSIZE(iree_hal_module_funcs_), // == export_count
+ .functions = iree_hal_module_funcs_,
+ .reflection_attr_count = 0,
+ .reflection_attrs = NULL,
+};
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator,
+ iree_vm_module_t** out_module) {
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_ASSERT_ARGUMENT(out_module);
+ *out_module = NULL; // stays NULL on every failure path below
+
+ // Setup the interface with the functions we implement ourselves. Any function
+ // we omit will be handled by the base native module.
+ static const iree_vm_module_t interface = {
+ .destroy = iree_hal_module_destroy,
+ .alloc_state = iree_hal_module_alloc_state,
+ .free_state = iree_hal_module_free_state,
+ };
+
+ // Allocate shared module state.
+ iree_host_size_t total_size =
+ iree_vm_native_module_size() + sizeof(iree_hal_module_t);
+ iree_vm_module_t* base_module = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_allocator_malloc(allocator, total_size, (void**)&base_module));
+ memset(base_module, 0, total_size);
+ iree_status_t status = iree_vm_native_module_initialize(
+ &interface, &iree_hal_module_descriptor_, allocator, base_module);
+ if (!iree_status_is_ok(status)) {
+ iree_allocator_free(allocator, base_module);
+ return status;
+ }
+ // The HAL fields share the allocation above (total_size includes them).
+ iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+ module->host_allocator = allocator;
+ module->shared_device = device;
+ iree_hal_device_retain(module->shared_device);
+ // The module now holds its own retained reference on the shared device.
+ *out_module = base_module;
+ return iree_ok_status();
+}
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
deleted file mode 100644
index 9f3c164..0000000
--- a/iree/modules/hal/hal_module.cc
+++ /dev/null
@@ -1,833 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/modules/hal/hal_module.h"
-
-#include <inttypes.h>
-
-#include "absl/base/macros.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/memory/memory.h"
-#include "absl/types/span.h"
-#include "iree/base/api.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/api.h"
-#include "iree/vm/native_module_cc.h"
-
-//===----------------------------------------------------------------------===//
-// Type registration
-//===----------------------------------------------------------------------===//
-
-static iree_vm_ref_type_descriptor_t iree_hal_allocator_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_buffer_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_buffer_view_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_command_buffer_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_layout_descriptor =
- {0};
-static iree_vm_ref_type_descriptor_t iree_hal_device_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_event_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_executable_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_executable_cache_descriptor = {0};
-static iree_vm_ref_type_descriptor_t iree_hal_executable_layout_descriptor = {
- 0};
-static iree_vm_ref_type_descriptor_t iree_hal_semaphore_descriptor = {0};
-
-#define IREE_VM_REGISTER_HAL_C_TYPE(type, name, destroy_fn, descriptor) \
- descriptor.type_name = iree_make_cstring_view(name); \
- descriptor.offsetof_counter = offsetof(iree_hal_resource_t, ref_count); \
- descriptor.destroy = (iree_vm_ref_destroy_t)destroy_fn; \
- IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_module_register_types() {
- static bool has_registered = false;
- if (has_registered) return iree_ok_status();
-
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_allocator_t, "hal.allocator",
- iree_hal_allocator_destroy,
- iree_hal_allocator_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_t, "hal.buffer",
- iree_hal_buffer_destroy,
- iree_hal_buffer_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_view_t, "hal.buffer_view",
- iree_hal_buffer_view_destroy,
- iree_hal_buffer_view_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_command_buffer_t, "hal.command_buffer",
- iree_hal_command_buffer_destroy,
- iree_hal_command_buffer_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_t, "hal.descriptor_set",
- iree_hal_descriptor_set_destroy,
- iree_hal_descriptor_set_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_layout_t,
- "hal.descriptor_set_layout",
- iree_hal_descriptor_set_layout_destroy,
- iree_hal_descriptor_set_layout_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_device_t, "hal.device",
- iree_hal_device_destroy,
- iree_hal_device_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_event_t, "hal.event",
- iree_hal_event_destroy,
- iree_hal_event_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_t, "hal.executable",
- iree_hal_executable_destroy,
- iree_hal_executable_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(
- iree_hal_executable_cache_t, "hal.executable_cache",
- iree_hal_executable_cache_destroy, iree_hal_executable_cache_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_layout_t,
- "hal.executable_layout",
- iree_hal_executable_layout_destroy,
- iree_hal_executable_layout_descriptor);
- IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_semaphore_t, "hal.semaphore",
- iree_hal_semaphore_destroy,
- iree_hal_semaphore_descriptor);
-
- has_registered = true;
- return iree_ok_status();
-}
-
-//===----------------------------------------------------------------------===//
-// Type wrappers
-//===----------------------------------------------------------------------===//
-
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_allocator, iree_hal_allocator_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer, iree_hal_buffer_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer_view, iree_hal_buffer_view_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_command_buffer,
- iree_hal_command_buffer_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set,
- iree_hal_descriptor_set_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout,
- iree_hal_descriptor_set_layout_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_event, iree_hal_event_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_cache,
- iree_hal_executable_cache_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_layout,
- iree_hal_executable_layout_t);
-IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t);
-
-namespace iree {
-namespace hal {
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Module type definitions
-//===----------------------------------------------------------------------===//
-
-class HALModuleState final {
- public:
- HALModuleState(iree_allocator_t allocator, iree_hal_device_t* shared_device)
- : allocator_(allocator), shared_device_(shared_device) {
- iree_hal_device_retain(shared_device_);
- }
-
- ~HALModuleState() {
- for (auto& ref : deferred_releases_) {
- iree_vm_ref_release(&ref);
- }
- deferred_releases_.clear();
- iree_hal_executable_cache_release(executable_cache_);
- iree_hal_device_release(shared_device_);
- }
-
- Status Initialize() {
- IREE_TRACE_SCOPE0("HALModuleState::Initialize");
-
- IREE_RETURN_IF_ERROR(iree_hal_executable_cache_create(
- shared_device_, iree_string_view_empty(), &executable_cache_));
-
- return OkStatus();
- }
-
- //===--------------------------------------------------------------------===//
- // Experimental APIs
- //===--------------------------------------------------------------------===//
- // NOTE: Ex* APIs are experimental and likely to be removed soon. Modules
- // using these APIs are not forward compatible.
-
- StatusOr<vm::ref<iree_hal_device_t>> ExSharedDevice() {
- return vm::retain_ref(shared_device_);
- }
-
- template <typename T>
- void ExDeferRelease(const vm::ref<T>& value) {
- deferred_releases_.push_back({0});
- iree_vm_ref_retain((iree_vm_ref_t*)&value, &deferred_releases_.back());
- }
-
- Status ExSubmitAndWait(
- const vm::ref<iree_hal_device_t>& device,
- const vm::ref<iree_hal_command_buffer_t>& command_buffer) {
- IREE_TRACE_SCOPE0("HALModuleState::ExSubmitAndWait");
-
- vm::ref<iree_hal_semaphore_t> semaphore;
- IREE_RETURN_IF_ERROR(
- iree_hal_semaphore_create(device.get(), 0ull, &semaphore));
-
- iree_hal_submission_batch_t batch;
- memset(&batch, 0, sizeof(batch));
- batch.command_buffer_count = 1;
- iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer.get()};
- batch.command_buffers = command_buffer_ptrs;
- batch.signal_semaphores.count = 1;
- iree_hal_semaphore_t* semaphore_ptrs[] = {semaphore.get()};
- batch.signal_semaphores.semaphores = semaphore_ptrs;
- uint64_t signal_value = 1ull;
- batch.signal_semaphores.payload_values = &signal_value;
- IREE_RETURN_IF_ERROR(iree_hal_device_queue_submit(
- device.get(), IREE_HAL_COMMAND_CATEGORY_ANY, 0, 1, &batch));
-
- IREE_RETURN_IF_ERROR(iree_hal_semaphore_wait_with_deadline(
- semaphore.get(), 1ull, IREE_TIME_INFINITE_FUTURE));
-
- {
- IREE_TRACE_SCOPE0("HALModuleState::DeferredReleases");
- for (auto& ref : deferred_releases_) {
- iree_vm_ref_release(&ref);
- }
- deferred_releases_.clear();
- }
-
- return OkStatus();
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_allocator_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorAllocate(
- const vm::ref<iree_hal_allocator_t>& allocator,
- iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage,
- int32_t allocation_size) {
- IREE_TRACE_SCOPE0("HALModuleState::AllocatorAllocate");
- vm::ref<iree_hal_buffer_t> buffer;
- IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
- allocator.get(), memory_types, buffer_usage, allocation_size, &buffer));
- return std::move(buffer);
- }
-
- StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorWrapByteBuffer(
- const vm::ref<iree_hal_allocator_t>& allocator,
- iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage,
- const vm::ref<iree_vm_ro_byte_buffer_t>& source, int32_t offset,
- int32_t length) {
- IREE_TRACE_SCOPE0("HALModuleState::AllocatorWrapByteBuffer");
-
- // TODO(benvanik): wrap when supported.
-
- buffer_usage |= IREE_HAL_BUFFER_USAGE_MAPPING;
-
- size_t buffer_length = source->data.data_length;
- if (length == -1) {
- length = static_cast<size_t>(buffer_length);
- }
- if (length < 0 || offset < 0 || offset > buffer_length ||
- offset + length > buffer_length) {
- return iree_make_status(
- IREE_STATUS_INVALID_ARGUMENT,
- "byte range out of bounds (requested %d-%d of available %zu" PRIu64
- ")",
- offset, (offset + length - 1), buffer_length);
- }
-
- vm::ref<iree_hal_buffer_t> buffer;
- IREE_RETURN_IF_ERROR(
- iree_hal_allocator_allocate_buffer(allocator.get(), memory_types,
- buffer_usage, length, &buffer),
- "failed to allocate buffer");
-
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_write_data(buffer.get(), 0, source->data.data + offset,
- length),
- "writing constant data");
-
- return buffer;
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_buffer_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_allocator_t>> BufferAllocator(
- const vm::ref<iree_hal_buffer_t>& buffer) {
- return vm::retain_ref(iree_hal_buffer_allocator(buffer.get()));
- }
-
- StatusOr<vm::ref<iree_hal_buffer_t>> BufferSubspan(
- const vm::ref<iree_hal_buffer_t>& source_buffer, int32_t source_offset,
- int32_t length) {
- IREE_TRACE_SCOPE0("HALModuleState::BufferSubspan");
- vm::ref<iree_hal_buffer_t> target_buffer;
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_subspan(source_buffer.get(), source_offset, length,
- &target_buffer),
- "subspan of an existing buffer (source_offset=%u, length=%u)",
- source_offset, length);
- return target_buffer;
- }
-
- Status BufferFill(const vm::ref<iree_hal_buffer_t>& target_buffer,
- int32_t target_offset, int32_t length, int32_t pattern) {
- IREE_TRACE_SCOPE0("HALModuleState::BufferFill");
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_fill(target_buffer.get(), target_offset, length,
- &pattern, sizeof(pattern)),
- "fill range failed (target_offset=%u, length=%u)", target_offset,
- length);
- return OkStatus();
- }
-
- StatusOr<int32_t> BufferLoad(const vm::ref<iree_hal_buffer_t>& source_buffer,
- int32_t source_offset, int32_t length) {
- IREE_TRACE_SCOPE0("HALModuleState::BufferLoad");
-
- uint32_t target_buffer = 0;
- if (length > sizeof(target_buffer)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "length %d exceeds max", length);
- }
-
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_read_data(source_buffer.get(), source_offset,
- &target_buffer, length),
- "read failed");
- return target_buffer;
- }
-
- Status BufferStore(int32_t value,
- const vm::ref<iree_hal_buffer_t>& target_buffer,
- int32_t target_offset, int32_t length) {
- IREE_TRACE_SCOPE0("HALModuleState::BufferStore");
-
- if (target_offset + length >
- iree_hal_buffer_byte_length(target_buffer.get())) {
- return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "out of bounds store");
- } else if (length > sizeof(value)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "length %d exceeds max", length);
- }
-
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_write_data(target_buffer.get(), target_offset, &value,
- length),
- "write failed");
- return OkStatus();
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_buffer_view_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_buffer_view_t>> BufferViewCreate(
- const vm::ref<iree_hal_buffer_t>& buffer,
- iree_hal_element_type_t element_type, absl::Span<const int32_t> shape) {
- vm::ref<iree_hal_buffer_view_t> buffer_view;
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_view_create(buffer.get(), element_type, shape.data(),
- shape.size(), &buffer_view),
- "failed to create buffer view");
- return std::move(buffer_view);
- }
-
- StatusOr<vm::ref<iree_hal_buffer_t>> BufferViewBuffer(
- const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
- return vm::retain_ref(iree_hal_buffer_view_buffer(buffer_view.get()));
- }
-
- StatusOr<int32_t> BufferViewByteLength(
- const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
- return iree_hal_buffer_view_byte_length(buffer_view.get());
- }
-
- StatusOr<int32_t> BufferViewElementType(
- const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
- return static_cast<int32_t>(
- iree_hal_buffer_view_element_type(buffer_view.get()));
- }
-
- StatusOr<int32_t> BufferViewRank(
- const vm::ref<iree_hal_buffer_view_t>& buffer_view) {
- return static_cast<int32_t>(
- iree_hal_buffer_view_shape_rank(buffer_view.get()));
- }
-
- StatusOr<int32_t> BufferViewDim(
- const vm::ref<iree_hal_buffer_view_t>& buffer_view, int32_t index) {
- return static_cast<int32_t>(
- iree_hal_buffer_view_shape_dim(buffer_view.get(), index));
- }
-
- Status BufferViewTrace(
- absl::string_view key,
- absl::Span<const vm::ref<iree_hal_buffer_view_t>> buffer_views) {
- fprintf(stderr, "=== %s ===\n", std::string(key).c_str());
- for (auto& view : buffer_views) {
- std::string result_str(4096, '\0');
- iree_status_t status;
- auto max_element_count = std::numeric_limits<iree_host_size_t>::max();
- do {
- iree_host_size_t actual_length = 0;
- status = iree_hal_buffer_view_format(view.get(), max_element_count,
- result_str.size() + 1,
- &result_str[0], &actual_length);
- result_str.resize(actual_length);
- } while (iree_status_is_out_of_range(status));
- IREE_RETURN_IF_ERROR(status);
- fprintf(stderr, "%s\n", result_str.c_str());
- }
- fprintf(stderr, "\n");
- return OkStatus();
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_command_buffer_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_command_buffer_t>> CommandBufferCreate(
- const vm::ref<iree_hal_device_t>& device,
- iree_hal_command_buffer_mode_t modes,
- iree_hal_command_category_t command_categories) {
- vm::ref<iree_hal_command_buffer_t> command_buffer;
- IREE_RETURN_IF_ERROR(
- iree_hal_command_buffer_create(device.get(), modes, command_categories,
- &command_buffer),
- "failed to create command buffer");
- return command_buffer;
- }
-
- Status CommandBufferBegin(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer) {
- return iree_hal_command_buffer_begin(command_buffer.get());
- }
-
- Status CommandBufferEnd(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer) {
- return iree_hal_command_buffer_end(command_buffer.get());
- }
-
- Status CommandBufferExecutionBarrier(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- iree_hal_execution_stage_t source_stage_mask,
- iree_hal_execution_stage_t target_stage_mask,
- iree_hal_execution_barrier_flags_t flags) {
- iree_hal_memory_barrier_t global_barrier;
- global_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE;
- global_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ;
- return iree_hal_command_buffer_execution_barrier(
- command_buffer.get(), source_stage_mask, target_stage_mask, flags, 1,
- &global_barrier, 0, nullptr);
- }
-
- Status CommandBufferFillBuffer(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_buffer_t>& target_buffer, int32_t target_offset,
- int32_t length, uint32_t pattern) {
- ExDeferRelease(target_buffer);
- return iree_hal_command_buffer_fill_buffer(
- command_buffer.get(), target_buffer.get(), target_offset, length,
- &pattern, sizeof(pattern));
- }
-
- Status CommandBufferCopyBuffer(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_buffer_t>& source_buffer, int32_t source_offset,
- const vm::ref<iree_hal_buffer_t>& target_buffer, int32_t target_offset,
- int32_t length) {
- ExDeferRelease(source_buffer);
- ExDeferRelease(target_buffer);
- return iree_hal_command_buffer_copy_buffer(
- command_buffer.get(), source_buffer.get(), source_offset,
- target_buffer.get(), target_offset, length);
- }
-
- Status CommandBufferPushConstants(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_executable_layout_t>& executable_layout,
- uint32_t offset, absl::Span<const uint32_t> values) {
- ExDeferRelease(executable_layout);
- return iree_hal_command_buffer_push_constants(
- command_buffer.get(), executable_layout.get(),
- offset * sizeof(uint32_t), values.data(),
- values.size() * sizeof(uint32_t));
- }
-
- Status CommandBufferPushDescriptorSet(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_executable_layout_t>& executable_layout,
- uint32_t set,
- absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t,
- int32_t>>
- bindings) {
- ExDeferRelease(executable_layout);
- absl::InlinedVector<iree_hal_descriptor_set_binding_t, 16> binding_structs(
- bindings.size());
- for (int i = 0; i < bindings.size(); ++i) {
- binding_structs[i] = {
- std::get<0>(bindings[i]), std::get<1>(bindings[i]).get(),
- static_cast<iree_device_size_t>(std::get<2>(bindings[i])),
- static_cast<iree_device_size_t>(std::get<3>(bindings[i]))};
- ExDeferRelease(std::get<1>(bindings[i]));
- }
- return iree_hal_command_buffer_push_descriptor_set(
- command_buffer.get(), executable_layout.get(), set,
- binding_structs.size(), binding_structs.data());
- }
-
- Status CommandBufferBindDescriptorSet(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_executable_layout_t>& executable_layout,
- uint32_t set, const vm::ref<iree_hal_descriptor_set_t>& descriptor_set,
- absl::Span<const int32_t> dynamic_offsets) {
- ExDeferRelease(executable_layout);
- ExDeferRelease(descriptor_set);
- absl::InlinedVector<iree_device_size_t, 4> dynamic_offset_values(
- dynamic_offsets.size());
- for (int i = 0; i < dynamic_offsets.size(); ++i) {
- dynamic_offset_values[i] =
- static_cast<iree_device_size_t>(dynamic_offsets[i]);
- }
- return iree_hal_command_buffer_bind_descriptor_set(
- command_buffer.get(), executable_layout.get(), set,
- descriptor_set.get(), dynamic_offset_values.size(),
- dynamic_offset_values.data());
- }
-
- Status CommandBufferDispatch(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_executable_t>& executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
- ExDeferRelease(executable);
- return iree_hal_command_buffer_dispatch(
- command_buffer.get(), executable.get(), entry_point, workgroup_x,
- workgroup_y, workgroup_z);
- }
-
- Status CommandBufferDispatchIndirect(
- const vm::ref<iree_hal_command_buffer_t>& command_buffer,
- const vm::ref<iree_hal_executable_t>& executable, int32_t entry_point,
- const vm::ref<iree_hal_buffer_t>& workgroups_buffer,
- int32_t workgroups_offset) {
- ExDeferRelease(executable);
- ExDeferRelease(workgroups_buffer);
- return iree_hal_command_buffer_dispatch_indirect(
- command_buffer.get(), executable.get(), entry_point,
- workgroups_buffer.get(), workgroups_offset);
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_descriptor_set_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_descriptor_set_t>> DescriptorSetCreate(
- const vm::ref<iree_hal_device_t>& device,
- const vm::ref<iree_hal_descriptor_set_layout_t>& set_layout,
- absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t,
- int32_t>>
- bindings) {
- absl::InlinedVector<iree_hal_descriptor_set_binding_t, 4> binding_structs(
- bindings.size());
- for (int i = 0; i < bindings.size(); ++i) {
- binding_structs[i] = {
- /*ordinal=*/std::get<0>(bindings[i]),
- /*buffer=*/std::get<1>(bindings[i]).get(),
- /*offset=*/static_cast<iree_device_size_t>(std::get<2>(bindings[i])),
- /*length=*/static_cast<iree_device_size_t>(std::get<3>(bindings[i]))};
- }
- vm::ref<iree_hal_descriptor_set_t> descriptor_set;
- IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create(
- device.get(), set_layout.get(), binding_structs.size(),
- binding_structs.data(), &descriptor_set));
- return std::move(descriptor_set);
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_descriptor_set_layout_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_descriptor_set_layout_t>> DescriptorSetLayoutCreate(
- const vm::ref<iree_hal_device_t>& device,
- iree_hal_descriptor_set_layout_usage_type_t usage_type,
- absl::Span<const std::tuple<uint32_t, iree_hal_descriptor_type_t,
- iree_hal_memory_access_t>>
- bindings) {
- // TODO(benvanik): custom marshaling for the structs.
- absl::InlinedVector<iree_hal_descriptor_set_layout_binding_t, 4>
- binding_structs(bindings.size());
- for (int i = 0; i < bindings.size(); ++i) {
- binding_structs[i] = {std::get<0>(bindings[i]), std::get<1>(bindings[i]),
- std::get<2>(bindings[i])};
- }
- vm::ref<iree_hal_descriptor_set_layout_t> descriptor_set_layout;
- IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_layout_create(
- device.get(), usage_type, binding_structs.size(),
- binding_structs.data(), &descriptor_set_layout));
- return std::move(descriptor_set_layout);
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_device_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_allocator_t>> DeviceAllocator(
- const vm::ref<iree_hal_device_t>& device) {
- return vm::retain_ref(iree_hal_device_allocator(device.get()));
- }
-
- StatusOr<int32_t> DeviceMatchID(const vm::ref<iree_hal_device_t>& device,
- absl::string_view pattern) {
- iree_string_view_t device_id = iree_hal_device_id(device.get());
- return iree_string_view_match_pattern(
- device_id, iree_string_view_t{pattern.data(), pattern.size()})
- ? 1
- : 0;
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_event_t
- //===--------------------------------------------------------------------===//
-
- //===--------------------------------------------------------------------===//
- // iree_hal_executable_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_executable_t>> ExecutableCreate(
- const vm::ref<iree_hal_device_t>& device,
- iree_hal_executable_format_t executable_format,
- const vm::ref<iree_vm_ro_byte_buffer_t>& executable_data,
- absl::Span<const vm::ref<iree_hal_executable_layout_t>>
- executable_layouts) {
- iree_hal_executable_spec_t spec;
- iree_hal_executable_spec_initialize(&spec);
-
- spec.executable_format = executable_format;
- spec.executable_data = executable_data->data;
-
- spec.executable_layout_count = executable_layouts.size();
- iree_hal_executable_layout_t** executable_layouts_ptr =
- (iree_hal_executable_layout_t**)iree_alloca(
- sizeof(executable_layouts_ptr[0]) * executable_layouts.size());
- for (size_t i = 0; i < executable_layouts.size(); ++i) {
- executable_layouts_ptr[i] = executable_layouts[i].get();
- }
- spec.executable_layouts = executable_layouts_ptr;
-
- vm::ref<iree_hal_executable_t> executable;
- IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable(
- executable_cache_, &spec, &executable));
- return std::move(executable);
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_executable_layout_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_executable_layout_t>> ExecutableLayoutCreate(
- const vm::ref<iree_hal_device_t>& device, int32_t push_constants,
- absl::Span<const vm::ref<iree_hal_descriptor_set_layout_t>> set_layouts) {
- iree_hal_descriptor_set_layout_t** set_layouts_ptr =
- (iree_hal_descriptor_set_layout_t**)iree_alloca(
- sizeof(set_layouts_ptr[0]) * set_layouts.size());
- for (size_t i = 0; i < set_layouts.size(); ++i) {
- set_layouts_ptr[i] = set_layouts[i].get();
- }
-
- vm::ref<iree_hal_executable_layout_t> executable_layout;
- IREE_RETURN_IF_ERROR(iree_hal_executable_layout_create(
- device.get(), push_constants, set_layouts.size(), set_layouts_ptr,
- &executable_layout));
- return std::move(executable_layout);
- }
-
- //===--------------------------------------------------------------------===//
- // iree_hal_semaphore_t
- //===--------------------------------------------------------------------===//
-
- StatusOr<vm::ref<iree_hal_semaphore_t>> SemaphoreCreate(
- const vm::ref<iree_hal_device_t>& device, uint32_t initial_value) {
- vm::ref<iree_hal_semaphore_t> semaphore;
- IREE_RETURN_IF_ERROR(
- iree_hal_semaphore_create(device.get(), initial_value, &semaphore));
- return std::move(semaphore);
- }
-
- StatusOr<std::tuple<int32_t, uint32_t>> SemaphoreQuery(
- const vm::ref<iree_hal_semaphore_t>& semaphore) {
- uint64_t value = 0;
- iree_status_t query_status =
- iree_hal_semaphore_query(semaphore.get(), &value);
- return std::make_tuple<int32_t, uint32_t>(iree_status_code(query_status),
- static_cast<uint32_t>(value));
- }
-
- Status SemaphoreSignal(const vm::ref<iree_hal_semaphore_t>& semaphore,
- uint32_t new_value) {
- return iree_hal_semaphore_signal(semaphore.get(), new_value);
- }
-
- Status SemaphoreFail(const vm::ref<iree_hal_semaphore_t>& semaphore,
- int32_t status_code) {
- iree_status_t status = iree_make_status(
- static_cast<iree_status_code_t>(status_code & IREE_STATUS_CODE_MASK));
- iree_hal_semaphore_fail(semaphore.get(), status);
- return OkStatus();
- }
-
- StatusOr<int32_t> SemaphoreAwait(
- const vm::ref<iree_hal_semaphore_t>& semaphore, uint32_t new_value) {
- // TODO(benvanik): coroutine magic.
- iree_status_t status = iree_hal_semaphore_wait_with_deadline(
- semaphore.get(), new_value, IREE_TIME_INFINITE_FUTURE);
- if (iree_status_is_ok(status)) {
- return 0;
- } else if (iree_status_is_deadline_exceeded(status)) {
- // Propagate deadline exceeded back to the VM.
- return static_cast<int32_t>(iree_status_consume_code(status));
- }
- return Status(std::move(status));
- }
-
- private:
- iree_allocator_t allocator_;
- iree_hal_device_t* shared_device_ = NULL;
- iree_hal_executable_cache_t* executable_cache_ = NULL;
-
- std::vector<iree_vm_ref_t> deferred_releases_;
-};
-
-//===----------------------------------------------------------------------===//
-// VM module interface implementation
-//===----------------------------------------------------------------------===//
-
-static const vm::NativeFunction<HALModuleState> kHALModuleFunctions[] = {
- vm::MakeNativeFunction("ex.shared_device", &HALModuleState::ExSharedDevice),
- vm::MakeNativeFunction("ex.submit_and_wait",
- &HALModuleState::ExSubmitAndWait),
-
- vm::MakeNativeFunction("allocator.allocate",
- &HALModuleState::AllocatorAllocate),
- vm::MakeNativeFunction("allocator.wrap.byte_buffer",
- &HALModuleState::AllocatorWrapByteBuffer),
-
- vm::MakeNativeFunction("buffer.allocator",
- &HALModuleState::BufferAllocator),
- vm::MakeNativeFunction("buffer.subspan", &HALModuleState::BufferSubspan),
- vm::MakeNativeFunction("buffer.fill", &HALModuleState::BufferFill),
- vm::MakeNativeFunction("buffer.load", &HALModuleState::BufferLoad),
- vm::MakeNativeFunction("buffer.store", &HALModuleState::BufferStore),
-
- vm::MakeNativeFunction("buffer_view.create",
- &HALModuleState::BufferViewCreate),
- vm::MakeNativeFunction("buffer_view.buffer",
- &HALModuleState::BufferViewBuffer),
- vm::MakeNativeFunction("buffer_view.byte_length",
- &HALModuleState::BufferViewByteLength),
- vm::MakeNativeFunction("buffer_view.element_type",
- &HALModuleState::BufferViewElementType),
- vm::MakeNativeFunction("buffer_view.rank", &HALModuleState::BufferViewRank),
- vm::MakeNativeFunction("buffer_view.dim", &HALModuleState::BufferViewDim),
- vm::MakeNativeFunction("buffer_view.trace",
- &HALModuleState::BufferViewTrace),
-
- vm::MakeNativeFunction("command_buffer.create",
- &HALModuleState::CommandBufferCreate),
- vm::MakeNativeFunction("command_buffer.begin",
- &HALModuleState::CommandBufferBegin),
- vm::MakeNativeFunction("command_buffer.end",
- &HALModuleState::CommandBufferEnd),
- vm::MakeNativeFunction("command_buffer.execution_barrier",
- &HALModuleState::CommandBufferExecutionBarrier),
- vm::MakeNativeFunction("command_buffer.fill_buffer",
- &HALModuleState::CommandBufferFillBuffer),
- vm::MakeNativeFunction("command_buffer.copy_buffer",
- &HALModuleState::CommandBufferCopyBuffer),
- vm::MakeNativeFunction("command_buffer.push_constants",
- &HALModuleState::CommandBufferPushConstants),
- vm::MakeNativeFunction("command_buffer.push_descriptor_set",
- &HALModuleState::CommandBufferPushDescriptorSet),
- vm::MakeNativeFunction("command_buffer.bind_descriptor_set",
- &HALModuleState::CommandBufferBindDescriptorSet),
- vm::MakeNativeFunction("command_buffer.dispatch",
- &HALModuleState::CommandBufferDispatch),
- vm::MakeNativeFunction("command_buffer.dispatch.indirect",
- &HALModuleState::CommandBufferDispatchIndirect),
-
- vm::MakeNativeFunction("descriptor_set.create",
- &HALModuleState::DescriptorSetCreate),
- vm::MakeNativeFunction("descriptor_set_layout.create",
- &HALModuleState::DescriptorSetLayoutCreate),
-
- vm::MakeNativeFunction("device.allocator",
- &HALModuleState::DeviceAllocator),
- vm::MakeNativeFunction("device.match.id", &HALModuleState::DeviceMatchID),
-
- vm::MakeNativeFunction("executable.create",
- &HALModuleState::ExecutableCreate),
-
- vm::MakeNativeFunction("executable_layout.create",
- &HALModuleState::ExecutableLayoutCreate),
-
- vm::MakeNativeFunction("semaphore.create",
- &HALModuleState::SemaphoreCreate),
- vm::MakeNativeFunction("semaphore.query", &HALModuleState::SemaphoreQuery),
- vm::MakeNativeFunction("semaphore.signal",
- &HALModuleState::SemaphoreSignal),
- vm::MakeNativeFunction("semaphore.fail", &HALModuleState::SemaphoreFail),
- vm::MakeNativeFunction("semaphore.await", &HALModuleState::SemaphoreAwait),
-};
-
-class HALModule final : public vm::NativeModule<HALModuleState> {
- public:
- HALModule(iree_allocator_t allocator, iree_hal_device_t* shared_device)
- : vm::NativeModule<HALModuleState>(
- "hal", allocator, absl::MakeConstSpan(kHALModuleFunctions)),
- shared_device_(shared_device) {
- iree_hal_device_retain(shared_device_);
- }
-
- ~HALModule() { iree_hal_device_release(shared_device_); }
-
- Status Initialize() {
- IREE_TRACE_SCOPE0("HALModule::Initialize");
- return OkStatus();
- }
-
- StatusOr<std::unique_ptr<HALModuleState>> CreateState(
- iree_allocator_t allocator) override {
- IREE_TRACE_SCOPE0("HALModule::CreateState");
- auto state = std::make_unique<HALModuleState>(allocator, shared_device_);
- IREE_RETURN_IF_ERROR(state->Initialize());
- return state;
- }
-
- private:
- iree_hal_device_t* shared_device_ = NULL;
-};
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator,
- iree_vm_module_t** out_module) {
- IREE_ASSERT_ARGUMENT(device);
- IREE_ASSERT_ARGUMENT(out_module);
- *out_module = nullptr;
- auto module = std::make_unique<HALModule>(allocator, device);
- IREE_RETURN_IF_ERROR(module->Initialize());
- *out_module = module.release()->interface();
- return iree_ok_status();
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/modules/hal/shims.c b/iree/modules/hal/shims.c
new file mode 100644
index 0000000..28fa9ee
--- /dev/null
+++ b/iree/modules/hal/shims.c
@@ -0,0 +1,52 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/modules/hal/shims.h"
+
+IREE_VM_ABI_DEFINE_SHIM(irii, v);
+IREE_VM_ABI_DEFINE_SHIM(r, i);
+IREE_VM_ABI_DEFINE_SHIM(r, ii);
+IREE_VM_ABI_DEFINE_SHIM(r, iii);
+IREE_VM_ABI_DEFINE_SHIM(r, iiii);
+IREE_VM_ABI_DEFINE_SHIM(r, r);
+IREE_VM_ABI_DEFINE_SHIM(r, v);
+IREE_VM_ABI_DEFINE_SHIM(rCiD, i);
+IREE_VM_ABI_DEFINE_SHIM(rCrD, v);
+IREE_VM_ABI_DEFINE_SHIM(ri, i);
+IREE_VM_ABI_DEFINE_SHIM(ri, r);
+IREE_VM_ABI_DEFINE_SHIM(ri, v);
+IREE_VM_ABI_DEFINE_SHIM(riCiD, r);
+IREE_VM_ABI_DEFINE_SHIM(riCiiiD, r);
+IREE_VM_ABI_DEFINE_SHIM(riCrD, r);
+IREE_VM_ABI_DEFINE_SHIM(rii, i);
+IREE_VM_ABI_DEFINE_SHIM(rii, r);
+IREE_VM_ABI_DEFINE_SHIM(riii, r);
+IREE_VM_ABI_DEFINE_SHIM(riii, v);
+IREE_VM_ABI_DEFINE_SHIM(riirii, r);
+IREE_VM_ABI_DEFINE_SHIM(rirCrD, r);
+IREE_VM_ABI_DEFINE_SHIM(ririi, v);
+IREE_VM_ABI_DEFINE_SHIM(rr, i);
+IREE_VM_ABI_DEFINE_SHIM(rr, r);
+IREE_VM_ABI_DEFINE_SHIM(rr, v);
+IREE_VM_ABI_DEFINE_SHIM(rrCiriiD, r);
+IREE_VM_ABI_DEFINE_SHIM(rriCiD, v);
+IREE_VM_ABI_DEFINE_SHIM(rriCiriiD, v);
+IREE_VM_ABI_DEFINE_SHIM(rriii, v);
+IREE_VM_ABI_DEFINE_SHIM(rriiii, v);
+IREE_VM_ABI_DEFINE_SHIM(rrirCiD, v);
+IREE_VM_ABI_DEFINE_SHIM(rriri, v);
+IREE_VM_ABI_DEFINE_SHIM(rririi, v);
+IREE_VM_ABI_DEFINE_SHIM(v, i);
+IREE_VM_ABI_DEFINE_SHIM(v, r);
+IREE_VM_ABI_DEFINE_SHIM(v, v);
diff --git a/iree/modules/hal/shims.h b/iree/modules/hal/shims.h
new file mode 100644
index 0000000..31a0bb5
--- /dev/null
+++ b/iree/modules/hal/shims.h
@@ -0,0 +1,394 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_MODULES_HAL_SHIMS_H_
+#define IREE_MODULES_HAL_SHIMS_H_
+
+#include "iree/base/api.h"
+#include "iree/base/attributes.h"
+#include "iree/vm/api.h"
+
+//===----------------------------------------------------------------------===//
+// Argument/result struct utilities
+//===----------------------------------------------------------------------===//
+
+#define IREE_VM_ABI_TYPE_NAME(types) iree_vm_abi_##types##_t
+
+#define IREE_VM_ABI_FIXED_STRUCT(types, body) \
+ IREE_VM_ABI_FIXED_STRUCT_IMPL(types, IREE_VM_ABI_TYPE_NAME(types), body)
+
+#define IREE_VM_ABI_VLA_STRUCT(types, vla_count, vla_field, body) \
+ IREE_VM_ABI_VLA_STRUCT_IMPL(types, vla_count, vla_field, \
+ IREE_VM_ABI_TYPE_NAME(types), body)
+
+#define IREE_VM_ABI_FIXED_STRUCT_IMPL(types, struct_type, body) \
+ typedef struct iree_vm_abi_##types##_s body IREE_ATTRIBUTE_PACKED \
+ struct_type; \
+ static inline struct_type* iree_vm_abi_##types##_checked_deref( \
+ iree_byte_span_t buffer) { \
+ return IREE_LIKELY(buffer.data_length == sizeof(struct_type)) \
+ ? (struct_type*)buffer.data \
+ : NULL; \
+ } \
+ static inline void iree_vm_abi_##types##_reset(struct_type* value) { \
+ memset(value, 0, sizeof(struct_type)); \
+ }
+
+#define IREE_VM_ABI_FIELD_SIZE(type, member) sizeof(((type*)NULL)->member)
+#define IREE_VM_ABI_VLA_STRUCT_IMPL(types, vla_count, vla_field, struct_type, \
+ body) \
+ typedef struct iree_vm_abi_##types##_s body IREE_ATTRIBUTE_PACKED \
+ struct_type; \
+ static inline struct_type* iree_vm_abi_##types##_checked_deref( \
+ iree_byte_span_t buffer) { \
+ return IREE_LIKELY(buffer.data_length >= sizeof(struct_type)) && \
+ IREE_LIKELY( \
+ buffer.data_length == \
+ sizeof(struct_type) + \
+ ((const struct_type*)buffer.data)->vla_count * \
+ IREE_VM_ABI_FIELD_SIZE(struct_type, \
+ vla_field[0])) \
+ ? (struct_type*)buffer.data \
+ : NULL; \
+ }
+
+//===----------------------------------------------------------------------===//
+// Shim function declaration/definition and accessor utilities
+//===----------------------------------------------------------------------===//
+
+typedef iree_status_t(IREE_API_PTR* iree_vm_native_function_target2_t)(
+ iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
+ void* IREE_RESTRICT module_state, const void* IREE_RESTRICT args,
+ void* IREE_RESTRICT rets);
+
+#define IREE_VM_ABI_DECLARE_SHIM(arg_types, ret_types) \
+ iree_status_t iree_vm_shim_##arg_types##_##ret_types( \
+ iree_vm_stack_t* IREE_RESTRICT stack, \
+ const iree_vm_function_call_t* IREE_RESTRICT call, \
+ iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module, \
+ void* IREE_RESTRICT module_state, \
+ iree_vm_execution_result_t* IREE_RESTRICT out_result);
+
+#define IREE_VM_ABI_DEFINE_SHIM(arg_types, ret_types) \
+ iree_status_t iree_vm_shim_##arg_types##_##ret_types( \
+ iree_vm_stack_t* IREE_RESTRICT stack, \
+ const iree_vm_function_call_t* IREE_RESTRICT call, \
+ iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module, \
+ void* IREE_RESTRICT module_state, \
+ iree_vm_execution_result_t* IREE_RESTRICT out_result) { \
+ const IREE_VM_ABI_TYPE_NAME(arg_types)* args = \
+ iree_vm_abi_##arg_types##_checked_deref(call->arguments); \
+ IREE_VM_ABI_TYPE_NAME(ret_types)* rets = \
+ iree_vm_abi_##ret_types##_checked_deref(call->results); \
+ if (IREE_UNLIKELY(!args || !rets)) { \
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, \
+ "argument/result signature mismatch"); \
+ } \
+ iree_vm_abi_##ret_types##_reset(rets); \
+ return target_fn(stack, module, module_state, args, rets); \
+ }
+
+#define IREE_VM_ABI_EXPORT(function_name, arg_types, ret_types) \
+ static iree_status_t function_name( \
+ iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module, \
+ iree_hal_module_state_t* IREE_RESTRICT state, \
+ IREE_VM_ABI_TYPE_NAME(arg_types) * IREE_RESTRICT args, \
+ IREE_VM_ABI_TYPE_NAME(ret_types) * IREE_RESTRICT rets)
+
+// TODO(benvanik): special case when source type and target type match.
+#define IREE_VM_ABI_VLA_STACK_CAST(args, vla_count, vla_field, target_type, \
+ max_count, out_count, out_ptrs) \
+ *(out_count) = (args)->vla_count; \
+ if (IREE_UNLIKELY((args)->vla_count > (max_count))) { \
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "count %u > %u", \
+ (args)->vla_count, (uint32_t)(max_count)); \
+ } \
+ *(out_ptrs) = \
+ (target_type*)iree_alloca((args)->vla_count * sizeof(target_type)); \
+ for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) { \
+ (*(out_ptrs))[i] = (target_type)((args)->vla_field[i].i0); \
+ }
+
+#define IREE_VM_ABI_VLA_STACK_DEREF(args, vla_count, vla_field, ref_type, \
+ max_count, out_count, out_ptrs) \
+ *(out_count) = (args)->vla_count; \
+ if (IREE_UNLIKELY((args)->vla_count > (max_count))) { \
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE, \
+ "count %u of " #ref_type " > %u", \
+ (args)->vla_count, (uint32_t)(max_count)); \
+ } \
+ *(out_ptrs) = \
+ (ref_type##_t**)iree_alloca((args)->vla_count * sizeof(ref_type##_t*)); \
+ for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) { \
+ IREE_RETURN_IF_ERROR( \
+ ref_type##_check_deref((args)->vla_field[i].r0, &(*(out_ptrs))[i])); \
+ }
+
+#define IREE_VM_ABI_VLA_HEAP_DEREF(args, vla_count, vla_field, ref_type, \
+ host_allocator, out_count, out_ptrs) \
+ *(out_count) = (args)->vla_count; /* *(out_ptrs) is heap-allocated from host_allocator; the caller must free it */ \
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc((host_allocator), (args)->vla_count * sizeof(ref_type##_t*), (void**)(out_ptrs))); \
+ for (iree_host_size_t i = 0; i < (args)->vla_count; ++i) { \
+ IREE_RETURN_IF_ERROR( \
+ ref_type##_check_deref((args)->vla_field[i].r0, &(*(out_ptrs))[i])); \
+ }
+
+//===----------------------------------------------------------------------===//
+// Structures used for arguments and results.
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_COMPILER_MSVC)
+#pragma pack(push, 1)
+#endif // IREE_COMPILER_MSVC
+
+// Special case for void (empty args/rets) as C structs can't have a 0 length.
+typedef struct {
+ int unused;
+} iree_vm_abi_v_t;
+static inline iree_vm_abi_v_t* iree_vm_abi_v_checked_deref(
+ iree_byte_span_t buffer) {
+ return (iree_vm_abi_v_t*)buffer.data;
+}
+static inline void iree_vm_abi_v_reset(iree_vm_abi_v_t* value) {}
+
+IREE_VM_ABI_FIXED_STRUCT(i, { int32_t i0; });
+
+IREE_VM_ABI_FIXED_STRUCT(ii, {
+ int32_t i0;
+ int32_t i1;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(iii, {
+ int32_t i0;
+ int32_t i1;
+ int32_t i2;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(iiii, {
+ int32_t i0;
+ int32_t i1;
+ int32_t i2;
+ int32_t i3;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(irii, {
+ int32_t i0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ int32_t i3;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(r, { iree_vm_ref_t r0; });
+
+IREE_VM_ABI_FIXED_STRUCT(rr, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(ri, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(ririi, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ iree_vm_ref_t r2;
+ int32_t i3;
+ int32_t i4;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rii, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(riii, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+ int32_t i3;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(riirii, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+ iree_vm_ref_t r3;
+ int32_t i4;
+ int32_t i5;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rriii, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ int32_t i3;
+ int32_t i4;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rriiii, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ int32_t i3;
+ int32_t i4;
+ int32_t i5;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rriri, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ iree_vm_ref_t r3;
+ int32_t i4;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rririi, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ iree_vm_ref_t r3;
+ int32_t i4;
+ int32_t i5;
+});
+
+IREE_VM_ABI_VLA_STRUCT(rCiD, a1_count, a1, {
+ iree_vm_ref_t r0;
+ iree_vm_size_t a1_count;
+ iree_vm_abi_i_t a1[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rCrD, a1_count, a1, {
+ iree_vm_ref_t r0;
+ iree_vm_size_t a1_count;
+ iree_vm_abi_r_t a1[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riCiD, a2_count, a2, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ iree_vm_size_t a2_count;
+ iree_vm_abi_i_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riCrD, a2_count, a2, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ iree_vm_size_t a2_count;
+ iree_vm_abi_r_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riiCriD, a3_count, a3, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+ iree_vm_size_t a3_count;
+ iree_vm_abi_ri_t a3[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rirCrD, a3_count, a3, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ iree_vm_ref_t r2;
+ iree_vm_size_t a3_count;
+ iree_vm_abi_r_t a3[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rriCiD, a3_count, a3, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ iree_vm_size_t a3_count;
+ iree_vm_abi_i_t a3[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rrirCiD, a4_count, a4, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ iree_vm_ref_t r3;
+ iree_vm_size_t a4_count;
+ iree_vm_abi_i_t a4[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(riCiiiD, a2_count, a2, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ iree_vm_size_t a2_count;
+ iree_vm_abi_iii_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rrCiriiD, a2_count, a2, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ iree_vm_size_t a2_count;
+ iree_vm_abi_irii_t a2[0];
+});
+
+IREE_VM_ABI_VLA_STRUCT(rriCiriiD, a3_count, a3, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ iree_vm_size_t a3_count;
+ iree_vm_abi_irii_t a3[0];
+});
+
+#if defined(IREE_COMPILER_MSVC)
+#pragma pack(pop)
+#endif // IREE_COMPILER_MSVC
+
+//===----------------------------------------------------------------------===//
+// Shims for marshaling arguments and results
+//===----------------------------------------------------------------------===//
+
+IREE_VM_ABI_DECLARE_SHIM(irii, v);
+IREE_VM_ABI_DECLARE_SHIM(r, i);
+IREE_VM_ABI_DECLARE_SHIM(r, ii);
+IREE_VM_ABI_DECLARE_SHIM(r, iii);
+IREE_VM_ABI_DECLARE_SHIM(r, iiii);
+IREE_VM_ABI_DECLARE_SHIM(r, r);
+IREE_VM_ABI_DECLARE_SHIM(r, v);
+IREE_VM_ABI_DECLARE_SHIM(rCiD, i);
+IREE_VM_ABI_DECLARE_SHIM(rCrD, v);
+IREE_VM_ABI_DECLARE_SHIM(ri, i);
+IREE_VM_ABI_DECLARE_SHIM(ri, r);
+IREE_VM_ABI_DECLARE_SHIM(ri, v);
+IREE_VM_ABI_DECLARE_SHIM(riCiD, r);
+IREE_VM_ABI_DECLARE_SHIM(riCiiiD, r);
+IREE_VM_ABI_DECLARE_SHIM(riCrD, r);
+IREE_VM_ABI_DECLARE_SHIM(rii, i);
+IREE_VM_ABI_DECLARE_SHIM(rii, r);
+IREE_VM_ABI_DECLARE_SHIM(riii, r);
+IREE_VM_ABI_DECLARE_SHIM(riii, v);
+IREE_VM_ABI_DECLARE_SHIM(riirii, r);
+IREE_VM_ABI_DECLARE_SHIM(rirCrD, r);
+IREE_VM_ABI_DECLARE_SHIM(ririi, v);
+IREE_VM_ABI_DECLARE_SHIM(rr, i);
+IREE_VM_ABI_DECLARE_SHIM(rr, r);
+IREE_VM_ABI_DECLARE_SHIM(rr, v);
+IREE_VM_ABI_DECLARE_SHIM(rrCiriiD, r);
+IREE_VM_ABI_DECLARE_SHIM(rriCiD, v);
+IREE_VM_ABI_DECLARE_SHIM(rriCiriiD, v);
+IREE_VM_ABI_DECLARE_SHIM(rriii, v);
+IREE_VM_ABI_DECLARE_SHIM(rriiii, v);
+IREE_VM_ABI_DECLARE_SHIM(rrirCiD, v);
+IREE_VM_ABI_DECLARE_SHIM(rriri, v);
+IREE_VM_ABI_DECLARE_SHIM(rririi, v);
+IREE_VM_ABI_DECLARE_SHIM(v, i);
+IREE_VM_ABI_DECLARE_SHIM(v, r);
+IREE_VM_ABI_DECLARE_SHIM(v, v);
+
+#endif // IREE_MODULES_HAL_SHIMS_H_
diff --git a/iree/tools/utils/vm_util.cc b/iree/tools/utils/vm_util.cc
index 03ff229..b758d8d 100644
--- a/iree/tools/utils/vm_util.cc
+++ b/iree/tools/utils/vm_util.cc
@@ -212,7 +212,7 @@
"variant %d has value type %d but descriptor information %s", i,
(int)variant.type.value_type, desc_str.c_str());
}
- auto* buffer_view = iree_hal_buffer_view_deref(&variant.ref);
+ auto* buffer_view = iree_hal_buffer_view_deref(variant.ref);
if (!buffer_view) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"failed dereferencing variant %d", i);
diff --git a/iree/vm/builtin_types.h b/iree/vm/builtin_types.h
index 4361a9f..c8144b5 100644
--- a/iree/vm/builtin_types.h
+++ b/iree/vm/builtin_types.h
@@ -31,6 +31,14 @@
iree_vm_ref_destroy_t destroy;
} iree_vm_ro_byte_buffer_t;
+// Returns a string view referencing the given |value| buffer (empty if NULL).
+static inline iree_string_view_t iree_vm_ro_byte_buffer_as_string(
+ const iree_vm_ro_byte_buffer_t* value) {
+ return value ? iree_make_string_view((const char*)value->data.data,
+ value->data.data_length)
+ : iree_string_view_empty();
+}
+
// The built-in mutable buffer type.
// This simply points at a span of memory. The memory could be owned (in which
// case a destroy function must be provided) or unowned (NULL destroy function).
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index 32594c2..86c6509 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -827,7 +827,7 @@
DISPATCH_OP(CORE, ListReserve, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (IREE_UNLIKELY(!list)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
@@ -838,7 +838,7 @@
DISPATCH_OP(CORE, ListSize, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (IREE_UNLIKELY(!list)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
@@ -849,7 +849,7 @@
DISPATCH_OP(CORE, ListResize, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (IREE_UNLIKELY(!list)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
@@ -860,7 +860,7 @@
DISPATCH_OP(CORE, ListGetI32, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (IREE_UNLIKELY(!list)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
@@ -875,7 +875,7 @@
DISPATCH_OP(CORE, ListSetI32, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (!list) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
@@ -1328,7 +1328,7 @@
DISPATCH_OP(EXT_I64, ListGetI64, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (IREE_UNLIKELY(!list)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
@@ -1343,7 +1343,7 @@
DISPATCH_OP(EXT_I64, ListSetI64, {
bool list_is_move;
iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move);
- iree_vm_list_t* list = iree_vm_list_deref(list_ref);
+ iree_vm_list_t* list = iree_vm_list_deref(*list_ref);
if (IREE_UNLIKELY(!list)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null");
}
diff --git a/iree/vm/list.c b/iree/vm/list.c
index 17569d1..64fbdc8 100644
--- a/iree/vm/list.c
+++ b/iree/vm/list.c
@@ -494,7 +494,7 @@
if (!iree_status_is_ok(iree_status_consume_code(status))) {
return NULL;
}
- status = iree_vm_ref_check(&value, type_descriptor->type);
+ status = iree_vm_ref_check(value, type_descriptor->type);
if (!iree_status_is_ok(iree_status_consume_code(status))) {
return NULL;
}
diff --git a/iree/vm/list_test.cc b/iree/vm/list_test.cc
index e17dd9d..4423cbc 100644
--- a/iree/vm/list_test.cc
+++ b/iree/vm/list_test.cc
@@ -146,8 +146,8 @@
for (iree_host_size_t i = 0; i < 5; ++i) {
iree_vm_ref_t ref_a{0};
IREE_ASSERT_OK(iree_vm_list_get_ref_retain(list, i, &ref_a));
- EXPECT_TRUE(test_a_isa(&ref_a));
- auto* a = test_a_deref(&ref_a);
+ EXPECT_TRUE(test_a_isa(ref_a));
+ auto* a = test_a_deref(ref_a);
EXPECT_EQ(i, a->data());
iree_vm_ref_release(&ref_a);
}
@@ -192,8 +192,8 @@
for (iree_host_size_t i = 5; i < 10; ++i) {
iree_vm_ref_t ref_a{0};
IREE_ASSERT_OK(iree_vm_list_get_ref_retain(list, i, &ref_a));
- EXPECT_TRUE(test_a_isa(&ref_a));
- auto* a = test_a_deref(&ref_a);
+ EXPECT_TRUE(test_a_isa(ref_a));
+ auto* a = test_a_deref(ref_a);
EXPECT_EQ(i, a->data());
iree_vm_ref_release(&ref_a);
}
diff --git a/iree/vm/native_module.c b/iree/vm/native_module.c
index 8e4b11d..5ad605d 100644
--- a/iree/vm/native_module.c
+++ b/iree/vm/native_module.c
@@ -26,10 +26,13 @@
iree_vm_module_t base_interface;
// Interface with optional user-provided function pointers.
- // user_interface.self will contain the user's module pointer that must be
- // passed to all functions.
iree_vm_module_t user_interface;
+ // The self passed to user_interface functions. This is the value of
+ // user_interface.self when one was provided at initialization, or the base
+ // pointer of this native module otherwise.
+ void* self;
+
// Allocator this module was allocated with and must be freed with.
iree_allocator_t allocator;
@@ -37,6 +40,10 @@
const iree_vm_native_module_descriptor_t* descriptor;
} iree_vm_native_module_t;
+IREE_API_EXPORT iree_host_size_t iree_vm_native_module_size() {
+ return sizeof(iree_vm_native_module_t);
+}
+
#if defined(NDEBUG)
static iree_status_t iree_vm_native_module_verify_descriptor(
const iree_vm_native_module_descriptor_t* module_descriptor) {
@@ -69,17 +76,23 @@
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
// Destroy the optional user-provided self.
- if (module->user_interface.destroy) {
- module->user_interface.destroy(module->user_interface.self);
+ if (module->self == module) {
+ iree_allocator_t allocator = module->allocator;
+ if (module->user_interface.destroy) {
+ module->user_interface.destroy(module->self);
+ }
+ iree_allocator_free(allocator, module);
+ } else {
+ if (module->user_interface.destroy) {
+ module->user_interface.destroy(module->self);
+ }
}
-
- iree_allocator_free(module->allocator, module);
}
static iree_string_view_t IREE_API_PTR iree_vm_native_module_name(void* self) {
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
if (module->user_interface.name) {
- return module->user_interface.name(module->user_interface.self);
+ return module->user_interface.name(module->self);
}
return module->descriptor->module_name;
}
@@ -88,7 +101,7 @@
iree_vm_native_module_signature(void* self) {
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
if (module->user_interface.signature) {
- return module->user_interface.signature(module->user_interface.self);
+ return module->user_interface.signature(module->self);
}
iree_vm_module_signature_t signature;
memset(&signature, 0, sizeof(signature));
@@ -155,9 +168,8 @@
if (out_name) memset(out_name, 0, sizeof(*out_name));
if (out_signature) memset(out_signature, 0, sizeof(*out_signature));
if (module->user_interface.get_function) {
- return module->user_interface.get_function(module->user_interface.self,
- linkage, ordinal, out_function,
- out_name, out_signature);
+ return module->user_interface.get_function(
+ module->self, linkage, ordinal, out_function, out_name, out_signature);
}
switch (linkage) {
case IREE_VM_FUNCTION_LINKAGE_IMPORT:
@@ -181,7 +193,7 @@
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
if (module->user_interface.get_function_reflection_attr) {
return module->user_interface.get_function_reflection_attr(
- module->user_interface.self, linkage, ordinal, index, key, value);
+ module->self, linkage, ordinal, index, key, value);
}
// TODO(benvanik): implement native module reflection.
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
@@ -194,8 +206,8 @@
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
memset(out_function, 0, sizeof(*out_function));
if (module->user_interface.lookup_function) {
- return module->user_interface.lookup_function(module->user_interface.self,
- linkage, name, out_function);
+ return module->user_interface.lookup_function(module->self, linkage, name,
+ out_function);
}
if (IREE_UNLIKELY(linkage != IREE_VM_FUNCTION_LINKAGE_EXPORT)) {
@@ -234,8 +246,8 @@
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
*out_module_state = NULL;
if (module->user_interface.alloc_state) {
- return module->user_interface.alloc_state(module->user_interface.self,
- allocator, out_module_state);
+ return module->user_interface.alloc_state(module->self, allocator,
+ out_module_state);
}
// Default to no state.
return iree_ok_status();
@@ -245,8 +257,7 @@
void* self, iree_vm_module_state_t* module_state) {
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
if (module->user_interface.free_state) {
- module->user_interface.free_state(module->user_interface.self,
- module_state);
+ module->user_interface.free_state(module->self, module_state);
return;
}
// No-op in the default implementation.
@@ -260,9 +271,8 @@
const iree_vm_function_signature_t* signature) {
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
if (module->user_interface.resolve_import) {
- return module->user_interface.resolve_import(module->user_interface.self,
- module_state, ordinal,
- function, signature);
+ return module->user_interface.resolve_import(module->self, module_state,
+ ordinal, function, signature);
}
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"native module does not support imports");
@@ -282,8 +292,8 @@
module->descriptor->export_count);
}
if (module->user_interface.begin_call) {
- return module->user_interface.begin_call(module->user_interface.self, stack,
- call, out_result);
+ return module->user_interface.begin_call(module->self, stack, call,
+ out_result);
}
// NOTE: VM stack is currently unused. We could stash things here for the
@@ -320,8 +330,7 @@
iree_vm_execution_result_t* out_result) {
iree_vm_native_module_t* module = (iree_vm_native_module_t*)self;
if (module->user_interface.resume_call) {
- return module->user_interface.resume_call(module->user_interface.self,
- stack, out_result);
+ return module->user_interface.resume_call(module->self, stack, out_result);
}
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"native module does not support resume");
@@ -361,11 +370,52 @@
iree_vm_native_module_t* module = NULL;
IREE_RETURN_IF_ERROR(iree_allocator_malloc(
allocator, sizeof(iree_vm_native_module_t), (void**)&module));
+
+ iree_status_t status = iree_vm_native_module_initialize(
+ interface, module_descriptor, allocator, (iree_vm_module_t*)module);
+ if (!iree_status_is_ok(status)) {
+ iree_allocator_free(allocator, module);
+ return status;
+ }
+
+ *out_module = &module->base_interface;
+ return iree_ok_status();
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_native_module_initialize(
+ const iree_vm_module_t* interface,
+ const iree_vm_native_module_descriptor_t* module_descriptor,
+ iree_allocator_t allocator, iree_vm_module_t* base_module) {
+ IREE_ASSERT_ARGUMENT(interface);
+ IREE_ASSERT_ARGUMENT(module_descriptor);
+ IREE_ASSERT_ARGUMENT(base_module);
+ iree_vm_native_module_t* module = (iree_vm_native_module_t*)base_module;
+
+ if (IREE_UNLIKELY(!interface->begin_call) &&
+ IREE_UNLIKELY(!module_descriptor->functions)) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "native modules must provide call support or function pointers");
+ } else if (IREE_UNLIKELY(!interface->begin_call) &&
+ IREE_UNLIKELY(module_descriptor->export_count !=
+ module_descriptor->function_count)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "native modules using the default call support "
+ "must have 1:1 exports:function pointers");
+ }
+
+ // Perform some optional debug-only verification of the descriptor.
+ // Since native modules are designed to be compiled in we don't need to do
+ // this in release builds.
+ IREE_RETURN_IF_ERROR(
+ iree_vm_native_module_verify_descriptor(module_descriptor));
module->allocator = allocator;
module->descriptor = module_descriptor;
// TODO(benvanik): version interface and copy only valid bytes.
memcpy(&module->user_interface, interface, sizeof(*interface));
+ module->self =
+ module->user_interface.self ? module->user_interface.self : module;
// Base interface that routes through our thunks.
iree_vm_module_initialize(&module->base_interface, module);
@@ -383,6 +433,5 @@
module->base_interface.begin_call = iree_vm_native_module_begin_call;
module->base_interface.resume_call = iree_vm_native_module_resume_call;
- *out_module = &module->base_interface;
return iree_ok_status();
}
diff --git a/iree/vm/native_module.h b/iree/vm/native_module.h
index 448abff..a1dd52e 100644
--- a/iree/vm/native_module.h
+++ b/iree/vm/native_module.h
@@ -102,6 +102,10 @@
const iree_vm_reflection_attr_t* reflection_attrs;
} iree_vm_native_module_descriptor_t;
+// Returns the size, in bytes, of the allocation required for native modules.
+// Callers may allocate more memory if they need additional storage.
+IREE_API_EXPORT iree_host_size_t iree_vm_native_module_size();
+
// Creates a new native module with the metadata tables in |descriptor|.
// These tables will be used for reflection and function lookup, and the
// provided function pointers will be called when state needs to be managed or
@@ -119,6 +123,11 @@
const iree_vm_native_module_descriptor_t* module_descriptor,
iree_allocator_t allocator, iree_vm_module_t** out_module);
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_vm_native_module_initialize(
+ const iree_vm_module_t* interface,
+ const iree_vm_native_module_descriptor_t* module_descriptor,
+ iree_allocator_t allocator, iree_vm_module_t* module);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/iree/vm/ref.c b/iree/vm/ref.c
index c8a69e5..58613a1 100644
--- a/iree/vm/ref.c
+++ b/iree/vm/ref.c
@@ -149,12 +149,6 @@
return iree_ok_status();
}
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_ref_check(iree_vm_ref_t* ref, iree_vm_ref_type_t type) {
- return ref->type == type ? iree_ok_status()
- : iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
-}
-
IREE_API_EXPORT void IREE_API_CALL iree_vm_ref_retain(iree_vm_ref_t* ref,
iree_vm_ref_t* out_ref) {
if (ref != out_ref && ref->ptr != out_ref->ptr) {
diff --git a/iree/vm/ref.h b/iree/vm/ref.h
index 1ef7fd4..5e05f67 100644
--- a/iree/vm/ref.h
+++ b/iree/vm/ref.h
@@ -175,12 +175,11 @@
void* ptr, iree_vm_ref_type_t type, iree_vm_ref_t* out_ref);
// Checks that the given reference-counted pointer |ref| is of |type|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_ref_check(iree_vm_ref_t* ref, iree_vm_ref_type_t type);
-
-#define IREE_VM_DEREF_OR_RETURN(value_type, value, ref, type) \
- IREE_RETURN_IF_ERROR(iree_vm_ref_check(ref, type)); \
- value_type* value = (value_type*)(ref)->ptr;
+static inline iree_status_t iree_vm_ref_check(const iree_vm_ref_t ref,
+ iree_vm_ref_type_t type) {
+ return ref.type == type ? iree_ok_status()
+ : iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
+}
// Retains the reference-counted pointer |ref|.
// |out_ref| will be released if it already contains a reference.
@@ -267,13 +266,13 @@
#define IREE_VM_DECLARE_TYPE_ADAPTERS(name, T) \
IREE_API_EXPORT iree_vm_ref_t IREE_API_CALL name##_retain_ref(T* value); \
IREE_API_EXPORT iree_vm_ref_t IREE_API_CALL name##_move_ref(T* value); \
- IREE_API_EXPORT T* IREE_API_CALL name##_deref(iree_vm_ref_t* ref); \
+ IREE_API_EXPORT T* IREE_API_CALL name##_deref(const iree_vm_ref_t ref); \
IREE_API_EXPORT iree_status_t IREE_API_CALL name##_check_deref( \
- iree_vm_ref_t* ref, T** out_ptr); \
+ const iree_vm_ref_t ref, T** out_ptr); \
IREE_API_EXPORT const iree_vm_ref_type_descriptor_t* IREE_API_CALL \
name##_get_descriptor(); \
- inline bool name##_isa(iree_vm_ref_t* ref) { \
- return name##_get_descriptor()->type == ref->type; \
+ static inline bool name##_isa(const iree_vm_ref_t ref) { \
+ return name##_get_descriptor()->type == ref.type; \
} \
IREE_API_EXPORT iree_vm_ref_type_t IREE_API_CALL name##_type_id(); \
IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
@@ -290,18 +289,18 @@
iree_vm_ref_wrap_assign(value, name##_descriptor.type, &ref); \
return ref; \
} \
- IREE_API_EXPORT T* IREE_API_CALL name##_deref(iree_vm_ref_t* ref) { \
+ IREE_API_EXPORT T* IREE_API_CALL name##_deref(const iree_vm_ref_t ref) { \
iree_status_t status = iree_vm_ref_check(ref, name##_descriptor.type); \
- if (!iree_status_is_ok(status)) { \
- iree_status_ignore(status); \
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) { \
+ IREE_IGNORE_ERROR(status); \
return NULL; \
} \
- return (T*)ref->ptr; \
+ return (T*)ref.ptr; \
} \
IREE_API_EXPORT iree_status_t IREE_API_CALL name##_check_deref( \
- iree_vm_ref_t* ref, T** out_ptr) { \
+ const iree_vm_ref_t ref, T** out_ptr) { \
IREE_RETURN_IF_ERROR(iree_vm_ref_check(ref, name##_descriptor.type)); \
- *out_ptr = (T*)ref->ptr; \
+ *out_ptr = (T*)ref.ptr; \
return iree_ok_status(); \
} \
IREE_API_EXPORT const iree_vm_ref_type_descriptor_t* IREE_API_CALL \
diff --git a/iree/vm/ref_test.cc b/iree/vm/ref_test.cc
index 5be26b5..ac47cf3 100644
--- a/iree/vm/ref_test.cc
+++ b/iree/vm/ref_test.cc
@@ -161,18 +161,18 @@
// Checking null refs is fine.
TEST(VMRefTest, CheckNull) {
iree_vm_ref_t null_ref = {0};
- IREE_EXPECT_OK(iree_vm_ref_check(&null_ref, IREE_VM_REF_TYPE_NULL));
+ IREE_EXPECT_OK(iree_vm_ref_check(null_ref, IREE_VM_REF_TYPE_NULL));
IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
::iree::Status(iree_vm_ref_check(
- &null_ref, static_cast<iree_vm_ref_type_t>(1234))));
+ null_ref, static_cast<iree_vm_ref_type_t>(1234))));
}
// Tests type checks.
TEST(VMRefTest, Check) {
iree_vm_ref_t a_ref = MakeRef<A>("AType");
- IREE_EXPECT_OK(iree_vm_ref_check(&a_ref, A::kTypeID));
+ IREE_EXPECT_OK(iree_vm_ref_check(a_ref, A::kTypeID));
IREE_EXPECT_STATUS_IS(IREE_STATUS_INVALID_ARGUMENT,
- ::iree::Status(iree_vm_ref_check(&a_ref, B::kTypeID)));
+ ::iree::Status(iree_vm_ref_check(a_ref, B::kTypeID)));
iree_vm_ref_release(&a_ref);
}
@@ -248,7 +248,7 @@
iree_vm_ref_t a_ref_0 = MakeRef<A>("AType");
iree_vm_ref_t a_ref_1 = {0};
iree_vm_ref_retain_or_move(/*is_move=*/1, &a_ref_0, &a_ref_1);
- IREE_EXPECT_OK(iree_vm_ref_check(&a_ref_0, IREE_VM_REF_TYPE_NULL));
+ IREE_EXPECT_OK(iree_vm_ref_check(a_ref_0, IREE_VM_REF_TYPE_NULL));
iree_vm_ref_release(&a_ref_1);
}
@@ -269,7 +269,7 @@
TEST(VMRefTest, RetainOrMoveMovingIntoSelf) {
iree_vm_ref_t a_ref = MakeRef<A>("AType");
iree_vm_ref_retain_or_move(/*is_move=*/1, &a_ref, &a_ref);
- IREE_EXPECT_OK(iree_vm_ref_check(&a_ref, A::kTypeID));
+ IREE_EXPECT_OK(iree_vm_ref_check(a_ref, A::kTypeID));
iree_vm_ref_release(&a_ref);
}
@@ -413,7 +413,7 @@
iree_vm_ref_t a_ref_0 = MakeRef<A>("AType");
iree_vm_ref_t a_ref_1 = {0};
iree_vm_ref_move(&a_ref_0, &a_ref_1);
- IREE_EXPECT_OK(iree_vm_ref_check(&a_ref_0, IREE_VM_REF_TYPE_NULL));
+ IREE_EXPECT_OK(iree_vm_ref_check(a_ref_0, IREE_VM_REF_TYPE_NULL));
iree_vm_ref_release(&a_ref_1);
}
@@ -421,7 +421,7 @@
TEST(VMRefTest, MovingIntoSelf) {
iree_vm_ref_t a_ref = MakeRef<A>("AType");
iree_vm_ref_move(&a_ref, &a_ref);
- IREE_EXPECT_OK(iree_vm_ref_check(&a_ref, A::kTypeID));
+ IREE_EXPECT_OK(iree_vm_ref_check(a_ref, A::kTypeID));
iree_vm_ref_release(&a_ref);
}
diff --git a/iree/vm/value.h b/iree/vm/value.h
index d5aa167..5ac1b8e 100644
--- a/iree/vm/value.h
+++ b/iree/vm/value.h
@@ -21,6 +21,11 @@
extern "C" {
#endif // __cplusplus
+// TODO(benvanik): support variable size in modules. vm.imports would need index
+// type and we'd have to make sure all native modules used this size type. It
+// would be both a flag on the compiler and a compile-time flag on the runtime.
+typedef int32_t iree_vm_size_t;
+
// Defines the type of a primitive value.
typedef enum {
// Not a value type.