Add support for input image with int8_t format
Add support for input imsage with int8_t format (e.g., for HPS model)
Change-Id: I7084af19fcb15f9ca66fcb5f881a127dc349a953
diff --git a/build_tools/gen_mlmodel_input.py b/build_tools/gen_mlmodel_input.py
index fb87b3a..bbcd8d0 100755
--- a/build_tools/gen_mlmodel_input.py
+++ b/build_tools/gen_mlmodel_input.py
@@ -36,6 +36,8 @@
help='Indicate it is quant model (default: False)')
parser.add_argument('--r', dest='float_input_range', default="-1.0, 1.0",
help='Float model input range (default: "-1.0, 1.0")')
+parser.add_argument('--sgn', dest='is_signed', action='store_true',
+ help='Indicate the input is signed (default: False)')
args = parser.parse_args()
@@ -50,7 +52,7 @@
file.write(struct.pack("<f", d))
-def gen_mlmodel_input(input_name, output_file, input_shape, is_quant):
+def gen_mlmodel_input(input_name, output_file, input_shape, is_quant, is_signed):
if not os.path.exists(input_name):
raise RuntimeError("Input file %s doesn't exist" % {input_name})
if len(input_shape) < 3:
@@ -73,6 +75,8 @@
low = np.min(float_input_range)
high = np.max(float_input_range)
input = (high - low) * input / 255.0 + low
+ elif is_signed:
+ input -= 128
write_binary_file(output_file, input, is_quant, is_audio)
@@ -81,4 +85,4 @@
input_shape = [int(x) for x in args.input_shape.split(',')]
float_input_range = [float(x) for x in args.float_input_range.split(',')]
gen_mlmodel_input(args.input_name, args.output_file,
- input_shape, args.is_quant)
+ input_shape, args.is_quant, args.is_signed)
diff --git a/cmake/iree_model_input.cmake b/cmake/iree_model_input.cmake
index 0871799..578e53a 100644
--- a/cmake/iree_model_input.cmake
+++ b/cmake/iree_model_input.cmake
@@ -26,7 +26,7 @@
function(iree_model_input)
cmake_parse_arguments(
_RULE
- "QUANT"
+ "QUANT;SIGNED"
"NAME;SHAPE;SRC;RANGE"
""
${ARGN}
@@ -62,6 +62,9 @@
if(_RULE_QUANT)
list(APPEND _ARGS "--q")
endif()
+ if(_RULE_SIGNED)
+ list(APPEND _ARGS "--sgn")
+ endif()
# Replace dependencies passed by ::name with iree::package::name
iree_package_ns(_PACKAGE_NS)