// blob: 61c86a9a0b813ef76fb5ff7c9eecf2148fb75d54 [file] [log] [blame]
// An example based on iree/samples/simple_embedding. Test a 1024-element int32
// multiplication with vector extension ISA.
#include <springbok.h>
#include <stdio.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
// Creates the HAL device for whichever backend target this sample is built
// against. The device is returned through `device` and must be released by
// the caller.
extern iree_status_t create_sample_device(iree_hal_device_t **device);

// Returns the byte span of the embedded bytecode module data.
// Declared with an explicit `(void)` parameter list so that the compiler
// rejects calls that pass arguments (an empty `()` declarator is not a
// prototype in C11 and disables argument checking).
extern const iree_const_byte_span_t load_bytecode_module_data(void);

// Prepares the two input buffers and their buffer views based on the data
// type. `buffer_length` is the length requested for each input (run() passes
// its element count here). All four outputs must be released by the caller.
extern iree_status_t prepare_input_hal_buffer_views(
    iree_hal_device_t *device, const int buffer_length, void **arg0_buffer,
    void **arg1_buffer, iree_hal_buffer_view_t **arg0_buffer_view,
    iree_hal_buffer_view_t **arg1_buffer_view);

// Checks the mapped output memory of the invocation and reports the outcome
// as a status. The caller keeps ownership of `mapped_memory` and unmaps it.
extern iree_status_t check_output_data(
    iree_hal_buffer_mapping_t *mapped_memory);
