Adding tflite import passes and metadata conversion. (#4698)
Progress on #3974. Unblocks #3972/#3975/#3978.
diff --git a/integrations/tensorflow/CMakeLists.txt b/integrations/tensorflow/CMakeLists.txt
index 34b8896..4242f90 100644
--- a/integrations/tensorflow/CMakeLists.txt
+++ b/integrations/tensorflow/CMakeLists.txt
@@ -55,10 +55,11 @@
integrations_iree_tensorflow_test_tools
BAZEL_TARGETS
//iree_tf_compiler:iree-tf-opt
+ //iree_tf_compiler:iree-opt-tflite
EXECUTABLE_PATHS
iree_tf_compiler/iree-tf-opt
+ iree_tf_compiler/iree-opt-tflite
)
- add_subdirectory(iree_tf_compiler/test)
endif()
if(${IREE_BUILD_PYTHON_BINDINGS})
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index b87785d..6456cae 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -37,6 +37,23 @@
)
cc_binary(
+ name = "iree-opt-tflite",
+ srcs = ["iree-opt-tflite-main.cpp"],
+ deps = [
+ "//iree_tf_compiler/TFL",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MlirOptLib",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TosaDialect",
+ "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+ ],
+)
+
+cc_binary(
name = "iree-tf-import",
srcs = ["iree-tf-import-main.cpp"],
deps = [
@@ -63,17 +80,12 @@
name = "iree-import-tflite",
srcs = ["iree-import-tflite-main.cpp"],
deps = [
+ "//iree_tf_compiler/TFL",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
- "@llvm-project//mlir:QuantOps",
- "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
- "@llvm-project//mlir:TosaDialect",
"@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_import",
- "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
- "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
- "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes",
],
)
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
new file mode 100644
index 0000000..3b9ef6d
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
@@ -0,0 +1,55 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "TFL",
+ srcs = [
+ "ConvertMetadata.cpp",
+ "Passes.cpp",
+ "StripMetadata.cpp",
+ "VerifyFullyConverted.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ defines = [
+ "IREE_COMPILER_TENSORFLOW_ENABLED",
+ ],
+ deps = [
+ "@iree//iree/compiler/Dialect/Flow/IR",
+ "@iree//iree/compiler/Dialect/IREE/IR",
+ "@iree//iree/compiler/Dialect/Shape/Conversion",
+ "@iree//iree/compiler/Dialect/Shape/IR",
+ "@iree//iree/compiler/Dialect/Shape/Transforms",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:ShapeTransforms",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TransformUtils",
+ "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes",
+ ],
+)
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
new file mode 100644
index 0000000..692b7e7
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
@@ -0,0 +1,117 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+// Extract the input and output names
+static void splitFunctionIONames(StringAttr namesAttr,
+ llvm::SmallVectorImpl<std::string> &names) {
+ SmallVector<StringRef, 4> namesRef;
+ llvm::SplitString(namesAttr.getValue(), namesRef, ",");
+ for (auto nameRef : namesRef) {
+ names.push_back(nameRef.str());
+ }
+}
+
+class ConvertModuleMetadataPass
+ : public PassWrapper<ConvertModuleMetadataPass, OperationPass<ModuleOp>> {
+ public:
+ void runOnOperation() override {
+ // None currently handled.
+ }
+};
+
+class ConvertFunctionMetadataPass
+ : public PassWrapper<ConvertFunctionMetadataPass, OperationPass<FuncOp>> {
+ public:
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+
+ // Setup TF entry functions as an IREE entry point and preserve the
+ // associated metadata. Note that TFLite uses `tf.entry_function`.
+ auto entryFunctionAttr =
+ funcOp->getAttrOfType<DictionaryAttr>("tf.entry_function");
+ if (entryFunctionAttr) {
+ setupEntryPointAttrs(funcOp, entryFunctionAttr);
+ }
+ }
+
+ private:
+ // TF/TFL pack their I/O names on an annoying dictionary. We want our shape
+ // names to match up with those for readability so we extract them here.
+ // Is this ugly? Yeah - but such is what we have to deal with here.
+ void setupEntryPointAttrs(FuncOp funcOp, DictionaryAttr entryFunctionAttr) {
+ auto inputsAttr =
+ entryFunctionAttr.get("inputs").template dyn_cast_or_null<StringAttr>();
+ auto outputsAttr = entryFunctionAttr.get("outputs")
+ .template dyn_cast_or_null<StringAttr>();
+ if (!inputsAttr || !outputsAttr) {
+ funcOp.emitError() << "functions with tf.entry_function must have "
+ "input and output names to be handled by IREE";
+ signalPassFailure();
+ return;
+ }
+
+ funcOp->setAttr("iree.module.export", UnitAttr::get(&getContext()));
+
+ SmallVector<std::string, 4> inputNames;
+ SmallVector<std::string, 4> outputNames;
+ splitFunctionIONames(inputsAttr, inputNames);
+ splitFunctionIONames(outputsAttr, outputNames);
+ if (inputNames.size() != funcOp.getNumArguments() ||
+ outputNames.size() != funcOp.getNumResults()) {
+ funcOp.emitError()
+ << "tf.entry_function attribute malformed: inputs/outputs don't "
+ "match the function signature";
+ signalPassFailure();
+ return;
+ }
+ for (unsigned i = 0; i < inputNames.size(); ++i) {
+ funcOp.setArgAttr(i, "iree.identifier",
+ StringAttr::get(inputNames[i], &getContext()));
+ }
+ for (unsigned i = 0; i < outputNames.size(); ++i) {
+ funcOp.setResultAttr(i, "iree.identifier",
+ StringAttr::get(outputNames[i], &getContext()));
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass() {
+ return std::make_unique<ConvertModuleMetadataPass>();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertFunctionMetadataPass() {
+ return std::make_unique<ConvertFunctionMetadataPass>();
+}
+
+static PassRegistration<ConvertModuleMetadataPass> modulePass(
+ "iree-tflite-convert-module-metadata",
+ "Converts TFLite attributes to IREE attributes on modules");
+
+static PassRegistration<ConvertFunctionMetadataPass> funcPass(
+ "iree-tflite-convert-function-metadata",
+ "Converts TFLite attributes to IREE attributes on functions");
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
new file mode 100644
index 0000000..cb5a5dd
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
@@ -0,0 +1,85 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TFL/Passes.h"
+
+#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+// All IREE-specific passes that lower TFL representations before reaching the
+// IREE core should go here.
+void buildTFLImportPassPipeline(OpPassManager &pm) {
+ //----------------------------------------------------------------------------
+ // Input IR cleanup
+ //----------------------------------------------------------------------------
+
+ pm.addPass(createInlinerPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createSymbolDCEPass());
+
+ //----------------------------------------------------------------------------
+ // Convert useful metadata into forms IREE's main compiler understands
+ //----------------------------------------------------------------------------
+
+ pm.addPass(createConvertModuleMetadataPass());
+ pm.nest<ModuleOp>().addPass(createConvertFunctionMetadataPass());
+
+ //----------------------------------------------------------------------------
+ // Convert all TFL ops to TOSA ops
+ //----------------------------------------------------------------------------
+
+ mlir::tosa::TOSATFLLegalizationPipelineOptions tosaOptions;
+ mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, tosaOptions);
+ pm.addPass(createCanonicalizerPass());
+
+ //----------------------------------------------------------------------------
+ // Lowering shape-related constructs
+ //----------------------------------------------------------------------------
+
+ // TODO(#3975): support dynamic shapes in tflite inputs.
+ // pm.addPass(iree_compiler::Shape::createConvertHLOToShapePass());
+ // pm.addPass(createCanonicalizerPass());
+ // pm.addPass(iree_compiler::Shape::createConvertShapeToShapexPass());
+ // pm.addPass(createCanonicalizerPass());
+
+ //----------------------------------------------------------------------------
+ // Remove the rest of the TFL goo and verify that all ops converted
+ //----------------------------------------------------------------------------
+
+ pm.addPass(createStripModuleMetadataPass());
+ pm.nest<ModuleOp>().addPass(createStripFunctionMetadataPass());
+ pm.addPass(createVerifyFullyConvertedPass());
+}
+
+void registerTFLImportPassPipeline() {
+ mlir::PassPipelineRegistration<> pipeline(
+ "iree-tflite-import-pipeline",
+ "Run IREE-specific passes for importing TFLite code into IREE",
+ [](OpPassManager &passManager) {
+ buildTFLImportPassPipeline(passManager);
+ });
+}
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
new file mode 100644
index 0000000..ef69025
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
@@ -0,0 +1,67 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASSES_H_
+#define IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
+// Create a single pipeline that will run all the needed IREE-specific TFL
+// import passes in the right order.
+void buildTFLImportPassPipeline(OpPassManager &pm);
+
+//===----------------------------------------------------------------------===//
+// IREE-specific passes for TFLite import
+//===----------------------------------------------------------------------===//
+
+// Converts TFLite attributes that are useful to corresponding IREE attributes.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass();
+std::unique_ptr<OperationPass<FuncOp>> createConvertFunctionMetadataPass();
+
+// Strips all leftover TFLite-related attributes; none are needed by IREE.
+std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass();
+std::unique_ptr<OperationPass<FuncOp>> createStripFunctionMetadataPass();
+
+// Validates whether any TFLite operations remain.
+std::unique_ptr<OperationPass<FuncOp>> createVerifyFullyConvertedPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+void registerTFLImportPassPipeline();
+
+inline void registerAllPasses() {
+ registerTFLImportPassPipeline();
+
+ createConvertModuleMetadataPass();
+ createConvertFunctionMetadataPass();
+ createStripModuleMetadataPass();
+ createStripFunctionMetadataPass();
+ createVerifyFullyConvertedPass();
+}
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
+
+#endif // IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASSES_H_
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
new file mode 100644
index 0000000..6ba141a
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
@@ -0,0 +1,98 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+static bool isTFLAttr(NamedAttribute &namedAttr) {
+ // NOTE: tflite mixes tf and tfl, for some reason.
+ auto name = namedAttr.first.strref();
+ if (name.startswith("tf.") || name.startswith("tf_") ||
+ name.startswith("tfl.") || name.startswith("tfl_")) {
+ return true;
+ }
+ StringRef attrNamespace = namedAttr.second.getDialect().getNamespace();
+ return attrNamespace == "tf" || attrNamespace == "tfl";
+}
+
+class StripModuleMetadataPass
+ : public PassWrapper<StripModuleMetadataPass, OperationPass<ModuleOp>> {
+ public:
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range(
+ moduleOp.getAttrs(),
+ [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); }));
+ for (auto namedAttr : stripAttrs) {
+ moduleOp.removeAttr(namedAttr.first);
+ }
+ }
+};
+
+class StripFunctionMetadataPass
+ : public PassWrapper<StripFunctionMetadataPass, OperationPass<FuncOp>> {
+ public:
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range(
+ funcOp.getAttrs(),
+ [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); }));
+ for (auto namedAttr : stripAttrs) {
+ funcOp.removeAttr(namedAttr.first);
+ }
+
+ for (int i = 0; i < funcOp.getNumArguments(); ++i) {
+ auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range(
+ funcOp.getArgAttrs(i),
+ [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); }));
+ for (auto namedAttr : stripAttrs) {
+ funcOp.removeArgAttr(i, namedAttr.first);
+ }
+ }
+
+ for (int i = 0; i < funcOp.getNumResults(); ++i) {
+ auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range(
+ funcOp.getResultAttrs(i),
+ [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); }));
+ for (auto namedAttr : stripAttrs) {
+ funcOp.removeResultAttr(i, namedAttr.first);
+ }
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass() {
+ return std::make_unique<StripModuleMetadataPass>();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createStripFunctionMetadataPass() {
+ return std::make_unique<StripFunctionMetadataPass>();
+}
+
+static PassRegistration<StripModuleMetadataPass> modulePass(
+ "iree-tflite-strip-module-metadata",
+ "Remove unneeded TFLite attributes from module ops");
+
+static PassRegistration<StripFunctionMetadataPass> funcPass(
+ "iree-tflite-strip-function-metadata",
+ "Remove unneeded TFLite attributes from func ops");
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
new file mode 100644
index 0000000..7c7a43a
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
@@ -0,0 +1,83 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+static bool isTFLOp(Operation *op) {
+ if (!op || !op->getDialect()) return false;
+ StringRef opNamespace = op->getDialect()->getNamespace();
+ return opNamespace == mlir::TFL::TensorFlowLiteDialect::getDialectNamespace();
+}
+
+class VerifyFullyConvertedPass
+ : public PassWrapper<VerifyFullyConvertedPass, FunctionPass> {
+ public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+ }
+
+ // Validates that no TFLite frontends ops are in the function.
+ void runOnFunction() override {
+ DenseSet<Operation *> illegalOps;
+ getFunction().walk([&](Operation *op) {
+ if (isTFLOp(op)) illegalOps.insert(op);
+ });
+ if (!illegalOps.empty()) {
+ emitLegalizationErrors(getFunction().getLoc(), illegalOps);
+ return signalPassFailure();
+ }
+ }
+
+ // Emits debug information which includes the number of ops of each type which
+ // failed to legalize.
+ void emitLegalizationErrors(Location loc,
+ const DenseSet<Operation *> &nonlegalizedOps) {
+ // Print op errors for each of the TFLite ops that still remain.
+ std::map<StringRef, int> opNameCounts;
+ for (Operation *nonlegalizedOp : nonlegalizedOps) {
+ StringRef opName = nonlegalizedOp->getName().getStringRef();
+ opNameCounts[opName]++;
+ nonlegalizedOp->emitOpError() << ": unlegalized TFLite op still exists";
+ }
+
+ std::vector<std::string> errorMessages;
+ errorMessages.reserve(opNameCounts.size());
+ for (const auto &opInfo : opNameCounts) {
+ errorMessages.push_back(
+ llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second));
+ }
+ emitError(loc) << "The following TFLite operations still remain: \n"
+ << llvm::join(errorMessages, "\n") << "\n";
+ }
+};
+
+static PassRegistration<VerifyFullyConvertedPass> pass(
+ "iree-tflite-verify-fully-converted",
+ "Verifies that all TFLite frontend ops were converted and none remain");
+
+std::unique_ptr<OperationPass<FuncOp>> createVerifyFullyConvertedPass() {
+ return std::make_unique<VerifyFullyConvertedPass>();
+}
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
similarity index 92%
rename from integrations/tensorflow/iree_tf_compiler/test/BUILD
rename to integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
index 1ed07b9..e9acbed 100644
--- a/integrations/tensorflow/iree_tf_compiler/test/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2019 Google LLC
+# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
name = "lit",
srcs = glob(["*.mlir"]),
data = [
- "//iree_tf_compiler:iree-tf-opt",
+ "//iree_tf_compiler:iree-opt-tflite",
"@iree//iree/tools:IreeFileCheck",
],
driver = "@iree//iree/tools:run_lit.sh",
diff --git a/integrations/tensorflow/iree_tf_compiler/test/CMakeLists.txt b/integrations/tensorflow/iree_tf_compiler/TFL/test/CMakeLists.txt
similarity index 92%
rename from integrations/tensorflow/iree_tf_compiler/test/CMakeLists.txt
rename to integrations/tensorflow/iree_tf_compiler/TFL/test/CMakeLists.txt
index 287b33f..e329fd6 100644
--- a/integrations/tensorflow/iree_tf_compiler/test/CMakeLists.txt
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2020 Google LLC
+# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,5 +22,5 @@
"${_GLOB_X_MLIR}"
DATA
iree::tools::IreeFileCheck
- iree_tf_compiler_iree-tf-opt
+ iree_tf_compiler_iree-opt-tflite
)
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir
new file mode 100644
index 0000000..2f77a9d
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/convert_metadata.mlir
@@ -0,0 +1,17 @@
+// RUN: iree-opt-tflite -split-input-file -pass-pipeline='iree-tflite-convert-module-metadata,func(iree-tflite-convert-function-metadata)' %s | IreeFileCheck %s
+
+module attributes {tfl.schema_version = 3 : i32} {
+ // CHECK: func @main(
+ // CHECK-SAME: %arg0: tensor<?xf32> {iree.identifier = "input0"},
+ // CHECK-SAME: %arg1: tensor<?xf32> {iree.identifier = "input1"}
+ // CHECK-SAME: ) -> (
+ // CHECK-SAME: tensor<?xf32> {iree.identifier = "output0"},
+ // CHECK-SAME: tensor<?xf32> {iree.identifier = "output1"})
+ // CHECK-SAME: attributes
+ // CHECK-SAME: iree.module.export
+ func @main(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) attributes {
+ tf.entry_function = {inputs = "input0,input1", outputs = "output0,output1"}
+ } {
+ return %arg0, %arg1 : tensor<?xf32>, tensor<?xf32>
+ }
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir
new file mode 100644
index 0000000..1c9504b
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/strip_metadata.mlir
@@ -0,0 +1,14 @@
+// RUN: iree-opt-tflite -split-input-file -verify-diagnostics -pass-pipeline='iree-tflite-strip-module-metadata,func(iree-tflite-strip-function-metadata)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: module {
+// CHECK-NOT: tf.schema_version
+module attributes {tfl.schema_version = 3 : i32} {
+ // CHECK: func @main
+ // CHECK-NOT: tf.entry_function
+ func @main(%arg0: tensor<1x8x8x3xf32>) -> tensor<1x8x8x3xf32> attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
+ // CHECK-NEXT: tfl.add
+ %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32>
+ %1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32>
+ return %1 : tensor<1x8x8x3xf32>
+ }
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/verify_fully_converted.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/verify_fully_converted.mlir
new file mode 100644
index 0000000..8b0c5cd
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/verify_fully_converted.mlir
@@ -0,0 +1,19 @@
+// RUN: iree-opt-tflite %s -iree-tflite-verify-fully-converted -split-input-file -verify-diagnostics
+
+// CHECK-LABEL: func @main
+func @main(%arg0: tensor<2xf32>) -> (tensor<2xf32>) {
+ // CHECK: "tosa.add"
+ %0 = "tosa.add"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+// expected-error@+4 {{'tfl.add' op : unlegalized TFLite op still exists}}
+// expected-error@+4 {{'tfl.sub' op : unlegalized TFLite op still exists}}
+// expected-error@below {{The following TFLite operations still remain}}
+func @main(%arg0: tensor<1x8x8x3xf32>) -> tensor<1x8x8x3xf32> attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
+ %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32>
+ %1 = tfl.sub %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32>
+ return %1 : tensor<1x8x8x3xf32>
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
index 44307d9..19a7e12 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "iree_tf_compiler/TFL/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
@@ -27,9 +25,6 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
using namespace llvm;
using namespace mlir;
@@ -68,11 +63,6 @@
// Initialize dialects.
DialectRegistry registry;
- registry.insert<mlir::TFL::TensorFlowLiteDialect>();
- registry.insert<mlir::tosa::TosaDialect>();
- registry.insert<quant::QuantizationDialect>();
- registry.insert<TF::TensorFlowDialect>();
- registry.insert<StandardOpsDialect>();
// Convert the Module proto into MLIR.
MLIRContext context;
@@ -124,10 +114,9 @@
}
// Run transformations.
- mlir::tosa::TOSATFLLegalizationPipelineOptions tosaOptions;
PassManager pm(&context, PassManager::Nesting::Implicit);
applyPassManagerCLOptions(pm);
- mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, tosaOptions);
+ mlir::iree_integrations::TFL::buildTFLImportPassPipeline(pm);
if (failed(pm.run(*module))) {
llvm::errs() << "Running iree-import-tflite pass pipeline failed (see "
"diagnostics)\n";
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-opt-tflite-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-opt-tflite-main.cpp
new file mode 100644
index 0000000..d11c867
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/iree-opt-tflite-main.cpp
@@ -0,0 +1,49 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Main entry function for iree-tflite-opt and derived binaries.
+//
+// This is a bare-bones, minimal *-opt just for testing the handful of local
+// passes here. If you need something, add it, but add only what you need as
+// each addition will likely end up on the build critical path.
+
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "llvm/Support/InitLLVM.h"
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/MlirOptMain.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+int main(int argc, char **argv) {
+ llvm::InitLLVM y(argc, argv);
+
+ mlir::DialectRegistry registry;
+ registry.insert<mlir::quant::QuantizationDialect>();
+ registry.insert<mlir::TF::TensorFlowDialect>();
+ registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+ registry.insert<mlir::StandardOpsDialect>();
+ registry.insert<mlir::tosa::TosaDialect>();
+
+ mlir::iree_integrations::TFL::registerAllPasses();
+
+ if (failed(MlirOptMain(argc, argv, "IREE-TFL modular optimizer driver\n",
+ registry,
+ /*preloadDialectsInContext=*/false))) {
+ return 1;
+ }
+ return 0;
+}