blob: de583dae78cd7ca06efd13c2e79c7c6936b01fec [file] [log] [blame]
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This tool generates a header with Micro Mutable Op Resolver code for a given
model. See README.md for more info.
"""
import os
import re
from absl import app
from absl import flags
from mako import template
from tflite_micro.tensorflow.lite.tools import visualize as visualize
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), 'templates')
TEMPLATE_DIR = os.path.abspath(TEMPLATE_DIR)
FLAGS = flags.FLAGS
flags.DEFINE_string(
'common_tflite_path', None,
'Common path to tflite files. This need to be an absolute path.'
'This would typically be the path to the directory where the models reside.'
)
flags.DEFINE_list(
'input_tflite_files', None,
'Relative path name list of the input TFLite files.'
'This would be relative to the common path.'
'This would typically be the name(s) of the tflite file(s).')
flags.DEFINE_string('output_dir', None, 'Directory to output generated files.')
flags.DEFINE_string(
'verify_op_list_against_header', None,
'Take micro_mutable_op_resolver.h as input and verifies that all generated operator calls are there.'
)
flags.mark_flag_as_required('common_tflite_path')
flags.mark_flag_as_required('input_tflite_files')
flags.mark_flag_as_required('output_dir')
def ParseString(word):
"""Converts a flatbuffer operator string to a format suitable for Micro
Mutable Op Resolver. Example: CONV_2D --> AddConv2D."""
# Edge case for AddDetectionPostprocess().
# The custom code is TFLite_Detection_PostProcess.
word = word.replace('TFLite', '')
word_split = re.split('_|-', word)
formated_op_string = ''
for part in word_split:
if len(part) > 1:
if part[0].isalpha():
formated_op_string += part[0].upper() + part[1:].lower()
else:
formated_op_string += part.upper()
else:
formated_op_string += part.upper()
# Edge case for AddUnidirectionalSequenceLSTM().
formated_op_string = formated_op_string.replace('Lstm', 'LSTM')
return 'Add' + formated_op_string
def GenerateMicroMutableOpsResolverHeaderFile(operators, name_of_model,
output_dir):
"""Generates Micro Mutable Op Resolver code based on a template."""
number_of_ops = len(operators)
outfile = 'micro_mutable_op_resolver.h'
template_file_path = os.path.join(TEMPLATE_DIR, outfile + '.mako')
build_template = template.Template(filename=template_file_path)
with open(output_dir + '/gen_' + outfile, 'w') as file_obj:
key_values_in_template = {
'model': name_of_model,
'number_of_ops': number_of_ops,
'operators': operators
}
file_obj.write(build_template.render(**key_values_in_template))
def GetModelOperatorsAndActivation(model_path):
"""Extracts a set of operators from a tflite model."""
custom_op_found = False
operators_and_activations = set()
with open(model_path, 'rb') as f:
data_bytes = bytearray(f.read())
data = visualize.CreateDictFromFlatbuffer(data_bytes)
for op_code in data["operator_codes"]:
if op_code['custom_code'] is None:
op_code["builtin_code"] = max(op_code["builtin_code"],
op_code["deprecated_builtin_code"])
else:
custom_op_found = True
operators_and_activations.add(
visualize.NameListToString(op_code['custom_code']))
for op_code in data["operator_codes"]:
# Custom operator already added.
if custom_op_found and visualize.BuiltinCodeToName(
op_code['builtin_code']) == "CUSTOM":
continue
operators_and_activations.add(
visualize.BuiltinCodeToName(op_code['builtin_code']))
return operators_and_activations
def VerifyOpList(op_list, header):
"""Make sure operators in list are not missing in header file ."""
supported_op_list = []
with open(header, 'r') as f:
for l in f.readlines():
if "TfLiteStatus Add" in l:
op = l.strip().split(' ')[1].split('(')[0]
supported_op_list.append(op)
for op in op_list:
if op not in supported_op_list:
print(f'{op} not supported by TFLM')
return True
return False
def main(_):
model_names = []
final_operator_list = []
merged_operator_list = []
common_model_path = FLAGS.common_tflite_path
relative_model_paths = FLAGS.input_tflite_files
for relative_model_path in relative_model_paths:
full_model_path = f"{common_model_path}/{relative_model_path}"
operators = GetModelOperatorsAndActivation(full_model_path)
model_name = full_model_path.split('/')[-1]
model_names.append(model_name)
parsed_operator_list = []
for op in sorted(list(operators)):
parsed_operator_list.append(ParseString(op))
merged_operator_list = merged_operator_list + parsed_operator_list
number_models = len(model_names)
if number_models > 1:
model_name = ", ".join(model_names)
[
final_operator_list.append(operator) for operator in merged_operator_list
if operator not in final_operator_list
]
if FLAGS.verify_op_list_against_header and VerifyOpList(
final_operator_list, FLAGS.verify_op_list_against_header):
return True
os.makedirs(FLAGS.output_dir, exist_ok=True)
GenerateMicroMutableOpsResolverHeaderFile(final_operator_list, model_name,
FLAGS.output_dir)
return False
if __name__ == '__main__':
app.run(main)