// Runs the sample end to end: registers the HAL types, creates the VM
// instance, the HAL device and the HAL/bytecode modules, resolves
// "module.simple_mul", invokes it on two prepared input buffer views, and
// checks the mapped output with check_output_data().
//
// Every step after the first is skipped once `result` holds an error, and all
// resources acquired so far are released at the end. This relies on the
// release/free helpers tolerating NULL, exactly as the module releases below
// already do.
//
// Returns the first failing status, or the status of check_output_data().
iree_status_t run() {
  // Nothing has been acquired yet, so returning early cannot leak.
  IREE_RETURN_IF_ERROR(iree_hal_module_register_types());

  iree_vm_instance_t *instance = NULL;
  iree_status_t result =
      iree_vm_instance_create(iree_allocator_system(), &instance);

  iree_hal_device_t *device = NULL;
  if (iree_status_is_ok(result)) {
    result = create_sample_device(&device);
  }

  iree_vm_module_t *hal_module = NULL;
  if (iree_status_is_ok(result)) {
    result =
        iree_hal_module_create(device, iree_allocator_system(), &hal_module);
  }

  // Load bytecode module from the embedded data.
  const iree_const_byte_span_t module_data = load_bytecode_module_data();
  iree_vm_module_t *bytecode_module = NULL;
  if (iree_status_is_ok(result)) {
    result = iree_vm_bytecode_module_create(module_data, iree_allocator_null(),
                                            iree_allocator_system(),
                                            &bytecode_module);
  }

  // Allocate a context that will hold the module state across invocations.
  iree_vm_context_t *context = NULL;
  iree_vm_module_t *modules[] = {hal_module, bytecode_module};
  if (iree_status_is_ok(result)) {
    result = iree_vm_context_create_with_modules(
        instance, IREE_VM_CONTEXT_FLAG_NONE, &modules[0],
        IREE_ARRAYSIZE(modules), iree_allocator_system(), &context);
  }
  // Drop the local module references; either may still be NULL here.
  iree_vm_module_release(hal_module);
  iree_vm_module_release(bytecode_module);

  // Lookup the entry point function.
  // Note that we use the synchronous variant which operates on pure type/shape
  // erased buffers.
  const char kMainFunctionName[] = "module.simple_mul";
  iree_vm_function_t main_function;
  if (iree_status_is_ok(result)) {
    result = (iree_vm_context_resolve_function(
        context, iree_make_cstring_view(kMainFunctionName), &main_function));
  }

  // Prepare the input buffers.
  void *arg0_buffer = NULL;
  void *arg1_buffer = NULL;
  iree_hal_buffer_view_t *arg0_buffer_view = NULL;
  iree_hal_buffer_view_t *arg1_buffer_view = NULL;
  const int kElementCount = 1024;
  if (iree_status_is_ok(result)) {
    result = prepare_input_hal_buffer_views(device, kElementCount, &arg0_buffer,
                                            &arg1_buffer, &arg0_buffer_view,
                                            &arg1_buffer_view);
  }

  // Setup call inputs with our buffers.
  iree_vm_list_t *inputs = NULL;
  if (iree_status_is_ok(result)) {
    result = iree_vm_list_create(
        /*element_type=*/NULL,
        /*capacity=*/2, iree_allocator_system(), &inputs);
  }
  // Ownership of the buffer views moves into these refs. They are handed to
  // `inputs` below when the pushes run and succeed; otherwise they are
  // released in the cleanup section.
  iree_vm_ref_t arg0_buffer_view_ref =
      iree_hal_buffer_view_move_ref(arg0_buffer_view);
  iree_vm_ref_t arg1_buffer_view_ref =
      iree_hal_buffer_view_move_ref(arg1_buffer_view);
  if (iree_status_is_ok(result)) {
    result = iree_vm_list_push_ref_move(inputs, &arg0_buffer_view_ref);
  }
  if (iree_status_is_ok(result)) {
    result = iree_vm_list_push_ref_move(inputs, &arg1_buffer_view_ref);
  }

  // Prepare outputs list to accept the results from the invocation.
  // The output vm list is allocated statically.
  iree_vm_list_t *outputs = NULL;
  if (iree_status_is_ok(result)) {
    result = iree_vm_list_create(
        /*element_type=*/NULL,
        /*capacity=*/1, iree_allocator_system(), &outputs);
  }

  // Invoke the function.
  // NOTE(review): a *context* flag constant is passed in the flags slot of the
  // invoke call; confirm whether an invocation-specific constant is intended.
  if (iree_status_is_ok(result)) {
    result = iree_vm_invoke(context, main_function, IREE_VM_CONTEXT_FLAG_NONE,
                            /*policy=*/NULL, inputs, outputs,
                            iree_allocator_system());
  }

  iree_hal_buffer_view_t *ret_buffer_view = NULL;
  if (iree_status_is_ok(result)) {
    // Get the result buffers from the invocation.
    ret_buffer_view = (iree_hal_buffer_view_t *)iree_vm_list_get_ref_deref(
        outputs, 0, iree_hal_buffer_view_get_descriptor());
    if (ret_buffer_view == NULL) {
      result = iree_make_status(IREE_STATUS_NOT_FOUND,
                                "can't find return buffer view");
    }
  }

  // Read back the results and ensure we got the right values.
  iree_hal_buffer_mapping_t mapped_memory;
  if (iree_status_is_ok(result)) {
    result = iree_hal_buffer_map_range(
        iree_hal_buffer_view_buffer(ret_buffer_view),
        IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &mapped_memory);
  }
  if (iree_status_is_ok(result)) {
    // Only reached when the mapping succeeded, so the unmap is always paired.
    result = check_output_data(&mapped_memory);
    iree_hal_buffer_unmap_range(&mapped_memory);
  }

  // Cleanup. Release any buffer view ref that was not moved into `inputs`
  // (a failure before or during the pushes); without this the buffer views
  // leak on those paths. Refs that were moved are expected to be empty by now,
  // which makes these calls no-ops. This happens before the host buffers are
  // freed so no view outlives the memory handed to it.
  iree_vm_ref_release(&arg0_buffer_view_ref);
  iree_vm_ref_release(&arg1_buffer_view_ref);
  iree_vm_list_release(inputs);
  iree_vm_list_release(outputs);
  iree_aligned_free(arg0_buffer);
  iree_aligned_free(arg1_buffer);
  iree_vm_context_release(context);
  // `device` is still NULL when instance or device creation failed; querying
  // its allocator would then dereference a NULL device.
  if (device != NULL) {
    IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
        stdout, iree_hal_device_allocator(device)));
  }
  iree_hal_device_release(device);
  iree_vm_instance_release(instance);
  return result;
}
// Program entry point: executes run() and reports the outcome.
// The numeric status code of run() becomes the process return value. On
// failure the status is printed to stderr and then freed; on success a log
// line is emitted instead.
int main() {
  const iree_status_t status = run();
  // Capture the code first: the status is freed on the failure path below.
  const int exit_code = (int)iree_status_code(status);

  if (iree_status_is_ok(status)) {
    LOG_INFO("simple_vec_mul finished successfully");
    return exit_code;
  }

  iree_status_fprint(stderr, status);
  iree_status_free(status);
  return exit_code;
}