blob: cb674a0fd478f51783efaa2fb49d996d30747d02 [file] [log] [blame]
// An example based on iree/samples/simple_embedding.
#include "samples/util/util.h"
#include <springbok.h>
#include <stdio.h>
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "samples/device/device.h"
extern const MlModel kModel;
// Creates the model's VM module and stores it in `*module`.
//
// When BUILD_EMITC_STATIC is defined the module comes from create_c_module();
// otherwise a bytecode module is built from the bytes returned by
// load_bytecode_module_data(). The module data is passed with
// iree_allocator_null(), presumably because the module does not own (and must
// not free) that data — TODO confirm against the data loader.
static iree_status_t create_module(iree_vm_module_t **module) {
#if defined(BUILD_EMITC_STATIC)
  return create_c_module(module);
#else
  const iree_const_byte_span_t bytecode = load_bytecode_module_data();
  return iree_vm_bytecode_module_create(bytecode, iree_allocator_null(),
                                        iree_allocator_system(), module);
#endif
}
// Fills the host input buffers via load_input_data() and wraps each one in a
// shaped HAL buffer view (shape from model->input_shape / num_input_dim,
// element type from model->hal_element_type).
//
// Both `arg_buffers` and `arg_buffer_views` are filled in place and must be
// released by the caller, including on failure: buffers/views produced before
// an error are left in the arrays. Returns the first failing status.
static iree_status_t prepare_input_hal_buffer_views(
    const MlModel *model, iree_hal_device_t *device, void **arg_buffers,
    iree_hal_buffer_view_t **arg_buffer_views) {
  // Populate the raw input memory first; nothing to wrap if this fails.
  iree_status_t status = load_input_data(model, arg_buffers);
  if (!iree_status_is_ok(status)) {
    return status;
  }

  // The buffers live in host memory that the device can also see. Not every
  // device supports this, but the ones targeted here do.
  iree_hal_memory_type_t memory_type =
      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;

  for (int idx = 0; idx < model->num_input; ++idx) {
    const iree_byte_span_t host_bytes = iree_make_byte_span(
        arg_buffers[idx],
        model->input_size_bytes[idx] * model->input_length[idx]);
    status = iree_hal_buffer_view_wrap_or_clone_heap_buffer(
        iree_hal_device_allocator(device), model->input_shape[idx],
        model->num_input_dim[idx], model->hal_element_type,
        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, memory_type,
        IREE_HAL_MEMORY_ACCESS_READ, IREE_HAL_BUFFER_USAGE_ALL, host_bytes,
        iree_allocator_null(), &arg_buffer_views[idx]);
    if (!iree_status_is_ok(status)) {
      break;  // Stop at the first failure; later views stay unset.
    }
  }
  return status;
}
iree_status_t run(const MlModel *model) {
IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
iree_vm_instance_t *instance = NULL;
iree_status_t result =
iree_vm_instance_create(iree_allocator_system(), &instance);
iree_hal_device_t *device = NULL;
if (iree_status_is_ok(result)) {
result = create_sample_device(iree_allocator_system(), &device);
}
iree_vm_module_t *hal_module = NULL;
if (iree_status_is_ok(result)) {
result =
iree_hal_module_create(device, iree_allocator_system(), &hal_module);
}
// Load bytecode or C module.
iree_vm_module_t *module = NULL;
if (iree_status_is_ok(result)) {
result = create_module(&module);
}
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t *context = NULL;
iree_vm_module_t *modules[] = {hal_module, module};
if (iree_status_is_ok(result)) {
result = iree_vm_context_create_with_modules(
instance, IREE_VM_CONTEXT_FLAG_NONE, &modules[0],
IREE_ARRAYSIZE(modules), iree_allocator_system(), &context);
}
iree_vm_module_release(hal_module);
iree_vm_module_release(module);
// Lookup the entry point function.
// Note that we use the synchronous variant which operates on pure type/shape
// erased buffers.
iree_vm_function_t main_function;
if (iree_status_is_ok(result)) {
result = (iree_vm_context_resolve_function(
context, iree_make_cstring_view(model->entry_func), &main_function));
}
// Prepare the input buffers.
void *arg_buffers[MAX_MODEL_INPUT_NUM] = {NULL};
iree_hal_buffer_view_t *arg_buffer_views[MAX_MODEL_INPUT_NUM] = {NULL};
if (iree_status_is_ok(result)) {
result = prepare_input_hal_buffer_views(model, device, arg_buffers,
arg_buffer_views);
}
// Setup call inputs with our buffers.
iree_vm_list_t *inputs = NULL;
if (iree_status_is_ok(result)) {
result = iree_vm_list_create(
/*element_type=*/NULL, /*capacity=*/model->num_input,
iree_allocator_system(), &inputs);
}
iree_vm_ref_t arg_buffer_view_ref;
for (int i = 0; i < model->num_input; ++i) {
arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_views[i]);
if (iree_status_is_ok(result)) {
result = iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref);
}
}
// Prepare outputs list to accept the results from the invocation.
// The output vm list is allocated statically.
iree_vm_list_t *outputs = NULL;
if (iree_status_is_ok(result)) {
result = iree_vm_list_create(
/*element_type=*/NULL,
/*capacity=*/1, iree_allocator_system(), &outputs);
}
// Invoke the function.
if (iree_status_is_ok(result)) {
result = iree_vm_invoke(context, main_function, IREE_VM_CONTEXT_FLAG_NONE,
/*policy=*/NULL, inputs, outputs,
iree_allocator_system());
}
for (int index_output = 0; index_output < model->num_output; index_output++) {
iree_hal_buffer_view_t *ret_buffer_view = NULL;
if (iree_status_is_ok(result)) {
// Get the result buffers from the invocation.
ret_buffer_view = (iree_hal_buffer_view_t *)iree_vm_list_get_ref_deref(
outputs, index_output, iree_hal_buffer_view_get_descriptor());
if (ret_buffer_view == NULL) {
result = iree_make_status(IREE_STATUS_NOT_FOUND,
"can't find return buffer view");
}
}
// Read back the results and ensure we got the right values.
iree_hal_buffer_mapping_t mapped_memory;
if (iree_status_is_ok(result)) {
result = iree_hal_buffer_map_range(
iree_hal_buffer_view_buffer(ret_buffer_view),
IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0,
IREE_WHOLE_BUFFER, &mapped_memory);
}
if (iree_status_is_ok(result)) {
result = check_output_data(model, &mapped_memory, index_output);
iree_hal_buffer_unmap_range(&mapped_memory);
}
}
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
for (int i = 0; i < model->num_input; ++i) {
if (arg_buffers[i] != NULL) {
iree_aligned_free(arg_buffers[i]);
}
}
iree_vm_context_release(context);
IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
stdout, iree_hal_device_allocator(device)));
iree_hal_device_release(device);
iree_vm_instance_release(instance);
return result;
}
// Program entry point: runs the embedded kModel and returns the resulting
// IREE status code as the process exit code (0 when run() succeeds). On
// failure the status is printed to stderr and freed; on success a completion
// message naming the model is logged.
int main() {
  const MlModel *model_ptr = &kModel;
  iree_status_t status = run(model_ptr);
  const int exit_code = (int)iree_status_code(status);
  if (iree_status_is_ok(status)) {
    LOG_INFO("%s finished successfully", model_ptr->model_name);
    return exit_code;
  }
  iree_status_fprint(stderr, status);
  iree_status_free(status);
  return exit_code;
}