blob: 89e985516963b8bdec2742cd819f1cb6ae28cf45 [file] [log] [blame]
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A simple test to check whether the tflite_micro package works after it is
# installed.
# To test from the perspective of a package user, use import paths to locations
# in the Python installation environment rather than to locations in the tflm
# source tree.
from tflite_micro import runtime
import numpy as np
import pkg_resources
import sys
def passed():
# Create an interpreter with a sine model
model = pkg_resources.resource_filename(__name__, "sine_float.tflite")
interpreter = runtime.Interpreter.from_file(model)
OUTPUT_INDEX = 0
INPUT_INDEX = 0
input_shape = interpreter.get_input_details(INPUT_INDEX).get("shape")
# The interpreter infers sin(x)
def infer(x):
tensor = np.array(x, np.float32).reshape(input_shape)
interpreter.set_input(tensor, INPUT_INDEX)
interpreter.invoke()
return interpreter.get_output(OUTPUT_INDEX).squeeze()
# Check a few inferred values against a numerical computation
PI = 3.14
inputs = (0.0, PI / 2, PI, 3 * PI / 2, 2 * PI)
outputs = [infer(x) for x in inputs]
goldens = np.sin(inputs)
return np.allclose(outputs, goldens, atol=0.05)
if __name__ == "__main__":
sys.exit(0 if passed() else 1)