Add support for mnist float model Add mnist float model. As the entry function is different from others, a field entry_func is added for the MlModel struct. Change-Id: I485bf10f0756e2805636b77c4f9ba5bd1899c22e
diff --git a/samples/float_model_embedding/CMakeLists.txt b/samples/float_model_embedding/CMakeLists.txt index 919aa1b..4cc4955 100644 --- a/samples/float_model_embedding/CMakeLists.txt +++ b/samples/float_model_embedding/CMakeLists.txt
@@ -18,6 +18,21 @@ PUBLIC ) +springbok_bytecode_module( + NAME + mnist_bytecode_module_dylib + SRC + "$ENV{ROOTDIR}/toolchain/iree/iree/samples/models/mnist.mlir" + C_IDENTIFIER + "samples_float_model_embedding_mnist_bytecode_module_dylib" + FLAGS + "-iree-input-type=mhlo" + "-riscv-v-vector-bits-min=512" + "-riscv-v-fixed-length-vector-lmul-max=8" + "-riscv-v-fixed-length-vector-elen-max=32" + PUBLIC +) + if(${BUILD_INTERNAL_MODELS}) if(NOT ${BUILD_WITH_SPRINGBOK}) @@ -155,6 +170,18 @@ "LINKER:--defsym=__stack_size__=100k" ) +iree_cc_binary( + NAME + mnist_embedded_sync + SRCS + "mnist.c" + DEPS + ::mnist_bytecode_module_dylib_c + samples::util::util + LINKOPTS + "LINKER:--defsym=__stack_size__=100k" +) + if(NOT ${BUILD_INTERNAL_MODELS}) return() endif() @@ -227,6 +254,8 @@ DEPS ::semantic_lift_bytecode_module_dylib_c samples::util::util + LINKOPTS + "LINKER:--defsym=__stack_size__=100k" ) iree_cc_binary( @@ -237,4 +266,6 @@ DEPS ::voice_commands_bytecode_module_dylib_c samples::util::util + LINKOPTS + "LINKER:--defsym=__stack_size__=100k" )
diff --git a/samples/float_model_embedding/barcode.c b/samples/float_model_embedding/barcode.c index 3d047f5..362e339 100644 --- a/samples/float_model_embedding/barcode.c +++ b/samples/float_model_embedding/barcode.c
@@ -22,6 +22,7 @@ 2 * 2 * 72, 2 * 2 * 18, 72, 18}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "barcode_float", }; @@ -51,7 +52,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/daredevil.c b/samples/float_model_embedding/daredevil.c index b74e425..488760e 100644 --- a/samples/float_model_embedding/daredevil.c +++ b/samples/float_model_embedding/daredevil.c
@@ -20,6 +20,7 @@ .output_length = {527}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "daredevil_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/fssd_25_8bit_v2.c b/samples/float_model_embedding/fssd_25_8bit_v2.c index e7f660a..c02fa9a 100644 --- a/samples/float_model_embedding/fssd_25_8bit_v2.c +++ b/samples/float_model_embedding/fssd_25_8bit_v2.c
@@ -20,6 +20,7 @@ .output_length = {1602, 1602 * 16}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "fssd_25_8bit_v2_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/mnist.c b/samples/float_model_embedding/mnist.c new file mode 100644 index 0000000..1a680b9 --- /dev/null +++ b/samples/float_model_embedding/mnist.c
@@ -0,0 +1,56 @@ +// mnist float model +// MlModel struct initialization to include model I/O info. +// Bytecode loading, input/output processes. + +#include <springbok.h> + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "samples/util/util.h" + +// Compiled module embedded here to avoid file IO: +#include "samples/float_model_embedding/mnist_bytecode_module_dylib_c.h" + +const MlModel kModel = { + .num_input_dim = 4, + .input_shape = {1, 28, 28, 1}, + .input_length = 28 * 28 * 1, + .input_size_bytes = sizeof(float), + .num_output = 1, + .output_length = {10}, + .output_size_bytes = sizeof(float), + .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.predict", + .model_name = "mnist", +}; + +const iree_const_byte_span_t load_bytecode_module_data() { + const struct iree_file_toc_t *module_file_toc = + samples_float_model_embedding_mnist_bytecode_module_dylib_create(); + return iree_make_const_byte_span(module_file_toc->data, + module_file_toc->size); +} + +iree_status_t load_input_data(const MlModel *model, void **buffer) { + // Populate initial value + srand(74886738); + for (int i = 0; i < model->input_length; ++i) { + int x = rand(); + ((float *)*buffer)[i] = (float)x / (float)RAND_MAX; + } + return iree_ok_status(); +} + +iree_status_t check_output_data(const MlModel *model, + iree_hal_buffer_mapping_t *mapped_memory, + int index_output) { + iree_status_t result = iree_ok_status(); + if (index_output > model->num_output || + mapped_memory->contents.data_length / model->output_size_bytes != + model->output_length[index_output]) { + result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); + } + LOG_INFO("Output #%d data length: %d\n", index_output, + mapped_memory->contents.data_length / model->output_size_bytes); + return result; +}
diff --git a/samples/float_model_embedding/mobilenet_v1.c b/samples/float_model_embedding/mobilenet_v1.c index fa15389..c8e83b1 100644 --- a/samples/float_model_embedding/mobilenet_v1.c +++ b/samples/float_model_embedding/mobilenet_v1.c
@@ -20,6 +20,7 @@ .output_length = {1001}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "mobilenet_v1_0.25_224_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/person_detection.c b/samples/float_model_embedding/person_detection.c index ad945e1..c4c8430 100644 --- a/samples/float_model_embedding/person_detection.c +++ b/samples/float_model_embedding/person_detection.c
@@ -20,6 +20,7 @@ .output_length = {2}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "person_detection_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/scenenet_v2.c b/samples/float_model_embedding/scenenet_v2.c index 1d6662d..e820152 100644 --- a/samples/float_model_embedding/scenenet_v2.c +++ b/samples/float_model_embedding/scenenet_v2.c
@@ -20,6 +20,7 @@ .output_length = {170}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "scenenet_v2_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/semantic_lift.c b/samples/float_model_embedding/semantic_lift.c index c037f24..8bf1744 100644 --- a/samples/float_model_embedding/semantic_lift.c +++ b/samples/float_model_embedding/semantic_lift.c
@@ -20,6 +20,7 @@ .output_length = {2, 2}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "semantic_lift_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/float_model_embedding/voice_commands.c b/samples/float_model_embedding/voice_commands.c index 29aa61e..d2b587b 100644 --- a/samples/float_model_embedding/voice_commands.c +++ b/samples/float_model_embedding/voice_commands.c
@@ -20,6 +20,7 @@ .output_length = {63}, .output_size_bytes = sizeof(float), .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32, + .entry_func = "module.main", .model_name = "voice_commands_float", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/barcode.c b/samples/quant_model_embedding/barcode.c index 60b3e01..a8119cb 100644 --- a/samples/quant_model_embedding/barcode.c +++ b/samples/quant_model_embedding/barcode.c
@@ -22,6 +22,7 @@ 2 * 2 * 72, 2 * 2 * 18, 72, 18}, .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "barcode_quant", }; @@ -50,7 +51,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/daredevil.c b/samples/quant_model_embedding/daredevil.c index 63aea14..97cc434 100644 --- a/samples/quant_model_embedding/daredevil.c +++ b/samples/quant_model_embedding/daredevil.c
@@ -20,6 +20,7 @@ .output_length = {527}, .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "daredevil_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/fssd_25_8bit_v2.c b/samples/quant_model_embedding/fssd_25_8bit_v2.c index 3631f90..3b38938 100644 --- a/samples/quant_model_embedding/fssd_25_8bit_v2.c +++ b/samples/quant_model_embedding/fssd_25_8bit_v2.c
@@ -21,6 +21,7 @@ 5 * 5 * 48, 5 * 5 * 3, 3 * 3 * 48, 3 * 3 * 3}, .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "fssd_25_8bit_v2_quant", }; @@ -49,7 +50,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/mobilenet_v1.c b/samples/quant_model_embedding/mobilenet_v1.c index 780d143..778781e 100644 --- a/samples/quant_model_embedding/mobilenet_v1.c +++ b/samples/quant_model_embedding/mobilenet_v1.c
@@ -20,6 +20,7 @@ .output_length = {1001}, .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "mobilenet_v1_0.25_224_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/mobilenet_v2.c b/samples/quant_model_embedding/mobilenet_v2.c index ec733e8..0f0f63c 100644 --- a/samples/quant_model_embedding/mobilenet_v2.c +++ b/samples/quant_model_embedding/mobilenet_v2.c
@@ -20,6 +20,7 @@ .output_length = {1001}, .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "mobilenet_v2_1.0_224_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/person_detection.c b/samples/quant_model_embedding/person_detection.c index a1ae000..1a7d4b2 100644 --- a/samples/quant_model_embedding/person_detection.c +++ b/samples/quant_model_embedding/person_detection.c
@@ -20,6 +20,7 @@ .output_length = {2}, .output_size_bytes = sizeof(int8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_SINT_8, + .entry_func = "module.main", .model_name = "person_detection_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/scenenet_v2.c b/samples/quant_model_embedding/scenenet_v2.c index 0e7edc7..ddc5a73 100644 --- a/samples/quant_model_embedding/scenenet_v2.c +++ b/samples/quant_model_embedding/scenenet_v2.c
@@ -18,8 +18,9 @@ .input_size_bytes = sizeof(uint8_t), .num_output = 1, .output_length = {170}, - .output_size_bytes = sizeof(float), + .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "scenenet_v2_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/semantic_lift.c b/samples/quant_model_embedding/semantic_lift.c index eb29774..727fcf0 100644 --- a/samples/quant_model_embedding/semantic_lift.c +++ b/samples/quant_model_embedding/semantic_lift.c
@@ -20,6 +20,7 @@ .output_length = {2, 2, 2}, .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "semantic_lift_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/quant_model_embedding/voice_commands.c b/samples/quant_model_embedding/voice_commands.c index da4a9e8..dab6188 100644 --- a/samples/quant_model_embedding/voice_commands.c +++ b/samples/quant_model_embedding/voice_commands.c
@@ -18,8 +18,9 @@ .input_size_bytes = sizeof(uint8_t), .num_output = 1, .output_length = {63}, - .output_size_bytes = sizeof(float), + .output_size_bytes = sizeof(uint8_t), .hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8, + .entry_func = "module.main", .model_name = "voice_commands_quant", }; @@ -48,7 +49,7 @@ model->output_length[index_output]) { result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); } - LOG_INFO("Output #%d data length: %d \n", index_output, + LOG_INFO("Output #%d data length: %d\n", index_output, mapped_memory->contents.data_length / model->output_size_bytes); return result; }
diff --git a/samples/util/util.c b/samples/util/util.c index d618534..655289a 100644 --- a/samples/util/util.c +++ b/samples/util/util.c
@@ -98,11 +98,10 @@ // Lookup the entry point function. // Note that we use the synchronous variant which operates on pure type/shape // erased buffers. - const char kMainFunctionName[] = "module.main"; iree_vm_function_t main_function; if (iree_status_is_ok(result)) { result = (iree_vm_context_resolve_function( - context, iree_make_cstring_view(kMainFunctionName), &main_function)); + context, iree_make_cstring_view(model->entry_func), &main_function)); } // Prepare the input buffers.
diff --git a/samples/util/util.h b/samples/util/util.h index 6cd2e38..6ad7e66 100644 --- a/samples/util/util.h +++ b/samples/util/util.h
@@ -9,6 +9,8 @@ #define MAX_MODEL_INPUT_DIM 4 #define MAX_MODEL_OUTPUTS 12 +#define MAX_ENTRY_FUNC_NAME 20 + typedef struct { int num_input_dim; @@ -19,6 +21,7 @@ int output_length[MAX_MODEL_OUTPUTS]; int output_size_bytes; enum iree_hal_element_types_t hal_element_type; + char entry_func[MAX_ENTRY_FUNC_NAME]; char model_name[]; } MlModel;