Added lowering for TFLite state to IREE::Utils operations (#7548)

Adding assign/read support for lowering from TFLite to IREE's Utils operators.
This adds elementary stateful TFLite behavior to IREE's support for MLIR. This
should move to tflite-to-tosa in the future.
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
index b6b91d3..87bcf3c 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
@@ -4,27 +4,49 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
 package(
     default_visibility = ["//visibility:public"],
     features = ["layering_check"],
     licenses = ["notice"],  # Apache 2.0
 )
 
+gentbl_cc_library(
+    name = "PassesIncGen",
+    tbl_outs = [
+        (
+            ["-gen-pass-decls"],
+            "Passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "Passes.td",
+    deps = [
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
+)
+
 cc_library(
     name = "TFL",
     srcs = [
         "ConvertMetadata.cpp",
+        "LowerGlobalTensors.cpp",
         "Passes.cpp",
+        "RetainCallOnceFuncs.cpp",
         "StripMetadata.cpp",
         "VerifyFullyConverted.cpp",
     ],
     hdrs = [
+        "PassDetail.h",
         "Passes.h",
+        "Passes.h.inc",
     ],
     defines = [
         "IREE_COMPILER_TENSORFLOW_ENABLED",
     ],
     deps = [
+        ":PassesIncGen",
         "@iree//iree/compiler/Dialect/Flow/IR",
         "@iree//iree/compiler/Dialect/Shape/IR",
         "@iree//iree/compiler/Dialect/Shape/Transforms",
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
index 9c3bdb4..f14c213 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/ConvertMetadata.cpp
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree_tf_compiler/TFL/PassDetail.h"
 #include "iree_tf_compiler/TFL/Passes.h"
 #include "llvm/ADT/StringExtras.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -14,6 +15,7 @@
 namespace mlir {
 namespace iree_integrations {
 namespace TFL {
+namespace {
 
 // Extract the input and output names
 static void splitFunctionIONames(StringAttr namesAttr,
@@ -26,15 +28,8 @@
 }
 
 class ConvertModuleMetadataPass
-    : public PassWrapper<ConvertModuleMetadataPass, OperationPass<ModuleOp>> {
+    : public ConvertModuleMetadataBase<ConvertModuleMetadataPass> {
  public:
-  StringRef getArgument() const override {
-    return "iree-tflite-convert-module-metadata";
-  }
-
-  StringRef getDescription() const override {
-    return "Converts TFLite attributes to IREE attributes on modules";
-  }
 
   void runOnOperation() override {
     // None currently handled.
@@ -42,15 +37,8 @@
 };
 
 class ConvertFunctionMetadataPass
-    : public PassWrapper<ConvertFunctionMetadataPass, OperationPass<FuncOp>> {
+    : public ConvertFunctionMetadataBase<ConvertFunctionMetadataPass> {
  public:
-  StringRef getArgument() const override {
-    return "iree-tflite-convert-function-metadata";
-  }
-
-  StringRef getDescription() const override {
-    return "Converts TFLite attributes to IREE attributes on functions";
-  }
 
   void runOnOperation() override {
     auto funcOp = getOperation();
@@ -104,6 +92,7 @@
     }
   }
 };
+}  // anonymous namespace
 
 std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass() {
   return std::make_unique<ConvertModuleMetadataPass>();
@@ -113,9 +102,6 @@
   return std::make_unique<ConvertFunctionMetadataPass>();
 }
 
-static PassRegistration<ConvertModuleMetadataPass> modulePass;
-static PassRegistration<ConvertFunctionMetadataPass> funcPass;
-
 }  // namespace TFL
 }  // namespace iree_integrations
 }  // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/LowerGlobalTensors.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/LowerGlobalTensors.cpp
new file mode 100644
index 0000000..67930ce
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/LowerGlobalTensors.cpp
@@ -0,0 +1,169 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/ConversionUtils.h"
+#include "iree_tf_compiler/TFL/PassDetail.h"
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+namespace {
+
+class LowerGlobalTensorsPass
+    : public LowerGlobalTensorsBase<LowerGlobalTensorsPass> {
+ public:
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<mlir::TFL::TensorFlowLiteDialect,
+                    iree_compiler::IREE::Util::UtilDialect>();
+  }
+
+  // Converts TFLite state operations to the IREE equivalent.
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+    mlir::OpBuilder builder(moduleOp.body());
+
+    DenseMap<StringRef, FuncOp> symNameToFunction;
+    for (auto func : moduleOp.getOps<FuncOp>()) {
+      symNameToFunction[func.sym_name()] = func;
+    }
+
+    DenseMap<StringRef, DenseElementsAttr> sharedNameToConstant;
+    DenseMap<StringRef, LocationAttr> sharedNameToLoc;
+
+    llvm::SmallVector<mlir::TFL::VarHandleOp, 6> handleOps;
+    llvm::SmallVector<mlir::TFL::AssignVariableOp, 6> assignOps;
+    llvm::SmallVector<mlir::TFL::ReadVariableOp, 6> readOps;
+    for (auto it : symNameToFunction) {
+      auto func = std::get<1>(it);
+      // Look through the initialization functions and find the assigned values
+      // for each handle, save out the constant value.
+      for (auto init : func.getOps<mlir::TFL::CallOnceOp>()) {
+        FuncOp initFunc = symNameToFunction[init.session_init_function()];
+        for (auto assign : initFunc.getOps<mlir::TFL::AssignVariableOp>()) {
+          auto handle = dyn_cast<mlir::TFL::VarHandleOp>(
+              assign.resource_id().getDefiningOp());
+          if (!handle) continue;
+
+          DenseElementsAttr constant;
+          if (!matchPattern(assign.value(), m_Constant(&constant))) continue;
+          auto name = handle.shared_name();
+          sharedNameToConstant[name] = constant;
+          sharedNameToLoc[name] = handle.getLoc();
+        }
+      }
+
+      // We also want to grab the list of operations to replace.
+      for (auto& op : func.getOps()) {
+        if (auto handle = dyn_cast<mlir::TFL::VarHandleOp>(op))
+          handleOps.push_back(handle);
+        if (auto assign = dyn_cast<mlir::TFL::AssignVariableOp>(op))
+          assignOps.push_back(assign);
+        if (auto read = dyn_cast<mlir::TFL::ReadVariableOp>(op))
+          readOps.push_back(read);
+      }
+    }
+
+    // TF::CallOnceOps are no longer needed as we have already extracted their
+    // state.
+    SmallVector<mlir::TFL::CallOnceOp> callOnceOps;
+    for (auto func : moduleOp.getOps<FuncOp>()) {
+      for (auto init : func.getOps<mlir::TFL::CallOnceOp>()) {
+        callOnceOps.push_back(init);
+      }
+    }
+    for (auto op : callOnceOps) op.erase();
+
+    // Create the Util::GlobalOps to store our new global variables.
+    DenseMap<StringRef, std::string> sharedNameToFlowName;
+    for (auto it : sharedNameToConstant) {
+      auto name = std::get<0>(it);
+      auto attribute = std::get<1>(it);
+      auto locIt = sharedNameToLoc.find(name);
+      LocationAttr loc = mlir::UnknownLoc();
+      if (locIt != sharedNameToLoc.end()) {
+        loc = std::get<1>(*locIt);
+      }
+
+      std::string flowSymName = "__iree_flow_" + name.str();
+
+      // TODO(suderman): Determine the global type based on all store
+      // operations.
+      auto global = builder.create<iree_compiler::IREE::Util::GlobalOp>(
+          loc, flowSymName, /*is_mutable=*/true, attribute.getType(),
+          attribute);
+      global.setPrivate();
+      sharedNameToFlowName[name] = std::move(flowSymName);
+    }
+
+    // Replace handles with global addresses.
+    for (auto handle : handleOps) {
+      auto name = handle.shared_name();
+      auto flowName = sharedNameToFlowName[name];
+      auto constIt = sharedNameToConstant.find(name);
+      if (constIt == sharedNameToConstant.end()) continue;
+
+      auto attribute = std::get<1>(*constIt);
+
+      builder.setInsertionPoint(handle);
+      auto address = builder.create<iree_compiler::IREE::Util::GlobalAddressOp>(
+          handle.getLoc(),
+          iree_compiler::IREE::Util::PtrType::get(attribute.getType()),
+          SymbolRefAttr::get(builder.getContext(), flowName));
+      handle.getResult().replaceAllUsesWith(address.getResult());
+      handle.erase();
+    }
+
+    // Replace the assign ops with a global store operation.
+    for (auto assign : assignOps) {
+      auto address = dyn_cast<iree_compiler::IREE::Util::GlobalAddressOp>(
+          assign.resource_id().getDefiningOp());
+      if (!address) continue;
+
+      builder.setInsertionPoint(assign);
+      builder.create<iree_compiler::IREE::Util::GlobalStoreIndirectOp>(
+          assign.getLoc(), assign.value(), assign.resource_id());
+      assign.erase();
+    }
+
+    // Replace the read ops with a global load operation.
+    for (auto read : readOps) {
+      auto address = dyn_cast<iree_compiler::IREE::Util::GlobalAddressOp>(
+          read.resource_id().getDefiningOp());
+      if (!address) continue;
+
+      auto ptrType =
+          address.getType().dyn_cast<iree_compiler::IREE::Util::PtrType>();
+      if (!ptrType) continue;
+
+      auto type = ptrType.getTargetType();
+
+      builder.setInsertionPoint(read);
+      auto load =
+          builder.create<iree_compiler::IREE::Util::GlobalLoadIndirectOp>(
+              read.getLoc(), type, read.resource_id());
+      read.getResult().replaceAllUsesWith(load);
+      read.erase();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createLowerGlobalTensorsPass() {
+  return std::make_unique<LowerGlobalTensorsPass>();
+}
+
+}  // namespace TFL
+}  // namespace iree_integrations
+}  // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/PassDetail.h b/integrations/tensorflow/iree_tf_compiler/TFL/PassDetail.h
new file mode 100644
index 0000000..0d943d7
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/PassDetail.h
@@ -0,0 +1,23 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASS_DETAIL_H_
+#define IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASS_DETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+
+#define GEN_PASS_CLASSES
+#include "iree_tf_compiler/TFL/Passes.h.inc"  // IWYU pragma: keep
+
+}  // namespace TFL
+}  // namespace iree_integrations
+}  // namespace mlir
+
+#endif  // IREE_INTEGRATIONS_TENSORFLOW_IREE_TF_COMPILER_TFL_PASS_DETAIL_H_
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
index 1ab195d..2790642 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
@@ -18,10 +18,21 @@
 namespace iree_integrations {
 namespace TFL {
 
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree_tf_compiler/TFL/Passes.h.inc"  // IWYU pragma: export
+}  // namespace
+
 // All IREE-specific passes that lower TFL representations before reaching the
 // IREE core should go here.
 void buildTFLImportPassPipeline(OpPassManager &pm) {
   //----------------------------------------------------------------------------
+  // Guarantee the call once functions are preserved.
+  //----------------------------------------------------------------------------
+
+  pm.addPass(createRetainCallOnceFuncsPass());
+
+  //----------------------------------------------------------------------------
   // Input IR cleanup
   //----------------------------------------------------------------------------
 
@@ -41,6 +52,7 @@
   //----------------------------------------------------------------------------
 
   mlir::tosa::TOSATFLLegalizationPipelineOptions tosaOptions;
+  pm.addPass(createLowerGlobalTensorsPass());
   mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, tosaOptions);
   pm.nest<FuncOp>().addPass(mlir::tosa::createStripQuantTypesPass());
   pm.addPass(createCanonicalizerPass());
@@ -73,6 +85,15 @@
       });
 }
 
+void registerAllPasses() {
+  registerTFLImportPassPipeline();
+
+  // Generated.
+  registerPasses();
+
+  createVerifyFullyConvertedPass();
+}
+
 }  // namespace TFL
 }  // namespace iree_integrations
 }  // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
