// An example based on iree/samples/simple_embedding.
2
3#include <springbok.h>
4#include <stdio.h>
5
6#include "iree/base/api.h"
7#include "iree/hal/api.h"
8#include "iree/modules/hal/module.h"
9#include "iree/vm/api.h"
10#include "iree/vm/bytecode_module.h"
11#include "samples/util/util.h"
12
13// A function to create the HAL device from the different backend targets.
14// The HAL device is returned based on the implementation, and it must be
15// released by the caller
16extern iree_status_t create_sample_device(iree_hal_device_t **device);
17
18extern const iree_const_byte_span_t load_bytecode_module_data();
19
20extern iree_status_t load_input_data(const MlModel *model, void **buffer);
21
22extern iree_status_t check_output_data(const MlModel *model,
23 iree_hal_buffer_mapping_t *mapped_memory,
24 int index_output);
25
26extern const MlModel kModel;
27
28// Prepare the input buffers and buffer_views based on the data type. They must
29// be released by the caller.
30static iree_status_t prepare_input_hal_buffer_views(
31 const MlModel *model, iree_hal_device_t *device, void **arg0_buffer,
32 iree_hal_buffer_view_t **arg0_buffer_view) {
33 iree_status_t result = iree_ok_status();
34 *arg0_buffer = iree_aligned_alloc(
35 sizeof(uint32_t), model->input_size_bytes * model->input_length);
36 if (*arg0_buffer == NULL) {
37 result = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED);
38 }
39
40 // Populate initial value
41 load_input_data(model, arg0_buffer);
42
43 // Wrap buffers in shaped buffer views.
44 // The buffers can be mapped on the CPU and that can also be used
45 // on the device. Not all devices support this, but the ones we have now do.
46
47 iree_hal_memory_type_t input_memory_type =
48 IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
49
50 result = iree_hal_buffer_view_wrap_or_clone_heap_buffer(
51 iree_hal_device_allocator(device), model->input_shape,
52 model->num_input_dim, model->hal_element_type, input_memory_type,
53 IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, IREE_HAL_MEMORY_ACCESS_READ,
54 IREE_HAL_BUFFER_USAGE_ALL,
55 iree_make_byte_span(*arg0_buffer,
56 model->input_size_bytes * model->input_length),
57 iree_allocator_null(), arg0_buffer_view);
58 return result;
59}
60
61iree_status_t run(const MlModel *model) {
62 IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
63
64 iree_vm_instance_t *instance = NULL;
65 iree_status_t result =
66 iree_vm_instance_create(iree_allocator_system(), &instance);
67
68 iree_hal_device_t *device = NULL;
69 if (iree_status_is_ok(result)) {
70 result = create_sample_device(&device);
71 }
72 iree_vm_module_t *hal_module = NULL;
73 if (iree_status_is_ok(result)) {
74 result =
75 iree_hal_module_create(device, iree_allocator_system(), &hal_module);
76 }
77 // Load bytecode module from the embedded data.
78 const iree_const_byte_span_t module_data = load_bytecode_module_data();
79
80 iree_vm_module_t *bytecode_module = NULL;
81 if (iree_status_is_ok(result)) {
82 result = iree_vm_bytecode_module_create(module_data, iree_allocator_null(),
83 iree_allocator_system(),
84 &bytecode_module);
85 }
86
87 // Allocate a context that will hold the module state across invocations.
88 iree_vm_context_t *context = NULL;
89 iree_vm_module_t *modules[] = {hal_module, bytecode_module};
90 if (iree_status_is_ok(result)) {
91 result = iree_vm_context_create_with_modules(
92 instance, &modules[0], IREE_ARRAYSIZE(modules), iree_allocator_system(),
93 &context);
94 }
95 iree_vm_module_release(hal_module);
96 iree_vm_module_release(bytecode_module);
97
98 // Lookup the entry point function.
99 // Note that we use the synchronous variant which operates on pure type/shape
100 // erased buffers.
Lun Dong9ee45102021-08-30 10:02:50 -0700101 iree_vm_function_t main_function;
102 if (iree_status_is_ok(result)) {
103 result = (iree_vm_context_resolve_function(
Lun Dongb7990a22021-09-17 21:46:30 +0000104 context, iree_make_cstring_view(model->entry_func), &main_function));
Lun Dong9ee45102021-08-30 10:02:50 -0700105 }
106
107 // Prepare the input buffers.
108 void *arg0_buffer = NULL;
109 iree_hal_buffer_view_t *arg0_buffer_view = NULL;
110 if (iree_status_is_ok(result)) {
111 result = prepare_input_hal_buffer_views(model, device, &arg0_buffer,
112 &arg0_buffer_view);
113 }
114
115 // Setup call inputs with our buffers.
116 iree_vm_list_t *inputs = NULL;
117 if (iree_status_is_ok(result)) {
118 result = iree_vm_list_create(
119 /*element_type=*/NULL,
120 /*capacity=*/1, iree_allocator_system(), &inputs);
121 }
122 iree_vm_ref_t arg0_buffer_view_ref =
123 iree_hal_buffer_view_move_ref(arg0_buffer_view);
124 if (iree_status_is_ok(result)) {
125 result = iree_vm_list_push_ref_move(inputs, &arg0_buffer_view_ref);
126 }
127
128 // Prepare outputs list to accept the results from the invocation.
129 // The output vm list is allocated statically.
130 iree_vm_list_t *outputs = NULL;
131 if (iree_status_is_ok(result)) {
132 result = iree_vm_list_create(
133 /*element_type=*/NULL,
134 /*capacity=*/1, iree_allocator_system(), &outputs);
135 }
136
137 // Invoke the function.
138 if (iree_status_is_ok(result)) {
139 result = iree_vm_invoke(context, main_function,
140 /*policy=*/NULL, inputs, outputs,
141 iree_allocator_system());
142 }
143
144 for (int index_output = 0; index_output < model->num_output; index_output++) {
145 iree_hal_buffer_view_t *ret_buffer_view = NULL;
146 if (iree_status_is_ok(result)) {
147 // Get the result buffers from the invocation.
148 ret_buffer_view = (iree_hal_buffer_view_t *)iree_vm_list_get_ref_deref(
149 outputs, index_output, iree_hal_buffer_view_get_descriptor());
150 if (ret_buffer_view == NULL) {
151 result = iree_make_status(IREE_STATUS_NOT_FOUND,
152 "can't find return buffer view");
153 }
154 }
155 // Read back the results and ensure we got the right values.
156 iree_hal_buffer_mapping_t mapped_memory;
157 if (iree_status_is_ok(result)) {
158 result = iree_hal_buffer_map_range(
159 iree_hal_buffer_view_buffer(ret_buffer_view),
160 IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &mapped_memory);
161 }
162 if (iree_status_is_ok(result)) {
163 result = check_output_data(model, &mapped_memory, index_output);
164 iree_hal_buffer_unmap_range(&mapped_memory);
165 }
166 }
167
168 iree_vm_list_release(inputs);
169 iree_vm_list_release(outputs);
170 iree_aligned_free(arg0_buffer);
Lun Dong9ee45102021-08-30 10:02:50 -0700171 iree_vm_context_release(context);
Cindy Liu44b65402021-09-03 00:48:35 -0700172 IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
173 stdout, iree_hal_device_allocator(device)));
174 iree_hal_device_release(device);
Lun Dong9ee45102021-08-30 10:02:50 -0700175 iree_vm_instance_release(instance);
176 return result;
177}
178
179int main() {
180 const MlModel *model_ptr = &kModel;
181 const iree_status_t result = run(model_ptr);
182 int ret = (int)iree_status_code(result);
183 if (!iree_status_is_ok(result)) {
184 iree_status_fprint(stderr, result);
185 iree_status_free(result);
186 } else {
187 LOG_INFO("%s finished successfully", model_ptr->model_name);
188 }
189
190 return ret;
191}