# blob: 83870fc0e24600b08add19908926bc4fc344642f
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Buffer and Tensor wrappers that generate C++ source text for model tensors."""
from typing import Dict, Optional
import string
import textwrap
from tflite_micro.codegen import utils
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
# Maps each schema TensorType enum value to the name of the matching
# TfLiteType enumerator, as spelled in the generated C++ source.
_TENSOR_TYPES: Dict[int, str] = {
    # Floating point
    schema_fb.TensorType.FLOAT16: "kTfLiteFloat16",
    schema_fb.TensorType.FLOAT32: "kTfLiteFloat32",
    schema_fb.TensorType.FLOAT64: "kTfLiteFloat64",
    # Integers
    schema_fb.TensorType.INT16: "kTfLiteInt16",
    schema_fb.TensorType.UINT16: "kTfLiteUInt16",
    schema_fb.TensorType.INT32: "kTfLiteInt32",
    schema_fb.TensorType.UINT32: "kTfLiteUInt32",
    schema_fb.TensorType.UINT8: "kTfLiteUInt8",
    schema_fb.TensorType.INT8: "kTfLiteInt8",
    schema_fb.TensorType.INT64: "kTfLiteInt64",
    schema_fb.TensorType.UINT64: "kTfLiteUInt64",
    # Strings, booleans, complex numbers and opaque handles
    schema_fb.TensorType.STRING: "kTfLiteString",
    schema_fb.TensorType.BOOL: "kTfLiteBool",
    schema_fb.TensorType.COMPLEX64: "kTfLiteComplex64",
    schema_fb.TensorType.COMPLEX128: "kTfLiteComplex128",
    schema_fb.TensorType.RESOURCE: "kTfLiteResource",
    schema_fb.TensorType.VARIANT: "kTfLiteVariant",
    # Sub-byte integer
    schema_fb.TensorType.INT4: "kTfLiteInt4",
}
class Buffer(object):
  """A named model buffer that is either a static C array or arena memory.

  A buffer that carries data is emitted as a static `uint8_t` array and
  addressed by name. A buffer without data (no buffer object, no `data`
  attribute value, or an empty `data` sequence) is treated as living in the
  arena and currently resolves to `nullptr`.
  """

  def __init__(self, buffer_name: str, buffer: schema_fb.BufferT):
    """Stores the C identifier and the schema buffer it describes.

    Args:
      buffer_name: Identifier used for the generated C array.
      buffer: Schema buffer object; may be None. Only its `data` attribute
        is read.
    """
    # TODO(rjascani): Get arena allocation offsets from preprocessor
    self._buffer_name = buffer_name
    self._buffer = buffer

  @property
  def _has_static_data(self) -> bool:
    """True when the buffer holds at least one byte to emit as a C array."""
    if self._buffer is None:
      return False
    data = self._buffer.data
    # Use len() rather than truthiness: `data` may be an array-like object
    # whose truth value is ambiguous. An empty sequence is treated like a
    # missing one so that no zero-length C array is ever generated.
    return data is not None and len(data) > 0

  @property
  def address(self) -> str:
    """C expression for the buffer's address.

    Returns:
      `&<buffer_name>` when static data exists, otherwise a `nullptr`
      expression annotated with the buffer name.
    """
    if not self._has_static_data:
      # TODO(rjascani): This needs to point into the arena
      return f"nullptr /* {self._buffer_name} */"
    return f"&{self._buffer_name}"

  def generate_c_buffer_array(self, indent: str) -> str:
    """Generates the C definition of the buffer's static byte array.

    Args:
      indent: Prefix applied to every non-blank line of the generated array
        definition.

    Returns:
      The indented `alignas(16) uint8_t` array definition with 12 bytes per
      row, or an un-indented one-line comment when the buffer has no static
      data.
    """
    if not self._has_static_data:
      return f"// {self._buffer_name} is located in the arena\n"

    buffer_template = string.Template(
        "alignas(16) uint8_t ${buffer_name}[${size}] = {\n"
        "${body}\n"
        "};\n")

    byte_strs = [f'0x{b:02X}' for b in self._buffer.data]

    lines = []
    for byte_strs_for_line in utils.split_into_chunks(byte_strs, 12):
      bytes_segment = ', '.join(byte_strs_for_line)
      lines.append(f' {bytes_segment},')

    return textwrap.indent(
        buffer_template.substitute(buffer_name=self._buffer_name,
                                   size=len(self._buffer.data),
                                   body='\n'.join(lines)), indent)
class Tensor(object):
  """Pairs a schema tensor with its Buffer and generates C++ source for it."""

  def __init__(self, buffer: Buffer, tensor: schema_fb.TensorT):
    """Stores the backing buffer and the schema tensor.

    Args:
      buffer: Buffer wrapper whose `address` is used as the tensor data
        pointer.
      tensor: Schema tensor object; its `buffer`, `shape` and `type`
        attributes are read.
    """
    self._buffer = buffer
    self._tensor: schema_fb.TensorT = tensor

  @property
  def buffer_index(self) -> int:
    """The schema tensor's `buffer` field (an index, not a flag)."""
    return self._tensor.buffer

  @property
  def buffer(self) -> Buffer:
    """The Buffer wrapper backing this tensor."""
    return self._buffer

  @property
  def has_shape(self) -> bool:
    """True when the schema tensor carries a shape (even an empty one)."""
    return self._tensor.shape is not None

  @property
  def needs_zero_length_int_array(self) -> bool:
    """True when the tensor has no shape and so uses the shared empty dims."""
    return not self.has_shape

  def generate_c_tensor_dims(self, type_name: str, tensor_name: str) -> str:
    """Generates the C definition of this tensor's dims array.

    Args:
      type_name: Prefix for the generated struct type (`<type_name>Dims`).
      tensor_name: Prefix for the generated variable (`<tensor_name>_dims`).

    Returns:
      The struct text produced by `utils.IntArray`, or a one-line comment
      when the tensor has no shape.
    """
    if not self.has_shape:
      return f"// No data dims necessary for {tensor_name}"
    return utils.IntArray(self._tensor.shape).generate_c_struct(
        type_name + "Dims", tensor_name + "_dims")

  def generate_c_tensor_init(self, tflite_tensor_name: str,
                             tensor_name: str) -> str:
    """Generates the C++ statement assigning a `TfLiteEvalTensor`.

    Args:
      tflite_tensor_name: Left-hand side expression being assigned.
      tensor_name: Prefix of the dims variable (`<tensor_name>_dims`) used
        when the tensor has a shape.

    Returns:
      The assignment statement as a string.

    Raises:
      KeyError: If the tensor's type has no entry in `_TENSOR_TYPES`.
    """
    init_template = string.Template(
        "${tflite_tensor_name} = TfLiteEvalTensor{\n"
        " .data = {.data = static_cast<void*>(${data})},\n"
        " .dims = ${dims},\n"
        " .type = ${tflite_type}};")
    # Shapeless tensors point at the shared `zero_length_int_array` symbol
    # instead of a per-tensor dims variable.
    dims_symbol = (f"{tensor_name}_dims"
                   if self.has_shape else "zero_length_int_array")
    return init_template.substitute(
        tflite_tensor_name=tflite_tensor_name,
        data=self._buffer.address,
        dims=f"reinterpret_cast<TfLiteIntArray*>(&{dims_symbol})",
        tflite_type=_TENSOR_TYPES[self._tensor.type])