blob: 2879d37166413a04b0dd6b95797a32ac869cdd9f [file] [log] [blame]
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Provides object representation for the model that is conducive to code
generation using templates. """
import abc
from typing import Optional
import string
import textwrap
from tflite_micro.codegen import utils
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
class Operator(abc.ABC):
  """Base class for a model operator that is rendered as C/C++ source.

  Wraps a flatbuffer object-API ``schema_fb.OperatorT``. Using
  ``string.Template``, it emits:
    * a C struct holding the operator's input/output/intermediate index
      arrays plus its builtin data, and
    * the ``TfLiteNode`` initializer that points into that struct.

  Subclasses must implement ``generate_c_builtin_data``.
  """

  def __init__(self, operator: schema_fb.OperatorT):
    """Captures the operator and wraps its tensor-index vectors.

    Args:
      operator: The unpacked flatbuffer operator to generate code for.
    """
    self._operator: schema_fb.OperatorT = operator
    self._inputs: utils.IntArray = utils.IntArray(self._operator.inputs)
    self._outputs: utils.IntArray = utils.IntArray(self._operator.outputs)
    # The flatbuffer object API may return vectors as numpy arrays. Their
    # truth value is ambiguous: bool() raises ValueError for len > 1, and a
    # single index 0 is falsy. So test presence and length explicitly.
    intermediates = self._operator.intermediates
    self._intermediates: Optional[utils.IntArray] = (
        utils.IntArray(intermediates)
        if intermediates is not None and len(intermediates) > 0 else None)

  def generate_c_node_data(self, type_name: str, node_name: str) -> str:
    """Returns C source that defines and declares this operator's node data.

    The result has the form ``struct <type_name> { <body> } <node_name>;``.
    The body holds:
      * the ``inputs`` and ``outputs`` index arrays,
      * an ``intermediates`` array, or a placeholder comment when the
        operator has none, and
      * the text returned by ``generate_c_builtin_data``.

    Args:
      type_name: Name of the C struct type to define.
      node_name: Name of the variable of that type to declare.

    Returns:
      The generated C code.
    """
    struct_template = string.Template("struct ${type_name} {\n"
                                      "${body}"
                                      "} ${node_name};")
    body_template = string.Template("${inputs}\n"
                                    "${outputs}\n"
                                    "${intermediates}\n"
                                    "${builtin_data}\n")
    if self._intermediates is not None:
      intermediates = self._intermediates.generate_c_struct(
          "Intermediates", "intermediates")
    else:
      intermediates = "// No intermediates"
    body = body_template.substitute(
        inputs=self._inputs.generate_c_struct("Inputs", "inputs"),
        outputs=self._outputs.generate_c_struct("Outputs", "outputs"),
        intermediates=intermediates,
        builtin_data=self.generate_c_builtin_data())
    return struct_template.substitute(type_name=type_name,
                                      node_name=node_name,
                                      body=textwrap.indent(body, " "))

  def generate_c_node_init(self, tflite_node_name: str,
                           node_data_name: str) -> str:
    """Returns a C++ statement that assigns a ``TfLiteNode`` for this operator.

    The node's ``inputs``, ``outputs``, ``intermediates`` (when present) and
    ``builtin_data`` point at the matching members of the variable named
    ``node_data_name``. The remaining fields are null or zero.

    Args:
      tflite_node_name: The ``TfLiteNode`` lvalue to assign to.
      node_data_name: Name of the node-data variable. Presumably this is the
        one declared by ``generate_c_node_data``; verify against the caller.

    Returns:
      The generated C++ assignment statement.
    """
    # NOTE(review): the reinterpret_casts assume utils.IntArray's generated
    # struct is layout-compatible with TfLiteIntArray. Confirm in utils.
    init_template = string.Template(
        "${tflite_node_name} = TfLiteNode{\n"
        " .inputs ="
        " reinterpret_cast<TfLiteIntArray*>(&${node_data_name}.inputs),\n"
        " .outputs ="
        " reinterpret_cast<TfLiteIntArray*>(&${node_data_name}.outputs),\n"
        " .intermediates = ${intermediates},\n"
        " .user_data = nullptr,\n"
        " .builtin_data ="
        " static_cast<void*>(&${node_data_name}.builtin_data),\n"
        " .custom_initial_data = nullptr,\n"
        " .custom_initial_data_size = 0};")
    if self._intermediates is not None:
      # Reference the node-data variable by name. The IntArray object itself
      # is not a C identifier.
      intermediates = (
          f"reinterpret_cast<TfLiteIntArray*>(&{node_data_name}.intermediates)")
    else:
      intermediates = "nullptr"
    return init_template.substitute(tflite_node_name=tflite_node_name,
                                    node_data_name=node_data_name,
                                    intermediates=intermediates)

  @property
  def op_code_index(self) -> int:
    """The operator's ``opcodeIndex`` field from the flatbuffer.

    Presumably this indexes the model's operator-code table; confirm with
    the schema.
    """
    return self._operator.opcodeIndex

  @abc.abstractmethod
  def generate_c_builtin_data(self) -> str:
    """Returns C code for this operator's builtin data.

    The text is placed inside the node-data struct body emitted by
    ``generate_c_node_data``. ``generate_c_node_init`` takes the address of
    a ``builtin_data`` member, so implementations are expected to declare
    one.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    # type(self): instances have no __name__ attribute, so self.__name__
    # would raise AttributeError instead of the intended error.
    raise NotImplementedError(
        f"Generating builtin data in {type(self).__name__}")