# blob: 5e97c63ae7f8be5b5922cd120d823927f66f02ee [file] [log] [blame]
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import shutil
from absl import app
from absl import flags
from mako import template
from tflite_micro.tensorflow.lite.micro.tools import generate_test_for_model
# Absolute path of the 'templates' directory that sits next to this script.
# It is passed to the generator in main() as the Mako template directory.
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), 'templates')
TEMPLATE_DIR = os.path.abspath(TEMPLATE_DIR)
# Global absl flag registry; values are read in main() after flag parsing.
FLAGS = flags.FLAGS
# Path of the model to generate a test for. main() derives the output
# sub-directory name from the file name of this path.
flags.DEFINE_string('input_tflite_file', None,
'Full path name to the input TFLite file.')
# Parent directory for the generated files.
# NOTE: the help text below relies on backslash line continuations inside
# the string literal, so any leading whitespace on the continued lines
# becomes part of the help text.
flags.DEFINE_string(
'output_dir', None, 'Directory to output generated files. \
Note that final output will be in FLAGS.output_dir/<base name of model>. \
Where <base name of model> will come from FLAGS.input_tflite_file.')
# Arena size forwarded to the test template; defaults to 1024 * 136
# (presumably bytes -- confirm against the template).
flags.DEFINE_integer('arena_size', 1024 * 136, 'Size of arena')
# When False, generate_golden() is a no-op and the templates receive a
# false verify_output value.
flags.DEFINE_boolean('verify_output', False,
'Verify output or just run model.')
# Both paths have no usable default (None), so absl must reject a run that
# omits them.
flags.mark_flag_as_required('input_tflite_file')
flags.mark_flag_as_required('output_dir')
class MicroMutableOpTestGenerator(generate_test_for_model.TestDataGenerator):
  """Renders the Mako test and BUILD templates for a single TFLite model.

  Extends generate_test_for_model.TestDataGenerator with the values the
  micro mutable op resolver templates need.

  Attributes:
    verify_output: Whether golden data is generated and output verified.
    arena_size: Arena size handed to the test template.
    target: File name of the model without its directory, cut at the first
      '.' of the file name.
    target_with_path: The model path with everything up to the last
      'tflite_micro/' or 'tflite-micro/' removed, ending in `target` instead
      of the full file name.
    input_type: Input data type name; always 'int8'.
    output_type: Output data type name; always 'int8'.
    input_types: Single-element list holding `input_type`.
  """

  def __init__(self, output_dir, model_path, verify_output, arena_size):
    """Initializes the generator.

    Args:
      output_dir: Directory the generated files are written to.
      model_path: '/'-separated path of the TFLite model file.
      verify_output: Truthy to generate golden data and verify the output.
      arena_size: Arena size to place in the generated test.
    """
    super().__init__(output_dir, [model_path], [0])  # Third argument not used.
    self.verify_output = verify_output
    self.arena_size = arena_size
    self.target = model_path.split('/')[-1].split('.')[0]

    # Drop everything up to and including the last 'tflite_micro/' or
    # 'tflite-micro/' so the path no longer depends on the checkout location.
    relative_path = model_path.split('tflite_micro/')[-1].split(
        'tflite-micro/')[-1]
    # Remove the extension from the file name only. Cutting the whole path at
    # its first '.' would mangle paths such as './gen/model/model.tflite' or
    # '/tmp/v1.2/model/model.tflite'.
    directory, separator, _ = relative_path.rpartition('/')
    self.target_with_path = directory + separator + self.target

    # Only int8 models supported
    self.input_type = 'int8'
    self.output_type = 'int8'
    # Not read in this class; presumably consumed by the base class.
    self.input_types = [self.input_type]

  def generate_golden(self):
    """Generates golden data through the base class if verify_output is set.

    Does nothing when verify_output is falsy.
    """
    if not self.verify_output:
      return
    super().generate_golden_single_in_single_out()

  def generate_test(self, template_dir, template_file, out_file):
    """Renders the test template and writes it to the output directory.

    Args:
      template_dir: Directory containing the Mako template.
      template_file: File name of the template inside `template_dir`.
      out_file: Name of the file to create inside self.output_dir.
    """
    test_template = template.Template(
        filename=os.path.join(template_dir, template_file))

    # path_to_target is target_with_path up to and including the first
    # directory named exactly like the target. Whole path components are
    # compared, so a directory that merely starts with the target name
    # (e.g. 'model_out' for target 'model') does not truncate the path.
    components = self.target_with_path.split('/')
    try:
      # Start at 1: the match must be preceded by a '/'.
      first_match = components.index(self.target, 1)
    except ValueError:
      # No directory part at all; keep the historical '<path>/<target>' form.
      path_to_target = self.target_with_path + '/' + self.target
    else:
      path_to_target = '/'.join(components[:first_match + 1])

    key_values_in_template = {
        'arena_size': self.arena_size,
        'verify_output': int(self.verify_output),
        'path_to_target': path_to_target,
        'target': self.target,
        'target_with_path': self.target_with_path,
        'input_dtype': self.input_type,
        'output_dtype': self.output_type
    }
    # Render before opening the output so that a template error does not
    # leave a truncated file behind.
    rendered = test_template.render(**key_values_in_template)
    with open(os.path.join(self.output_dir, out_file), 'w') as file_obj:
      file_obj.write(rendered)

  def generate_build_file(self, template_dir):
    """Renders 'BUILD.mako' and writes it as 'BUILD' in the output directory.

    Args:
      template_dir: Directory containing 'BUILD.mako'.
    """
    build_template = template.Template(
        filename=os.path.join(template_dir, 'BUILD.mako'))
    key_values_in_template = {
        'verify_output': self.verify_output,
        'target': self.target,
        'input_dtype': self.input_type,
        'output_dtype': self.output_type
    }
    # Render before opening the output so that a template error does not
    # leave a truncated file behind.
    rendered = build_template.render(**key_values_in_template)
    with open(os.path.join(self.output_dir, 'BUILD'), 'w') as file_obj:
      file_obj.write(rendered)
