Import tf_saved_model.global_tensor into flow.variable
This adds support for variables to TfSavedModelAdoptExportsPass.
Also:
- Add iree/integrations/tensorflow/compiler/test/README.md to aid future
debugging.
- The tf_saved_model verifier now guarantees that there are no calls from
inside the module to an exported function. So we can remove the "Validate
that no one calls this" check.
PiperOrigin-RevId: 282418353
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index 20660b4..7fa9e5b 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -38,10 +38,12 @@
}),
deps = select({
"//iree:enable_tensorflow": [
+ "@llvm//:support",
"@local_config_mlir//:Support",
"@local_config_mlir//:TransformUtils",
"@local_config_mlir//:IR",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+ "//iree/compiler/Dialect/Flow/IR",
],
"//conditions:default": [
],
diff --git a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
index 2e29b97..c4bdadf 100644
--- a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
+++ b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
@@ -13,31 +13,119 @@
// limitations under the License.
#include "integrations/tensorflow/compiler/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace iree_compiler {
+namespace {
+
+LogicalResult ImportTfSavedModelGlobalTensorsToIREEFlow(ModuleOp module) {
+ OpBuilder global_builder(module.getBodyRegion());
+ SymbolTable symbol_table(module);
+
+ DenseMap<StringRef, std::string> sym_name_to_flow_sym_name;
+ for (auto global_tensor : module.getOps<tf_saved_model::GlobalTensorOp>()) {
+ auto exported_names = tf_saved_model::GetExportedNames(global_tensor);
+ std::string flow_sym_name;
+ if (exported_names.empty()) {
+ flow_sym_name = "__iree_flow_" + global_tensor.sym_name().str();
+ } else if (exported_names.size() == 1) {
+ flow_sym_name = exported_names[0].str();
+ } else {
+ return global_tensor.emitError()
+ << "Multiple exported names for global tensor not supported yet";
+ }
+ sym_name_to_flow_sym_name[global_tensor.sym_name()] = flow_sym_name;
+ global_builder.create<IREE::Flow::VariableOp>(
+ global_tensor.getLoc(), flow_sym_name, global_tensor.is_mutable(),
+ global_tensor.type(), global_tensor.value());
+ }
+
+ for (auto func : module.getOps<FuncOp>()) {
+ SmallVector<unsigned, 4> args_to_erase;
+ for (int i = 0, e = func.getNumArguments(); i < e; i++) {
+ tf_saved_model::GlobalTensorOp global_tensor =
+ tf_saved_model::LookupBoundInput(func, i, symbol_table);
+ if (!global_tensor) {
+ continue;
+ }
+ args_to_erase.push_back(i);
+ auto flow_sym_ref = global_builder.getSymbolRefAttr(
+ sym_name_to_flow_sym_name[global_tensor.sym_name()]);
+ Value *arg = func.getArgument(i);
+ if (global_tensor.is_mutable()) {
+ // The value is a tensor<*x!tf.resource> type, which flows into
+ // tf.ReadVariableOp/tf.AssignVariableOp.
+ // XLA resource functionalization should have canonicalized everything
+ // to uses of those two ops in the body of the tf_saved_model exported
+ // function.
+ for (OpOperand &operand : llvm::make_early_inc_range(arg->getUses())) {
+ if (auto read_variable =
+ dyn_cast<TF::ReadVariableOp>(operand.getOwner())) {
+ auto load = OpBuilder(read_variable)
+ .create<IREE::Flow::VariableLoadOp>(
+ read_variable.getLoc(),
+ read_variable.value()->getType(), flow_sym_ref);
+ read_variable.value()->replaceAllUsesWith(load.result());
+ read_variable.erase();
+ continue;
+ }
+ if (auto assign_variable =
+ dyn_cast<TF::AssignVariableOp>(operand.getOwner())) {
+ OpBuilder(assign_variable)
+ .create<IREE::Flow::VariableStoreOp>(assign_variable.getLoc(),
+ flow_sym_ref,
+ assign_variable.value());
+ assign_variable.erase();
+ continue;
+ }
+ return operand.getOwner()->emitError()
+ << "unknown op operating on resource for global tensor";
+ }
+ } else {
+ // The value is already a tensor value type. Just RAUW it with a
+ // `flow.variable.load`.
+ auto load =
+ OpBuilder(func.getBody())
+ .create<IREE::Flow::VariableLoadOp>(
+ global_tensor.getLoc(), arg->getType(), flow_sym_ref);
+ arg->replaceAllUsesWith(load.result());
+ }
+ }
+ func.eraseArguments(args_to_erase);
+ }
+
+ // Erase all the global tensors.
+ for (auto global_tensor : llvm::make_early_inc_range(
+ module.getOps<tf_saved_model::GlobalTensorOp>())) {
+ global_tensor.erase();
+ }
+ return success();
+}
+
+} // namespace
+
class TFSavedModelAdoptExportsPass
: public ModulePass<TFSavedModelAdoptExportsPass> {
public:
void runOnModule() override {
mlir::Builder builder(getModule());
- // TODO(laurenzo): Import tf_saved_model.global_tensor ops.
- for (auto global_tensor :
- getModule().getOps<mlir::tf_saved_model::GlobalTensorOp>()) {
- global_tensor.emitError()
- << "This pass doesn't support global tensors yet";
- signalPassFailure();
- return;
+ if (failed(ImportTfSavedModelGlobalTensorsToIREEFlow(getModule()))) {
+ return signalPassFailure();
}
// Handle saved model exported functions.
@@ -46,16 +134,6 @@
auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
if (exported_names.empty()) continue;
- // TODO(laurenzo): Validate that no one calls this (they shouldn't)
- // before modifying in place.
- if (!mlir::SymbolTable::symbolKnownUseEmpty(func.getName(),
- getModule())) {
- func.emitError()
- << "Exported function is also called, which is not supported yet";
- signalPassFailure();
- return;
- }
-
// TODO(laurenzo): After sequencer rework, we should just keep the
// function name as-is and create explicit export ops for each exported
// function.
@@ -108,17 +186,6 @@
}
}
- // TODO(laurenzo): Handle bound inputs.
- for (int i = 0, e = func.getNumArguments(); i < e; i++) {
- if (func.getArgAttrOfType<mlir::SymbolRefAttr>(
- i, "tf_saved_model.bound_input")) {
- // emit error and signal pass failure
- func.emitError() << "This pass doesn't support bound inputs yet";
- signalPassFailure();
- return;
- }
- }
-
// Remove its designation as a saved model export.
func.removeAttr("tf_saved_model.exported_names");
}
diff --git a/integrations/tensorflow/compiler/test/README.md b/integrations/tensorflow/compiler/test/README.md
new file mode 100644
index 0000000..e008d61
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/README.md
@@ -0,0 +1,21 @@
+# Running tests manually
+
+```shell
+$ bazel test :saved_model_adopt_exports
+```
+
+This will capture the output and pass it through FileCheck and report pass/fail,
+along with a hopefully informative description of what failed.
+
+# Debugging failures
+
+During development, it can be useful to just see the raw output directly.
+
+To see the raw output of the MLIR import and conversion process:
+
+```shell
+$ bazel run :saved_model_adopt_exports -- --disable_filecheck
+```
+
+Look for the `RUN_TEST: <test_name>` and `FINISH_TEST: <test_name>` lines to
+narrow in on the test that interests you.
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
index afbb341..f017a43 100644
--- a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -15,6 +15,7 @@
# pylint: disable=invalid-name
# pylint: disable=missing-docstring
+# pylint: disable=line-too-long
import pyiree
import tensorflow.compat.v2 as tf
@@ -51,28 +52,129 @@
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True)
+# T0002: Tests that bound global vars import properly.
-# Tests that a bound global var imports properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0002_FlatArgsResultsBoundGlobalVar
-# CHECK: [ERROR]: This pass doesn't support global tensors yet
-# CHECK: FINISH_TEST_WITH_EXCEPTION
-class T0002_FlatArgsResultsBoundGlobalVar(tf.Module):
+
+# CHECK-LABEL: RUN_TEST: T0002a_SimpleVarRead
+# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<f32>
+# CHECK: func @f() -> (tensor<f32> {tf_saved_model.index_path = []})
+# CHECK: attributes{{.*}}iree.module.export
+# CHECK: flow.variable.load @v : tensor<f32>
+# CHECK: FINISH_TEST
+class T0002a_SimpleVarRead(tf.Module):
def __init__(self):
- self.v = tf.Variable([1., 2., 3., 4.])
+ self.v = tf.Variable(0.)
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- return a * b + self.v
+ @tf.function(input_signature=[])
+ def f(self):
+ return self.v
+
+
+# CHECK-LABEL: RUN_TEST: T0002b_SimpleVarWrite
+# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<f32>
+# CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.index_path = [0]})
+# CHECK: attributes{{.*}}iree.module.export
+# CHECK: flow.variable.store @v, %arg0 : tensor<f32>
+# CHECK: FINISH_TEST
+class T0002b_SimpleVarWrite(tf.Module):
+
+ def __init__(self):
+ self.v = tf.Variable(0.)
+
+ @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
+ def f(self, a):
+ self.v.assign(a)
+
+
+# CHECK-LABEL: RUN_TEST: T0002c_SimpleConst
+# flow.variable [[CONST:@.+]] dense<0.000000e+00> : tensor<f32>
+# func @f() -> (tensor<f32> {tf_saved_model.index_path = []})
+# attributes{{.*}}iree.module.export
+# flow.variable.load [[CONST]] : tensor<f32>
+# CHECK: FINISH_TEST
+class T0002c_SimpleConst(tf.Module):
+
+ def __init__(self):
+ self.c = tf.constant(0.)
+
+ @tf.function(input_signature=[])
+ def f(self):
+ return self.c
+
+
+# CHECK-LABEL: RUN_TEST: T0002d_VarCompatibleShapeChange
+# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<1xf32>
+# CHECK: func @f()
+# CHECK: attributes{{.*}}iree.module.export
+# CHECK: [[CONST_2xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
+# CHECK: [[CONST_3xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
+# CHECK: flow.variable.store @v, [[CONST_2xf32]] : tensor<2xf32>
+# CHECK: flow.variable.store @v, [[CONST_3xf32]] : tensor<3xf32>
+# CHECK: FINISH_TEST
+class T0002d_VarCompatibleShapeChange(tf.Module):
+
+ def __init__(self):
+ self.v = tf.Variable([0.], shape=[None])
+
+ @tf.function(input_signature=[])
+ def f(self):
+ self.v.assign(tf.constant([0., 1.]))
+ self.v.assign(tf.constant([0., 1., 2.]))
+
+
+# CHECK-LABEL: RUN_TEST: T0002e_Error_VarMultipleExportedNames
+# CHECK: [ERROR]: Multiple exported names for global tensor not supported yet
+# CHECK: FINISH_TEST
+class T0002e_Error_VarMultipleExportedNames(tf.Module):
+
+ def __init__(self):
+ self.v = tf.Variable(0.)
+ self.v2 = self.v
+
+
+# CHECK-LABEL: RUN_TEST: T0002f_Error_UnsupportedResourceOp
+# CHECK: [ERROR]: unknown op operating on resource for global tensor
+# CHECK: FINISH_TEST
+class T0002f_Error_UnsupportedResourceOp(tf.Module):
+
+ def __init__(self):
+ self.v = tf.Variable([0.], shape=[None])
+
+ @tf.function(input_signature=[])
+ def f(self):
+ self.v.assign_add(tf.constant([0., 1.]))
pyiree.tf_test_driver.add_test(
- test_name="T0002_FlatArgsResultsBoundGlobalVar",
- tf_module_builder=T0002_FlatArgsResultsBoundGlobalVar,
+ test_name="T0002a_SimpleVarRead",
+ tf_module_builder=T0002a_SimpleVarRead,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True)
+pyiree.tf_test_driver.add_test(
+ test_name="T0002b_SimpleVarWrite",
+ tf_module_builder=T0002b_SimpleVarWrite,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True)
+pyiree.tf_test_driver.add_test(
+ test_name="T0002c_SimpleConst",
+ tf_module_builder=T0002c_SimpleConst,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True)
+pyiree.tf_test_driver.add_test(
+ test_name="T0002d_VarCompatibleShapeChange",
+ tf_module_builder=T0002d_VarCompatibleShapeChange,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True)
+pyiree.tf_test_driver.add_test(
+ test_name="T0002e_Error_VarMultipleExportedNames",
+ tf_module_builder=T0002e_Error_VarMultipleExportedNames,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True,
+ expect_pass_failure=True)
+pyiree.tf_test_driver.add_test(
+ test_name="T0002f_Error_UnsupportedResourceOp",
+ tf_module_builder=T0002f_Error_UnsupportedResourceOp,
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True,
expect_pass_failure=True)