# blob: 759db56a9027d53152797adb270d5bc7aae2b228
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for converting .tflite, .bmp and .wav files to cc arrays."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import struct
import wave
import numpy as np
from PIL import Image
def generate_file(out_fname, array_name, array_type, array_contents, size):
    """Write an array of values to a CC or header file.

    A '.cc' output receives the array definition plus a `<array_name>_size`
    constant; a '.h' output receives the matching `extern` declarations.

    Args:
      out_fname: Path of the file to generate. Must end with '.cc' or '.h'.
        Missing parent directories are created.
      array_name: C identifier used for the array; the size constant is named
        `<array_name>_size`.
      array_type: C element type of the array (e.g. 'unsigned char').
      array_contents: Comma-separated initializer text placed between the
        braces of the array definition. Only used for '.cc' outputs.
      size: Value written to the `<array_name>_size` constant. Only used for
        '.cc' outputs.

    Raises:
      ValueError: If `out_fname` ends with neither '.cc' nor '.h'.
    """
    # Validate first so an unsupported name never creates directories.
    if not out_fname.endswith(('.cc', '.h')):
        raise ValueError('generated file must be end with .cc or .h')

    # dirname is '' for a bare filename, and os.makedirs('') raises.
    out_dir = os.path.dirname(out_fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if out_fname.endswith('.cc'):
        # The header is included by its path relative to 'genfiles/'. Only
        # the trailing '.cc' is swapped, so a '.cc' elsewhere in the path is
        # left alone.
        include_path = out_fname.split('genfiles/')[-1]
        header_path = include_path[:-len('.cc')] + '.h'
        text = ''.join([
            '#include <cstdint>\n\n',
            '#include "{}"\n\n'.format(header_path),
            'const unsigned int {}_size = {};\n'.format(array_name, str(size)),
            'alignas(16) const {} {}[] = {{'.format(array_type, array_name),
            array_contents,
            '};\n',
        ])
    else:
        text = ''.join([
            '#include <cstdint>\n\n',
            'extern const unsigned int {}_size;\n'.format(array_name),
            'extern const {} {}[];\n'.format(array_type, array_name),
        ])

    # The context manager closes the file even if the write fails.
    with open(out_fname, 'w') as out_file:
        out_file.write(text)
def bytes_to_hexstring(buffer):
    """Render a byte sequence as comma-separated hex literals.

    Each element is formatted with the built-in `hex`, so b'\\x00\\xff'
    becomes '0x0,0xff'. An empty buffer yields an empty string.
    """
    return ','.join(hex(byte) for byte in buffer)
def generate_array(input_fname):
    """Return array size and array of data from the input file.

    The file is interpreted according to its name suffix:
      * '.bmp': decoded with PIL; the raw image bytes become hex literals.
      * '.wav': frames are unpacked as little-endian 16-bit samples.
      * '.csv': the first line is used as-is (one array per file).
      * '.npy': loaded with numpy, converted to float32 and flattened.
      * anything else (including '.tflite'): the raw file bytes become hex
        literals.

    Args:
      input_fname: Path of the file to convert.

    Returns:
      A two-element list `[size, out_string]` where `out_string` is the
      comma-separated initializer text for the C array. `size` is the number
      of elements, except for '.wav' where it is the number of frames.
    """
    if input_fname.endswith('.bmp'):
        # The context manager releases the image file once the pixels are read.
        with Image.open(input_fname, mode='r') as img:
            image_bytes = img.tobytes()
        return [len(image_bytes), bytes_to_hexstring(image_bytes)]

    if input_fname.endswith('.wav'):
        with wave.open(input_fname, mode='r') as wav_file:
            num_channels = wav_file.getnchannels()
            n_frames = wav_file.getnframes()
            frames = wav_file.readframes(n_frames)
        # 'h' means 16-bit samples; other sample widths fail to unpack.
        samples = struct.unpack('<%dh' % (num_channels * n_frames), frames)
        # NOTE(review): the reported size is the frame count, which is smaller
        # than the number of emitted samples for multi-channel audio. Kept
        # for backward compatibility -- confirm what consumers expect.
        return [n_frames, ','.join(map(str, samples))]

    if input_fname.endswith('.csv'):
        with open(input_fname, 'r') as input_file:
            # Assume one array per csv file. The line terminator is dropped
            # so it does not end up inside the generated initializer.
            elements = input_file.readline().strip()
        # Count only non-blank entries, so an empty file reports 0 elements
        # and a trailing comma does not inflate the size.
        num_elements = sum(1 for item in elements.split(',') if item.strip())
        return [num_elements, elements]

    if input_fname.endswith('.npy'):
        data = np.float32(np.load(input_fname, allow_pickle=False))
        data_1d = data.flatten()
        return [len(data_1d), ','.join([str(x) for x in data_1d])]

    # '.tflite' models and any unrecognised file are embedded byte-for-byte.
    with open(input_fname, 'rb') as input_file:
        buffer = input_file.read()
    return [len(buffer), bytes_to_hexstring(buffer)]
# Ordered (filename suffix, array-name suffix, C element type) rules used by
# get_array_name; the first matching suffix wins.
_ARRAY_NAME_RULES = (
    ('.tflite', '_model_data', 'unsigned char'),
    ('.bmp', '_image_data', 'unsigned char'),
    ('.wav', '_audio_data', 'int16_t'),
    ('_int32.csv', '_test_data', 'int32_t'),
    ('_int16.csv', '_test_data', 'int16_t'),
    ('_int8.csv', '_test_data', 'int8_t'),
    ('_float.csv', '_test_data', 'float'),
    ('.npy', '_test_data', 'float'),
)


def get_array_name(input_fname):
    """Return the C array name and element type for an input file.

    The name is 'g_' plus the file's base name plus a suffix chosen by file
    type. For names with several dots only the component just before the
    extension is used (e.g. 'clip.v2.wav' -> 'g_v2_audio_data'); this is kept
    for backward compatibility with previously generated names.

    Args:
      input_fname: Path of the input file; may be relative.

    Returns:
      A two-element list `[array_name, array_type]`. Files that match no
      known suffix get the bare 'g_<name>' and 'unsigned char'.
    """
    # Work on the final path component only, so dots or separators in parent
    # directories (and platform-specific separators) cannot leak into the name.
    file_name = os.path.basename(os.path.abspath(input_fname))
    pieces = file_name.split('.')
    # A name without any dot has no extension to drop; use it whole.
    stem = pieces[-2] if len(pieces) > 1 else file_name
    base_array_name = 'g_' + stem

    for file_suffix, name_suffix, array_type in _ARRAY_NAME_RULES:
        if input_fname.endswith(file_suffix):
            return [base_array_name + name_suffix, array_type]
    return [base_array_name, 'unsigned char']
def main():
    """Create cc sources with c arrays from the input files given on the CLI.

    Two modes are supported, selected by the `output` argument:
      * `output` ends with '.cc' or '.h': exactly one input is converted and
        written to that single file.
      * otherwise `output` is a base directory: for every distinct input a
        '.cc' and a '.h' file are generated under it, and each generated
        '.cc' path is printed to stdout.

    Raises:
      ValueError: In directory mode, if an input is not a .tflite, .bmp,
        .wav, .npy or .csv file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'output',
        help='base directory for all outputs or a cc or header to generate.')
    parser.add_argument(
        'inputs',
        nargs='+',
        help='input wav, bmp, tflite, npy or csv files to convert. '
        'If output is a cc or header only one input may be specified.')
    args = parser.parse_args()

    if args.output.endswith(('.cc', '.h')):
        # Report bad usage explicitly; an assert would vanish under -O.
        if len(args.inputs) != 1:
            parser.error(
                'exactly one input may be specified when output is a cc or '
                'header file.')
        size, cc_array = generate_array(args.inputs[0])
        generated_array_name, array_type = get_array_name(args.inputs[0])
        generate_file(args.output, generated_array_name, array_type, cc_array,
                      size)
        return

    # Input suffix -> suffix appended to the generated file's base name.
    output_suffixes = (
        ('.tflite', '_model_data'),
        ('.bmp', '_image_data'),
        ('.wav', '_audio_data'),
        ('.csv', '_test_data'),
        ('.npy', '_test_data'),
    )

    # Deduplicate inputs to prevent duplicate generated files (ODR issue).
    # dict.fromkeys keeps the first occurrence and preserves argument order.
    for input_file in dict.fromkeys(args.inputs):
        name_suffix = next(
            (tag for ext, tag in output_suffixes if input_file.endswith(ext)),
            None)
        if name_suffix is None:
            raise ValueError(
                'input file must be .tflite, .bmp, .wav , .npy or .csv')

        output_base_fname = os.path.join(
            args.output, os.path.splitext(input_file)[0]) + name_suffix
        output_cc_fname = output_base_fname + '.cc'
        # Print output cc filename for Make to include it in the build.
        print(output_cc_fname)
        output_hdr_fname = output_base_fname + '.h'

        size, cc_array = generate_array(input_file)
        generated_array_name, array_type = get_array_name(input_file)
        generate_file(output_cc_fname, generated_array_name, array_type,
                      cc_array, size)
        generate_file(output_hdr_fname, generated_array_name, array_type,
                      cc_array, size)
# Run the converter only when this file is executed as a script, so that
# importing it as a library has no side effects.
if __name__ == '__main__':
    main()