| #!/usr/bin/env python3 |
| # Copyright 2023 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """ |
| Load images and generate ML model inputs in C include file style. |
| """ |
| |
import argparse
import glob
import logging as log
import os
import re

from PIL import Image
| |
def main():
    """Convert input images into C arrays for use as ML model inputs.

    Command-line flags:
      --input/-i: input image file; may be repeated, one C array per file.
      --output/-o: output C file name. The arrays are wrapped in an include
        guard derived from this name.
      --input_shape/-s: optional expected image size, compared against PIL's
        ``Image.size`` (width, height), e.g. "320 240".
      --input_mode/-m: required PIL image mode (default "L").
      --signed: subtract 128 from every byte; the result is stored as a
        two's-complement byte in the ``unsigned char`` array.

    For every input, a checksum of the emitted bytes is printed to stdout.
    Invalid inputs are reported through ``parser.error`` (exit status 2)
    before the output file is created.
    """
    parser = argparse.ArgumentParser(description='Generate inputs for ML models.')
    parser.add_argument("--input",
                        "-i",
                        action='append',
                        required=True,
                        help="Input image file")
    parser.add_argument("--output",
                        "-o",
                        required=True,
                        help="Output c file")
    parser.add_argument("--input_shape",
                        "-s",
                        nargs='+',
                        type=int,
                        help='Input image shape (example: "320 240")')
    parser.add_argument("--input_mode",
                        "-m",
                        default="L",
                        help='Input image mode (example: "RGB")')
    parser.add_argument("--signed",
                        default=False,
                        dest='is_signed',
                        action='store_true',
                        help='Indicate the input is signed (default: False)')
    args = parser.parse_args()

    # Load and validate every image before touching the output file so that a
    # bad input does not leave a truncated header behind.
    arrays = []
    seen_names = set()
    for infile_path in args.input:
        try:
            data = _load_image_bytes(infile_path, args.input_mode, args.input_shape)
        except ValueError as err:
            parser.error(str(err))  # exits; never returns
        if args.is_signed:
            # (b - 128) & 0xff is the two's-complement byte of the signed
            # value, so every initializer stays within unsigned char range.
            data = bytes((b - 128) & 0xff for b in data)
        print("Checksum is ", hex(_compute_checksum(data)))

        # Array name: file name up to its first '.', made C-identifier safe.
        data_var = _to_c_identifier(os.path.basename(infile_path).split('.')[0])
        if data_var in seen_names:
            parser.error(f"{infile_path}: array name {data_var!r} is already "
                         "used by another input")
        seen_names.add(data_var)
        arrays.append(_format_c_array(data_var, data))

    include_guard = _to_c_identifier(f"MATCHA_{args.output.upper()}") + "_"
    with open(args.output, 'w') as outfile:
        outfile.write(f"#ifndef {include_guard}\n#define {include_guard}\n\n")
        outfile.writelines(arrays)
        # A C source file should end with a newline.
        outfile.write(f"#endif // {include_guard}\n")


def _to_c_identifier(name: str) -> str:
    """Return ``name`` with every character invalid in a C identifier replaced by '_'.

    A '_' is prepended when the result would be empty or start with a digit,
    so the return value is always a syntactically valid C identifier.
    """
    ident = re.sub(r'[^0-9A-Za-z_]', '_', name)
    if not ident or ident[0].isdigit():
        ident = '_' + ident
    return ident


def _load_image_bytes(path: str, mode: str, shape) -> bytes:
    """Open the image at ``path``, validate it, and return its raw sample bytes.

    Args:
      path: Input image file name.
      mode: Required PIL image mode (e.g. "L").
      shape: Required size compared against ``Image.size`` (width, height),
        or None to skip the size check.

    Returns:
      The image data with one byte per channel sample and channels
      interleaved, as produced by Pillow's ``Image.tobytes()``.

    Raises:
      ValueError: if the mode or size does not match, if ``get_bytes_per_pixel``
        rejects the mode, or if the raw data length does not agree with the
        bytes-per-pixel value for the mode.
    """
    with Image.open(path) as im:
        if im.mode != mode:
            raise ValueError(
                f"{path}: image mode {im.mode!r} does not match expected {mode!r}")
        bytes_per_pixel = get_bytes_per_pixel(im.mode)
        if shape is not None and im.size != tuple(shape):
            raise ValueError(
                f"{path}: image size {im.size} does not match expected {tuple(shape)}")
        width, height = im.size
        data = im.tobytes()
    # TODO(ykwang): modes wider than one byte per channel would need a wider
    # C element type; reject anything whose raw layout is not byte-per-channel.
    expected_len = width * height * bytes_per_pixel
    if len(data) != expected_len:
        raise ValueError(
            f"{path}: expected {expected_len} bytes for mode {mode!r}, got {len(data)}")
    return data


def _compute_checksum(data: bytes) -> int:
    """Return the sum, modulo 2**32, of ``data`` read as little-endian 32-bit words.

    A trailing partial word is zero-padded.
    """
    padded = data + bytes(-len(data) % 4)
    total = sum(int.from_bytes(padded[i:i + 4], 'little')
                for i in range(0, len(padded), 4))
    return total & 0xffffffff


def _format_c_array(name: str, data: bytes, values_per_line: int = 12) -> str:
    """Render ``data`` as a C ``unsigned char`` array plus a ``<name>_len`` constant."""
    rows = [", ".join(hex(value) for value in data[start:start + values_per_line])
            for start in range(0, len(data), values_per_line)]
    body = ",\n  ".join(rows)
    return (f"const unsigned char {name}[] = {{\n"
            f"  {body}\n"
            "};\n\n"
            f"const unsigned int {name}_len = {len(data)};\n\n")
| |
def get_bytes_per_pixel(mode: str) -> int:
    """Return the number of bytes per pixel for a supported PIL image mode.

    Args:
      mode: PIL image mode string.

    Returns:
      1 for "L"; 3 for "RGB" and "YCbCr".

    Raises:
      ValueError: if ``mode`` is not one of the supported modes.
    """
    # A lookup table avoids the original `elif "YCbCr":` bug, where a bare
    # non-empty string literal was always truthy and made the error unreachable.
    bytes_per_pixel = {"L": 1, "RGB": 3, "YCbCr": 3}
    try:
        return bytes_per_pixel[mode]
    except KeyError:
        raise ValueError(f"mode: {mode} is not supported yet") from None
| |
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()