index 594e2af..bd836f6 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.h
@@ -25,10 +25,16 @@
 // IREE-specific passes for TFLite import
 //===----------------------------------------------------------------------===//
 
+// Retain functions used by tfl.call_once to avoid removal.
+std::unique_ptr<OperationPass<ModuleOp>> createRetainCallOnceFuncsPass();
+
 // Converts TFLite attributes that are useful to corresponding IREE attributes.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass();
 std::unique_ptr<OperationPass<FuncOp>> createConvertFunctionMetadataPass();
 
+// Lowers TFLite's global tensor operations to the Util dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createLowerGlobalTensorsPass();
+
 // Strips all leftover TFLite-related attributes; none are needed by IREE.
 std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass();
 std::unique_ptr<OperationPass<FuncOp>> createStripFunctionMetadataPass();
@@ -42,15 +48,7 @@
 
 void registerTFLImportPassPipeline();
 
-inline void registerAllPasses() {
-  registerTFLImportPassPipeline();
-
-  createConvertModuleMetadataPass();
-  createConvertFunctionMetadataPass();
-  createStripModuleMetadataPass();
-  createStripFunctionMetadataPass();
-  createVerifyFullyConvertedPass();
-}
+void registerAllPasses();
 
 }  // namespace TFL
 }  // namespace iree_integrations
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.td b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.td
new file mode 100644
index 0000000..43ab16a
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.td
@@ -0,0 +1,54 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_INTEGRATIONS_TFL_PASSES
+#define IREE_INTEGRATIONS_TFL_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ConvertFunctionMetadata :
+    Pass<"iree-tflite-convert-function-metadata", "mlir::FuncOp"> {
+  let summary = "Converts TFLite attributes to IREE attributes on functions.";
+  let constructor = "mlir::iree_integrations::TFL::createConvertFunctionMetadataPass()";
+}
+
+def ConvertModuleMetadata :
+    Pass<"iree-tflite-convert-module-metadata", "mlir::ModuleOp"> {
+  let summary = "Converts TFLite attributes to IREE attributes on modules.";
+  let constructor = "mlir::iree_integrations::TFL::createConvertModuleMetadataPass()";
+}
+
+def LowerGlobalTensors :
+    Pass<"iree-tflite-lower-global-tensors", "mlir::ModuleOp"> {
+  let summary = "Lowers tflite global tensors to IREE flow dialect variables.";
+  let constructor = "mlir::iree_integrations::TFL::createLowerGlobalTensorsPass()";
+}
+
+def RetainCallOnceFuncs :
+    Pass<"iree-tflite-retain-call-once-funcs", "mlir::ModuleOp"> {
+  let summary = "Guarantees that functions used by tfl.call_once are retained.";
+  let constructor = "mlir::iree_integrations::TFL::createRetainCallOnceFuncsPass()";
+}
+
+def StripFunctionMetadata :
+    Pass<"iree-tflite-strip-function-metadata", "mlir::FuncOp"> {
+  let summary = "Guarantees that functions used by tfl.call_once are retained.";
+  let constructor = "mlir::iree_integrations::TFL::createStripFunctionMetadataPass()";
+}
+
+def StripModuleMetadata :
+    Pass<"iree-tflite-strip-module-metadata", "mlir::ModuleOp"> {
+  let summary = "Guarantees that functions used by tfl.call_once are retained.";
+  let constructor = "mlir::iree_integrations::TFL::createStripModuleMetadataPass()";
+}
+
+def VerifyFullyConverted :
+    Pass<"iree-tflite-verify-fully-converted", "mlir::FuncOp"> {
+  let summary = "Verifies that all TFLite frontend ops were converted and none remain.";
+  let constructor = "mlir::iree_integrations::TFL::createVerifyFullyConvertedPass()";
+}
+
+#endif  // IREE_INTEGRATIONS_TFL_PASSES
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/RetainCallOnceFuncs.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/RetainCallOnceFuncs.cpp
new file mode 100644
index 0000000..8d8cabb
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/RetainCallOnceFuncs.cpp
@@ -0,0 +1,58 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/ConversionUtils.h"
+#include "iree_tf_compiler/TFL/PassDetail.h"
+#include "iree_tf_compiler/TFL/Passes.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TFL {
+namespace {
+
+class RetainCallOnceFuncsPass
+    : public RetainCallOnceFuncsBase<RetainCallOnceFuncsPass> {
+ public:
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    llvm::DenseMap<StringRef, FuncOp> funcMap;
+    for (auto func : moduleOp.getOps<mlir::FuncOp>()) {
+      funcMap[func.sym_name()] = func;
+    }
+
+    for (auto func : moduleOp.getOps<mlir::FuncOp>()) {
+      for (auto callOnce : func.getOps<mlir::TFL::CallOnceOp>()) {
+        auto callFunc = funcMap[callOnce.session_init_function()];
+        callOnce->setAttr("session_init_function_symbol",
+                          SymbolRefAttr::get(callFunc));
+      }
+    }
+  }
+};
+
+}  // anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createRetainCallOnceFuncsPass() {
+  return std::make_unique<RetainCallOnceFuncsPass>();
+}
+
+}  // namespace TFL
+}  // namespace iree_integrations
+}  // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
index 0c6ef17..0ba99e6 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/StripMetadata.cpp
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree_tf_compiler/TFL/PassDetail.h"
 #include "iree_tf_compiler/TFL/Passes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -11,6 +12,7 @@
 namespace mlir {
 namespace iree_integrations {
 namespace TFL {
+namespace {
 
 static bool isTFLAttr(NamedAttribute &namedAttr) {
   // NOTE: tflite mixes tf and tfl, for some reason.
@@ -24,15 +26,8 @@
 }
 
 class StripModuleMetadataPass
-    : public PassWrapper<StripModuleMetadataPass, OperationPass<ModuleOp>> {
+    : public StripModuleMetadataBase<StripModuleMetadataPass> {
  public:
-  StringRef getArgument() const override {
-    return "iree-tflite-strip-module-metadata";
-  }
-
-  StringRef getDescription() const override {
-    return "Remove unneeded TFLite attributes from module ops";
-  }
 
   void runOnOperation() override {
     auto moduleOp = getOperation();
@@ -46,15 +41,8 @@
 };
 
 class StripFunctionMetadataPass
-    : public PassWrapper<StripFunctionMetadataPass, OperationPass<FuncOp>> {
+    : public StripFunctionMetadataBase<StripFunctionMetadataPass> {
  public:
-  StringRef getArgument() const override {
-    return "iree-tflite-strip-function-metadata";
-  }
-
-  StringRef getDescription() const override {
-    return "Remove unneeded TFLite attributes from func ops";
-  }
 
   void runOnOperation() override {
     auto funcOp = getOperation();
@@ -85,6 +73,8 @@
   }
 };
 
+}  // anonymous namespace
+
 std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass() {
   return std::make_unique<StripModuleMetadataPass>();
 }
@@ -93,9 +83,6 @@
   return std::make_unique<StripFunctionMetadataPass>();
 }
 
-static PassRegistration<StripModuleMetadataPass> modulePass;
-static PassRegistration<StripFunctionMetadataPass> funcPass;
-
 }  // namespace TFL
 }  // namespace iree_integrations
 }  // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
index 961937e..3a6a782 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/VerifyFullyConverted.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Utils/ConversionUtils.h"
+#include "iree_tf_compiler/TFL/PassDetail.h"
 #include "iree_tf_compiler/TFL/Passes.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -16,21 +17,14 @@
 namespace mlir {
 namespace iree_integrations {
 namespace TFL {
+namespace {
 
 class VerifyFullyConvertedPass
-    : public PassWrapper<VerifyFullyConvertedPass, FunctionPass> {
+    : public VerifyFullyConvertedBase<VerifyFullyConvertedPass> {
  public:
-  StringRef getArgument() const override {
-    return "iree-tflite-verify-fully-converted";
-  }
-
-  StringRef getDescription() const override {
-    return "Verifies that all TFLite frontend ops were converted and none "
-           "remain";
-  }
 
   // Validates that no TFLite frontends ops are in the function.
-  void runOnFunction() override {
+  void runOnOperation() override {
     ConversionTarget target(getContext());
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
     target.addIllegalDialect<mlir::TFL::TensorFlowLiteDialect>();
@@ -40,7 +34,7 @@
   }
 };
 
-static PassRegistration<VerifyFullyConvertedPass> pass;
+}  // anonymous namespace
 
 std::unique_ptr<OperationPass<FuncOp>> createVerifyFullyConvertedPass() {
   return std::make_unique<VerifyFullyConvertedPass>();
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
index 5826703..7af87a5 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/BUILD
@@ -18,6 +18,8 @@
     srcs = enforce_glob(
         [
             "convert_metadata.mlir",
+            "lower_global_tensors.mlir",
+            "retain_call_once_funcs.mlir",
             "strip_metadata.mlir",
             "verify_fully_converted.mlir",
         ],
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/lower_global_tensors.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/lower_global_tensors.mlir
new file mode 100644
index 0000000..e9d8668
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/lower_global_tensors.mlir
@@ -0,0 +1,122 @@
+// RUN: iree-opt-tflite -split-input-file -allow-unregistered-dialect -pass-pipeline='iree-tflite-lower-global-tensors' %s | IreeFileCheck %s
+
+// CHECK-LABEL: module {
+module {
+  // CHECK: util.global private mutable @__iree_flow_Variable = dense<1.000000e+00> : tensor<16x16xf32>
+  // CHECK: func @state
+  func @state(%arg0: tensor<16x16xf32>) -> () {
+    "tfl.call_once"() {session_init_function = "StateInit"} : () -> ()
+    return
+  }
+
+  func private @StateInit() {
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+    %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+    "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module {
+module {
+  // CHECK: util.global private mutable @__iree_flow_Variable = dense<1.000000e+00> : tensor<16x16xf32>
+
+  // CHECK: func @assign
+  func @assign(%arg0: tensor<16x16xf32>) -> () {
+    "tfl.call_once"() {session_init_function = "AssignInit"} : () -> ()
+    // CHECK: %[[ADDR:.+]] = util.global.address @__iree_flow_Variable : !util.ptr<tensor<16x16xf32>>
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+    // CHECK: util.global.store.indirect %arg0, %[[ADDR]]
+    "tfl.assign_variable"(%0, %arg0) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return
+  }
+
+  func private @AssignInit() {
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+    %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+    "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module {
+module {
+  // CHECK: util.global private mutable @__iree_flow_Variable = dense<1.000000e+00> : tensor<16x16xf32>
+
+  // CHECK: func @read
+  func @read(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+    "tfl.call_once"() {session_init_function = "ReadInit"} : () -> ()
+
+    // CHECK: %[[ADDR:.+]] = util.global.address @__iree_flow_Variable : !util.ptr<tensor<16x16xf32>>
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+    // CHECK: %[[LOAD:.+]] = util.global.load.indirect %[[ADDR]] : !util.ptr<tensor<16x16xf32>> -> tensor<16x16xf32>
+    %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+    return %1 : tensor<16x16xf32>
+  }
+
+  func private @ReadInit() {
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+    %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+    "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module {
+module {
+  // CHECK: util.global private mutable @__iree_flow_Variable = dense<2.000000e+00> : tensor<16x16xf32>
+
+  // func @readAssign
+  func @readAssign(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+    "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> ()
+    // CHECK: %[[ADDR:.+]] = util.global.address @__iree_flow_Variable : !util.ptr<tensor<16x16xf32>>
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+    // CHECK: %[[LOAD:.+]] = util.global.load.indirect %[[ADDR]] : !util.ptr<tensor<16x16xf32>> -> tensor<16x16xf32>
+    %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+
+    // CHECK: %[[ADD:.+]] = tfl.add %[[LOAD]], %arg0
+    %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32>
+
+    // CHECK: util.global.store.indirect %[[ADD]], %[[ADDR]]
+    "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return %2 : tensor<16x16xf32>
+  }
+  func private @ReadAssignInit() {
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+    %1 = "tfl.pseudo_const"() {value = dense<2.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+    "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return
+  }
+}
+
+// -----
+
+module {
+  // CHECK-label: @nostate
+  func @nostate(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+    "tfl.call_once"() {session_init_function = "NoStateInit"} : () -> ()
+    // CHECK: tfl.var_handle
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+
+    // CHECK: tfl.read_variable
+    %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+
+    %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32>
+
+    // CHECK: tfl.assign_variable
+    "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return %2 : tensor<16x16xf32>
+  }
+  func private @NoStateInit() {
+    return
+  }
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/retain_call_once_funcs.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/retain_call_once_funcs.mlir
new file mode 100644
index 0000000..6fa63ad
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/retain_call_once_funcs.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt-tflite -allow-unregistered-dialect -split-input-file -pass-pipeline='iree-tflite-retain-call-once-funcs' %s | IreeFileCheck %s
+
+// CHECK-LABEL: module {
+module {
+  // CHECK-LABEL: @main
+  func @main(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) {
+    // CHECK: "tfl.call_once"() {session_init_function = "NoOp", session_init_function_symbol = @NoOp} : () -> ()
+    "tfl.call_once"() {session_init_function = "NoOp"} : () -> ()
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+    %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32>
+    %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32>
+    "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return %2 : tensor<16x16xf32>
+  }
+  func private @NoOp() {
+    %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource>
+    %1 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32>
+    "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> ()
+    return
+  }
+}
\ No newline at end of file