Lun Dong | db85fd7 | 2021-11-02 02:39:48 -0700 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
Adam Jesionowski | 6e273a7 | 2022-04-14 12:20:20 -0700 | [diff] [blame] | 2 | # Copyright 2022 Google LLC |
| 3 | # |
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | # you may not use this file except in compliance with the License. |
| 6 | # You may obtain a copy of the License at |
| 7 | # |
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | # |
| 10 | # Unless required by applicable law or agreed to in writing, software |
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | # See the License for the specific language governing permissions and |
| 14 | # limitations under the License. |
| 15 | |
Cindy Liu | 60ab777 | 2021-12-05 17:30:31 -0800 | [diff] [blame] | 16 | """Generate ML model inputs from images.""" |
Lun Dong | db85fd7 | 2021-11-02 02:39:48 -0700 | [diff] [blame] | 17 | import argparse |
| 18 | import os |
Lun Dong | db85fd7 | 2021-11-02 02:39:48 -0700 | [diff] [blame] | 19 | import struct |
| 20 | import urllib.request |
| 21 | |
| 22 | import numpy as np |
| 23 | from PIL import Image |
Lun Dong | a43730e | 2022-06-22 21:52:14 -0700 | [diff] [blame] | 24 | from scipy.io import wavfile |
Lun Dong | db85fd7 | 2021-11-02 02:39:48 -0700 | [diff] [blame] | 25 | |
| 26 | |
| 27 | parser = argparse.ArgumentParser( |
| 28 | description='Generate inputs for ML models.') |
| 29 | parser.add_argument('--i', dest='input_name', |
| 30 | help='Model input image name', required=True) |
| 31 | parser.add_argument('--o', dest='output_file', |
| 32 | help='Output binary name', required=True) |
| 33 | parser.add_argument('--s', dest='input_shape', |
| 34 | help='Model input shape (example: "1, 224, 224, 3")', required=True) |
| 35 | parser.add_argument('--q', dest='is_quant', action='store_true', |
| 36 | help='Indicate it is quant model (default: False)') |
Lun Dong | c12a70b | 2021-12-08 05:12:27 +0000 | [diff] [blame] | 37 | parser.add_argument('--r', dest='float_input_range', default="-1.0, 1.0", |
| 38 | help='Float model input range (default: "-1.0, 1.0")') |
Lun Dong | 5e4bc36 | 2022-08-03 16:24:22 -0700 | [diff] [blame] | 39 | parser.add_argument('--sgn', dest='is_signed', action='store_true', |
| 40 | help='Indicate the input is signed (default: False)') |
Lun Dong | db85fd7 | 2021-11-02 02:39:48 -0700 | [diff] [blame] | 41 | args = parser.parse_args() |
| 42 | |
| 43 | |
def write_binary_file(file_path, input, is_quant, is_audio):
    """Write a flat sequence of values to `file_path` as packed little-endian data.

    Element format, in priority order:
      * `is_audio`: signed 16-bit integers ("<h");
      * `is_quant`: unsigned 8-bit integers ("<B");
      * otherwise: 32-bit floats ("<f").

    Args:
      file_path: Destination path; created or truncated.
      input: Iterable of numeric values (e.g. a 1-D numpy array). The name
        shadows the builtin but is kept for caller compatibility.
      is_quant: Pack values as unsigned bytes (ignored when `is_audio`).
      is_audio: Pack values as signed 16-bit integers.

    Raises:
      struct.error / OverflowError: If a value cannot be represented in the
        selected format. In that case no file is written.
    """
    if is_audio:
        code = 'h'
    elif is_quant:
        code = 'B'
    else:
        code = 'f'
    values = tuple(input)
    # A single pack/write for the whole buffer instead of one per element;
    # struct still validates every value exactly as per-element packing did.
    payload = struct.pack('<%d%s' % (len(values), code), *values)
    with open(file_path, 'wb') as file:
        file.write(payload)
