blob: 3ea48e34b38927a3f0ad4efa412cfa5ea0fe4fb1 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Load images and generate ML model inputs in C include file style.
"""
import argparse
import logging as log
import glob, os
from PIL import Image
def _include_guard(output_path):
    """Return the include-guard macro name derived from the output path.

    The path is upper-cased and '/', '.' and '-' are replaced with '_', then
    wrapped as MATCHA_<path>_.
    """
    sanitized = (output_path.upper()
                 .replace('/', '_')
                 .replace('.', '_')
                 .replace('-', '_'))
    return f"MATCHA_{sanitized}_"


def _checksum(byte_values):
    """Return the sum of the bytes packed as little-endian 32-bit words.

    Byte i contributes at bit offset 8 * (i % 4); the running sum wraps
    modulo 2**32.
    """
    checksum = 0
    for idx, value in enumerate(byte_values):
        checksum = (checksum + (value << ((idx % 4) * 8))) & 0xffffffff
    return checksum


def _format_c_array(var_name, byte_values, per_line=12):
    """Return C source defining `<var_name>[]` and `<var_name>_len`.

    Values are written as hex literals, `per_line` per line.
    """
    rows = [", ".join(hex(value) for value in byte_values[i:i + per_line])
            for i in range(0, len(byte_values), per_line)]
    text = f"const unsigned char {var_name}[] = {{\n"
    text += " " + ",\n ".join(rows) + "\n};\n\n"
    text += f"const unsigned int {var_name}_len = {len(byte_values)};\n\n"
    return text


def main():
    """Convert input images into C arrays and write them to one include file.

    Command-line arguments:
      --input/-i (repeatable): input image file. Each image becomes a
        `const unsigned char` array named after the file's base name up to
        the first '.', plus a matching `<name>_len` constant.
      --output/-o: path of the generated C include file.
      --input_shape/-s: optional expected size, compared against PIL's
        `Image.size`; the check is skipped when omitted.
      --input_mode/-m: expected PIL image mode (default "L").
      --signed: subtract 128 from every pixel before emitting it; values are
        written as their two's-complement byte.

    For each image, the checksum of its emitted bytes (little-endian 32-bit
    word sum modulo 2**32) is printed to stdout.

    Raises:
      ValueError: an image's mode or size does not match the arguments, a
        pixel is out of range, or the mode is unsupported.
      NotImplementedError: the mode needs more than one byte per pixel.
    """
    parser = argparse.ArgumentParser(description='Generate inputs for ML models.')
    parser.add_argument("--input",
                        "-i",
                        action='append',
                        required=True,
                        help="Input image file")
    parser.add_argument("--output",
                        "-o",
                        required=True,
                        help="Output c file")
    parser.add_argument("--input_shape",
                        "-s",
                        nargs='+',
                        type=int,
                        help='Input image shape (example: "320 240")')
    parser.add_argument("--input_mode",
                        "-m",
                        default="L",
                        help='Input image mode (example: "RGB")')
    parser.add_argument("--signed",
                        default=False,
                        dest='is_signed',
                        action='store_true',
                        help='Indicate the input is signed (default: False)')
    args = parser.parse_args()

    include_guard = _include_guard(args.output)
    # Build the whole file in memory and write it only after every image has
    # been validated, so a bad input never leaves a truncated header behind.
    chunks = [f"#ifndef {include_guard}\n", f"#define {include_guard}\n\n"]

    for infile_path in args.input:
        with Image.open(infile_path) as im:
            mode = im.mode
            size = im.size
            # Explicit checks instead of `assert`, which `python -O` strips.
            if mode != args.input_mode:
                raise ValueError(
                    f"{infile_path}: image mode {mode!r} does not match "
                    f"--input_mode {args.input_mode!r}")
            bytes_per_pixel = get_bytes_per_pixel(mode)
            if (args.input_shape is not None
                    and size != tuple(args.input_shape)):
                raise ValueError(
                    f"{infile_path}: image size {size} does not match "
                    f"--input_shape {tuple(args.input_shape)}")
            im_data = list(im.getdata())

        # TODO(ykwang): Enable other modes by using word type
        # The array element type is `unsigned char`, so only one-byte modes
        # can be emitted; multi-band modes yield tuple pixels.
        if bytes_per_pixel != 1:
            raise NotImplementedError(
                f"{infile_path}: mode {mode!r} uses {bytes_per_pixel} bytes "
                "per pixel; only 1-byte modes are supported")

        if args.is_signed:
            im_data = [pixel - 128 for pixel in im_data]

        low, high = (-128, 128) if args.is_signed else (0, 256)
        for pixel in im_data:
            if not low <= pixel < high:
                raise ValueError(
                    f"{infile_path}: pixel value {pixel} is outside "
                    f"[{low}, {high})")

        # Two's-complement byte of each value: matches the `unsigned char`
        # element type and is exactly what the checksum covers.
        byte_values = [pixel & 0xff for pixel in im_data]
        print("Checksum is ", hex(_checksum(byte_values)))

        # Array variable name: base file name up to the first '.'.
        data_var = os.path.basename(infile_path).split('.')[0]
        chunks.append(_format_c_array(data_var, byte_values))

    chunks.append(f"#endif // {include_guard}\n")
    with open(args.output, 'w') as outfile:
        outfile.write("".join(chunks))
def get_bytes_per_pixel(mode: str) -> int:
    """Return the number of bytes per pixel for a supported PIL image mode.

    Args:
      mode: PIL image mode string.

    Returns:
      1 for "L"; 3 for "RGB" and "YCbCr".

    Raises:
      ValueError: if `mode` is not one of the supported modes.
    """
    # Explicit table: every supported mode is compared against `mode`, so an
    # unknown mode cannot fall through to a default size.
    sizes = {"L": 1, "RGB": 3, "YCbCr": 3}
    try:
        return sizes[mode]
    except KeyError:
        raise ValueError(f"mode: {mode} is not supported yet") from None
if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()