| # Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Python package for TFLM Python Interpreter""" |
| |
| import os |
| |
| from tflite_micro.python.tflite_micro import _runtime |
| from tflite_micro.tensorflow.lite.tools import flatbuffer_utils |
| |
| |
class Interpreter(object):
    """Python front end for the TFLM `_runtime.InterpreterWrapper`.

    Instances are normally created with `from_file()` or `from_bytes()`. Every
    public method validates its arguments and then forwards to the wrapped
    interpreter object held in `self._interpreter`.
    """

    def __init__(self, model_data, custom_op_registerers, arena_size):
        """Creates the wrapped interpreter for the given model.

        Args:
            model_data: The model contents. Must not be None and must support
                `len()` when `arena_size` is None.
            custom_op_registerers: List of strings, each of which is the name
                of a custom OP registerer.
            arena_size: Tensor arena size in bytes. If None, it defaults to 10
                times the length of `model_data`.

        Raises:
            ValueError: If `model_data` is None, or if `custom_op_registerers`
                is not a list made up only of strings.
        """
        if model_data is None:
            raise ValueError("Model must not be None")

        is_list_of_str = isinstance(custom_op_registerers, list) and all(
            isinstance(name, str) for name in custom_op_registerers)
        if not is_list_of_str:
            raise ValueError(
                "Custom ops registerers must be a list of strings")

        # This is a heuristic to ensure that the arena is sufficiently sized.
        if arena_size is None:
            arena_size = len(model_data) * 10

        # Some models make use of resource variables ops, get the count here.
        num_resource_variables = flatbuffer_utils.count_resource_variables(
            model_data)
        print("Number of resource variables the model uses = ",
              num_resource_variables)

        self._interpreter = _runtime.InterpreterWrapper(
            model_data, custom_op_registerers, arena_size,
            num_resource_variables)

    @staticmethod
    def _check_index(index):
        """Raises ValueError unless `index` is a non-negative tensor index."""
        if index is None or index < 0:
            raise ValueError("Index must be a non-negative integer")

    @classmethod
    def from_file(cls, model_path, custom_op_registerers=None,
                  arena_size=None):
        """Instantiates a TFLM interpreter from a model .tflite filepath.

        Args:
            model_path: Filepath to the .tflite model.
            custom_op_registerers: List of strings, each of which is the name
                of a custom OP registerer. None (the default) means no custom
                OP registerers.
            arena_size: Tensor arena size in bytes. If unused, tensor arena
                size will default to 10 times the model size.

        Returns:
            An instance of the class this method is called on.

        Raises:
            ValueError: If `model_path` is None or is not an existing file.
        """
        if model_path is None or not os.path.isfile(model_path):
            raise ValueError("Invalid model file path")

        # A None default avoids sharing one mutable list across calls.
        if custom_op_registerers is None:
            custom_op_registerers = []

        with open(model_path, "rb") as model_file:
            model_data = model_file.read()

        # Build through `cls` so that subclasses get instances of themselves.
        return cls(model_data, custom_op_registerers, arena_size)

    @classmethod
    def from_bytes(cls, model_data, custom_op_registerers=None,
                   arena_size=None):
        """Instantiates a TFLM interpreter from a model in byte array.

        Args:
            model_data: Model in byte array format.
            custom_op_registerers: List of strings, each of which is the name
                of a custom OP registerer. None (the default) means no custom
                OP registerers.
            arena_size: Tensor arena size in bytes. If unused, tensor arena
                size will default to 10 times the model size.

        Returns:
            An instance of the class this method is called on.
        """
        # A None default avoids sharing one mutable list across calls.
        if custom_op_registerers is None:
            custom_op_registerers = []

        # Build through `cls` so that subclasses get instances of themselves.
        return cls(model_data, custom_op_registerers, arena_size)

    def print_allocations(self):
        """Invoke the RecordingMicroAllocator to print the arena usage.

        This should be called after `invoke()`.

        Returns:
            This method does not return anything, but it dumps the arena
            usage to stderr.
        """
        self._interpreter.PrintAllocations()

    def invoke(self):
        """Invoke the TFLM interpreter to run an inference.

        This should be called after `set_input()`.

        Returns:
            Status code of the C++ invoke function. A RuntimeError will be
            raised as well upon any error.
        """
        return self._interpreter.Invoke()

    def reset(self):
        """Reset the model state to what it was at interpreter creation.

        That is the state right after Init and Prepare were called for the
        very first time. This should be called after invoking a stateful model
        such as an LSTM.

        Returns:
            Status code returned by the wrapped interpreter's `Reset()` call.
        """
        return self._interpreter.Reset()

    def set_input(self, input_data, index):
        """Set input data into input tensor.

        This should be called before `invoke()`.

        Args:
            input_data: Input data in numpy array format. The numpy array
                format is chosen to be consistent with TFLite interpreter.
            index: An integer between 0 and the number of input tensors
                (exclusive) consistent with the order defined in the list of
                inputs in the .tflite model.

        Raises:
            ValueError: If `input_data` is None, or if `index` is None or
                negative.
        """
        if input_data is None:
            raise ValueError("Input data must not be None")
        self._check_index(index)

        self._interpreter.SetInputTensor(input_data, index)

    def get_output(self, index):
        """Get data from output tensor.

        The output data correspond to the most recent `invoke()`.

        Args:
            index: An integer between 0 and the number of output tensors
                (exclusive) consistent with the order defined in the list of
                outputs in the .tflite model.

        Returns:
            Output data in numpy array format. The numpy array format is
            chosen to be consistent with TFLite interpreter.

        Raises:
            ValueError: If `index` is None or negative.
        """
        self._check_index(index)

        return self._interpreter.GetOutputTensor(index)

    def get_input_details(self, index):
        """Get input tensor information.

        Args:
            index (int): An integer between 0 and the number of input tensors
                (exclusive) consistent with the order defined in the list of
                inputs in the .tflite model.

        Returns:
            A dictionary with details about the input tensor. It contains the
            following fields that describe the tensor:
            + `shape`: The shape of the tensor.
            + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
            + `quantization_parameters`: A dictionary of parameters used to
              quantize the tensor:
              ~ `scales`: List of scales (one if per-tensor quantization).
              ~ `zero_points`: List of zero_points (one if per-tensor
                quantization).
              ~ `quantized_dimension`: Specifies the dimension of per-axis
                quantization, in the case of multiple scales/zero_points.

        Raises:
            ValueError: If `index` is None or negative.
        """
        self._check_index(index)

        return self._interpreter.GetInputTensorDetails(index)

    def get_output_details(self, index):
        """Get output tensor information.

        Args:
            index (int): An integer between 0 and the number of output tensors
                (exclusive) consistent with the order defined in the list of
                outputs in the .tflite model.

        Returns:
            A dictionary with details about the output tensor. It contains the
            following fields that describe the tensor:
            + `shape`: The shape of the tensor.
            + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
            + `quantization_parameters`: A dictionary of parameters used to
              quantize the tensor:
              ~ `scales`: List of scales (one if per-tensor quantization).
              ~ `zero_points`: List of zero_points (one if per-tensor
                quantization).
              ~ `quantized_dimension`: Specifies the dimension of per-axis
                quantization, in the case of multiple scales/zero_points.

        Raises:
            ValueError: If `index` is None or negative.
        """
        self._check_index(index)

        return self._interpreter.GetOutputTensorDetails(index)