blob: 7be9e35e0875397a88b8e9a805299d3473f4f1cb [file] [log] [blame]
include(CMakeParseArguments)
# iree_model_input()
#
# CMake function to load an external model input (an image) and convert it
# into embedded C data via iree_c_embed_data().
#
# SRC is either a local file or an https: URL. A URL is downloaded with wget
# at build time into the current binary directory. The image is then turned
# into "<NAME>.bin" by build_tools/gen_mlmodel_input.py, and that binary is
# embedded via iree_c_embed_data() under the name "<NAME>_c" (generating
# "<NAME>_c.c" and "<NAME>_c.h").
#
# Parameters:
# NAME: Name of model input image. Used to derive the generated file and
#     target names. Required.
# SHAPE: Input shape, e.g. "1, 96, 96, 1". Forwarded to the generator script
#     as "--s=<SHAPE>".
# SRC: Input image path or https: URL. All whitespace is stripped, so a long
#     URL may be split across lines with a backslash line continuation.
#     Required. A relative local path is resolved like add_custom_command()
#     DEPENDS: against the current source directory if the file exists there,
#     otherwise against the current binary directory (e.g. a generated file).
# RANGE: Optional. Forwarded to the generator script as "--r=<RANGE>"
#     (presumably the input value range -- see gen_mlmodel_input.py).
# QUANT: When added, indicate it's a quant model (passes "--q" to the
#     generator script).
#
# Examples:
# iree_model_input(
# NAME
# person_detection_quant_input
# SHAPE
# "1, 96, 96, 1"
# SRC
# "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/ \
# tensorflow/lite/micro/examples/person_detection/testdata/person.bmp"
# QUANT
# )
#
function(iree_model_input)
  # PARSE_ARGV keeps empty values and embedded semicolons intact, unlike the
  # ${ARGN} form.
  cmake_parse_arguments(
    PARSE_ARGV 0
    _RULE
    "QUANT"
    "NAME;SHAPE;SRC;RANGE"
    ""
  )

  # Fail loudly on typos (e.g. "QAUNT") instead of silently ignoring them.
  if(_RULE_UNPARSED_ARGUMENTS)
    message(FATAL_ERROR
      "iree_model_input: unexpected arguments: ${_RULE_UNPARSED_ARGUMENTS}")
  endif()
  foreach(_REQUIRED_ARG IN ITEMS NAME SRC)
    if("${_RULE_${_REQUIRED_ARG}}" STREQUAL "")
      message(FATAL_ERROR "iree_model_input: ${_REQUIRED_ARG} is required")
    endif()
  endforeach()

  # Strip all whitespace so a URL wrapped with a line continuation (which
  # keeps the next line's leading indentation) becomes a single token.
  string(REGEX REPLACE "[ \t\r\n]" "" _RULE_SRC_TRIM "${_RULE_SRC}")

  if(_RULE_SRC_TRIM MATCHES "^https:")
    # Derive the name from the *trimmed* URL so stray whitespace in SRC can
    # never end up in the downloaded file's name.
    get_filename_component(_INPUT_NAME "${_RULE_SRC_TRIM}" NAME)
    set(_INPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${_INPUT_NAME}")
    # PATH is searched by default, so no HINTS are needed.
    find_program(_WGET wget REQUIRED)
    # wget ignores -P when -O is given, so pass the absolute destination to
    # -O directly. The download then does not depend on the working directory.
    add_custom_command(
      OUTPUT
        "${_INPUT_FILE}"
      COMMAND
        "${_WGET}" -q -O "${_INPUT_FILE}" "${_RULE_SRC_TRIM}"
      COMMENT
        "Download ${_INPUT_NAME} from ${_RULE_SRC_TRIM}"
      VERBATIM
    )
  elseif(IS_ABSOLUTE "${_RULE_SRC_TRIM}")
    set(_INPUT_FILE "${_RULE_SRC_TRIM}")
  elseif(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${_RULE_SRC_TRIM}")
    # Mirror add_custom_command(DEPENDS) resolution, so the script reads the
    # same file the build tracks as a dependency.
    set(_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/${_RULE_SRC_TRIM}")
  else()
    set(_INPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_SRC_TRIM}")
  endif()

  set(_GEN_INPUT_SCRIPT "${CMAKE_SOURCE_DIR}/build_tools/gen_mlmodel_input.py")
  # Relative OUTPUT paths resolve to the current binary dir, which is also the
  # command's WORKING_DIRECTORY below, so "--o=" and OUTPUT name the same file.
  set(_OUTPUT_BINARY "${_RULE_NAME}.bin")
  set(_ARGS
    "--i=${_INPUT_FILE}"
    "--o=${_OUTPUT_BINARY}"
    "--s=${_RULE_SHAPE}"
  )
  # Compare against "" so a literal "0"/"OFF" range is still forwarded.
  if(NOT "${_RULE_RANGE}" STREQUAL "")
    list(APPEND _ARGS "--r=${_RULE_RANGE}")
  endif()
  if(_RULE_QUANT)
    list(APPEND _ARGS "--q")
  endif()

  # NOTE(review): the .py script is executed directly, so this relies on its
  # shebang line and executable bit rather than an explicit Python interpreter.
  # VERBATIM makes arguments with spaces (e.g. the shape) escape portably.
  add_custom_command(
    OUTPUT
      "${_OUTPUT_BINARY}"
    COMMAND
      "${_GEN_INPUT_SCRIPT}" ${_ARGS}
    DEPENDS
      "${_GEN_INPUT_SCRIPT}"
      "${_INPUT_FILE}"
    WORKING_DIRECTORY
      "${CMAKE_CURRENT_BINARY_DIR}"
    VERBATIM
  )

  iree_c_embed_data(
    NAME
      "${_RULE_NAME}_c"
    GENERATED_SRCS
      "${_OUTPUT_BINARY}"
    C_FILE_OUTPUT
      "${_RULE_NAME}_c.c"
    H_FILE_OUTPUT
      "${_RULE_NAME}_c.h"
    FLATTEN
    PUBLIC
  )
endfunction()