blob: bbcd8d06880ea5febdf9b6365394edae1994b44e [file] [log] [blame]
Lun Dongdb85fd72021-11-02 02:39:48 -07001#!/usr/bin/env python3
Adam Jesionowski6e273a72022-04-14 12:20:20 -07002# Copyright 2022 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Generate ML model input binaries from image, raw binary, or WAV audio files."""
Lun Dongdb85fd72021-11-02 02:39:48 -070017import argparse
18import os
Lun Dongdb85fd72021-11-02 02:39:48 -070019import struct
20import urllib.request
21
22import numpy as np
23from PIL import Image
Lun Donga43730e2022-06-22 21:52:14 -070024from scipy.io import wavfile
Lun Dongdb85fd72021-11-02 02:39:48 -070025
26
# Command-line interface. Arguments are parsed at module level so that `args`
# is available as a module-level name to the entry-point code further down.
parser = argparse.ArgumentParser(description='Generate inputs for ML models.')
parser.add_argument(
    '--i', dest='input_name', required=True,
    help='Model input image name')
parser.add_argument(
    '--o', dest='output_file', required=True,
    help='Output binary name')
parser.add_argument(
    '--s', dest='input_shape', required=True,
    help='Model input shape (example: "1, 224, 224, 3")')
parser.add_argument(
    '--q', dest='is_quant', action='store_true',
    help='Indicate it is quant model (default: False)')
parser.add_argument(
    '--r', dest='float_input_range', default="-1.0, 1.0",
    help='Float model input range (default: "-1.0, 1.0")')
parser.add_argument(
    '--sgn', dest='is_signed', action='store_true',
    help='Indicate the input is signed (default: False)')
args = parser.parse_args()
42
43
def write_binary_file(file_path, input, is_quant, is_audio):
    """Serialize a flat sequence of numbers to a little-endian binary file.

    Every element is written with the same fixed-width format:
      * ``is_audio`` true: signed 16-bit integer (``<h``); this takes
        precedence over ``is_quant``;
      * ``is_quant`` true: unsigned 8-bit integer (``<B``);
      * otherwise: 32-bit float (``<f``).

    Args:
        file_path: Destination path; the file is created or truncated.
        input: Iterable of numeric values, written in iteration order.
            (The name shadows the builtin but is kept for existing callers.)
        is_quant: Select the unsigned 8-bit format.
        is_audio: Select the signed 16-bit format.

    Raises:
        struct.error: If a value cannot be packed in the selected integer
            format (wrong type or out of range).
        OverflowError: If a value is too large for a 32-bit float.
    """
    # The format does not depend on the element, so pick it once up front
    # instead of re-testing the flags for every value.
    if is_audio:
        fmt = "<h"
    elif is_quant:
        fmt = "<B"
    else:
        fmt = "<f"
    pack = struct.Struct(fmt).pack

    # Pack everything before touching the destination so that a packing
    # failure does not leave a truncated output file behind.
    payload = b"".join(pack(d) for d in input)

    # Write-only access is sufficient; the file is never read back here.
    with open(file_path, "wb") as file:
        file.write(payload)
53
54
def gen_mlmodel_input(input_name, output_file, input_shape, is_quant, is_signed):
    """Convert an input file into a flat binary blob and write it out.

    How the input is loaded depends on its file extension:
      * no extension or ``.bin``: raw values read as uint8 when ``is_quant``
        is true, float32 otherwise;
      * ``.wav``: samples as returned by ``scipy.io.wavfile.read``;
      * anything else: opened as an image with PIL and resized to
        ``(input_shape[2], input_shape[1])``.

    The loaded data is flattened to ``prod(input_shape)`` elements (raw and
    audio inputs are truncated to that length first). Then:
      * if ``is_quant`` is false, values are rescaled from 0..255 onto the
        range given by the module-level ``float_input_range``;
      * else if ``is_signed`` is true, 128 is subtracted from every value.

    The result is passed to ``write_binary_file``.

    Args:
        input_name: Path of the input file.
        output_file: Path of the binary file to write.
        input_shape: Sequence of ints with at least 3 entries.
        is_quant: Whether the model takes quantized input.
        is_signed: Whether quantized input should be shifted by -128.

    Raises:
        RuntimeError: If ``input_name`` does not exist.
        ValueError: If ``input_shape`` has fewer than 3 dimensions (NumPy
            also raises ValueError when the loaded data cannot be reshaped
            to ``prod(input_shape)`` elements).
    """
    if not os.path.exists(input_name):
        # Format the name itself. Wrapping it in braces builds a set, whose
        # repr (e.g. "{'img.png'}") would end up in the message.
        raise RuntimeError("Input file %s doesn't exist" % input_name)
    if len(input_shape) < 3:
        raise ValueError("Input shape < 3 dimensions")

    # Total element count of the model input; computed once and reused.
    num_elements = int(np.prod(input_shape))

    input_ext = os.path.splitext(input_name)[1]
    is_audio = input_ext == '.wav'
    if input_ext in ('', '.bin'):
        with open(input_name, mode='rb') as f:
            data = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32)
        data = data[:num_elements].reshape(num_elements)
    elif is_audio:
        _, data = wavfile.read(input_name)
        data = data[:num_elements].reshape(num_elements)
    else:
        # Close the image file handle as soon as the resized copy exists.
        with Image.open(input_name) as img:
            resized_img = img.resize((input_shape[2], input_shape[1]))
        data = np.array(resized_img).reshape(num_elements)

    if not is_quant:
        # `float_input_range` is a module-level name assigned by the
        # script's entry-point code, not a parameter of this function.
        low = np.min(float_input_range)
        high = np.max(float_input_range)
        data = (high - low) * data / 255.0 + low
    elif is_signed:
        # In-place so the array keeps its dtype; for uint8 data the
        # subtraction wraps modulo 256.
        data -= 128
    write_binary_file(output_file, data, is_quant, is_audio)
Lun Dongdb85fd72021-11-02 02:39:48 -070081
82
if __name__ == '__main__':
    # --s and --r both arrive as comma-separated strings; convert them to
    # numeric lists. `float_input_range` is assigned at module level because
    # gen_mlmodel_input reads it as a global rather than as an argument.
    input_shape = list(map(int, args.input_shape.split(',')))
    float_input_range = list(map(float, args.float_input_range.split(',')))
    gen_mlmodel_input(
        args.input_name,
        args.output_file,
        input_shape,
        args.is_quant,
        args.is_signed,
    )