blob: b6cdbbd748dcc37f909a7d45926e2a085dbbaf51 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate ML model inputs from images, WAV audio, or raw binary files."""
import argparse
import os
import struct
import numpy as np
from PIL import Image
from scipy.io import wavfile
# Command-line options as (flag, add_argument keyword settings) pairs.
# They are registered in this order, which is also the order --help lists them.
_CLI_OPTIONS = (
    ('--i', {
        'dest': 'input_name',
        'required': True,
        'help': 'Model input image name',
    }),
    ('--o', {
        'dest': 'output_file',
        'required': True,
        'help': 'Output binary name',
    }),
    ('--s', {
        'dest': 'input_shape',
        'required': True,
        'help': 'Model input shape (example: "1, 224, 224, 3")',
    }),
    ('--q', {
        'dest': 'is_quant',
        'action': 'store_true',
        'help': 'Indicate it is quant model (default: False)',
    }),
    ('--r', {
        'dest': 'float_input_range',
        'default': "-1.0, 1.0",
        'help': 'Float model input range (default: "-1.0, 1.0")',
    }),
    ('--sgn', {
        'dest': 'is_signed',
        'action': 'store_true',
        'help': 'Indicate the input is signed (default: False)',
    }),
)

parser = argparse.ArgumentParser(
    description='Generate inputs for ML models.')
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)

# Parsed at module load; code further down reads its options from `args`.
args = parser.parse_args()
def write_binary_file(file_path, data, is_quant, is_audio):
    """Write `data` to `file_path` as packed little-endian binary values.

    Each element is encoded as:
      * int16 ("<h") when `is_audio` is true (takes precedence over
        `is_quant`),
      * uint8 ("<B") when `is_quant` is true,
      * float32 ("<f") otherwise.

    Args:
      file_path: Path of the file to create or overwrite.
      data: Iterable of numeric values to encode.
      is_quant: Whether the values belong to a quantized model.
      is_audio: Whether the values are audio samples.

    Raises:
      struct.error: If a value does not fit the selected format. The output
        file is not touched in that case.
    """
    # The element format does not depend on the element, so choose it once.
    if is_audio:
        fmt = "<h"
    elif is_quant:
        fmt = "<B"
    else:
        fmt = "<f"
    pack = struct.Struct(fmt).pack

    # Encode everything before opening the file so that a value that cannot
    # be packed never leaves a truncated or partially written output behind.
    payload = b"".join(pack(value) for value in data)
    with open(file_path, "wb") as file:
        file.write(payload)
def _parse_number_list(value, cast):
    """Return `value` as a list of numbers converted with `cast`.

    `value` may be a comma-separated string (the command-line form, e.g.
    "1, 224, 224, 3") or an iterable of numbers.
    """
    items = value.split(',') if isinstance(value, str) else value
    return [cast(item) for item in items]


def gen_mlmodel_input(input_name, output_file, input_shape, is_quant,
                      is_signed, float_input_range):
    """Convert an input file into a flat binary file of model input values.

    The kind of input is selected from the extension of `input_name`:
      * no extension, '.bin' or '.raw': the file is read as raw uint8
        (`is_quant`) or float32 values;
      * '.wav': samples are read with `scipy.io.wavfile`;
      * anything else: the file is opened as an image and resized using
        dimensions 2 and 1 of the shape. For non-quantized models the values
        are rescaled from [0, 255] to `float_input_range`; for signed
        quantized models 128 is subtracted.
    In every case the first prod(input_shape) values are flattened and handed
    to `write_binary_file`.

    Args:
      input_name: Path of the input file.
      output_file: Path of the binary file to write.
      input_shape: Model input shape, either a comma-separated string such
        as "1, 224, 224, 3" or a sequence of integers.
      is_quant: Whether the model is quantized.
      is_signed: Whether a quantized image input is signed.
      float_input_range: Value range for float image inputs, either a
        comma-separated string such as "-1.0, 1.0" or a sequence of numbers.
        Only read for non-quantized image inputs.

    Raises:
      RuntimeError: If `input_name` does not exist.
      ValueError: If the shape cannot be parsed, if an image input is given
        a shape with fewer than 3 dimensions, or if the input does not
        provide enough values to fill the shape.
    """
    if not os.path.exists(input_name):
        raise RuntimeError(f"Input file {input_name} doesn't exist")

    # Work from the caller-supplied shape (string or sequence of ints).
    input_shape = _parse_number_list(input_shape, int)
    num_elements = int(np.prod(input_shape))

    input_ext = os.path.splitext(input_name)[1]
    is_audio = False
    if input_ext in ('', '.bin', '.raw'):
        with open(input_name, mode='rb') as f:
            data = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32)
        data = data[:num_elements].reshape(num_elements)
    elif input_ext == '.wav':
        is_audio = True
        # The sample rate returned by wavfile.read is not needed.
        _, data = wavfile.read(input_name)
        data = data[:num_elements].reshape(num_elements)
    else:
        # Dimensions 1 and 2 are indexed below, so the rank has to be checked
        # on the parsed shape (not on the length of the raw string).
        if len(input_shape) < 3:
            raise ValueError("Input shape < 3 dimensions")
        # NOTE(review): resize takes (width, height); this presumably assumes
        # an NHWC shape -- confirm against the model.
        resized_img = Image.open(input_name).resize(
            (input_shape[2], input_shape[1]))
        data = np.array(resized_img).reshape(num_elements)
        if not is_quant:
            value_range = _parse_number_list(float_input_range, float)
            low = np.min(value_range)
            high = np.max(value_range)
            # Map pixel values from [0, 255] onto [low, high].
            data = (high - low) * data / 255.0 + low
        elif is_signed:
            # In-place shift; assumes the decoded image is uint8, in which
            # case the subtraction wraps modulo 256 -- TODO confirm.
            data -= 128
    write_binary_file(output_file, data, is_quant, is_audio)
if __name__ == '__main__':
    # Script entry point: forward the parsed command-line options.
    gen_mlmodel_input(
        input_name=args.input_name,
        output_file=args.output_file,
        input_shape=args.input_shape,
        is_quant=args.is_quant,
        is_signed=args.is_signed,
        float_input_range=args.float_input_range,
    )