Plumbing through IREE_HAL_MODULE_FLAG_SYNCHRONOUS. This will be used to avoid yielding when coroutines are not supported. We could remove this and reduce code size after #9612.
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp index 14265be..b56c736 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp +++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -272,8 +272,8 @@ iree_hal_driver_release(driver); // Create hal module. - IREE_CHECK_OK( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + IREE_CHECK_OK(iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_NONE, + iree_allocator_system(), &hal_module)); // Bytecode module. IREE_CHECK_OK(iree_vm_bytecode_module_create(
diff --git a/docs/website/docs/bindings/c-api.md b/docs/website/docs/bindings/c-api.md index cd3e515..e79143a 100644 --- a/docs/website/docs/bindings/c-api.md +++ b/docs/website/docs/bindings/c-api.md
@@ -128,7 +128,8 @@ // We'll load this module into a VM context later. iree_vm_module_t* hal_module = NULL; IREE_CHECK_OK( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_NONE, + iree_allocator_system(), &hal_module)); // The reference to the driver can be released now. iree_hal_driver_release(driver); ```
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc index ecd4a27..9103341 100644 --- a/runtime/bindings/python/vm.cc +++ b/runtime/bindings/python/vm.cc
@@ -22,9 +22,10 @@ VmModule CreateHalModule(HalDevice* device) { iree_vm_module_t* module; - CheckApiStatus(iree_hal_module_create(device->raw_ptr(), - iree_allocator_system(), &module), - "Error creating hal module"); + CheckApiStatus( + iree_hal_module_create(device->raw_ptr(), IREE_HAL_MODULE_FLAG_NONE, + iree_allocator_system(), &module), + "Error creating hal module"); return VmModule::StealFromRawPtr(module); }
diff --git a/runtime/bindings/tflite/interpreter.c b/runtime/bindings/tflite/interpreter.c index f7c47cd..f12177b 100644 --- a/runtime/bindings/tflite/interpreter.c +++ b/runtime/bindings/tflite/interpreter.c
@@ -61,8 +61,9 @@ "failed creating the default device for driver '%.*s'", (int)driver_name.size, driver_name.data); - IREE_RETURN_IF_ERROR(iree_hal_module_create( - interpreter->device, interpreter->allocator, &interpreter->hal_module)); + IREE_RETURN_IF_ERROR( + iree_hal_module_create(interpreter->device, IREE_HAL_MODULE_FLAG_NONE, + interpreter->allocator, &interpreter->hal_module)); return iree_ok_status(); }
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc index c925057..49e3d26 100644 --- a/runtime/src/iree/modules/check/check_test.cc +++ b/runtime/src/iree/modules/check/check_test.cc
@@ -45,8 +45,9 @@ } IREE_ASSERT_OK(iree_hal_driver_create_default_device( hal_driver, iree_allocator_system(), &device_)); - IREE_ASSERT_OK( - iree_hal_module_create(device_, iree_allocator_system(), &hal_module_)); + IREE_ASSERT_OK(iree_hal_module_create(device_, IREE_HAL_MODULE_FLAG_NONE, + iree_allocator_system(), + &hal_module_)); iree_hal_driver_release(hal_driver); IREE_ASSERT_OK(
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index bf83c61..eada4af 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c
@@ -121,6 +121,7 @@ typedef struct iree_hal_module_t { iree_allocator_t host_allocator; + iree_hal_module_flags_t flags; iree_hal_device_t* shared_device; // TODO(benvanik): types. } iree_hal_module_t; @@ -130,6 +131,7 @@ typedef struct iree_hal_module_state_t { iree_allocator_t host_allocator; + iree_hal_module_flags_t flags; iree_hal_device_t* shared_device; iree_status_t loop_status; iree_hal_executable_cache_t* executable_cache; @@ -155,6 +157,7 @@ iree_allocator_malloc(host_allocator, sizeof(*state), (void**)&state)); memset(state, 0, sizeof(*state)); state->host_allocator = host_allocator; + state->flags = module->flags; state->shared_device = module->shared_device; iree_hal_device_retain(state->shared_device); @@ -1538,9 +1541,9 @@ .functions = iree_hal_module_funcs_, }; -IREE_API_EXPORT iree_status_t -iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator, - iree_vm_module_t** out_module) { +IREE_API_EXPORT iree_status_t iree_hal_module_create( + iree_hal_device_t* device, iree_hal_module_flags_t flags, + iree_allocator_t host_allocator, iree_vm_module_t** out_module) { IREE_ASSERT_ARGUMENT(device); IREE_ASSERT_ARGUMENT(out_module); *out_module = NULL; @@ -1559,17 +1562,18 @@ iree_vm_native_module_size() + sizeof(iree_hal_module_t); iree_vm_module_t* base_module = NULL; IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, total_size, (void**)&base_module)); + iree_allocator_malloc(host_allocator, total_size, (void**)&base_module)); memset(base_module, 0, total_size); iree_status_t status = iree_vm_native_module_initialize( - &interface, &iree_hal_module_descriptor_, allocator, base_module); + &interface, &iree_hal_module_descriptor_, host_allocator, base_module); if (!iree_status_is_ok(status)) { - iree_allocator_free(allocator, base_module); + iree_allocator_free(host_allocator, base_module); return status; } iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module); - module->host_allocator = allocator; + module->host_allocator = host_allocator; + module->flags = flags; module->shared_device = device; iree_hal_device_retain(module->shared_device);
diff --git a/runtime/src/iree/modules/hal/module.h b/runtime/src/iree/modules/hal/module.h index 86ded4e..68a5408 100644 --- a/runtime/src/iree/modules/hal/module.h +++ b/runtime/src/iree/modules/hal/module.h
@@ -36,6 +36,14 @@ extern "C" { #endif // __cplusplus +enum iree_hal_module_flag_bits_t { + IREE_HAL_MODULE_FLAG_NONE = 0u, + + // Forces HAL methods to block instead of yielding as a coroutine. + IREE_HAL_MODULE_FLAG_SYNCHRONOUS = 1u << 0, +}; +typedef uint32_t iree_hal_module_flags_t; + // Registers the custom types used by the HAL module. // WARNING: not thread-safe; call at startup before using. IREE_API_EXPORT iree_status_t iree_hal_module_register_types(void); @@ -43,9 +51,9 @@ // Creates the HAL module initialized to use a specific |device|. // Each context using this module will share the device and have compatible // allocations. -IREE_API_EXPORT iree_status_t -iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator, - iree_vm_module_t** out_module); +IREE_API_EXPORT iree_status_t iree_hal_module_create( + iree_hal_device_t* device, iree_hal_module_flags_t flags, + iree_allocator_t host_allocator, iree_vm_module_t** out_module); // Returns the device currently in use by the HAL module. // Returns NULL if no device has been initialized yet.
diff --git a/runtime/src/iree/runtime/session.c b/runtime/src/iree/runtime/session.c index 29abb28..7ae2548 100644 --- a/runtime/src/iree/runtime/session.c +++ b/runtime/src/iree/runtime/session.c
@@ -94,7 +94,8 @@ // Lower-level usage of the VM can avoid the HAL if it's not required. iree_vm_module_t* hal_module = NULL; if (iree_status_is_ok(status)) { - status = iree_hal_module_create(device, host_allocator, &hal_module); + status = iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_NONE, + host_allocator, &hal_module); } if (iree_status_is_ok(status)) { status = iree_vm_context_register_modules(
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c index 1b97685..cd74c5d 100644 --- a/runtime/src/iree/tooling/trace_replay.c +++ b/runtime/src/iree/tooling/trace_replay.c
@@ -109,8 +109,9 @@ document, module_node, iree_make_cstring_view("driver"), &driver_node)); IREE_RETURN_IF_ERROR(iree_trace_replay_create_device( replay, driver_node, replay->host_allocator, &replay->device)); - IREE_RETURN_IF_ERROR(iree_hal_module_create( - replay->device, replay->host_allocator, &module)); + IREE_RETURN_IF_ERROR( + iree_hal_module_create(replay->device, IREE_HAL_MODULE_FLAG_NONE, + replay->host_allocator, &module)); } if (!module) { return iree_make_status(
diff --git a/samples/simple_embedding/simple_embedding.c b/samples/simple_embedding/simple_embedding.c index 19c6437..c51cc6f 100644 --- a/samples/simple_embedding/simple_embedding.c +++ b/samples/simple_embedding/simple_embedding.c
@@ -43,7 +43,8 @@ "create device"); iree_vm_module_t* hal_module = NULL; IREE_RETURN_IF_ERROR( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + iree_hal_module_create(device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, + iree_allocator_system(), &hal_module)); // Load bytecode module from the embedded data. const iree_const_byte_span_t module_data = load_bytecode_module_data();
diff --git a/tools/android/run_module_app/src/main.cc b/tools/android/run_module_app/src/main.cc index 703b339..5a18516 100644 --- a/tools/android/run_module_app/src/main.cc +++ b/tools/android/run_module_app/src/main.cc
@@ -108,8 +108,8 @@ iree_make_string_view(invocation.device.data(), invocation.device.size()), iree_allocator_system(), &device)); iree_vm_module_t* hal_module = nullptr; - IREE_RETURN_IF_ERROR( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + IREE_RETURN_IF_ERROR(iree_hal_module_create( + device, IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module)); iree_vm_context_t* context = nullptr; // Order matters. The input module will likely be dependent on the hal module.
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc index 946b3a0..0650f63 100644 --- a/tools/iree-benchmark-module-main.cc +++ b/tools/iree-benchmark-module-main.cc
@@ -318,7 +318,8 @@ IREE_RETURN_IF_ERROR(iree_hal_create_device_from_flags( iree_hal_default_device_uri(), iree_allocator_system(), &device_)); IREE_RETURN_IF_ERROR( - iree_hal_module_create(device_, iree_allocator_system(), &hal_module_)); + iree_hal_module_create(device_, IREE_HAL_MODULE_FLAG_NONE, + iree_allocator_system(), &hal_module_)); IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create( flatbuffer_contents->const_buffer, iree_file_contents_deallocator(flatbuffer_contents),
diff --git a/tools/iree-check-module-main.cc b/tools/iree-check-module-main.cc index 4dda0fb..6d0bdf8 100644 --- a/tools/iree-check-module-main.cc +++ b/tools/iree-check-module-main.cc
@@ -113,8 +113,8 @@ IREE_RETURN_IF_ERROR(iree_hal_create_device_from_flags( iree_hal_default_device_uri(), iree_allocator_system(), &device)); iree_vm_module_t* hal_module = nullptr; - IREE_RETURN_IF_ERROR( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + IREE_RETURN_IF_ERROR(iree_hal_module_create( + device, IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module)); iree_vm_module_t* check_module = nullptr; IREE_RETURN_IF_ERROR( iree_check_module_create(iree_allocator_system(), &check_module));
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc index 2e6a713..e049f76 100644 --- a/tools/iree-run-mlir-main.cc +++ b/tools/iree-run-mlir-main.cc
@@ -367,8 +367,8 @@ iree_allocator_system(), &device)); iree_vm_module_t* hal_module = nullptr; - IREE_RETURN_IF_ERROR( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + IREE_RETURN_IF_ERROR(iree_hal_module_create( + device, IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module)); // Evaluate all exported functions. auto run_function = [&](int ordinal) -> Status {
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc index 42ab36f..ab9bdfd 100644 --- a/tools/iree-run-module-main.cc +++ b/tools/iree-run-module-main.cc
@@ -114,8 +114,8 @@ IREE_RETURN_IF_ERROR(iree_hal_create_device_from_flags( iree_hal_default_device_uri(), iree_allocator_system(), &device)); iree_vm_module_t* hal_module = nullptr; - IREE_RETURN_IF_ERROR( - iree_hal_module_create(device, iree_allocator_system(), &hal_module)); + IREE_RETURN_IF_ERROR(iree_hal_module_create( + device, IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &hal_module)); iree_vm_context_t* context = nullptr; // Order matters. The input module will likely be dependent on the hal module.