blob: 913335de42d007652314f1b878eed7b4108ea81a [file] [log] [blame]
Adam Jesionowski6e273a72022-04-14 12:20:20 -07001/*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
// Float simple_mul sample: bytecode module loading and input/output processing.
18
Cindy Liu3c4d6272022-08-04 18:54:12 -070019#include "model_util/util.h"
Cindy Liu986ee242021-07-29 23:23:17 -070020
21// Compiled module embedded here to avoid file IO:
Lun Dong6243e9e2022-09-06 15:25:07 -070022#if defined(BUILD_VMVX)
23#if !defined(BUILD_EMITC)
24#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_vmvx_c.h"
25#else
26#include "samples/simple_vec_mul/simple_float_mul_c_module_vmvx_emitc.h"
27#endif // !defined(BUILD_EMITC)
28#else
Lun Dong53c187c2022-01-24 18:48:13 +000029#if !defined(BUILD_EMITC)
30#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static.h"
31#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static_c.h"
Lun Dong72890d02021-12-03 18:51:03 -080032#else
33#include "samples/simple_vec_mul/simple_float_mul_c_module_static_c.h"
34#include "samples/simple_vec_mul/simple_float_mul_c_module_static_emitc.h"
Lun Dong6243e9e2022-09-06 15:25:07 -070035#endif // #if !defined(BUILD_EMITC)
36#endif // #if defined(BUILD_VMVX)
Cindy Liu986ee242021-07-29 23:23:17 -070037
Cindy Liue3240e22021-10-07 17:24:21 -070038const MlModel kModel = {
39 .num_input = 2,
40 .num_input_dim = {1, 1},
41 .input_shape = {{1024}, {1024}},
42 .input_length = {1024, 1024},
43 .input_size_bytes = {sizeof(float), sizeof(float)},
44 .num_output = 1,
45 .output_length = {1024},
46 .output_size_bytes = sizeof(float),
47 .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
48 .entry_func = "module.simple_mul",
49 .model_name = "simple_float_vec_mul",
50};
51
Cindy Liucaad9952022-08-10 10:55:34 -070052iree_status_t create_module(iree_vm_instance_t *instance,
53 iree_vm_module_t **module) {
Lun Dong53c187c2022-01-24 18:48:13 +000054#if !defined(BUILD_EMITC)
Lun Dong6243e9e2022-09-06 15:25:07 -070055#if defined(BUILD_VMVX)
56 const struct iree_file_toc_t *module_file_toc =
57 samples_simple_vec_mul_simple_float_mul_bytecode_module_vmvx_create();
58#else
Cindy Liu986ee242021-07-29 23:23:17 -070059 const struct iree_file_toc_t *module_file_toc =
Lun Dong53c187c2022-01-24 18:48:13 +000060 samples_simple_vec_mul_simple_float_mul_bytecode_module_static_create();
Lun Dong6243e9e2022-09-06 15:25:07 -070061#endif // #if defined(BUILD_VMVX)
Lun Dong53c187c2022-01-24 18:48:13 +000062 return iree_vm_bytecode_module_create(
Cindy Liucaad9952022-08-10 10:55:34 -070063 instance,
Lun Dong53c187c2022-01-24 18:48:13 +000064 iree_make_const_byte_span(module_file_toc->data, module_file_toc->size),
65 iree_allocator_null(), iree_allocator_system(), module);
Lun Dong72890d02021-12-03 18:51:03 -080066#else
Cindy Liucaad9952022-08-10 10:55:34 -070067 return module_create(instance, iree_allocator_system(), module);
Lun Dong6243e9e2022-09-06 15:25:07 -070068#endif // #if !defined(BUILD_EMITC)
Lun Dong72890d02021-12-03 18:51:03 -080069}
70
Lun Dong6243e9e2022-09-06 15:25:07 -070071#if !defined(BUILD_VMVX)
Lun Dong96ec2e02022-03-08 23:30:36 +000072iree_hal_executable_library_query_fn_t library_query(void) {
73 return &simple_mul_dispatch_0_library_query;
Lun Dong72890d02021-12-03 18:51:03 -080074}
Lun Dong6243e9e2022-09-06 15:25:07 -070075#endif
Cindy Liu986ee242021-07-29 23:23:17 -070076
Lun Dongdbc0ab82022-01-07 18:18:10 +000077iree_status_t load_input_data(const MlModel *model, void **buffer,
Lun Dong96ec2e02022-03-08 23:30:36 +000078 iree_const_byte_span_t **byte_span) {
Cindy Liue3240e22021-10-07 17:24:21 -070079 iree_status_t result = alloc_input_buffer(model, buffer);
Cindy Liu986ee242021-07-29 23:23:17 -070080 // Populate initial values
81 // arg0 = 0, 1/4, 1/2, 3/4... 1023/4
82 // arg1 = 0, 1/2, 1, 3/2... 1023/2
Cindy Liu986ee242021-07-29 23:23:17 -070083 if (iree_status_is_ok(result)) {
Cindy Liue3240e22021-10-07 17:24:21 -070084 for (int i = 0; i < model->input_length[0]; ++i) {
85 ((float *)buffer[0])[i] = i / 4.0f;
86 ((float *)buffer[1])[i] = i / 2.0f;
87 }
Cindy Liu986ee242021-07-29 23:23:17 -070088 }
Lun Dongdbc0ab82022-01-07 18:18:10 +000089 for (int i = 0; i < model->num_input; ++i) {
Lun Dong96ec2e02022-03-08 23:30:36 +000090 byte_span[i] = malloc(sizeof(iree_const_byte_span_t));
91 *byte_span[i] = iree_make_const_byte_span(
Lun Dongdbc0ab82022-01-07 18:18:10 +000092 buffer[i], model->input_size_bytes[i] * model->input_length[i]);
93 }
Cindy Liu986ee242021-07-29 23:23:17 -070094 return result;
95}
96
Adam Jesionowskida880932021-12-20 11:24:07 -080097iree_status_t process_output(const MlModel *model,
Lun Dong53c187c2022-01-24 18:48:13 +000098 iree_hal_buffer_mapping_t *buffers,
Adam Jesionowski1a82e442022-09-16 10:28:59 -070099 uint32_t *output_length) {
Cindy Liu986ee242021-07-29 23:23:17 -0700100 iree_status_t result = iree_ok_status();
Lun Dong53c187c2022-01-24 18:48:13 +0000101 for (int i = 0; i < buffers[0].contents.data_length / sizeof(float); ++i) {
Adam Jesionowskida880932021-12-20 11:24:07 -0800102 if (((const float *)buffers[0].contents.data)[i] != i * i / 8.0f) {
Cindy Liu986ee242021-07-29 23:23:17 -0700103 result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
104 break;
105 }
106 }
107 return result;
108}