blob: 5bd8918b123001c5cb85531a6fb0d45e9f97fc1e [file] [log] [blame]
#!/usr/bin/env python3
"""Generate ML model inputs from external images."""
import argparse
import os
import sys
import struct
import urllib.request
import numpy as np
from PIL import Image
# Command-line interface. The arguments are parsed right here at module
# level, so `args` is available to everything defined below.
parser = argparse.ArgumentParser(description='Generate inputs for ML models.')
# Required: image path, output path and the model's input shape.
parser.add_argument(
    '--i', required=True, dest='input_name',
    help='Model input image name')
parser.add_argument(
    '--o', required=True, dest='output_file',
    help='Output binary name')
parser.add_argument(
    '--s', required=True, dest='input_shape',
    help='Model input shape (example: "1, 224, 224, 3")')
# Optional: quantized-model switch and the URL to fetch the image from.
parser.add_argument(
    '--q', action='store_true', dest='is_quant',
    help='Indicate it is quant model (default: False)')
parser.add_argument('--u', help='Input image URL', dest='img_url')
args = parser.parse_args()
def write_binary_file(file_path, input, is_quant):
    """Write a sequence of numbers to a raw little-endian binary file.

    Args:
        file_path: Destination path; the file is created or truncated.
        input: Iterable of numeric values, serialized in iteration order.
        is_quant: If true, every value is packed as one unsigned byte
            ("<B"); otherwise as a 4-byte little-endian float ("<f").

    Raises:
        struct.error: If a value cannot be packed with the selected format
            (e.g. outside 0..255 when ``is_quant`` is true).
    """
    # Compile the format once instead of re-parsing it for every element.
    packer = struct.Struct("<B" if is_quant else "<f")
    # Write-only access is sufficient; the file is never read back here.
    with open(file_path, "wb") as out:
        # Build the whole payload and emit it with a single write call
        # rather than one tiny write per element.
        out.write(b"".join(map(packer.pack, input)))
def gen_mlmodel_input(input_name, output_file, input_shape, is_quant, img_url):
    """Convert an image into a flat binary input file for an ML model.

    If ``input_name`` does not exist locally it is first downloaded from
    ``img_url``. The image is then resized, flattened to
    ``prod(input_shape)`` elements and written to ``output_file`` with
    ``write_binary_file``.

    Args:
        input_name: Local image path (also the download destination).
        output_file: Path of the binary file to produce.
        input_shape: Sequence of ints with at least 3 entries. Entries 1 and
            2 are treated as height and width -- this assumes the NHWC
            layout of the "1, 224, 224, 3" example in the --s help text.
        is_quant: If true, the pixel values are written unchanged as bytes;
            otherwise they are rescaled with ``2/255 * x - 1`` (0..255 maps
            to -1..1) and written as floats.
        img_url: URL used to fetch the image when it is missing locally; may
            be None when the file already exists.

    Raises:
        ValueError: If ``input_shape`` has fewer than 3 dimensions, if the
            image is missing and no URL was given, or if the resized image
            does not contain exactly ``prod(input_shape)`` values.
    """
    # Validate the shape before doing any network or file work.
    if len(input_shape) < 3:
        raise ValueError("Input shape < 3 dimensions")

    if not os.path.exists(input_name):
        if not img_url:
            # Fail with a clear message instead of an obscure urllib error.
            raise ValueError(
                "Input image '%s' does not exist and no image URL was given"
                % input_name)
        urllib.request.urlretrieve(img_url, input_name)

    height, width = input_shape[1], input_shape[2]
    # The context manager closes the underlying image file once the pixel
    # data has been copied into the array.
    with Image.open(input_name) as img:
        # PIL's resize() expects (width, height), whereas the resulting
        # array is laid out (height, width, channels).
        pixels = np.array(img.resize((width, height)))

    # reshape() also acts as a check that the element count matches.
    data = pixels.reshape(np.prod(input_shape))
    if not is_quant:
        data = 2.0 / 255.0 * data - 1
    write_binary_file(output_file, data, is_quant)
if __name__ == '__main__':
    # Convert the comma-separated shape string (e.g. "1, 224, 224, 3") into
    # a list of ints; int() tolerates the whitespace around each number.
    input_shape = [int(x) for x in args.input_shape.split(',')]
    # --u is optional, so args.img_url may be None. Only strip whitespace
    # from the URL when one was actually supplied; otherwise pass None on.
    img_url = None
    if args.img_url is not None:
        img_url = args.img_url.replace(' ', '')
    gen_mlmodel_input(args.input_name, args.output_file,
                      input_shape, args.is_quant, img_url)