blob: 4f94022ec3e665b04b5961797363c80f2b84051e [file] [log] [blame]
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Demonstrates using external transient buffers with the IREE C API.
//
// This sample shows how to:
// 1. Query the transient buffer size needed for a function using reflection
// attributes (iree.abi.transients.size.constant or iree.abi.transients.size)
// 2. Allocate the transient buffer from the device
// 3. Pass it to the function invocation along with inputs and outputs
//
// NOTE: this file does not properly handle error cases and will leak on
// failure. Applications that are just going to exit()/abort() on failure can
// probably get away with the same thing but really should prefer not to.
#include <stdio.h>
#include <stdlib.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"
// Creates the HAL device for whichever backend target this sample is built
// against. Only the declaration is visible here; each backend target supplies
// its own implementation.
//
// On success |out_device| receives the device. The caller owns the returned
// device and must release it (Run() does so with iree_hal_device_release).
// |host_allocator| is forwarded from the caller; presumably it backs the
// device's host-side allocations - confirm against the backend implementation.
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
// Returns the VM bytecode module data for the backend target this sample is
// built against. The bytecode module is generated for the specific backend
// and platform; only the declaration is visible here.
//
// Declared with an explicit `(void)` parameter list so that this is a real
// prototype: in C an empty `()` leaves the parameters unspecified and lets
// mistaken calls with arguments compile silently.
//
// NOTE(review): Run() passes the returned span to
// iree_vm_bytecode_module_create together with iree_allocator_null(), so the
// data presumably has to outlive the module - confirm with the implementation.
extern const iree_const_byte_span_t load_bytecode_module_data(void);
// Queries the number of bytes of transient storage |main_function| needs.
//
// Lookup order:
//   1. The `iree.abi.transients.size.constant` reflection attribute, parsed as
//      an unsigned 64-bit integer.
//   2. Otherwise the function named by the `iree.abi.transients.size`
//      attribute is resolved in |context| and invoked with |main_inputs| (the
//      same inputs as the main function); element 0 of its results is read as
//      an i64 byte count.
//
// |host_allocator| is used for the temporary result list and the invocation.
// On success |out_size| receives the size; on failure it is left at 0.
//
// Returns:
//   - NOT_FOUND if neither attribute is present.
//   - INVALID_ARGUMENT if the constant attribute cannot be parsed, or if the
//     size (constant or queried) is negative or does not fit in
//     iree_host_size_t.
//   - Any failure from resolving or invoking the size query function.
static iree_status_t query_transient_size(
    iree_vm_context_t* context, const iree_vm_function_t* main_function,
    iree_vm_list_t* main_inputs, iree_allocator_t host_allocator,
    iree_host_size_t* out_size) {
  *out_size = 0;

  // First check for a constant size attribute: no invocation is needed then.
  iree_string_view_t size_constant_attr = iree_vm_function_lookup_attr_by_name(
      main_function, IREE_SV("iree.abi.transients.size.constant"));
  if (!iree_string_view_is_empty(size_constant_attr)) {
    // Parse into a genuine uint64_t. iree_host_size_t is not guaranteed to be
    // 64 bits wide, so handing |out_size| straight to the uint64 parser is a
    // type mismatch and could write past the caller's variable.
    uint64_t constant_size = 0;
    if (!iree_string_view_atoi_uint64(size_constant_attr, &constant_size)) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "failed to parse integer attribute");
    }
    // Round-trip through the host size type to detect truncation.
    if ((uint64_t)(iree_host_size_t)constant_size != constant_size) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "constant transient size %llu does not fit in the host size type",
          (unsigned long long)constant_size);
    }
    *out_size = (iree_host_size_t)constant_size;
    return iree_ok_status();
  }

  // No constant size - need to call the size query function.
  iree_string_view_t size_func_name = iree_vm_function_lookup_attr_by_name(
      main_function, IREE_SV("iree.abi.transients.size"));
  if (iree_string_view_is_empty(size_func_name)) {
    return iree_make_status(
        IREE_STATUS_NOT_FOUND,
        "function has no transient size information (missing both "
        "iree.abi.transients.size.constant and iree.abi.transients.size)");
  }

  // Resolve the size query function.
  iree_vm_function_t size_function;
  IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(context, size_func_name,
                                                        &size_function));

  // Call the size function with the same inputs as the main function.
  iree_vm_list_t* size_outputs = NULL;
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(iree_vm_make_undefined_type_def(),
                          /*capacity=*/1, host_allocator, &size_outputs),
      "can't allocate size query output list");
  iree_status_t status = iree_vm_invoke(
      context, size_function, IREE_VM_INVOCATION_FLAG_NONE,
      /*policy=*/NULL, main_inputs, size_outputs, host_allocator);
  if (iree_status_is_ok(status)) {
    // Extract the size from the output (should be an i64).
    iree_vm_value_t size_value;
    status = iree_vm_list_get_value_as(size_outputs, 0, IREE_VM_VALUE_TYPE_I64,
                                       &size_value);
    if (iree_status_is_ok(status)) {
      // The result is signed: reject negative values and anything that would
      // be truncated by the conversion to the host size type instead of
      // silently turning it into a bogus allocation size.
      if (size_value.i64 < 0 ||
          (uint64_t)(iree_host_size_t)size_value.i64 !=
              (uint64_t)size_value.i64) {
        status = iree_make_status(
            IREE_STATUS_INVALID_ARGUMENT,
            "size query function returned an unusable transient size %lld",
            (long long)size_value.i64);
      } else {
        *out_size = (iree_host_size_t)size_value.i64;
      }
    }
  }

  // The output list is released on both the success and the failure path.
  iree_vm_list_release(size_outputs);
  return status;
}
// Number of f32 elements in the sample's 1-D input and output buffers.
#define INPUT_SIZE 64

