Add initial version of IREE TFLite importer compiler frontend.
* This uses the TOSA conversion path upstream.
* Since IREE does not yet support ingesting TOSA, tests that do full compilation are skipped (but present).
* This uncovered some layering and ergonomic issues upstream that I will likely pre-patch before landing this.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0aaf142..2890087 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,6 +46,7 @@
option(IREE_BUILD_JAVA_BINDINGS "Builds the IREE java bindings." OFF)
option(IREE_BUILD_EXPERIMENTAL "Builds experimental projects." OFF)
option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." OFF)
+option(IREE_BUILD_TFLITE_COMPILER "Builds the TFLite compiler frontend." OFF)
option(IREE_BUILD_XLA_COMPILER "Builds TensorFlow XLA compiler frontend." OFF)
set(IREE_HAL_DRIVERS_TO_BUILD "all"
@@ -57,7 +58,9 @@
# Note that this is a normal CMake variable used to gate build features (not
# a cache variable that is user-settable).
set(IREE_ENABLE_TENSORFLOW OFF)
-if(${IREE_BUILD_TENSORFLOW_COMPILER} OR ${IREE_BUILD_XLA_COMPILER})
+if(${IREE_BUILD_TENSORFLOW_COMPILER} OR
+ ${IREE_BUILD_TFLITE_COMPILER} OR
+ ${IREE_BUILD_XLA_COMPILER})
set(IREE_ENABLE_TENSORFLOW ON)
endif()
diff --git a/bindings/python/pyiree/compiler2/CMakeLists.txt b/bindings/python/pyiree/compiler2/CMakeLists.txt
index f5bf412..c203bdd 100644
--- a/bindings/python/pyiree/compiler2/CMakeLists.txt
+++ b/bindings/python/pyiree/compiler2/CMakeLists.txt
@@ -22,6 +22,7 @@
"__init__.py"
"core.py"
"tf.py"
+ "tflite.py"
"tools.py"
"xla.py"
)
diff --git a/bindings/python/pyiree/compiler2/tflite.py b/bindings/python/pyiree/compiler2/tflite.py
new file mode 100644
index 0000000..64f9bf2
--- /dev/null
+++ b/bindings/python/pyiree/compiler2/tflite.py
@@ -0,0 +1,175 @@
+# Lint-as: python3
+"""TFLite compiler interface."""
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .tools import find_tool, invoke_immediate, invoke_pipeline
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+
+__all__ = [
+ "compile_file",
+ "compile_str",
+ "is_available",
+ "DEFAULT_TESTING_BACKENDS",
+ "ImportOptions",
+]
+
+_IMPORT_TOOL = "iree-import-tflite"
+
+
+def is_available():
+ """Determine if the XLA frontend is available."""
+ try:
+ find_tool(_IMPORT_TOOL)
+ except ValueError:
+ logging.warning("Unable to find IREE tool %s", _IMPORT_TOOL)
+ return False
+ return True
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+ """Import options layer on top of the backend compiler options."""
+
+ def __init__(self,
+ input_arrays: Sequence[str] = (),
+ output_arrays: Sequence[str] = (),
+ import_only: bool = False,
+ import_extra_args: Sequence[str] = (),
+ save_temp_tfl_input: Optional[str] = None,
+ save_temp_iree_input: Optional[str] = None,
+ **kwargs):
+ """Initialize options from keywords.
+
+ Args:
+ input_arrays: Sequence of input array node names (if different from
+ default).
+ output_arrays: Sequence of output array node names (if different from
+ default).
+ import_only: Only import the module. If True, the result will be textual
+ MLIR that can be further fed to the IREE compiler. If False (default),
+ the result will be the fully compiled IREE binary. In both cases,
+ bytes-like output is returned. Note that if the output_file= is
+ specified and import_only=True, then the MLIR form will be written to
+ the output file.
+ import_extra_args: Extra arguments to pass to the iree-tf-import tool.
+ save_temp_tfl_input: Optionally save the IR that results from importing
+ the flatbuffer (prior to any further transformations).
+ save_temp_iree_input: Optionally save the IR that is the result of the
+ import (ready to be passed to IREE).
+ """
+ super().__init__(**kwargs)
+ self.input_arrays = input_arrays
+ self.output_arrays = output_arrays
+ self.import_only = import_only
+ self.import_extra_args = import_extra_args
+ self.save_temp_tfl_input = save_temp_tfl_input
+ self.save_temp_iree_input = save_temp_iree_input
+
+
+def build_import_command_line(input_path: str,
+ options: ImportOptions) -> List[str]:
+ """Builds a command line for invoking the import stage.
+
+ Args:
+ input_path: The input path.
+ options: Import options.
+ Returns:
+ List of strings of command line.
+ """
+ import_tool = find_tool(_IMPORT_TOOL)
+ cl = [
+ import_tool,
+ input_path,
+ ]
+ if options.import_only and options.output_file:
+ # Import stage directly outputs.
+ if options.output_file:
+ cl.append(f"-o={options.output_file}")
+ # Input arrays.
+ if options.input_arrays:
+ for input_array in options.input_arrays:
+ cl.append(f"--input-array={input_array}")
+ for output_array in options.output_arrays:
+ cl.append(f"--output-array={output_array}")
+ # Save temps flags.
+ if options.save_temp_tfl_input:
+ cl.append(f"--save-temp-tfl-input={options.save_temp_tfl_input}")
+ if options.save_temp_iree_input:
+ cl.append(f"--save-temp-iree-input={options.save_temp_iree_input}")
+ # Extra args.
+ cl.extend(options.import_extra_args)
+ return cl
+
+
+def compile_file(fb_path: str, **kwargs):
+ """Compiles a TFLite flatbuffer file to an IREE binary.
+
+ Args:
+ fb_path: Path to the flatbuffer.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line(fb_path, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", options)
+ result = invoke_pipeline([import_cl, compile_cl])
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_str(fb_content: bytes, **kwargs):
+ """Compiles in-memory TFLite flatbuffer to an IREE binary.
+
+ Args:
+ xla_content: Flatbuffer content as bytes.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line("-", options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl, immediate_input=fb_content)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", options)
+ result = invoke_pipeline([import_cl, compile_cl], immediate_input=fb_content)
+ if options.output_file:
+ return None
+ return result
diff --git a/bindings/python/pyiree/compiler2/tools.py b/bindings/python/pyiree/compiler2/tools.py
index 76df5dc..69f79ab 100644
--- a/bindings/python/pyiree/compiler2/tools.py
+++ b/bindings/python/pyiree/compiler2/tools.py
@@ -37,6 +37,7 @@
# a python module that provides a `get_tool` function for getting its absolute
# path. This dictionary maps the tool name to the module.
_TOOL_MODULE_MAP = {
+ "iree-import-tflite": "pyiree.tools.tflite",
"iree-import-xla": "pyiree.tools.xla",
"iree-tf-import": "pyiree.tools.tf",
"iree-translate": "pyiree.tools.core",
@@ -47,6 +48,7 @@
_TOOL_MODULE_PACKAGES = {
"pyiree.tools.core": "google-iree-tools-core",
"pyiree.tools.tf": "google-iree-tools-tf",
+ "pyiree.tools.tflite": "google-iree-tools-tflite",
"pyiree.tools.xla": "google-iree-tools-xla",
}
diff --git a/bindings/python/tests/CMakeLists.txt b/bindings/python/tests/CMakeLists.txt
index 2e1f330..d332051 100644
--- a/bindings/python/tests/CMakeLists.txt
+++ b/bindings/python/tests/CMakeLists.txt
@@ -28,6 +28,13 @@
iree_py_test(
NAME
+ compiler_tflite_test
+ SRCS
+ "compiler_tflite_test.py"
+)
+
+iree_py_test(
+ NAME
compiler_xla_test
SRCS
"compiler_xla_test.py"
diff --git a/bindings/python/tests/compiler_tflite_test.py b/bindings/python/tests/compiler_tflite_test.py
new file mode 100644
index 0000000..b49e729
--- /dev/null
+++ b/bindings/python/tests/compiler_tflite_test.py
@@ -0,0 +1,104 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+# TODO: No idea why pytype cannot find names from this module.
+# pytype: disable=name-error
+from pyiree.compiler2.tflite import *
+
+if not is_available():
+ print(f"Skipping test {__file__} because the IREE TFLite compiler "
+ f"is not installed")
+ sys.exit(0)
+
+
+class CompilerTest(unittest.TestCase):
+
+ def testImportBinaryPbFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ text = compile_file(path, import_only=True).decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("tosa.mul", text)
+
+ @unittest.skip("IREE tosa compilation not implemented yet")
+ def testCompileBinaryPbFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ binary = compile_file(path, target_backends=DEFAULT_TESTING_BACKENDS)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportBinaryPbFileOutputFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = compile_file(path, import_only=True, output_file=f.name)
+ self.assertIsNone(output)
+ with open(f.name, "rt") as f_read:
+ text = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("%s", text)
+ self.assertIn("tosa.mul", text)
+
+ @unittest.skip("IREE tosa compilation not implemented yet")
+ def testCompileBinaryPbFileOutputFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = compile_file(path,
+ output_file=f.name,
+ target_backends=DEFAULT_TESTING_BACKENDS)
+ self.assertIsNone(output)
+ with open(f.name, "rb") as f_read:
+ binary = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportBinaryPbBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with open(path, "rb") as f:
+ content = f.read()
+ text = compile_str(content, import_only=True).decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("tosa.mul", text)
+
+ @unittest.skip("IREE tosa compilation not implemented yet")
+ def testCompileBinaryPbBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with open(path, "rb") as f:
+ content = f.read()
+ binary = compile_str(content, target_backends=DEFAULT_TESTING_BACKENDS)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/bindings/python/tests/testdata/generate_tflite.py b/bindings/python/tests/testdata/generate_tflite.py
new file mode 100644
index 0000000..f7764b0
--- /dev/null
+++ b/bindings/python/tests/testdata/generate_tflite.py
@@ -0,0 +1,36 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import tensorflow as tf
+
+
+class Squared(tf.Module):
+
+ @tf.function
+ def __call__(self, x):
+ return tf.square(x)
+
+
+model = Squared()
+concrete_func = model.__call__.get_concrete_function(
+ tf.TensorSpec(shape=[4, 3], dtype=tf.float32))
+
+converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+tflite_model = converter.convert()
+
+this_dir = os.path.dirname(__file__)
+with open(os.path.join(this_dir, "tflite_sample.fb"), "wb") as f:
+ f.write(tflite_model)
diff --git a/bindings/python/tests/testdata/tflite_sample.fb b/bindings/python/tests/testdata/tflite_sample.fb
new file mode 100644
index 0000000..52cb9e4
--- /dev/null
+++ b/bindings/python/tests/testdata/tflite_sample.fb
Binary files differ
diff --git a/integrations/tensorflow/CMakeLists.txt b/integrations/tensorflow/CMakeLists.txt
index f54fd26..7b7f01a 100644
--- a/integrations/tensorflow/CMakeLists.txt
+++ b/integrations/tensorflow/CMakeLists.txt
@@ -30,6 +30,11 @@
list(APPEND _executable_paths integrations/tensorflow/compiler/iree-tf-import)
endif()
+if(${IREE_BUILD_TFLITE_COMPILER})
+ list(APPEND _bazel_targets //integrations/tensorflow/compiler:iree-import-tflite)
+ list(APPEND _executable_paths integrations/tensorflow/compiler/iree-import-tflite)
+endif()
+
if(${IREE_BUILD_XLA_COMPILER})
list(APPEND _bazel_targets //integrations/tensorflow/compiler:iree-import-xla)
list(APPEND _executable_paths integrations/tensorflow/compiler/iree-import-xla)
diff --git a/integrations/tensorflow/bindings/python/CMakeLists.txt b/integrations/tensorflow/bindings/python/CMakeLists.txt
index 44fce9b..f88027a 100644
--- a/integrations/tensorflow/bindings/python/CMakeLists.txt
+++ b/integrations/tensorflow/bindings/python/CMakeLists.txt
@@ -23,6 +23,10 @@
_add_overlay_subdirectory(pyiree/tools/tf)
endif()
+if(${IREE_BUILD_TFLITE_COMPILER})
+ _add_overlay_subdirectory(pyiree/tools/tflite)
+endif()
+
if(${IREE_BUILD_XLA_COMPILER})
_add_overlay_subdirectory(pyiree/tools/xla)
endif()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tools/tflite/CMakeLists.txt b/integrations/tensorflow/bindings/python/pyiree/tools/tflite/CMakeLists.txt
new file mode 100644
index 0000000..faee272
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tools/tflite/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_py_library(
+ NAME
+ tflite
+ SRCS
+ "__init__.py"
+ DEPS
+ integrations_iree_tensorflow_importers
+)
+
+iree_symlink_tool(
+ TARGET tflite
+ FROM_TOOL_TARGET integrations_tensorflow_compiler_iree-import-tflite
+ TO_EXE_NAME iree-import-tflite
+)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tools/tflite/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tools/tflite/__init__.py
new file mode 100644
index 0000000..104fc49
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tools/tflite/__init__.py
@@ -0,0 +1,29 @@
+# Lint-as: python3
+"""TFLite tools."""
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import os
+import platform
+
+
+def get_tool(exe_name: str) -> Optional[str]:
+ if platform.system() == "Windows":
+ exe_name = exe_name + ".exe"
+ this_path = os.path.dirname(__file__)
+ tool_path = os.path.join(this_path, exe_name)
+ return tool_path
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index 6ac0eca..4e874de 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -120,6 +120,30 @@
)
cc_binary(
+ name = "iree-import-tflite",
+ srcs = ["iree-import-tflite-main.cpp"],
+ deps = [
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:TosaDialect",
+ "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_import",
+ "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tosa_legalize_tfl",
+ ] + [
+ # TODO: See about removing this dep (TFLite shouldn't need tf ops).
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow",
+ # TODO: There are layering violations upstream. Fix and remove these.
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tosa_legalize_common",
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tosa_pipelines",
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_tosa_passes",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+cc_binary(
name = "iree-import-xla",
srcs = ["iree-import-xla-main.cpp"],
deps = [
diff --git a/integrations/tensorflow/compiler/iree-import-tflite-main.cpp b/integrations/tensorflow/compiler/iree-import-tflite-main.cpp
new file mode 100644
index 0000000..163ee18
--- /dev/null
+++ b/integrations/tensorflow/compiler/iree-import-tflite-main.cpp
@@ -0,0 +1,145 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Support/FileUtilities.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tosa/tosa_passpipes.h"
+
+using namespace llvm;
+using namespace mlir;
+
+int main(int argc, char **argv) {
+ llvm::InitLLVM y(argc, argv);
+
+ static cl::opt<std::string> inputPath(
+ cl::Positional, cl::desc("<TFLite FlatBuffer>"), cl::Required);
+ static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
+ cl::value_desc("filename"),
+ cl::init("-"));
+ static llvm::cl::opt<std::string> saveTempTflInput(
+ "save-temp-tfl-input",
+ llvm::cl::desc("Save the TFL pipeline input to this file"),
+ llvm::cl::init(""));
+ static llvm::cl::opt<std::string> saveTempIreeImport(
+ "save-temp-iree-input",
+ llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
+ "intermediate in a pipeline)"),
+ llvm::cl::init(""));
+
+ static cl::list<std::string> inputArrayFlag(
+ "input-array",
+ llvm::cl::desc("Input tensor, if different from the default inputs"),
+ llvm::cl::ZeroOrMore);
+ static cl::list<std::string> outputArrayFlag(
+ "output-array",
+ llvm::cl::desc("Output tensor, if different from the default outputs"),
+ llvm::cl::ZeroOrMore);
+
+ // Register any command line options.
+ registerAsmPrinterCLOptions();
+ registerMLIRContextCLOptions();
+ cl::ParseCommandLineOptions(argc, argv);
+
+ // Initialize dialects.
+ DialectRegistry registry;
+ registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+ registry.insert<mlir::tosa::TosaDialect>();
+ registry.insert<quant::QuantizationDialect>();
+ registry.insert<TF::TensorFlowDialect>();
+ registry.insert<StandardOpsDialect>();
+
+ // Convert the Module proto into MLIR.
+ MLIRContext context;
+ registry.loadAll(&context);
+
+ // Load input buffer.
+ std::string errorMessage;
+ auto inputFile = openInputFile(inputPath, &errorMessage);
+ if (!inputFile) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ // Convert.
+ std::vector<std::string> inputArrays(inputArrayFlag.begin(),
+ inputArrayFlag.end());
+ std::vector<std::string> outputArrays(outputArrayFlag.begin(),
+ outputArrayFlag.end());
+ auto loc = mlir::FileLineColLoc::get(inputFile->getBufferIdentifier(), 0, 0,
+ &context);
+ OwningModuleRef module = tflite::FlatBufferToMlir(
+ absl::string_view(inputFile->getBufferStart(),
+ inputFile->getBufferSize()),
+ &context, loc,
+ /*use_external_constant=*/false, inputArrays, outputArrays);
+ if (!module) {
+ // Error should have emitted.
+ llvm::errs() << "Unable to import TFLite flatbuffer to MLIR Module\n";
+ return 2;
+ }
+
+ // Save.
+ auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult {
+ auto outputFile = openOutputFile(savePath);
+ if (!outputFile) {
+ llvm::errs() << "Could not open output file: " << savePath << "\n";
+ return failure();
+ }
+ OpPrintingFlags printFlags;
+ printFlags.enableDebugInfo();
+ module->print(outputFile->os(), printFlags);
+ outputFile->os() << "\n";
+ outputFile->keep();
+ return success();
+ };
+
+ // Save temp input.
+ if (!saveTempTflInput.empty()) {
+ if (failed(saveToFile(saveTempTflInput))) return 10;
+ }
+
+ // Run transformations.
+ mlir::tosa::TOSALegalizationPipelineOptions tosaOptions;
+ PassManager pm(&context, PassManager::Nesting::Implicit);
+ applyPassManagerCLOptions(pm);
+ mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, tosaOptions);
+ if (failed(pm.run(*module))) {
+ llvm::errs() << "Running iree-import-tflite pass pipeline failed (see "
+ "diagnostics)\n";
+ return 3;
+ }
+
+ // Save temp output.
+ if (!saveTempIreeImport.empty()) {
+ if (failed(saveToFile(saveTempIreeImport))) return 10;
+ }
+
+ // Save output.
+ if (failed(saveToFile(outputFilename))) return 3;
+ return 0;
+}
diff --git a/scripts/check_tabs.sh b/scripts/check_tabs.sh
index 494ec10..3b2467e 100755
--- a/scripts/check_tabs.sh
+++ b/scripts/check_tabs.sh
@@ -26,6 +26,7 @@
"/third_party/"
"^third_party/"
"\.pb$"
+ "\.fb$"
)
# Join on |