| // Copyright 2025 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Demonstrates using external transient buffers with the IREE C API. |
| // |
| // This sample shows how to: |
| // 1. Query the transient buffer size needed for a function using reflection |
| // attributes (iree.abi.transients.size.constant or iree.abi.transients.size) |
| // 2. Allocate the transient buffer from the device |
| // 3. Pass it to the function invocation along with inputs and outputs |
| // |
| // NOTE: this file does not properly handle error cases and will leak on |
| // failure. Applications that are just going to exit()/abort() on failure can |
| // probably get away with the same thing but really should prefer not to. |
| |
| #include <stdio.h> |
| #include <stdlib.h> |
| |
| #include "iree/base/api.h" |
| #include "iree/hal/api.h" |
| #include "iree/modules/hal/module.h" |
| #include "iree/vm/api.h" |
| #include "iree/vm/bytecode/module.h" |
| |
| // A function to create the HAL device from the different backend targets. |
| // The HAL device is returned based on the implementation, and it must be |
| // released by the caller. |
| extern iree_status_t create_sample_device(iree_allocator_t host_allocator, |
| iree_hal_device_t** out_device); |
| |
| // A function to load the vm bytecode module from the different backend targets. |
| // The bytecode module is generated for the specific backend and platform. |
| extern const iree_const_byte_span_t load_bytecode_module_data(); |
| |
| // Queries the size needed for transient storage for the given function. |
| // This checks for the iree.abi.transients.size.constant attribute first, |
| // and if not present, calls the function referenced by |
| // iree.abi.transients.size. |
| static iree_status_t query_transient_size( |
| iree_vm_context_t* context, const iree_vm_function_t* main_function, |
| iree_vm_list_t* main_inputs, iree_allocator_t host_allocator, |
| iree_host_size_t* out_size) { |
| // First check for a constant size attribute. |
| iree_string_view_t size_constant_attr = iree_vm_function_lookup_attr_by_name( |
| main_function, IREE_SV("iree.abi.transients.size.constant")); |
| if (!iree_string_view_is_empty(size_constant_attr)) { |
| // Constant size is specified - parse it directly. |
| if (!iree_string_view_atoi_uint64(size_constant_attr, out_size)) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "failed to parse integer attribute"); |
| } |
| return iree_ok_status(); |
| } |
| |
| // No constant size - need to call the size query function. |
| iree_string_view_t size_func_name = iree_vm_function_lookup_attr_by_name( |
| main_function, IREE_SV("iree.abi.transients.size")); |
| if (iree_string_view_is_empty(size_func_name)) { |
| return iree_make_status( |
| IREE_STATUS_NOT_FOUND, |
| "function has no transient size information (missing both " |
| "iree.abi.transients.size.constant and iree.abi.transients.size)"); |
| } |
| |
| // Resolve the size query function. |
| iree_vm_function_t size_function; |
| IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(context, size_func_name, |
| &size_function)); |
| |
| // Call the size function with the same inputs as the main function. |
| iree_vm_list_t* size_outputs = NULL; |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| /*capacity=*/1, host_allocator, &size_outputs), |
| "can't allocate size query output list"); |
| iree_status_t status = iree_vm_invoke( |
| context, size_function, IREE_VM_INVOCATION_FLAG_NONE, |
| /*policy=*/NULL, main_inputs, size_outputs, host_allocator); |
| if (iree_status_is_ok(status)) { |
| // Extract the size from the output (should be an i64). |
| iree_vm_value_t size_value; |
| status = iree_vm_list_get_value_as(size_outputs, 0, IREE_VM_VALUE_TYPE_I64, |
| &size_value); |
| if (iree_status_is_ok(status)) { |
| *out_size = (iree_host_size_t)size_value.i64; |
| } |
| } |
| iree_vm_list_release(size_outputs); |
| return status; |
| } |
| |
| #define INPUT_SIZE 64 |
| |
| iree_status_t Run() { |
| iree_allocator_t host_allocator = iree_allocator_system(); |
| |
| iree_vm_instance_t* instance = NULL; |
| IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, |
| host_allocator, &instance)); |
| IREE_CHECK_OK(iree_hal_module_register_all_types(instance)); |
| |
| iree_hal_device_t* device = NULL; |
| IREE_CHECK_OK(create_sample_device(host_allocator, &device), "create device"); |
| iree_vm_module_t* hal_module = NULL; |
| IREE_CHECK_OK(iree_hal_module_create( |
| instance, iree_hal_module_device_policy_default(), /*device_count=*/1, |
| &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, |
| iree_hal_module_debug_sink_stdio(stderr), host_allocator, &hal_module)); |
| |
| // Load bytecode module from the embedded data. |
| const iree_const_byte_span_t module_data = load_bytecode_module_data(); |
| iree_vm_module_t* bytecode_module = NULL; |
| IREE_CHECK_OK(iree_vm_bytecode_module_create( |
| instance, module_data, iree_allocator_null(), host_allocator, |
| &bytecode_module)); |
| |
| // Allocate a context that will hold the module state across invocations. |
| iree_vm_context_t* context = NULL; |
| iree_vm_module_t* modules[] = {hal_module, bytecode_module}; |
| IREE_CHECK_OK(iree_vm_context_create_with_modules( |
| instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], |
| host_allocator, &context)); |
| iree_vm_module_release(hal_module); |
| iree_vm_module_release(bytecode_module); |
| |
| // Lookup the entry point function. |
| const char kMainFunctionName[] = "module.in_place_computation"; |
| iree_vm_function_t main_function; |
| IREE_CHECK_OK(iree_vm_context_resolve_function( |
| context, iree_make_cstring_view(kMainFunctionName), &main_function)); |
| |
| // Prepare input buffer: 64xf32 filled with 1.0. |
| float input_data[INPUT_SIZE]; |
| for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) { |
| input_data[i] = 1.0f; |
| } |
| iree_hal_dim_t input_shape[1] = {INPUT_SIZE}; |
| iree_hal_buffer_view_t* input_buffer_view = NULL; |
| IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer_copy( |
| device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(input_shape), |
| input_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, |
| (iree_hal_buffer_params_t){ |
| .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, |
| .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, |
| }, |
| iree_make_const_byte_span(input_data, sizeof(input_data)), |
| &input_buffer_view)); |
| |
| // Prepare output buffer: 64xf32 of zeros (storage for the result). |
| float output_data[INPUT_SIZE]; |
| for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) { |
| output_data[i] = 0.0f; |
| } |
| iree_hal_buffer_view_t* output_buffer_view = NULL; |
| IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer_copy( |
| device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(input_shape), |
| input_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, |
| (iree_hal_buffer_params_t){ |
| .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, |
| .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, |
| }, |
| iree_make_const_byte_span(output_data, sizeof(output_data)), |
| &output_buffer_view)); |
| |
| // Setup inputs list for size query (input + output, but no transient yet). |
| iree_vm_list_t* size_query_inputs = NULL; |
| IREE_CHECK_OK( |
| iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| /*capacity=*/3, host_allocator, &size_query_inputs), |
| "can't allocate size query input list"); |
| iree_vm_ref_t input_ref = iree_hal_buffer_view_move_ref(input_buffer_view); |
| iree_vm_ref_t output_ref = iree_hal_buffer_view_move_ref(output_buffer_view); |
| IREE_CHECK_OK(iree_vm_list_push_ref_retain(size_query_inputs, &input_ref)); |
| IREE_CHECK_OK(iree_vm_list_push_ref_retain(size_query_inputs, &output_ref)); |
| |
| // Query the transient buffer size. |
| iree_host_size_t transient_size = 0; |
| IREE_CHECK_OK(query_transient_size(context, &main_function, size_query_inputs, |
| host_allocator, &transient_size), |
| "failed to query transient size"); |
| fprintf(stdout, "Transient buffer size needed: %zu bytes\n", transient_size); |
| |
| // Allocate the transient buffer. |
| iree_hal_buffer_t* transient_buffer = NULL; |
| IREE_CHECK_OK(iree_hal_allocator_allocate_buffer( |
| iree_hal_device_allocator(device), |
| (iree_hal_buffer_params_t){ |
| .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, |
| .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, |
| }, |
| transient_size, &transient_buffer), |
| "failed to allocate transient buffer"); |
| |
| // Setup call inputs: input, output, transient. |
| iree_vm_list_t* inputs = NULL; |
| IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| /*capacity=*/3, host_allocator, &inputs), |
| "can't allocate input vm list"); |
| |
| iree_vm_ref_t transient_ref = iree_hal_buffer_move_ref(transient_buffer); |
| IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_ref)); |
| IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &output_ref)); |
| IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &transient_ref)); |
| |
| // Prepare outputs list to accept the results from the invocation. |
| iree_vm_list_t* outputs = NULL; |
| IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| /*capacity=*/1, host_allocator, &outputs), |
| "can't allocate output vm list"); |
| |
| // Synchronously invoke the function. |
| IREE_CHECK_OK( |
| iree_vm_invoke(context, main_function, IREE_VM_INVOCATION_FLAG_NONE, |
| /*policy=*/NULL, inputs, outputs, host_allocator)); |
| |
| // Get the result buffer from the invocation. |
| iree_hal_buffer_view_t* ret_buffer_view = |
| iree_vm_list_get_buffer_view_assign(outputs, 0); |
| if (ret_buffer_view == NULL) { |
| return iree_make_status(IREE_STATUS_NOT_FOUND, |
| "can't find return buffer view"); |
| } |
| |
| // Read back the results and ensure we got the right values. |
| // Expected: (input + 1.0) * 2.0 + input = (1.0 + 1.0) * 2.0 + 1.0 = 5.0 |
| float results[INPUT_SIZE]; |
| IREE_CHECK_OK(iree_hal_device_transfer_d2h( |
| device, iree_hal_buffer_view_buffer(ret_buffer_view), 0, results, |
| sizeof(results), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, |
| iree_infinite_timeout())); |
| fprintf(stdout, "Results: "); |
| for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) { |
| if (i > 0) fprintf(stdout, ", "); |
| fprintf(stdout, "%.1f", results[i]); |
| if (i >= 5) { |
| fprintf(stdout, ", ..."); |
| break; |
| } |
| } |
| fprintf(stdout, "\n"); |
| |
| // Verify the results. |
| for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) { |
| if (results[i] != 5.0f) { |
| return iree_make_status( |
| IREE_STATUS_UNKNOWN, |
| "result mismatches at index %zu: expected 5.0, got %.1f", i, |
| results[i]); |
| } |
| } |
| |
| iree_vm_list_release(size_query_inputs); |
| iree_vm_list_release(inputs); |
| iree_vm_list_release(outputs); |
| iree_hal_device_release(device); |
| iree_vm_context_release(context); |
| iree_vm_instance_release(instance); |
| return iree_ok_status(); |
| } |
| |
| int main() { |
| const iree_status_t result = Run(); |
| int ret = (int)iree_status_code(result); |
| if (!iree_status_is_ok(result)) { |
| iree_status_fprint(stderr, result); |
| iree_status_free(result); |
| } else { |
| fprintf(stdout, "external_transients sample completed successfully\n"); |
| } |
| return ret; |
| } |