def main(_):
  """Generates the test sources for the model given by the flags.

  Creates FLAGS.output_dir/<base name of model>, copies the model into it
  and then generates golden data (only if FLAGS.verify_output), the BUILD
  file, the makefile and the test source.

  Args:
    _: Positional command-line arguments passed by absl; unused.
  """
  model_path = FLAGS.input_tflite_file
  model_name = model_path.split('/')[-1]
  base_model_name = model_name.split('.')[0]
  name_of_make_target = 'generated_micro_mutable_op_resolver_' + base_model_name

  # Built with plain '/' concatenation on purpose: the generator derives the
  # template paths by splitting the model path on '/'.
  out_dir = FLAGS.output_dir + '/' + base_model_name
  os.makedirs(out_dir, exist_ok=True)

  # Copy model to out dir to get the Mako generation right
  new_model_path = out_dir + '/' + model_name
  try:
    shutil.copyfile(model_path, new_model_path)
  except shutil.SameFileError:
    # The input model already is the file in the output directory (e.g. a
    # re-run on a previously copied model); there is nothing to copy.
    pass

  data_generator = MicroMutableOpTestGenerator(out_dir, new_model_path,
                                               FLAGS.verify_output,
                                               FLAGS.arena_size)
  data_generator.generate_golden()
  data_generator.generate_build_file(TEMPLATE_DIR)
  # generate_makefile is not defined in this file; presumably inherited from
  # generate_test_for_model.TestDataGenerator.
  data_generator.generate_makefile(
      test_file='micro_mutable_op_resolver_test.cc',
      src_prefix=name_of_make_target)
  data_generator.generate_test(
      TEMPLATE_DIR,
      template_file='micro_mutable_op_resolver_test.cc.mako',
      out_file='micro_mutable_op_resolver_test.cc')
# Script entry point: when executed directly (not imported), let absl parse
# the command-line flags and then call main().
if __name__ == '__main__':
  app.run(main)