Added lowering for TFLite state to IREE::Utils operations (#7548)
Adding assign/read support for lowering from TFLite to IREE's Utils operators.
This adds elementary stateful TFLite behavior to IREE's support for MLIR. This
should move to tflite-to-tosa in the future.
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
index b6b91d3..87bcf3c 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
@@ -4,27 +4,49 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
+gentbl_cc_library(
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["-gen-pass-decls"],
+ "Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
+
cc_library(
name = "TFL",
srcs = [
"ConvertMetadata.cpp",
+ "LowerGlobalTensors.cpp",
"Passes.cpp",
+ "RetainCallOnceFuncs.cpp",
"StripMetadata.cpp",
"VerifyFullyConverted.cpp",
],
hdrs = [
+ "PassDetail.h",
"Passes.h",
+ "Passes.h.inc",
],
defines = [
"IREE_COMPILER_TENSORFLOW_ENABLED",
],
deps = [
+ ":PassesIncGen",
"@iree//iree/compiler/Dialect/Flow/IR",
"@iree//iree/compiler/Dialect/Shape/IR",
"@iree//iree/compiler/Dialect/Shape/Transforms",
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
index 9c3bdb4..f14c213 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree_tf_compiler/TFL/PassDetail.h"
#include "iree_tf_compiler/TFL/Passes.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -14,6 +15,7 @@
namespace mlir {
namespace iree_integrations {
namespace TFL {
+namespace {
// Extract the input and output names
static void splitFunctionIONames(StringAttr namesAttr,
@@ -26,15 +28,8 @@
}
class ConvertModuleMetadataPass
- : public PassWrapper<ConvertModuleMetadataPass, OperationPass<ModuleOp>> {
+ : public ConvertModuleMetadataBase<ConvertModuleMetadataPass> {
public:
- StringRef getArgument() const override {
- return "iree-tflite-convert-module-metadata";
- }
-
- StringRef getDescription() const override {
- return "Converts TFLite attributes to IREE attributes on modules";
- }
void runOnOperation() override {
// None currently handled.
@@ -42,15 +37,8 @@
};
class ConvertFunctionMetadataPass
- : public PassWrapper<ConvertFunctionMetadataPass, OperationPass<FuncOp>> {
+ : public ConvertFunctionMetadataBase<ConvertFunctionMetadataPass> {
public:
- StringRef getArgument() const override {
- return "iree-tflite-convert-function-metadata";
- }
-
- StringRef getDescription() const override {
- return "Converts TFLite attributes to IREE attributes on functions";
- }
void runOnOperation() override {
auto funcOp = getOperation();
@@ -104,6 +92,7 @@
}
}
};
+} // anonymous namespace
std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass() {
return std::make_unique<ConvertModuleMetadataPass>();
@@ -113,9 +102,6 @@
return std::make_unique<ConvertFunctionMetadataPass>();
}
-static PassRegistration<ConvertModuleMetadataPass> modulePass;
-static PassRegistration<ConvertFunctionMetadataPass> funcPass;
-
} // namespace TFL
} // namespace iree_integrations
} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/LowerGlobalTensors.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/LowerGlobalTensors.cpp
new file mode 100644
index 0000000..67930ce
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/LowerGlobalTensors.cpp
@@ -0,0 +1,169 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/ConversionUtils.h"
+#include "iree_tf_compiler/TFL/PassDetail.h"
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+namespace {
+
+class LowerGlobalTensorsPass
+ : public LowerGlobalTensorsBase<LowerGlobalTensorsPass> {
+ public:
+ void getDependentDialects(DialectRegistry& registry) const override {
+ registry.insert<mlir::TFL::TensorFlowLiteDialect,
+ iree_compiler::IREE::Util::UtilDialect>();
+ }
+
+ // Converts TFLite state operations to the IREE equivalent.
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ mlir::OpBuilder builder(moduleOp.body());
+
+ DenseMap<StringRef, FuncOp> symNameToFunction;
+ for (auto func : moduleOp.getOps<FuncOp>()) {
+ symNameToFunction[func.sym_name()] = func;
+ }
+
+ DenseMap<StringRef, DenseElementsAttr> sharedNameToConstant;
+ DenseMap<StringRef, LocationAttr> sharedNameToLoc;
+
+ llvm::SmallVector<mlir::TFL::VarHandleOp, 6> handleOps;
+ llvm::SmallVector<mlir::TFL::AssignVariableOp, 6> assignOps;
+ llvm::SmallVector<mlir::TFL::ReadVariableOp, 6> readOps;
+ for (auto it : symNameToFunction) {
+ auto func = std::get<1>(it);
+ // Look through the initialization functions and find the assigned values
+ // for each handle, save out the constant value.
+ for (auto init : func.getOps<mlir::TFL::CallOnceOp>()) {
+ FuncOp initFunc = symNameToFunction[init.session_init_function()];
+ for (auto assign : initFunc.getOps<mlir::TFL::AssignVariableOp>()) {
+ auto handle = dyn_cast<mlir::TFL::VarHandleOp>(
+ assign.resource_id().getDefiningOp());
+ if (!handle) continue;
+
+ DenseElementsAttr constant;
+ if (!matchPattern(assign.value(), m_Constant(&constant))) continue;
+ auto name = handle.shared_name();
+ sharedNameToConstant[name] = constant;
+ sharedNameToLoc[name] = handle.getLoc();
+ }
+ }
+
+ // We also want to grab the list of operations to replace.
+ for (auto& op : func.getOps()) {
+ if (auto handle = dyn_cast<mlir::TFL::VarHandleOp>(op))
+ handleOps.push_back(handle);
+ if (auto assign = dyn_cast<mlir::TFL::AssignVariableOp>(op))
+ assignOps.push_back(assign);
+ if (auto read = dyn_cast<mlir::TFL::ReadVariableOp>(op))
+ readOps.push_back(read);
+ }
+ }
+
+ // TF::CallOnceOps are no longer needed as we have already extracted their
+ // state.
+ SmallVector<mlir::TFL::CallOnceOp> callOnceOps;
+ for (auto func : moduleOp.getOps<FuncOp>()) {
+ for (auto init : func.getOps<mlir::TFL::CallOnceOp>()) {
+ callOnceOps.push_back(init);
+ }
+ }
+ for (auto op : callOnceOps) op.erase();
+
+ // Create the Util::GlobalOps to store our new global variables.
+ DenseMap<StringRef, std::string> sharedNameToFlowName;
+ for (auto it : sharedNameToConstant) {
+ auto name = std::get<0>(it);
+ auto attribute = std::get<1>(it);
+ auto locIt = sharedNameToLoc.find(name);
+ LocationAttr loc = mlir::UnknownLoc();
+ if (locIt != sharedNameToLoc.end()) {
+ loc = std::get<1>(*locIt);
+ }
+
+ std::string flowSymName = "__iree_flow_" + name.str();
+
+ // TODO(suderman): Determine the global type based on all store
+ // operations.
+ auto global = builder.create<iree_compiler::IREE::Util::GlobalOp>(
+ loc, flowSymName, /*is_mutable=*/true, attribute.getType(),
+ attribute);
+ global.setPrivate();
+ sharedNameToFlowName[name] = std::move(flowSymName);
+ }
+
+ // Replace handles with global addresses.
+ for (auto handle : handleOps) {
+ auto name = handle.shared_name();
+ auto flowName = sharedNameToFlowName[name];
+ auto constIt = sharedNameToConstant.find(name);
+ if (constIt == sharedNameToConstant.end()) continue;
+
+ auto attribute = std::get<1>(*constIt);
+
+ builder.setInsertionPoint(handle);
+ auto address = builder.create<iree_compiler::IREE::Util::GlobalAddressOp>(
+ handle.getLoc(),
+ iree_compiler::IREE::Util::PtrType::get(attribute.getType()),
+ SymbolRefAttr::get(builder.getContext(), flowName));
+ handle.getResult().replaceAllUsesWith(address.getResult());
+ handle.erase();
+ }
+
+ // Replace the assign ops with a global store operation.
+ for (auto assign : assignOps) {
+ auto address = dyn_cast<iree_compiler::IREE::Util::GlobalAddressOp>(
+ assign.resource_id().getDefiningOp());
+ if (!address) continue;
+
+ builder.setInsertionPoint(assign);
+ builder.create<iree_compiler::IREE::Util::GlobalStoreIndirectOp>(
+ assign.getLoc(), assign.value(), assign.resource_id());
+ assign.erase();
+ }
+
+ // Replace the read ops with a global load operation.
+ for (auto read : readOps) {
+ auto address = dyn_cast<iree_compiler::IREE::Util::GlobalAddressOp>(
+ read.resource_id().getDefiningOp());
+ if (!address) continue;
+
+ auto ptrType =
+ address.getType().dyn_cast<iree_compiler::IREE::Util::PtrType>();
+ if (!ptrType) continue;
+
+ auto type = ptrType.getTargetType();
+
+ builder.setInsertionPoint(read);
+ auto load =
+ builder.create<iree_compiler::IREE::Util::GlobalLoadIndirectOp>(
+ read.getLoc(), type, read.resource_id());
+ read.getResult().replaceAllUsesWith(load);
+ read.erase();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createLowerGlobalTensorsPass() {
+ return std::make_unique<LowerGlobalTensorsPass>();
+}
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/PassDetail.h b/integrations/tensorflow/iree_tf_compiler/TFL/PassDetail.h
new file mode 100644
index 0000000..0d943d7
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/PassDetail.h
@@ -0,0 +1,23 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASS_DETAIL_H_
+#define IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASS_DETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+#define GEN_PASS_CLASSES
+#include "iree_tf_compiler/TFL/Passes.h.inc" // IWYU pragma: keep
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
+
+#endif // IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASS_DETAIL_H_
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
index 1ab195d..2790642 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
@@ -18,10 +18,21 @@
namespace iree_integrations {
namespace TFL {
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree_tf_compiler/TFL/Passes.h.inc" // IWYU pragma: export
+} // namespace
+
// All IREE-specific passes that lower TFL representations before reaching the
// IREE core should go here.
void buildTFLImportPassPipeline(OpPassManager &pm) {
//----------------------------------------------------------------------------
+ // Guarantee the call once functions are preserved.
+ //----------------------------------------------------------------------------
+
+ pm.addPass(createRetainCallOnceFuncsPass());
+
+ //----------------------------------------------------------------------------
// Input IR cleanup
//----------------------------------------------------------------------------
@@ -41,6 +52,7 @@
//----------------------------------------------------------------------------
mlir::tosa::TOSATFLLegalizationPipelineOptions tosaOptions;
+ pm.addPass(createLowerGlobalTensorsPass());
mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, tosaOptions);
pm.nest<FuncOp>().addPass(mlir::tosa::createStripQuantTypesPass());
pm.addPass(createCanonicalizerPass());
@@ -73,6 +85,15 @@
});
}
+void registerAllPasses() {
+ registerTFLImportPassPipeline();
+
+ // Generated.
+ registerPasses();
+
+ createVerifyFullyConvertedPass();
+}
+
} // namespace TFL
} // namespace iree_integrations
} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
index 594e2af..bd836f6 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
@@ -25,10 +25,16 @@
// IREE-specific passes for TFLite import
//===----------------------------------------------------------------------===//
+// Retain functions used by tfl.call_once to avoid removal.
+std::unique_ptr<OperationPass<ModuleOp>> createRetainCallOnceFuncsPass();
+
// Converts TFLite attributes that are useful to corresponding IREE attributes.
std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass();
std::unique_ptr<OperationPass<FuncOp>> createConvertFunctionMetadataPass();
+// Lowers TFLite's global tensor operations to the Util dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createLowerGlobalTensorsPass();
+
// Strips all leftover TFLite-related attributes; none are needed by IREE.
std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass();
std::unique_ptr<OperationPass<FuncOp>> createStripFunctionMetadataPass();
@@ -42,15 +48,7 @@
void registerTFLImportPassPipeline();
-inline void registerAllPasses() {
- registerTFLImportPassPipeline();
-
- createConvertModuleMetadataPass();
- createConvertFunctionMetadataPass();
- createStripModuleMetadataPass();
- createStripFunctionMetadataPass();
- createVerifyFullyConvertedPass();
-}
+void registerAllPasses();
} // namespace TFL
} // namespace iree_integrations
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.td b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.td
new file mode 100644
index 0000000..43ab16a
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.td
@@ -0,0 +1,54 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_INTEGRATIONS_TFL_PASSES
+#define IREE_INTEGRATIONS_TFL_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ConvertFunctionMetadata :
+ Pass<"iree-tflite-convert-function-metadata", "mlir::FuncOp"> {
+ let summary = "Converts TFLite attributes to IREE attributes on functions.";
+ let constructor = "mlir::iree_integrations::TFL::createConvertFunctionMetadataPass()";
+}
+
+def ConvertModuleMetadata :
+ Pass<"iree-tflite-convert-module-metadata", "mlir::ModuleOp"> {
+ let summary = "Converts TFLite attributes to IREE attributes on modules.";
+ let constructor = "mlir::iree_integrations::TFL::createConvertModuleMetadataPass()";
+}
+
+def LowerGlobalTensors :
+ Pass<"iree-tflite-lower-global-tensors", "mlir::ModuleOp"> {
+ let summary = "Lowers tflite global tensors to IREE flow dialect variables.";
+ let constructor = "mlir::iree_integrations::TFL::createLowerGlobalTensorsPass()";
+}
+
+def RetainCallOnceFuncs :
+ Pass<"iree-tflite-retain-call-once-funcs", "mlir::ModuleOp"> {
+ let summary = "Guarantees that functions used by tfl.call_once are retained.";
+ let constructor = "mlir::iree_integrations::TFL::createRetainCallOnceFuncsPass()";
+}
+
+def StripFunctionMetadata :
+ Pass<"iree-tflite-strip-function-metadata", "mlir::FuncOp"> {
+ let summary = "Guarantees that functions used by tfl.call_once are retained.";
+ let constructor = "mlir::iree_integrations::TFL::createStripFunctionMetadataPass()";
+}
+
+def StripModuleMetadata :
+ Pass<"iree-tflite-strip-module-metadata", "mlir::ModuleOp"> {
+ let summary = "Guarantees that functions used by tfl.call_once are retained.";
+ let constructor = "mlir::iree_integrations::TFL::createStripModuleMetadataPass()";
+}
+
+def VerifyFullyConverted :
+ Pass<"iree-tflite-verify-fully-converted", "mlir::FuncOp"> {
+ let summary = "Verifies that all TFLite frontend ops were converted and none remain.";
+ let constructor = "mlir::iree_integrations::TFL::createVerifyFullyConvertedPass()";
+}
+
+#endif // IREE_INTEGRATIONS_TFL_PASSES
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/RetainCallOnceFuncs.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/RetainCallOnceFuncs.cpp
new file mode 100644
index 0000000..8d8cabb
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/RetainCallOnceFuncs.cpp
@@ -0,0 +1,58 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/ConversionUtils.h"
+#include "iree_tf_compiler/TFL/PassDetail.h"
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+namespace {
+
+class RetainCallOnceFuncsPass
+ : public RetainCallOnceFuncsBase<RetainCallOnceFuncsPass> {
+ public:
+ void getDependentDialects(DialectRegistry& registry) const override {
+ registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ llvm::DenseMap<StringRef, FuncOp> funcMap;
+ for (auto func : moduleOp.getOps<mlir::FuncOp>()) {
+ funcMap[func.sym_name()] = func;
+ }
+
+ for (auto func : moduleOp.getOps<mlir::FuncOp>()) {
+ for (auto callOnce : func.getOps<mlir::TFL::CallOnceOp>()) {
+ auto callFunc = funcMap[callOnce.session_init_function()];
+ callOnce->setAttr("session_init_function_symbol",
+ SymbolRefAttr::get(callFunc));
+ }
+ }
+ }
+};
+
+} // anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createRetainCallOnceFuncsPass() {
+ return std::make_unique<RetainCallOnceFuncsPass>();
+}
+
+} // namespace TFL
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
index 0c6ef17..0ba99e6 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree_tf_compiler/TFL/PassDetail.h"
#include "iree_tf_compiler/TFL/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -11,6 +12,7 @@
namespace mlir {
namespace iree_integrations {
namespace TFL {
+namespace {
static bool isTFLAttr(NamedAttribute &namedAttr) {
// NOTE: tflite mixes tf and tfl, for some reason.
@@ -24,15 +26,8 @@
}
class StripModuleMetadataPass
- : public PassWrapper<StripModuleMetadataPass, OperationPass<ModuleOp>> {
+ : public StripModuleMetadataBase<StripModuleMetadataPass> {
public:
- StringRef getArgument() const override {
- return "iree-tflite-strip-module-metadata";
- }
-
- StringRef getDescription() const override {
- return "Remove unneeded TFLite attributes from module ops";
- }
void runOnOperation() override {
auto moduleOp = getOperation();
@@ -46,15 +41,8 @@
};
class StripFunctionMetadataPass
- : public PassWrapper<StripFunctionMetadataPass, OperationPass<FuncOp>> {
+ : public StripFunctionMetadataBase<StripFunctionMetadataPass> {
public:
- StringRef getArgument() const override {
- return "iree-tflite-strip-function-metadata";
- }
-
- StringRef getDescription() const override {
- return "Remove unneeded TFLite attributes from func ops";
- }
void runOnOperation() override {
auto funcOp = getOperation();
@@ -85,6 +73,8 @@
}
};
+} // anonymous namespace
+
std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass() {
return std::make_unique<StripModuleMetadataPass>();
}
@@ -93,9 +83,6 @@
return std::make_unique<StripFunctionMetadataPass>();
}
-static PassRegistration<StripModuleMetadataPass> modulePass;
-static PassRegistration<StripFunctionMetadataPass> funcPass;
-
} // namespace TFL
} // namespace iree_integrations
} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
index 961937e..3a6a782 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Utils/ConversionUtils.h"
+#include "iree_tf_compiler/TFL/PassDetail.h"
#include "iree_tf_compiler/TFL/Passes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -16,21 +17,14 @@
namespace mlir {
namespace iree_integrations {
namespace TFL {
+namespace {
class VerifyFullyConvertedPass
- : public PassWrapper<VerifyFullyConvertedPass, FunctionPass> {
+ : public VerifyFullyConvertedBase<VerifyFullyConvertedPass> {
public:
- StringRef getArgument() const override {
- return "iree-tflite-verify-fully-converted";
- }
-
- StringRef getDescription() const override {
- return "Verifies that all TFLite frontend ops were converted and none "
- "remain";
- }
// Validates that no TFLite frontends ops are in the function.
- void runOnFunction() override {
+ void runOnOperation() override {
ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
target.addIllegalDialect<mlir::TFL::TensorFlowLiteDialect>();
@@ -40,7 +34,7 @@
}
};
-static PassRegistration<VerifyFullyConvertedPass> pass;
+} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> createVerifyFullyConvertedPass() {
return std::make_unique<VerifyFullyConvertedPass>();
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
index 5826703..7af87a5 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
@@ -18,6 +18,8 @@
srcs = enforce_glob(
[
"convert_metadata.mlir",
+ "lower_global_tensors.mlir",
+ "retain_call_once_funcs.mlir",
"strip_metadata.mlir",
"verify_fully_converted.mlir",
],
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/lower_global_tensors.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/lower_global_tensors.mlir
new file mode 100644
index 0000000..e9d8668
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/lower_global_tensors.mlir
@@ -0,0 +1,122 @@
+// RUN: iree-opt-tflite -split-input-file -allow-unregistered-dialect -pass-pipeline='iree-tflite-lower-global-tensors' %s | IreeFileCheck %s
+
+// CHECK-LABEL: module {
+module {
+ // CHECK: util.global private mutable @__iree_flow_Variable = dense<1.000000e+00> : tensor<16x16xf32>
+ // CHECK: func @state
+ func @state(%arg0: tensor<16x16xf32>) -> () {
+ "tfl.call_once"() {session_init_function = "StateInit"} : () -> ()
+ return
+ }
+
+ func private @StateInit() {
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+ %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+ "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module {
+module {
+ // CHECK: util.global private mutable @__iree_flow_Variable = dense<1.000000e+00> : tensor<16x16xf32>
+
+ // CHECK: func @assign
+ func @assign(%arg0: tensor<16x16xf32>) -> () {
+ "tfl.call_once"() {session_init_function = "AssignInit"} : () -> ()
+ // CHECK: %[[ADDR:.+]] = util.global.address @__iree_flow_Variable : !util.ptr<tensor<16x16xf32>>
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+ // CHECK: util.global.store.indirect %arg0, %[[ADDR]]
+ "tfl.assign_variable"(%0, %arg0) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return
+ }
+
+ func private @AssignInit() {
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+ %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+ "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module {
+module {
+ // CHECK: util.global private mutable @__iree_flow_Variable = dense<1.000000e+00> : tensor<16x16xf32>
+
+ // CHECK: func @read
+ func @read(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+ "tfl.call_once"() {session_init_function = "ReadInit"} : () -> ()
+
+ // CHECK: %[[ADDR:.+]] = util.global.address @__iree_flow_Variable : !util.ptr<tensor<16x16xf32>>
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+ // CHECK: %[[LOAD:.+]] = util.global.load.indirect %[[ADDR]] : !util.ptr<tensor<16x16xf32>> -> tensor<16x16xf32>
+ %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+ return %1 : tensor<16x16xf32>
+ }
+
+ func private @ReadInit() {
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+ %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+ "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module {
+module {
+ // CHECK: util.global private mutable @__iree_flow_Variable = dense<2.000000e+00> : tensor<16x16xf32>
+
+ // func @readAssign
+ func @readAssign(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+ "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> ()
+ // CHECK: %[[ADDR:.+]] = util.global.address @__iree_flow_Variable : !util.ptr<tensor<16x16xf32>>
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+ // CHECK: %[[LOAD:.+]] = util.global.load.indirect %[[ADDR]] : !util.ptr<tensor<16x16xf32>> -> tensor<16x16xf32>
+ %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+
+ // CHECK: %[[ADD:.+]] = tfl.add %[[LOAD]], %arg0
+ %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32>
+
+ // CHECK: util.global.store.indirect %[[ADD]], %[[ADDR]]
+ "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return %2 : tensor<16x16xf32>
+ }
+ func private @ReadAssignInit() {
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+ %1 = "tfl.pseudo_const"() {value = dense<2.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+ "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return
+ }
+}
+
+// -----
+
+module {
+ // CHECK-label: @nostate
+ func @nostate(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+ "tfl.call_once"() {session_init_function = "NoStateInit"} : () -> ()
+ // CHECK: tfl.var_handle
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+ // CHECK: tfl.read_variable
+ %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+
+ %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32>
+
+ // CHECK: tfl.assign_variable
+ "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return %2 : tensor<16x16xf32>
+ }
+ func private @NoStateInit() {
+ return
+ }
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/retain_call_once_funcs.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/retain_call_once_funcs.mlir
new file mode 100644
index 0000000..6fa63ad
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/retain_call_once_funcs.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt-tflite -allow-unregistered-dialect -split-input-file -pass-pipeline='iree-tflite-retain-call-once-funcs' %s | IreeFileCheck %s
+
+// CHECK-LABEL: module {
+module {
+ // CHECK-LABEL: @main
+ func @main(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+ // CHECK: "tfl.call_once"() {session_init_function = "NoOp", session_init_function_symbol = @NoOp} : () -> ()
+ "tfl.call_once"() {session_init_function = "NoOp"} : () -> ()
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+ %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+ %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32>
+ "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return %2 : tensor<16x16xf32>
+ }
+ func private @NoOp() {
+ %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+ %1 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+ "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+ return
+ }
+}
\ No newline at end of file