# blob: 867b4ae88b37d4c3ce538a1be230ce5a8380da18 [file] [log] [blame]
# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import os
import sys
import tempfile
import unittest
from iree.compiler.api import (
Session,
Source,
Output,
)
# TODO: No idea why pytype cannot find names from this module.
# pytype: disable=name-error
import iree.compiler.tools.tflite
# When the TFLite compiler tooling reports itself unavailable, print a notice
# and exit with status 0 before any test class is defined.
if not iree.compiler.tools.tflite.is_available():
    print(
        f"Skipping test {__file__} because the IREE TFLite compiler is not installed"
    )
    raise SystemExit(0)
def mlir_bytecode_to_text(bytecode):
    """Parse a serialized MLIR module and return its IR as text.

    A fresh ``Session`` and invocation are created, ``bytecode`` is wrapped
    as a ``Source`` and parsed, and the resulting IR is written to an
    in-memory ``Output`` whose contents are returned.

    Args:
        bytecode: Buffer holding the module to parse; it is handed directly
            to ``Source.wrap_buffer``.

    Returns:
        The emitted IR as a ``str``.
    """
    session = Session()
    inv = session.invocation()
    source = Source.wrap_buffer(session, bytecode)
    inv.parse_source(source)
    out = Output.open_membuffer()
    try:
        inv.output_ir(out)
        # Copy the mapped memory into bytes and decode it. Calling str() on a
        # bytes object would return its repr ("b'...'" with escaped newlines)
        # instead of the IR text itself.
        return bytes(out.map_memory()).decode("utf-8", errors="replace")
    finally:
        # Release the output buffer even if emitting or mapping the IR fails.
        out.close()
class CompilerTest(unittest.TestCase):
    """Tests for the TFLite ``compile_file`` and ``compile_str`` entry points.

    Every test uses the ``testdata/tflite_sample.fb`` file located beside this
    module. Import-only results are converted to text and checked for
    ``tosa.mul``; compiled results are checked for the bytes ``main``.
    """

    @staticmethod
    def _sample_model_path():
        """Return the path of ``testdata/tflite_sample.fb`` beside this file."""
        return os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb")

    def _read_sample_model(self):
        """Return the raw bytes of the sample model file."""
        with open(self._sample_model_path(), "rb") as model_file:
            return model_file.read()

    def _compile_file_to_temp_output(self, **compile_kwargs):
        """Run ``compile_file`` with ``output_file`` set and return what it wrote.

        Asserts that ``compile_file`` returns ``None`` in this mode. The
        temporary file is removed whether or not the compile succeeds.
        """
        with tempfile.NamedTemporaryFile("wt", delete=False) as tmp:
            try:
                # Close our own handle first; only the file name is handed to
                # the compiler.
                tmp.close()
                result = iree.compiler.tools.tflite.compile_file(
                    self._sample_model_path(), output_file=tmp.name, **compile_kwargs
                )
                self.assertIsNone(result)
                with open(tmp.name, "rb") as produced:
                    return produced.read()
            finally:
                os.remove(tmp.name)

    def _check_imported(self, imported):
        """Convert an import-only result to text, log it and look for tosa.mul."""
        ir_text = mlir_bytecode_to_text(imported)
        logging.info("%s", ir_text)
        self.assertIn("tosa.mul", ir_text)

    def _check_compiled(self, compiled):
        """Log the size of a compiled result and look for the bytes ``main``."""
        logging.info("Binary length = %d", len(compiled))
        self.assertIn(b"main", compiled)

    def testImportBinaryPbFile(self):
        """Import-only compile from a file path, result returned in memory."""
        tflite = iree.compiler.tools.tflite
        imported = tflite.compile_file(self._sample_model_path(), import_only=True)
        self._check_imported(imported)

    def testCompileBinaryPbFile(self):
        """Full compile from a file path, result returned in memory."""
        tflite = iree.compiler.tools.tflite
        compiled = tflite.compile_file(
            self._sample_model_path(),
            target_backends=tflite.DEFAULT_TESTING_BACKENDS,
        )
        self._check_compiled(compiled)

    def testImportBinaryPbFileOutputFile(self):
        """Import-only compile from a file path, result written to a file."""
        imported = self._compile_file_to_temp_output(import_only=True)
        self._check_imported(imported)

    def testCompileBinaryPbFileOutputFile(self):
        """Full compile from a file path, result written to a file."""
        compiled = self._compile_file_to_temp_output(
            target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS
        )
        self._check_compiled(compiled)

    def testImportBinaryPbBytes(self):
        """Import-only compile from in-memory model bytes."""
        model_bytes = self._read_sample_model()
        imported = iree.compiler.tools.tflite.compile_str(
            model_bytes, import_only=True
        )
        self._check_imported(imported)

    def testCompileBinaryPbBytes(self):
        """Full compile from in-memory model bytes."""
        tflite = iree.compiler.tools.tflite
        model_bytes = self._read_sample_model()
        compiled = tflite.compile_str(
            model_bytes, target_backends=tflite.DEFAULT_TESTING_BACKENDS
        )
        self._check_compiled(compiled)
if __name__ == "__main__":
    # Configure logging at DEBUG level so the logging.info() output produced
    # by the tests is emitted when this file is run directly.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()