| 53 | |
| 54 | |
def _truncate_to(data, num_elements, source):
    """Return the first `num_elements` entries of `data` as a 1-D array.

    Raises:
      ValueError: If `source` does not yield exactly `num_elements` values
        after truncation (too short, or e.g. a multi-channel WAV, which
        scipy returns as a 2-D array).
    """
    clipped = data[:num_elements]
    if clipped.size != num_elements:
        raise ValueError(
            "%s provides %d values where the input shape needs %d" %
            (source, clipped.size, num_elements))
    return clipped.reshape(num_elements)


def gen_mlmodel_input(input_name, output_file, input_shape, is_quant,
                      is_signed, float_input_range=None):
    """Convert an image, WAV or raw binary file into a flat model-input blob.

    Args:
      input_name: Source file. No extension or `.bin` is read as raw uint8
        (quant) or float32 data; `.wav` is read with scipy and written as
        int16; any other extension is opened as an image with PIL.
        Extensions are matched case-insensitively.
      output_file: Path of the binary file to write.
      input_shape: Model input shape as a list of ints, e.g. [1, 224, 224, 3].
        Images are resized to width=input_shape[2], height=input_shape[1].
      is_quant: Whether the model takes quantized 8-bit input.
      is_signed: For quantized image input, shift pixels by -128 so the
        written bytes are the int8 two's-complement encoding.
      float_input_range: Optional (low, high) range that 0..255 image pixels
        are linearly mapped onto for float models. When None, the
        module-level `float_input_range` set by the script entry point is
        used if present, otherwise (-1.0, 1.0) (the `--r` default).

    Raises:
      RuntimeError: If `input_name` does not exist.
      ValueError: If `input_shape` has fewer than 3 dimensions, or the input
        cannot supply exactly prod(input_shape) values.
    """
    if not os.path.exists(input_name):
        raise RuntimeError("Input file %s doesn't exist" % input_name)
    if len(input_shape) < 3:
        raise ValueError("Input shape < 3 dimensions")
    if float_input_range is None:
        # Backward compatibility: the script entry point publishes the parsed
        # --r range as a module-level `float_input_range`.
        float_input_range = globals().get('float_input_range', (-1.0, 1.0))

    num_elements = int(np.prod(input_shape))
    input_ext = os.path.splitext(input_name)[1].lower()
    is_audio = False
    if input_ext in ('', '.bin'):
        # Raw tensor data, already in the model's element type.
        with open(input_name, mode='rb') as f:
            data = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32)
        data = _truncate_to(data, num_elements, input_name)
    elif input_ext == '.wav':
        is_audio = True
        _, data = wavfile.read(input_name)
        data = _truncate_to(data, num_elements, input_name)
    else:
        # PIL takes (width, height); input_shape is indexed here as
        # [batch, height, width, ...].
        with Image.open(input_name) as img:
            resized_img = img.resize((input_shape[2], input_shape[1]))
        data = np.array(resized_img).reshape(num_elements)
        if not is_quant:
            # Map 0..255 pixel values linearly onto [low, high].
            low = np.min(float_input_range)
            high = np.max(float_input_range)
            data = (high - low) * data / 255.0 + low
        elif is_signed:
            # Assumes an 8-bit image (pixels load as uint8): in-place uint8
            # subtraction wraps modulo 256, so each byte becomes the int8
            # two's-complement encoding of (pixel - 128), which the writer
            # emits unchanged as an unsigned byte.
            data -= 128
    write_binary_file(output_file, data, is_quant, is_audio)
Lun Dong | db85fd7 | 2021-11-02 02:39:48 -0700 | [diff] [blame] | 81 | |
| 82 | |
if __name__ == '__main__':
    # Turn the comma-separated CLI strings into numeric lists. The range is
    # left as a module-level name because gen_mlmodel_input reads it from
    # module scope.
    input_shape = list(map(int, args.input_shape.split(',')))
    float_input_range = list(map(float, args.float_input_range.split(',')))
    gen_mlmodel_input(args.input_name, args.output_file, input_shape,
                      args.is_quant, args.is_signed)