// Runs the sample end to end:
//   1. Creates the VM instance, HAL device, HAL module, bytecode module and
//      context.
//   2. Uploads a 64xf32 input of ones and a 64xf32 output of zeros.
//   3. Queries the transient size, allocates the transient buffer and invokes
//      `module.in_place_computation` with (input, output, transient).
//   4. Reads the result back, prints the first values and checks that every
//      element equals 5.0.
//
// Returns OK on success, NOT_FOUND if the invocation produced no buffer view
// and UNKNOWN if a result element mismatches; in those cases all resources
// acquired here are still released before returning. Setup and invocation
// failures are handled by IREE_CHECK_OK rather than by returning a status.
iree_status_t Run(void) {
  iree_allocator_t host_allocator = iree_allocator_system();

  iree_vm_instance_t* instance = NULL;
  IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
                                        host_allocator, &instance));
  IREE_CHECK_OK(iree_hal_module_register_all_types(instance));

  iree_hal_device_t* device = NULL;
  IREE_CHECK_OK(create_sample_device(host_allocator, &device), "create device");

  iree_vm_module_t* hal_module = NULL;
  IREE_CHECK_OK(iree_hal_module_create(
      instance, iree_hal_module_device_policy_default(), /*device_count=*/1,
      &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
      iree_hal_module_debug_sink_stdio(stderr), host_allocator, &hal_module));

  // Load bytecode module from the embedded data.
  const iree_const_byte_span_t module_data = load_bytecode_module_data();
  iree_vm_module_t* bytecode_module = NULL;
  IREE_CHECK_OK(iree_vm_bytecode_module_create(
      instance, module_data, iree_allocator_null(), host_allocator,
      &bytecode_module));

  // Allocate a context that will hold the module state across invocations.
  iree_vm_context_t* context = NULL;
  iree_vm_module_t* modules[] = {hal_module, bytecode_module};
  IREE_CHECK_OK(iree_vm_context_create_with_modules(
      instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0],
      host_allocator, &context));
  // The local module references are dropped once the context has been created
  // with them.
  iree_vm_module_release(hal_module);
  iree_vm_module_release(bytecode_module);

  // Lookup the entry point function.
  const char kMainFunctionName[] = "module.in_place_computation";
  iree_vm_function_t main_function;
  IREE_CHECK_OK(iree_vm_context_resolve_function(
      context, iree_make_cstring_view(kMainFunctionName), &main_function));

  // Prepare input buffer: 64xf32 filled with 1.0.
  float input_data[INPUT_SIZE];
  for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) {
    input_data[i] = 1.0f;
  }
  iree_hal_dim_t input_shape[1] = {INPUT_SIZE};
  iree_hal_buffer_view_t* input_buffer_view = NULL;
  IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer_copy(
      device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(input_shape),
      input_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
      (iree_hal_buffer_params_t){
          .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
          .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
      },
      iree_make_const_byte_span(input_data, sizeof(input_data)),
      &input_buffer_view));

  // Prepare output buffer: 64xf32 of zeros (storage for the result). It uses
  // the same shape as the input.
  float output_data[INPUT_SIZE];
  for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) {
    output_data[i] = 0.0f;
  }
  iree_hal_buffer_view_t* output_buffer_view = NULL;
  IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer_copy(
      device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(input_shape),
      input_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
      (iree_hal_buffer_params_t){
          .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
          .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
      },
      iree_make_const_byte_span(output_data, sizeof(output_data)),
      &output_buffer_view));

  // Setup inputs list for size query (input + output, but no transient yet).
  iree_vm_list_t* size_query_inputs = NULL;
  IREE_CHECK_OK(
      iree_vm_list_create(iree_vm_make_undefined_type_def(),
                          /*capacity=*/3, host_allocator, &size_query_inputs),
      "can't allocate size query input list");
  // The buffer views are moved into refs; the size query list retains them so
  // the same refs can be moved into the call inputs list afterwards.
  iree_vm_ref_t input_ref = iree_hal_buffer_view_move_ref(input_buffer_view);
  iree_vm_ref_t output_ref = iree_hal_buffer_view_move_ref(output_buffer_view);
  IREE_CHECK_OK(iree_vm_list_push_ref_retain(size_query_inputs, &input_ref));
  IREE_CHECK_OK(iree_vm_list_push_ref_retain(size_query_inputs, &output_ref));

  // Query the transient buffer size.
  iree_host_size_t transient_size = 0;
  IREE_CHECK_OK(query_transient_size(context, &main_function, size_query_inputs,
                                     host_allocator, &transient_size),
                "failed to query transient size");
  fprintf(stdout, "Transient buffer size needed: %zu bytes\n", transient_size);

  // Allocate the transient buffer.
  iree_hal_buffer_t* transient_buffer = NULL;
  IREE_CHECK_OK(iree_hal_allocator_allocate_buffer(
                    iree_hal_device_allocator(device),
                    (iree_hal_buffer_params_t){
                        .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
                        .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
                    },
                    transient_size, &transient_buffer),
                "failed to allocate transient buffer");

  // Setup call inputs: input, output, transient.
  iree_vm_list_t* inputs = NULL;
  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(),
                                    /*capacity=*/3, host_allocator, &inputs),
                "can't allocate input vm list");
  iree_vm_ref_t transient_ref = iree_hal_buffer_move_ref(transient_buffer);
  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_ref));
  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &output_ref));
  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &transient_ref));

  // Prepare outputs list to accept the results from the invocation.
  iree_vm_list_t* outputs = NULL;
  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(),
                                    /*capacity=*/1, host_allocator, &outputs),
                "can't allocate output vm list");

  // Synchronously invoke the function.
  IREE_CHECK_OK(
      iree_vm_invoke(context, main_function, IREE_VM_INVOCATION_FLAG_NONE,
                     /*policy=*/NULL, inputs, outputs, host_allocator));

  // From here on failures are recorded in |status| and fall through to the
  // release calls below instead of returning early and leaking everything.
  iree_status_t status = iree_ok_status();

  // Get the result buffer from the invocation.
  iree_hal_buffer_view_t* ret_buffer_view =
      iree_vm_list_get_buffer_view_assign(outputs, 0);
  if (ret_buffer_view == NULL) {
    status = iree_make_status(IREE_STATUS_NOT_FOUND,
                              "can't find return buffer view");
  }

  if (iree_status_is_ok(status)) {
    // Read back the results and ensure we got the right values.
    // Expected: (input + 1.0) * 2.0 + input = (1.0 + 1.0) * 2.0 + 1.0 = 5.0
    float results[INPUT_SIZE];
    IREE_CHECK_OK(iree_hal_device_transfer_d2h(
        device, iree_hal_buffer_view_buffer(ret_buffer_view), 0, results,
        sizeof(results), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
        iree_infinite_timeout()));

    // Print only the first six values followed by an ellipsis.
    fprintf(stdout, "Results: ");
    for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) {
      if (i > 0) fprintf(stdout, ", ");
      fprintf(stdout, "%.1f", results[i]);
      if (i >= 5) {
        fprintf(stdout, ", ...");
        break;
      }
    }
    fprintf(stdout, "\n");

    // Verify the results; stop at the first mismatch.
    for (iree_host_size_t i = 0; i < INPUT_SIZE; ++i) {
      if (results[i] != 5.0f) {
        status = iree_make_status(
            IREE_STATUS_UNKNOWN,
            "result mismatches at index %zu: expected 5.0, got %.1f", i,
            results[i]);
        break;
      }
    }
  }

  // Release everything acquired above on both the success and failure paths.
  iree_vm_list_release(size_query_inputs);
  iree_vm_list_release(inputs);
  iree_vm_list_release(outputs);
  iree_hal_device_release(device);
  iree_vm_context_release(context);
  iree_vm_instance_release(instance);
  return status;
}
// Program entry point: runs the sample and uses the resulting status code as
// the process exit code. On failure the status is printed to stderr and freed;
// on success a completion message is printed to stdout.
int main() {
  const iree_status_t run_status = Run();
  // Capture the code up front, before the status is freed on the failure path.
  const int exit_code = (int)iree_status_code(run_status);
  if (iree_status_is_ok(run_status)) {
    fprintf(stdout, "external_transients sample completed successfully\n");
    return exit_code;
  }
  iree_status_fprint(stderr, run_status);
  iree_status_free(run_status);
  return exit_code;
}