Add a basic pass to "adopt" tf_saved_model exports as IREE exports.
This allows some simple e2e cases to work but needs some significant work to cover all cases.
PiperOrigin-RevId: 277587410
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
new file mode 100644
index 0000000..a1310df
--- /dev/null
+++ b/integrations/tensorflow/compiler/BUILD
@@ -0,0 +1,52 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "tensorflow",
+ srcs = select({
+ "//iree:enable_tensorflow": [
+ "TFSavedModelAdoptExports.cpp",
+ ],
+ "//conditions:default": [
+ ],
+ }),
+ hdrs = [
+ "Passes.h",
+ ],
+ defines = select({
+ "//iree:enable_tensorflow": [
+ "IREE_COMPILER_TENSORFLOW_ENABLED",
+ ],
+ "//conditions:default": [
+ ],
+ }),
+ deps = select({
+ "//iree:enable_tensorflow": [
+ "@local_config_mlir//:Support",
+ "@local_config_mlir//:TransformUtils",
+ "@local_config_mlir//:IR",
+ ],
+ "//conditions:default": [
+ ],
+ }) + [
+ "@local_config_mlir//:Pass",
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+ ],
+ alwayslink = 1,
+)
diff --git a/integrations/tensorflow/compiler/Passes.h b/integrations/tensorflow/compiler/Passes.h
new file mode 100644
index 0000000..2452978
--- /dev/null
+++ b/integrations/tensorflow/compiler/Passes.h
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_INTEGRATIONS_TENSORFLOW_COMPILER_PASSES_H_
+#define IREE_INTEGRATIONS_TENSORFLOW_COMPILER_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// TensorFlow Import
+//===----------------------------------------------------------------------===//
+
+// In a module tagged with `tf_saved_model.semantics`, adopts any exported
+// SavedModel functions to be used as IREE exported functions.
+std::unique_ptr<OpPassBase<ModuleOp>> createTFSavedModelAdoptExportsPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_INTEGRATIONS_TENSORFLOW_COMPILER_PASSES_H_
diff --git a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
new file mode 100644
index 0000000..a88bc32
--- /dev/null
+++ b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
@@ -0,0 +1,139 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/iree/integrations/tensorflow/compiler/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+class TFSavedModelAdoptExportsPass
+ : public ModulePass<TFSavedModelAdoptExportsPass> {
+ public:
+ void runOnModule() override {
+ mlir::Builder builder(getModule());
+
+ // TODO(laurenzo): Import tf_saved_model.global_tensor ops.
+ for (auto global_tensor :
+ getModule().getOps<mlir::tf_saved_model::GlobalTensorOp>()) {
+ global_tensor.emitError()
+ << "This pass doesn't support global tensors yet";
+ signalPassFailure();
+ return;
+ }
+
+ // Handle saved model exported functions.
+ for (auto func : getModule().getOps<FuncOp>()) {
+ // Transfer exported names to IREE.
+ auto exported_names = mlir::tf_saved_model::GetExportedNames(func);
+ if (exported_names.empty()) continue;
+
+ // TODO(laurenzo): Validate that no one calls this (they shouldn't)
+ // before modifying in place.
+ if (!mlir::SymbolTable::symbolKnownUseEmpty(func.getName(),
+ getModule())) {
+ func.emitError()
+ << "Exported function is also called, which is not supported yet";
+ signalPassFailure();
+ return;
+ }
+
+ // TODO(laurenzo): After sequencer rework, we should just keep the
+ // function name as-is and create explicit export ops for each exported
+ // function.
+ if (exported_names.size() > 1) {
+ func.emitError() << "Multiple exported names not supported yet";
+ signalPassFailure();
+ return;
+ }
+ func.setName(exported_names.front());
+
+ // Tag it as an IREE exported function.
+ func.setAttr("iree.module.export", builder.getUnitAttr());
+
+ // TODO(laurenzo): Validate and map structured arguments signaled via
+ // non-monotonic tf_saved_model.index_path attributes. For now, just fail
+ // if we encounter such arguments.
+ for (int i = 0, e = func.getNumArguments(); i < e; i++) {
+ auto array = func.getArgAttrOfType<mlir::ArrayAttr>(
+ i, "tf_saved_model.index_path");
+ if (!array) continue;
+ auto attrs = array.getValue();
+ if (attrs.size() == 1) {
+ if (auto integer = attrs.front().dyn_cast<IntegerAttr>()) {
+ if (integer.getValue() == i) {
+ continue;
+ }
+ }
+ }
+ func.emitError()
+ << "This pass doesn't support structured arguments yet";
+ signalPassFailure();
+ return;
+ }
+
+ // TODO(laurenzo): Also accept structured results. For now, just fail
+ // if any are found.
+ if (func.getNumResults() > 1) {
+ func.emitError() << "This pass doesn't support multiple results yet";
+ signalPassFailure();
+ return;
+ }
+ for (int i = 0, e = func.getNumResults(); i < e; i++) {
+ auto array = func.getResultAttrOfType<mlir::ArrayAttr>(
+ i, "tf_saved_model.index_path");
+ if (array && array.size() != 0) {
+ func.emitError()
+ << "This pass doesn't support structured results yet";
+ signalPassFailure();
+ return;
+ }
+ }
+
+ // TODO(laurenzo): Handle bound inputs.
+ for (int i = 0, e = func.getNumArguments(); i < e; i++) {
+ if (func.getArgAttrOfType<mlir::SymbolRefAttr>(
+ i, "tf_saved_model.bound_input")) {
+ // emit error and signal pass failure
+ func.emitError() << "This pass doesn't support bound inputs yet";
+ signalPassFailure();
+ return;
+ }
+ }
+
+ // Remove its designation as a saved model export.
+ func.removeAttr("tf_saved_model.exported_names");
+ }
+
+ // We should have now removed anything requiring saved model semantics.
+ getModule().removeAttr("tf_saved_model.semantics");
+ }
+};
+
+std::unique_ptr<OpPassBase<ModuleOp>> createTFSavedModelAdoptExportsPass() {
+ return std::make_unique<TFSavedModelAdoptExportsPass>();
+}
+
+static PassRegistration<TFSavedModelAdoptExportsPass> pass(
+ "iree-tf-saved-model-adopt-exports", "Adopts TF saved model exports");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/integrations/tensorflow/compiler/test/BUILD b/integrations/tensorflow/compiler/test/BUILD
new file mode 100644
index 0000000..1a43886
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/BUILD
@@ -0,0 +1,51 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//iree:build_defs.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "iree_glob_lit_tests",
+ "iree_setup_lit_package",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+ data = [
+ "//iree/tools:iree-opt",
+ ],
+)
+
+iree_glob_lit_tests()
+
+py_test(
+ name = "saved_model_adopt_exports",
+ srcs = [
+ "iree_tf_test_driver.py",
+ "saved_model_adopt_exports.py",
+ ],
+ args = [
+ "--filecheck_binary=$(location //third_party/llvm/llvm:FileCheck)",
+ ],
+ data = [
+ "@llvm//:FileCheck",
+ ],
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//bindings/python/pyiree",
+ ],
+)
diff --git a/integrations/tensorflow/compiler/test/iree_tf_test_driver.py b/integrations/tensorflow/compiler/test/iree_tf_test_driver.py
new file mode 100644
index 0000000..a4f25d7
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/iree_tf_test_driver.py
@@ -0,0 +1,113 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Utilities for running tests from TensorFlow models."""
+
+import contextlib
+import io
+import subprocess
+import sys
+import tempfile
+import traceback
+
+from absl import app
+from absl import flags
+import tensorflow.compat.v2 as tf
+import pyiree
+
+flags.DEFINE_string("filecheck_binary", "filecheck",
+ "Location of the filecheck binary.")
+flags.DEFINE_bool("disable_filecheck", False,
+ "Disables filecheck redirection (for debugging).")
+FLAGS = flags.FLAGS
+
+ALL_TEST_DICTS = []
+
+
+def add_test(**kwargs):
+ assert "test_name" in kwargs, "'test_name' is a required argument"
+ ALL_TEST_DICTS.append(kwargs)
+
+
+def _run_test(test_dict):
+ """Runs an individual test dict."""
+ tf_module_builder_lambda = test_dict["tf_module_builder"]
+ tf_module = tf_module_builder_lambda()
+ ctx = pyiree.binding.compiler.CompilerContext()
+ with tempfile.TemporaryDirectory() as sm_path:
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(tf_module, sm_path, options=options)
+ input_module = pyiree.binding.tf_interop.load_saved_model(ctx, sm_path)
+
+ passes = test_dict.get("passes")
+ expect_pass_failure = test_dict.get("expect_pass_failure")
+ if passes:
+ try:
+ input_module.run_pass_pipeline(passes)
+ except: # pylint: disable=bare-except
+ if not expect_pass_failure:
+ print(
+ "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
+ file=sys.stderr)
+ print(input_module.to_asm(), file=sys.stderr)
+ raise
+
+ # Print the input module ASM.
+ if test_dict.get("print_input_module"):
+ print(input_module.to_asm())
+
+
+def _internal_run_tests():
+ """Main function that runs all tests."""
+ test_count = 0
+ for test_dict in ALL_TEST_DICTS:
+ test_count += 1
+ test_name = test_dict["test_name"]
+ print("RUN_TEST:", test_name)
+ try:
+ _run_test(test_dict)
+ print("FINISH_TEST:", test_name)
+ except: # pylint: disable=bare-except
+ # Error goes to stdout for FileCheck.
+ traceback.print_exc(file=sys.stdout)
+ print("FINISH_TEST_WITH_EXCEPTION:", test_name)
+
+ print("FINISHED: RAN", test_count, "TESTS", file=sys.stderr)
+
+
+def run_tests(main_file, with_filecheck=True):
+ """Main entry point."""
+
+ def internal_main(unused_argv):
+ """App main."""
+ # In case if running with a version prior to v2 defaulting.
+ tf.enable_v2_behavior()
+ if with_filecheck and not FLAGS.disable_filecheck:
+ # Capture and run through filecheck.
+ filecheck_capture_io = io.StringIO()
+ with contextlib.redirect_stdout(filecheck_capture_io):
+ _internal_run_tests()
+ filecheck_capture_io.flush()
+ filecheck_input = filecheck_capture_io.getvalue()
+ p = subprocess.Popen(
+ [FLAGS.filecheck_binary, main_file, "--dump-input=fail"],
+ stdin=subprocess.PIPE)
+ p.communicate(filecheck_input.encode("UTF-8"))
+ sys.exit(p.returncode)
+ else:
+ # Just run directly.
+ _internal_run_tests()
+
+ app.run(internal_main)
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
new file mode 100644
index 0000000..5a212a3
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -0,0 +1,185 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests supported features of saved models."""
+
+# pylint: disable=invalid-name
+# pylint: disable=missing-docstring
+# pylint: disable=g-import-not-at-top
+
+# Always load the test driver as a relative import.
+import os
+import sys
+sys.path.insert(0, os.path.dirname(__file__))
+import iree_tf_test_driver
+import tensorflow.compat.v2 as tf
+
+SAVED_MODEL_IMPORT_PASSES = [
+ "tf-executor-graph-pruning",
+ "tf-standard-pipeline",
+ "iree-tf-saved-model-adopt-exports",
+ "canonicalize",
+]
+
+
+# Tests that a simple example with flat args and a single result and no
+# captures imports properly.
+# CHECK-LABEL: RUN_TEST: T0001_FlatArgsResultsNoBoundGlobals
+# CHECK: module
+# CHECK-NOT: tf_saved_model.semantics
+# CHECK: @simple_mul_no_capture
+# CHECK-NEXT: iree.module.export
+# CHECK: FINISH_TEST
+class T0001_FlatArgsResultsNoBoundGlobals(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul_no_capture(self, a, b):
+ return a * b
+
+
+iree_tf_test_driver.add_test(
+ test_name="T0001_FlatArgsResultsNoBoundGlobals",
+ tf_module_builder=T0001_FlatArgsResultsNoBoundGlobals,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True)
+
+
+# Tests that a bound global var imports properly.
+# NOTE: This is currently an error and needs to be implemented
+# CHECK-LABEL: RUN_TEST: T0002_FlatArgsResultsBoundGlobalVar
+# CHECK: [ERROR]: This pass doesn't support global tensors yet
+# CHECK: FINISH_TEST_WITH_EXCEPTION
+class T0002_FlatArgsResultsBoundGlobalVar(tf.Module):
+
+ def __init__(self):
+ self.v = tf.Variable([1., 2., 3., 4.])
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ return a * b + self.v
+
+
+iree_tf_test_driver.add_test(
+ test_name="T0002_FlatArgsResultsBoundGlobalVar",
+ tf_module_builder=T0002_FlatArgsResultsBoundGlobalVar,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True,
+ expect_pass_failure=True)
+
+
+# Tests that a structured argument is handled properly.
+# NOTE: This is currently an error and needs to be implemented
+# CHECK-LABEL: RUN_TEST: T0003_StructuredArgs
+# CHECK: [ERROR]: This pass doesn't support structured arguments yet
+# CHECK: FINISH_TEST_WITH_EXCEPTION
+class T0003_StructuredArgs(tf.Module):
+
+ @tf.function(input_signature=[{
+ "x": tf.TensorSpec([4], tf.float32),
+ "y": tf.TensorSpec([4], tf.float32)
+ }])
+ def simple_mul(self, d):
+ return d["x"] * d["y"]
+
+
+iree_tf_test_driver.add_test(
+ test_name="T0003_StructuredArgs",
+ tf_module_builder=T0003_StructuredArgs,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True,
+ expect_pass_failure=True)
+
+
+# Tests that a structured argument is handled properly.
+# NOTE: This is currently an error and needs to be implemented
+# CHECK-LABEL: RUN_TEST: T0003_StructuredMultipleResult
+# CHECK: [ERROR]: This pass doesn't support multiple results yet
+# CHECK: FINISH_TEST_WITH_EXCEPTION
+class T0003_StructuredMultipleResult(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ product = a * b
+ return {"x": product, "x_squared": product * product}
+
+
+iree_tf_test_driver.add_test(
+ test_name="T0003_StructuredMultipleResult",
+ tf_module_builder=T0003_StructuredMultipleResult,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True,
+ expect_pass_failure=True)
+
+
+# Tests that a structured argument is handled properly.
+# NOTE: This is currently an error and needs to be implemented
+# CHECK-LABEL: RUN_TEST: T0004_StructuredSingleResult
+# CHECK: [ERROR]: This pass doesn't support structured results yet
+# CHECK: FINISH_TEST_WITH_EXCEPTION
+class T0004_StructuredSingleResult(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ product = a * b
+ return {"x": product}
+
+
+iree_tf_test_driver.add_test(
+ test_name="T0004_StructuredSingleResult",
+ tf_module_builder=T0004_StructuredSingleResult,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True,
+ expect_pass_failure=True)
+
+
+# Tests that a structured argument is handled properly.
+# NOTE: This is currently an error and needs to be implemented
+# CHECK-LABEL: RUN_TEST: T0005_MultipleExportedFuncNames
+# CHECK: [ERROR]: Multiple exported names not supported yet
+# CHECK: FINISH_TEST_WITH_EXCEPTION
+class T0005_MultipleExportedFuncNames(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ product = a * b
+ return {"x": product}
+
+
+# Force a function alias.
+T0005_MultipleExportedFuncNames.another_copy = (
+ T0005_MultipleExportedFuncNames.simple_mul)
+
+iree_tf_test_driver.add_test(
+ test_name="T0005_MultipleExportedFuncNames",
+ tf_module_builder=T0005_MultipleExportedFuncNames,
+ passes=SAVED_MODEL_IMPORT_PASSES,
+ print_input_module=True,
+ expect_pass_failure=True)
+
+if __name__ == "__main__":
+ iree_tf_test_driver.run_tests(__file__, with_filecheck=True)
diff --git a/iree/compiler/Translation/Sequencer/BUILD b/iree/compiler/Translation/Sequencer/BUILD
index 4245b19..c3e7776 100644
--- a/iree/compiler/Translation/Sequencer/BUILD
+++ b/iree/compiler/Translation/Sequencer/BUILD
@@ -17,6 +17,7 @@
"//iree/compiler/Utils",
"//iree/hal:executable_format",
"//iree/schemas",
+ "//third_party/iree/integrations/tensorflow/compiler:tensorflow",
"@com_github_google_flatbuffers//:flatbuffers",
"@llvm//:support",
"@local_config_mlir//:IR",
diff --git a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp b/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
index b88725d..5e55ea7 100644
--- a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
+++ b/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
@@ -22,6 +22,7 @@
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/minireflect.h"
+#include "third_party/iree/integrations/tensorflow/compiler/Passes.h"
#include "iree/base/status.h"
#include "iree/compiler/IR/ConfigOps.h"
#include "iree/compiler/IR/Sequencer/OpWriters.h"
@@ -50,6 +51,7 @@
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Translation.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@@ -210,10 +212,45 @@
VMModuleBuilder *moduleBuilder);
LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder);
+ // Optional pass pipelines.
+ LogicalResult runTensorFlowImportPasses(ModuleOp module);
+
ModuleTranslationOptions options_;
};
+#if defined(IREE_COMPILER_TENSORFLOW_ENABLED)
+// Builds a pass pipeline that imports a module imported from TensorFlow
+// (that has already been legalized to XLA HLO ops).
+// NOTE: We will likely pull in more of the XLA legalization over time as the
+// dependency story is worked out.
+void buildTensorFlowImportPassPipeline(PassManager *passManager) {
+ passManager->addPass(createTFSavedModelAdoptExportsPass());
+}
+
+LogicalResult SequencerTranslator::runTensorFlowImportPasses(ModuleOp module) {
+ if (!module.getAttr("tf_saved_model.semantics")) {
+ // Not a TensorFlow module. Do nothing.
+ return success();
+ }
+
+ // Run passes to import from TensorFlow.
+ auto tensorflowPasses = createPassManager(module.getContext(), options());
+ buildTensorFlowImportPassPipeline(tensorflowPasses.get());
+ if (failed(runPassPipeline(options(), tensorflowPasses.get(), module))) {
+ module.emitError() << "Failed to run TensorFlow import passes";
+ return failure();
+ }
+ return success();
+}
+#else
+LogicalResult SequencerTranslator::runTensorFlowImportPasses(ModuleOp module) {
+ // NO-OP
+}
+#endif
+
std::vector<uint8_t> SequencerTranslator::translateModule(ModuleOp module) {
+ runTensorFlowImportPasses(module);
+
// Run one large set of passes to get to a partitioned module.
auto partitioningPasses = createPassManager(module.getContext(), options());
buildLegalizeInputPassPipeline(partitioningPasses.get());
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index cca9c86..d6a02b4 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -20,6 +20,7 @@
"//iree/compiler/Transforms/Interpreter",
"//iree/compiler/Transforms/Sequencer",
"//iree/compiler/Translation/SPIRV",
+ "//third_party/iree/integrations/tensorflow/compiler:tensorflow",
"@llvm//:support",
"@local_config_mlir//:AffineDialectRegistration",
"@local_config_mlir//:MlirOptLib",