Enable loading realistic model inputs from external images/files
Currently all model inputs are arbitrarily random.
We here add support for allowing loading realistic model input from
external images/files. Only person_detection and mobilenet are supported
for now.
Change-Id: Ib189ff9735dd656a63d01d2c5b90123942ae7c60
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 201fcab..9c41392 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -74,5 +74,6 @@
include($ENV{ROOTDIR}/toolchain/iree/build_tools/cmake/iree_copts.cmake)
include(springbok_bytecode_module)
+include(iree_model_input)
# Add the included directory here.
add_subdirectory(samples)
diff --git a/build_tools/gen_mlmodel_input.py b/build_tools/gen_mlmodel_input.py
new file mode 100755
index 0000000..5bd8918
--- /dev/null
+++ b/build_tools/gen_mlmodel_input.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+"""Generate ML model inputs from external images."""
+import argparse
+import os
+import sys
+import struct
+import urllib.request
+
+import numpy as np
+from PIL import Image
+
+
+parser = argparse.ArgumentParser(
+ description='Generate inputs for ML models.')
+parser.add_argument('--i', dest='input_name',
+ help='Model input image name', required=True)
+parser.add_argument('--o', dest='output_file',
+ help='Output binary name', required=True)
+parser.add_argument('--s', dest='input_shape',
+ help='Model input shape (example: "1, 224, 224, 3")', required=True)
+parser.add_argument('--q', dest='is_quant', action='store_true',
+ help='Indicate it is quant model (default: False)')
+parser.add_argument('--u', dest='img_url', help='Input image URL')
+args = parser.parse_args()
+
+
+def write_binary_file(file_path, input, is_quant):
+ with open(file_path, "wb+") as file:
+ for d in input:
+ if is_quant:
+ file.write(struct.pack("<B", d))
+ else:
+ file.write(struct.pack("<f", d))
+
+
+def gen_mlmodel_input(input_name, output_file, input_shape, is_quant, img_url):
+ if not os.path.exists(input_name):
+ urllib.request.urlretrieve(img_url, input_name)
+ if len(input_shape) < 3:
+ raise ValueError("Input shape < 3 dimensions")
+ resized_img = Image.open(input_name).resize(
+ (input_shape[1], input_shape[2]))
+ input = np.array(resized_img).reshape(np.prod(input_shape))
+ if not is_quant:
+ input = 2.0 / 255.0 * input - 1
+ write_binary_file(output_file, input, is_quant)
+
+
+if __name__ == '__main__':
+ # convert input shape to a list
+ input_shape = [int(x) for x in args.input_shape.split(',')]
+ # remove whitespace in image URL if any
+ img_url = args.img_url.replace(' ', '')
+ gen_mlmodel_input(args.input_name, args.output_file,
+ input_shape, args.is_quant, img_url)
diff --git a/cmake/iree_model_input.cmake b/cmake/iree_model_input.cmake
new file mode 100644
index 0000000..9d66348
--- /dev/null
+++ b/cmake/iree_model_input.cmake
@@ -0,0 +1,66 @@
+include(CMakeParseArguments)
+
+# iree_model_input()
+#
+# CMake function to load an external model input (an image)
+# and convert to the iree_c_embed_data.
+#
+# Parameters:
+# NAME: Name of model input image.
+# SHAPE: Input shape.
+# SRC: Input image URL.
+# QUANT: When added, indicate it's a quant model.
+#
+# Examples:
+# iree_model_input(
+# NAME
+# person_detection_quant_input
+# SHAPE
+# "1, 96, 96, 1"
+# SRC
+# "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/ \
+# tensorflow/lite/micro/examples/person_detection/testdata/person.bmp"
+# QUANT
+# )
+#
+function(iree_model_input)
+ cmake_parse_arguments(
+ _RULE
+ "QUANT"
+ "NAME;SHAPE;SRC"
+ ""
+ ${ARGN}
+ )
+
+ get_filename_component(EXT_STR "${_RULE_SRC}" LAST_EXT)
+ set(_GEN_INPUT_SCRIPT "${CMAKE_SOURCE_DIR}/build_tools/gen_mlmodel_input.py")
+ set(_OUTPUT_BINARY ${_RULE_NAME}.bin)
+ set(_ARGS)
+ list(APPEND _ARGS "--i=${_RULE_NAME}${EXT_STR}")
+ list(APPEND _ARGS "--o=${_OUTPUT_BINARY}")
+ list(APPEND _ARGS "--s=${_RULE_SHAPE}")
+ list(APPEND _ARGS "--u=${_RULE_SRC}")
+ if(_RULE_QUANT)
+ list(APPEND _ARGS "--q")
+ endif()
+
+ add_custom_command(
+ OUTPUT
+ ${_OUTPUT_BINARY}
+ COMMAND
+ ${_GEN_INPUT_SCRIPT} ${_ARGS}
+ )
+
+ iree_c_embed_data(
+ NAME
+ "${_RULE_NAME}_c"
+ GENERATED_SRCS
+ "${_OUTPUT_BINARY}"
+ C_FILE_OUTPUT
+ "${_RULE_NAME}_c.c"
+ H_FILE_OUTPUT
+ "${_RULE_NAME}_c.h"
+ FLATTEN
+ PUBLIC
+ )
+endfunction()
diff --git a/samples/float_model_embedding/CMakeLists.txt b/samples/float_model_embedding/CMakeLists.txt
index 8d3ef0a..fbd36b3 100644
--- a/samples/float_model_embedding/CMakeLists.txt
+++ b/samples/float_model_embedding/CMakeLists.txt
@@ -145,6 +145,34 @@
endif(${BUILD_INTERNAL_MODELS})
#-------------------------------------------------------------------------------
+# Binaries to execute the IREE model input
+#-------------------------------------------------------------------------------
+
+iree_model_input(
+ NAME
+ mobilenet_input
+ SHAPE
+ "1, 224, 224, 3"
+ SRC
+ "https://storage.googleapis.com/download.tensorflow.org/ \
+ example_images/YellowLabradorLooking_new.jpg"
+)
+
+if(${BUILD_INTERNAL_MODELS})
+
+iree_model_input(
+ NAME
+ person_detection_input
+ SHAPE
+ "1, 96, 96, 1"
+ SRC
+ "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/ \
+ tensorflow/lite/micro/examples/person_detection/testdata/person.bmp"
+)
+
+endif(${BUILD_INTERNAL_MODELS})
+
+#-------------------------------------------------------------------------------
# Binaries to execute the MLIR bytecode modules
#-------------------------------------------------------------------------------
@@ -165,6 +193,7 @@
"mobilenet_v1.c"
DEPS
::mobilenet_v1_bytecode_module_dylib_c
+ ::mobilenet_input_c
samples::util::util
LINKOPTS
"LINKER:--defsym=__stack_size__=100k"
@@ -241,6 +270,7 @@
"person_detection.c"
DEPS
::person_detection_bytecode_module_dylib_c
+ ::person_detection_input_c
samples::util::util
LINKOPTS
"LINKER:--defsym=__stack_size__=100k"
diff --git a/samples/float_model_embedding/mobilenet_v1.c b/samples/float_model_embedding/mobilenet_v1.c
index 0472dde..7535395 100644
--- a/samples/float_model_embedding/mobilenet_v1.c
+++ b/samples/float_model_embedding/mobilenet_v1.c
@@ -9,6 +9,7 @@
#include "samples/util/util.h"
// Compiled module embedded here to avoid file IO:
+#include "samples/float_model_embedding/mobilenet_input_c.h"
#include "samples/float_model_embedding/mobilenet_v1_bytecode_module_dylib_c.h"
const MlModel kModel = {
@@ -34,14 +35,8 @@
iree_status_t load_input_data(const MlModel *model, void **buffer) {
iree_status_t result = alloc_input_buffer(model, buffer);
- // Populate initial value
- srand(33333333);
- if (iree_status_is_ok(result)) {
- for (int i = 0; i < model->input_length[0]; ++i) {
- int x = rand();
- ((float *)*buffer)[i] = (float)x / (float)RAND_MAX;
- }
- }
+ const struct iree_file_toc_t *input_file_toc = mobilenet_input_c_create();
+ memcpy(*buffer, input_file_toc->data, input_file_toc->size);
return result;
}
@@ -56,5 +51,16 @@
}
LOG_INFO("Output #%d data length: %d", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
+ // find the label index with best prediction
+ float best_out = 0.0;
+ int best_idx = -1;
+ for (int i = 0; i < model->output_length[index_output]; ++i) {
+ float out = ((float *)mapped_memory->contents.data)[i];
+ if (out > best_out) {
+ best_out = out;
+ best_idx = i;
+ }
+ }
+ LOG_INFO("Image prediction result is: id: %d", best_idx + 1);
return result;
}
diff --git a/samples/float_model_embedding/person_detection.c b/samples/float_model_embedding/person_detection.c
index 19cfb4f..a892093 100644
--- a/samples/float_model_embedding/person_detection.c
+++ b/samples/float_model_embedding/person_detection.c
@@ -10,6 +10,7 @@
// Compiled module embedded here to avoid file IO:
#include "samples/float_model_embedding/person_detection_bytecode_module_dylib_c.h"
+#include "samples/float_model_embedding/person_detection_input_c.h"
const MlModel kModel = {
.num_input = 1,
@@ -34,14 +35,9 @@
iree_status_t load_input_data(const MlModel *model, void **buffer) {
iree_status_t result = alloc_input_buffer(model, buffer);
- // Populate initial value
- srand(44444444);
- if (iree_status_is_ok(result)) {
- for (int i = 0; i < model->input_length[0]; ++i) {
- int x = rand();
- ((float *)*buffer)[i] = (float)x / (float)RAND_MAX;
- }
- }
+ const struct iree_file_toc_t *input_file_toc =
+ person_detection_input_c_create();
+ memcpy(*buffer, input_file_toc->data, input_file_toc->size);
return result;
}
@@ -56,5 +52,15 @@
}
LOG_INFO("Output #%d data length: %d", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
+
+ float *data = (float *)mapped_memory->contents.data;
+ char buffer[20];
+ int chars_needed = float_to_str(0, NULL, data[0]);
+ float_to_str(chars_needed, buffer, data[0]);
+ LOG_INFO("Output: Non-person Score: %s", buffer);
+ chars_needed = float_to_str(0, NULL, data[1]);
+ float_to_str(chars_needed, buffer, data[1]);
+ LOG_INFO("Output: Person Score: %s", buffer);
+
return result;
}
diff --git a/samples/quant_model_embedding/CMakeLists.txt b/samples/quant_model_embedding/CMakeLists.txt
index cca899d..f728e9e 100644
--- a/samples/quant_model_embedding/CMakeLists.txt
+++ b/samples/quant_model_embedding/CMakeLists.txt
@@ -147,6 +147,32 @@
endif(${BUILD_INTERNAL_MODELS})
#-------------------------------------------------------------------------------
+# Binaries to execute the IREE model input
+#-------------------------------------------------------------------------------
+
+iree_model_input(
+ NAME
+ mobilenet_quant_input
+ SHAPE
+ "1, 224, 224, 3"
+ SRC
+ "https://storage.googleapis.com/download.tensorflow.org/ \
+ example_images/YellowLabradorLooking_new.jpg"
+ QUANT
+)
+
+iree_model_input(
+ NAME
+ person_detection_quant_input
+ SHAPE
+ "1, 96, 96, 1"
+ SRC
+ "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/ \
+ tensorflow/lite/micro/examples/person_detection/testdata/person.bmp"
+ QUANT
+)
+
+#-------------------------------------------------------------------------------
# Binaries to execute the MLIR bytecode modules
#-------------------------------------------------------------------------------
@@ -179,6 +205,7 @@
"mobilenet_v1.c"
DEPS
::mobilenet_v1_bytecode_module_dylib_c
+ ::mobilenet_quant_input_c
samples::util::util
LINKOPTS
"LINKER:--defsym=__stack_size__=100k"
@@ -191,6 +218,7 @@
"person_detection.c"
DEPS
::person_detection_bytecode_module_dylib_c
+ ::person_detection_quant_input_c
samples::util::util
LINKOPTS
"LINKER:--defsym=__stack_size__=128k"
diff --git a/samples/quant_model_embedding/mobilenet_v1.c b/samples/quant_model_embedding/mobilenet_v1.c
index 3570a4b..2fe8509 100644
--- a/samples/quant_model_embedding/mobilenet_v1.c
+++ b/samples/quant_model_embedding/mobilenet_v1.c
@@ -9,6 +9,7 @@
#include "samples/util/util.h"
// Compiled module embedded here to avoid file IO:
+#include "samples/quant_model_embedding/mobilenet_quant_input_c.h"
#include "samples/quant_model_embedding/mobilenet_v1_bytecode_module_dylib_c.h"
const MlModel kModel = {
@@ -34,13 +35,9 @@
iree_status_t load_input_data(const MlModel *model, void **buffer) {
iree_status_t result = alloc_input_buffer(model, buffer);
- // Populate initial value
- srand(33333333);
- if (iree_status_is_ok(result)) {
- for (int i = 0; i < model->input_length[0]; ++i) {
- ((uint8_t *)*buffer)[i] = (uint8_t)rand();
- }
- }
+ const struct iree_file_toc_t *input_file_toc =
+ mobilenet_quant_input_c_create();
+ memcpy(*buffer, input_file_toc->data, input_file_toc->size);
return result;
}
@@ -55,5 +52,16 @@
}
LOG_INFO("Output #%d data length: %d", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
+ // find the label index with best prediction
+ int best_out = 0;
+ int best_idx = -1;
+ for (int i = 0; i < model->output_length[index_output]; ++i) {
+ uint8_t out = ((uint8_t *)mapped_memory->contents.data)[i];
+ if (out > best_out) {
+ best_out = out;
+ best_idx = i;
+ }
+ }
+ LOG_INFO("Image prediction result is: id: %d", best_idx + 1);
return result;
}
diff --git a/samples/quant_model_embedding/person_detection.c b/samples/quant_model_embedding/person_detection.c
index 0cd1a09..819fc1e 100644
--- a/samples/quant_model_embedding/person_detection.c
+++ b/samples/quant_model_embedding/person_detection.c
@@ -10,6 +10,7 @@
// Compiled module embedded here to avoid file IO:
#include "samples/quant_model_embedding/person_detection_bytecode_module_dylib_c.h"
+#include "samples/quant_model_embedding/person_detection_quant_input_c.h"
const MlModel kModel = {
.num_input = 1,
@@ -34,13 +35,9 @@
iree_status_t load_input_data(const MlModel *model, void **buffer) {
iree_status_t result = alloc_input_buffer(model, buffer);
- // Populate initial value
- srand(44444444);
- if (iree_status_is_ok(result)) {
- for (int i = 0; i < model->input_length[0]; ++i) {
- ((int8_t *)*buffer)[i] = (int8_t)rand();
- }
- }
+ const struct iree_file_toc_t *input_file_toc =
+ person_detection_quant_input_c_create();
+ memcpy(*buffer, input_file_toc->data, input_file_toc->size);
return result;
}
@@ -55,5 +52,7 @@
}
LOG_INFO("Output #%d data length: %d", index_output,
mapped_memory->contents.data_length / model->output_size_bytes);
+ int8_t *data = (int8_t *)mapped_memory->contents.data;
+ LOG_INFO("Output: Non-person Score: %d; Person Score: %d", data[0], data[1]);
return result;
}