Add support for mnist float model
Add mnist float model. As the entry function is different from others, a
field entry_func is added for the MlModel struct.
Change-Id: I485bf10f0756e2805636b77c4f9ba5bd1899c22e
diff --git a/samples/float_model_embedding/CMakeLists.txt b/samples/float_model_embedding/CMakeLists.txt
index 919aa1b..4cc4955 100644
--- a/samples/float_model_embedding/CMakeLists.txt
+++ b/samples/float_model_embedding/CMakeLists.txt
@@ -18,6 +18,21 @@
PUBLIC
)
+springbok_bytecode_module(
+ NAME
+ mnist_bytecode_module_dylib
+ SRC
+ "$ENV{ROOTDIR}/toolchain/iree/iree/samples/models/mnist.mlir"
+ C_IDENTIFIER
+ "samples_float_model_embedding_mnist_bytecode_module_dylib"
+ FLAGS
+ "-iree-input-type=mhlo"
+ "-riscv-v-vector-bits-min=512"
+ "-riscv-v-fixed-length-vector-lmul-max=8"
+ "-riscv-v-fixed-length-vector-elen-max=32"
+ PUBLIC
+)
+
if(${BUILD_INTERNAL_MODELS})
if(NOT ${BUILD_WITH_SPRINGBOK})
@@ -155,6 +170,18 @@
"LINKER:--defsym=__stack_size__=100k"
)
+iree_cc_binary(
+ NAME
+ mnist_embedded_sync
+ SRCS
+ "mnist.c"
+ DEPS
+ ::mnist_bytecode_module_dylib_c
+ samples::util::util
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=100k"
+)
+
if(NOT ${BUILD_INTERNAL_MODELS})
return()
endif()
@@ -227,6 +254,8 @@
DEPS
::semantic_lift_bytecode_module_dylib_c
samples::util::util
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=100k"
)
iree_cc_binary(
@@ -237,4 +266,6 @@
DEPS
::voice_commands_bytecode_module_dylib_c
samples::util::util
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=100k"
)
diff --git a/samples/float_model_embedding/barcode.c b/samples/float_model_embedding/barcode.c
index 3d047f5..362e339 100644
--- a/samples/float_model_embedding/barcode.c
+++ b/samples/float_model_embedding/barcode.c
@@ -22,6 +22,7 @@
2 * 2 * 72, 2 * 2 * 18, 72, 18},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "barcode_float",
};
@@ -51,7 +52,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/daredevil.c b/samples/float_model_embedding/daredevil.c
index b74e425..488760e 100644
--- a/samples/float_model_embedding/daredevil.c
+++ b/samples/float_model_embedding/daredevil.c
@@ -20,6 +20,7 @@
.output_length = {527},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "daredevil_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/fssd_25_8bit_v2.c b/samples/float_model_embedding/fssd_25_8bit_v2.c
index e7f660a..c02fa9a 100644
--- a/samples/float_model_embedding/fssd_25_8bit_v2.c
+++ b/samples/float_model_embedding/fssd_25_8bit_v2.c
@@ -20,6 +20,7 @@
.output_length = {1602, 1602 * 16},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "fssd_25_8bit_v2_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/mnist.c b/samples/float_model_embedding/mnist.c
new file mode 100644
index 0000000..1a680b9
--- /dev/null
+++ b/samples/float_model_embedding/mnist.c
@@ -0,0 +1,56 @@
+// mnist float model
+// MlModel struct initialization to include model I/O info.
+// Bytecode loading, input/output processes.
+
+#include <springbok.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "samples/util/util.h"
+
+// Compiled module embedded here to avoid file IO:
+#include "samples/float_model_embedding/mnist_bytecode_module_dylib_c.h"
+
+const MlModel kModel = {
+ .num_input_dim = 4,
+ .input_shape = {1, 28, 28, 1},
+ .input_length = 28 * 28 * 1,
+ .input_size_bytes = sizeof(float),
+ .num_output = 1,
+ .output_length = {10},
+ .output_size_bytes = sizeof(float),
+ .hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.predict",
+ .model_name = "mnist",
+};
+
+const iree_const_byte_span_t load_bytecode_module_data() {
+ const struct iree_file_toc_t *module_file_toc =
+ samples_float_model_embedding_mnist_bytecode_module_dylib_create();
+ return iree_make_const_byte_span(module_file_toc->data,
+ module_file_toc->size);
+}
+
+iree_status_t load_input_data(const MlModel *model, void **buffer) {
+ // Populate initial value
+ srand(74886738);
+ for (int i = 0; i < model->input_length; ++i) {
+ int x = rand();
+ ((float *)*buffer)[i] = (float)x / (float)RAND_MAX;
+ }
+ return iree_ok_status();
+}
+
+iree_status_t check_output_data(const MlModel *model,
+ iree_hal_buffer_mapping_t *mapped_memory,
+ int index_output) {
+ iree_status_t result = iree_ok_status();
+ if (index_output > model->num_output ||
+ mapped_memory->contents.data_length / model->output_size_bytes !=
+ model->output_length[index_output]) {
+ result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
+ }
+ LOG_INFO("Output #%d data length: %d\n", index_output,
+ mapped_memory->contents.data_length / model->output_size_bytes);
+ return result;
+}
diff --git a/samples/float_model_embedding/mobilenet_v1.c b/samples/float_model_embedding/mobilenet_v1.c
index fa15389..c8e83b1 100644
--- a/samples/float_model_embedding/mobilenet_v1.c
+++ b/samples/float_model_embedding/mobilenet_v1.c
@@ -20,6 +20,7 @@
.output_length = {1001},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "mobilenet_v1_0.25_224_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/person_detection.c b/samples/float_model_embedding/person_detection.c
index ad945e1..c4c8430 100644
--- a/samples/float_model_embedding/person_detection.c
+++ b/samples/float_model_embedding/person_detection.c
@@ -20,6 +20,7 @@
.output_length = {2},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "person_detection_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/scenenet_v2.c b/samples/float_model_embedding/scenenet_v2.c
index 1d6662d..e820152 100644
--- a/samples/float_model_embedding/scenenet_v2.c
+++ b/samples/float_model_embedding/scenenet_v2.c
@@ -20,6 +20,7 @@
.output_length = {170},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "scenenet_v2_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/semantic_lift.c b/samples/float_model_embedding/semantic_lift.c
index c037f24..8bf1744 100644
--- a/samples/float_model_embedding/semantic_lift.c
+++ b/samples/float_model_embedding/semantic_lift.c
@@ -20,6 +20,7 @@
.output_length = {2, 2},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "semantic_lift_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/float_model_embedding/voice_commands.c b/samples/float_model_embedding/voice_commands.c
index 29aa61e..d2b587b 100644
--- a/samples/float_model_embedding/voice_commands.c
+++ b/samples/float_model_embedding/voice_commands.c
@@ -20,6 +20,7 @@
.output_length = {63},
.output_size_bytes = sizeof(float),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+ .entry_func = "module.main",
.model_name = "voice_commands_float",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/barcode.c b/samples/quant_model_embedding/barcode.c
index 60b3e01..a8119cb 100644
--- a/samples/quant_model_embedding/barcode.c
+++ b/samples/quant_model_embedding/barcode.c
@@ -22,6 +22,7 @@
2 * 2 * 72, 2 * 2 * 18, 72, 18},
.output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "barcode_quant",
};
@@ -50,7 +51,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/daredevil.c b/samples/quant_model_embedding/daredevil.c
index 63aea14..97cc434 100644
--- a/samples/quant_model_embedding/daredevil.c
+++ b/samples/quant_model_embedding/daredevil.c
@@ -20,6 +20,7 @@
.output_length = {527},
.output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "daredevil_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/fssd_25_8bit_v2.c b/samples/quant_model_embedding/fssd_25_8bit_v2.c
index 3631f90..3b38938 100644
--- a/samples/quant_model_embedding/fssd_25_8bit_v2.c
+++ b/samples/quant_model_embedding/fssd_25_8bit_v2.c
@@ -21,6 +21,7 @@
5 * 5 * 48, 5 * 5 * 3, 3 * 3 * 48, 3 * 3 * 3},
.output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "fssd_25_8bit_v2_quant",
};
@@ -49,7 +50,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/mobilenet_v1.c b/samples/quant_model_embedding/mobilenet_v1.c
index 780d143..778781e 100644
--- a/samples/quant_model_embedding/mobilenet_v1.c
+++ b/samples/quant_model_embedding/mobilenet_v1.c
@@ -20,6 +20,7 @@
.output_length = {1001},
.output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "mobilenet_v1_0.25_224_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/mobilenet_v2.c b/samples/quant_model_embedding/mobilenet_v2.c
index ec733e8..0f0f63c 100644
--- a/samples/quant_model_embedding/mobilenet_v2.c
+++ b/samples/quant_model_embedding/mobilenet_v2.c
@@ -20,6 +20,7 @@
.output_length = {1001},
.output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "mobilenet_v2_1.0_224_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/person_detection.c b/samples/quant_model_embedding/person_detection.c
index a1ae000..1a7d4b2 100644
--- a/samples/quant_model_embedding/person_detection.c
+++ b/samples/quant_model_embedding/person_detection.c
@@ -20,6 +20,7 @@
.output_length = {2},
.output_size_bytes = sizeof(int8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_SINT_8,
+ .entry_func = "module.main",
.model_name = "person_detection_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/scenenet_v2.c b/samples/quant_model_embedding/scenenet_v2.c
index 0e7edc7..ddc5a73 100644
--- a/samples/quant_model_embedding/scenenet_v2.c
+++ b/samples/quant_model_embedding/scenenet_v2.c
@@ -18,8 +18,9 @@
.input_size_bytes = sizeof(uint8_t),
.num_output = 1,
.output_length = {170},
- .output_size_bytes = sizeof(float),
+ .output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "scenenet_v2_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/semantic_lift.c b/samples/quant_model_embedding/semantic_lift.c
index eb29774..727fcf0 100644
--- a/samples/quant_model_embedding/semantic_lift.c
+++ b/samples/quant_model_embedding/semantic_lift.c
@@ -20,6 +20,7 @@
.output_length = {2, 2, 2},
.output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "semantic_lift_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/quant_model_embedding/voice_commands.c b/samples/quant_model_embedding/voice_commands.c
index da4a9e8..dab6188 100644
--- a/samples/quant_model_embedding/voice_commands.c
+++ b/samples/quant_model_embedding/voice_commands.c
@@ -18,8 +18,9 @@
.input_size_bytes = sizeof(uint8_t),
.num_output = 1,
.output_length = {63},
- .output_size_bytes = sizeof(float),
+ .output_size_bytes = sizeof(uint8_t),
.hal_element_type = IREE_HAL_ELEMENT_TYPE_UINT_8,
+ .entry_func = "module.main",
.model_name = "voice_commands_quant",
};
@@ -48,7 +49,7 @@
model->output_length[index_output]) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
}
- LOG_INFO("Output #%d data length: %d \n", index_output,
+ LOG_INFO("Output #%d data length: %d\n", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
return result;
}
diff --git a/samples/util/util.c b/samples/util/util.c
index d618534..655289a 100644
--- a/samples/util/util.c
+++ b/samples/util/util.c
@@ -98,11 +98,10 @@
// Lookup the entry point function.
// Note that we use the synchronous variant which operates on pure type/shape
// erased buffers.
- const char kMainFunctionName[] = "module.main";
iree_vm_function_t main_function;
if (iree_status_is_ok(result)) {
result = (iree_vm_context_resolve_function(
- context, iree_make_cstring_view(kMainFunctionName), &main_function));
+ context, iree_make_cstring_view(model->entry_func), &main_function));
}
// Prepare the input buffers.
diff --git a/samples/util/util.h b/samples/util/util.h
index 6cd2e38..6ad7e66 100644
--- a/samples/util/util.h
+++ b/samples/util/util.h
@@ -9,6 +9,8 @@
#define MAX_MODEL_INPUT_DIM 4
#define MAX_MODEL_OUTPUTS 12
+#define MAX_ENTRY_FUNC_NAME 20
+
typedef struct {
int num_input_dim;
@@ -19,6 +21,7 @@
int output_length[MAX_MODEL_OUTPUTS];
int output_size_bytes;
enum iree_hal_element_types_t hal_element_type;
+ char entry_func[MAX_ENTRY_FUNC_NAME];
char model_name[];
} MlModel;