Wire saved_model import in properly.

Refactors compiler bindings to preserve diagnostics.
Adds ability to run pass pipelines.
Adds e2e unit test.
Updates colabs to demonstrate.

Closes #106

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/106 from stellaraccident:tfxla2 2106756d76af81ec57f94acff37f5d4891d227b3
PiperOrigin-RevId: 277208555
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index 6ff4b59..e6025d3 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -67,36 +67,48 @@
 
 cc_library(
     name = "base",
+    srcs = [
+        "compiler.cc",
+        "hal.cc",
+        "initialize.cc",
+        "rt.cc",
+        "status_utils.cc",
+        "vm.cc",
+    ],
     hdrs = [
         "binding.h",
+        "compiler.h",
+        "hal.h",
+        "initialize.h",
+        "rt.h",
         "status_utils.h",
+        "vm.h",
     ],
     copts = DEFAULT_COPTS,
     features = DEFAULT_FEATURES,
     deps = [
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
-        "@iree_pybind11//:pybind11",
+        "@local_config_mlir//:IR",
         "//iree/base:api",
         "//iree/base:status",
-    ] + PYTHON_HEADERS_DEPS,
+        "//iree/base:init",
+        "//iree/hal:api",
+        "//iree/rt:api",
+        "//iree/schemas",
+        "//iree/vm:api",
+        "@llvm//:support",
+        "@local_config_mlir//:Parser",
+        "@local_config_mlir//:Pass",
+        "@iree_pybind11//:pybind11",
+    ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS,
 )
 
 iree_py_extension(
     name = "binding",
     srcs = [
-        "compiler.cc",
-        "compiler.h",
-        "hal.cc",
-        "hal.h",
-        "initialize.cc",
-        "initialize.h",
         "initialize_module.cc",
-        "rt.cc",
-        "rt.h",
-        "status_utils.cc",
-        "vm.cc",
-        "vm.h",
     ],
     copts = DEFAULT_COPTS,
     features = DEFAULT_FEATURES,
@@ -104,21 +116,8 @@
     win_def_file = "export.def",
     deps = [
         ":base",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:optional",
-        "//bindings/python/pyiree/tensorflow",
-        "//iree/base:api",
-        "//iree/base:init",
-        "//iree/hal:api",
-        "//iree/rt:api",
-        "//iree/schemas",
-        "//iree/vm:api",
-        "@llvm//:support",
-        "@local_config_mlir//:IR",
-        "@local_config_mlir//:Parser",
-        "@iree_pybind11//:pybind11",
-    ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS,
+        "//bindings/python/pyiree/tf_interop",
+    ],
 )
 
 py_test(
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc
index 238973b..7f207df 100644
--- a/bindings/python/pyiree/compiler.cc
+++ b/bindings/python/pyiree/compiler.cc
@@ -23,9 +23,9 @@
 #include "iree/schemas/module_def_generated.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
 #include "mlir/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
 
 namespace py = pybind11;
 
@@ -60,33 +60,122 @@
 
 }  // namespace
 
-std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) {
+CompilerContextBundle::CompilerContextBundle() {
   InitializeExtension({});
+  // Setup a diagnostic handler.
+  mlir_context()->getDiagEngine().registerHandler(
+      [this](mlir::Diagnostic& d) { diagnostics_.push_back(std::move(d)); });
+}
+CompilerContextBundle::~CompilerContextBundle() = default;
 
-  MLIRContext context;
+std::string CompilerContextBundle::ConsumeDiagnosticsAsString() {
+  std::string s;
+  llvm::raw_string_ostream sout(s);
+  bool first = true;
+  for (auto& d : diagnostics_) {
+    if (!first) {
+      sout << "\n\n";
+    } else {
+      first = false;
+    }
 
+    switch (d.getSeverity()) {
+      case DiagnosticSeverity::Note:
+        sout << "[NOTE]";
+        break;
+      case DiagnosticSeverity::Warning:
+        sout << "[WARNING]";
+        break;
+      case DiagnosticSeverity::Error:
+        sout << "[ERROR]";
+        break;
+      case DiagnosticSeverity::Remark:
+        sout << "[REMARK]";
+        break;
+      default:
+        sout << "[UNKNOWN]";
+    }
+    // Message.
+    sout << ": " << d << "\n\t";
+
+    // Location.
+    d.getLocation().print(sout);
+  }
+
+  diagnostics_.clear();
+  return sout.str();
+}
+
+void CompilerContextBundle::ClearDiagnostics() { diagnostics_.clear(); }
+
+CompilerModuleBundle CompilerContextBundle::ParseAsm(
+    const std::string& asm_text) {
   // Arrange to get a view that includes a terminating null to avoid additional
   // copy.
-  const char* moduleAsmChars = moduleAsm.c_str();
-  StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1);
+  const char* asm_chars = asm_text.c_str();
+  StringRef asm_sr(asm_chars, asm_text.size() + 1);
 
-  // TODO(laurenzo): This error handling is super hoaky. Hook into the MLIR
-  // error reporter and plumb through properly.
-  OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context);
-  if (!mlirModule) {
-    throw std::runtime_error("Failed to parse MLIR asm");
+  auto module_ref = parseMLIRModuleFromString(asm_sr, mlir_context());
+  if (!module_ref) {
+    throw RaiseValueError("Failed to parse MLIR asm");
   }
+  return CompilerModuleBundle(shared_from_this(), module_ref.release());
+}
 
-  auto moduleBlob =
-      mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get());
-  if (moduleBlob.empty()) {
+std::string CompilerModuleBundle::ToAsm() {
+  // Print to asm.
+  std::string asm_output;
+  llvm::raw_string_ostream sout(asm_output);
+  OpPrintingFlags print_flags;
+  module_op().print(sout, print_flags);
+  return sout.str();
+}
+
+std::shared_ptr<OpaqueBlob> CompilerModuleBundle::CompileToSequencerBlob() {
+  auto module_blob =
+      mlir::iree_compiler::translateMlirToIreeSequencerModule(module_op());
+  if (module_blob.empty()) {
     throw std::runtime_error("Failed to translate MLIR module");
   }
-  return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob));
+  return std::make_shared<OpaqueByteVectorBlob>(std::move(module_blob));
+}
+
+void CompilerModuleBundle::RunPassPipeline(
+    const std::vector<std::string>& pipelines) {
+  mlir::PassManager pm(context_->mlir_context());
+
+  // Parse the pass pipelines.
+  std::string error;
+  llvm::raw_string_ostream error_stream(error);
+  for (const auto& pipeline : pipelines) {
+    if (failed(mlir::parsePassPipeline(pipeline, pm, error_stream))) {
+      throw RaiseValueError(error_stream.str().c_str());
+    }
+  }
+
+  // Run them.
+  if (failed(pm.run(module_op_))) {
+    throw RaisePyError(PyExc_RuntimeError,
+                       "Error running pass pipelines (see diagnostics)");
+  }
 }
 
 void SetupCompilerBindings(pybind11::module m) {
-  m.def("compile_module_from_asm", CompileModuleFromAsm);
+  py::class_<CompilerContextBundle, std::shared_ptr<CompilerContextBundle>>(
+      m, "CompilerContext")
+      .def(py::init<>([]() {
+        // Need explicit make_shared to avoid UB with enable_shared_from_this.
+        return std::make_shared<CompilerContextBundle>();
+      }))
+      .def("parse_asm", &CompilerContextBundle::ParseAsm)
+      .def("get_diagnostics",
+           &CompilerContextBundle::ConsumeDiagnosticsAsString)
+      .def("clear_diagnostics", &CompilerContextBundle::ClearDiagnostics);
+  py::class_<CompilerModuleBundle>(m, "CompilerModule")
+      .def("to_asm", &CompilerModuleBundle::ToAsm)
+      .def("compile_to_sequencer_blob",
+           &CompilerModuleBundle::CompileToSequencerBlob)
+      .def("run_pass_pipeline", &CompilerModuleBundle::RunPassPipeline);
 }
 
 }  // namespace python
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h
index 0bd6624..d40c057 100644
--- a/bindings/python/pyiree/compiler.h
+++ b/bindings/python/pyiree/compiler.h
@@ -18,10 +18,58 @@
 #include <string>
 
 #include "bindings/python/pyiree/binding.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
 
 namespace iree {
 namespace python {
 
+class CompilerContextBundle;
+class CompilerModuleBundle;
+
+// Wraps an MLIR module and its producing context.
+class CompilerModuleBundle {
+ public:
+  CompilerModuleBundle(std::shared_ptr<CompilerContextBundle> context,
+                       mlir::ModuleOp module_op)
+      : context_(std::move(context)), module_op_(std::move(module_op)) {}
+
+  mlir::ModuleOp& module_op() { return module_op_; }
+  std::string ToAsm();
+
+  // Runs one or more pass pipelines (as is mlir::parsePassPipeline).
+  void RunPassPipeline(const std::vector<std::string>& pipelines);
+
+  // Compiles the MLIR module to an IREE sequencer module.
+  std::shared_ptr<OpaqueBlob> CompileToSequencerBlob();
+
+ private:
+  std::shared_ptr<CompilerContextBundle> context_;
+  mlir::ModuleOp module_op_;
+};
+
+// Bundle of MLIRContext related things that facilitates interop with
+// Python.
+class CompilerContextBundle
+    : public std::enable_shared_from_this<CompilerContextBundle> {
+ public:
+  CompilerContextBundle();
+  ~CompilerContextBundle();
+
+  mlir::MLIRContext* mlir_context() { return &mlir_context_; }
+
+  CompilerModuleBundle ParseAsm(const std::string& asm_text);
+
+  // Consumes/clears diagnostics.
+  std::string ConsumeDiagnosticsAsString();
+  void ClearDiagnostics();
+
+ private:
+  mlir::MLIRContext mlir_context_;
+  std::vector<mlir::Diagnostic> diagnostics_;
+};
+
 void SetupCompilerBindings(pybind11::module m);
 
 }  // namespace python
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py
index 5cd3de0..b90b2c4 100644
--- a/bindings/python/pyiree/compiler_test.py
+++ b/bindings/python/pyiree/compiler_test.py
@@ -23,16 +23,24 @@
 
 class CompilerTest(absltest.TestCase):
 
-  def testModuleCompileAndIntrospectFromAsm(self):
+  def testParseError(self):
+    ctx = binding.compiler.CompilerContext()
+    with self.assertRaises(ValueError):
+      ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")
+    diag_str = ctx.get_diagnostics()
+    self.assertRegex(diag_str, "custom op 'FOOBAR' is unknown")
 
-    m = binding.compiler.compile_module_from_asm("""
-    func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
-          attributes { iree.module.export } {
-        %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-        return %0 : tensor<4xf32>
-    }
-    """)
-    self.assertTrue(m)
+  def testParseAndCompileToSequencer(self):
+    ctx = binding.compiler.CompilerContext()
+    input_module = ctx.parse_asm("""
+      func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
+            attributes { iree.module.export } {
+          %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+          return %0 : tensor<4xf32>
+      }
+      """)
+    binary = input_module.compile_to_sequencer_blob()
+    self.assertTrue(binary)
 
 
 if __name__ == '__main__':
diff --git a/bindings/python/pyiree/initialize_module.cc b/bindings/python/pyiree/initialize_module.cc
index 021f6a5..5a2ebdb 100644
--- a/bindings/python/pyiree/initialize_module.cc
+++ b/bindings/python/pyiree/initialize_module.cc
@@ -18,7 +18,7 @@
 #include "bindings/python/pyiree/initialize.h"
 #include "bindings/python/pyiree/rt.h"
 #include "bindings/python/pyiree/status_utils.h"
-#include "bindings/python/pyiree/tensorflow/register_tensorflow.h"
+#include "bindings/python/pyiree/tf_interop/register_tensorflow.h"
 #include "bindings/python/pyiree/vm.h"
 
 namespace iree {
diff --git a/bindings/python/pyiree/runtime_test.py b/bindings/python/pyiree/runtime_test.py
index 3d7d480..a6d3db6 100644
--- a/bindings/python/pyiree/runtime_test.py
+++ b/bindings/python/pyiree/runtime_test.py
@@ -22,14 +22,16 @@
 
 
 def create_simple_mul_module():
-  blob = binding.compiler.compile_module_from_asm("""
+  ctx = binding.compiler.CompilerContext()
+  input_module = ctx.parse_asm("""
     func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
           attributes { iree.module.export } {
         %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
         return %0 : tensor<4xf32>
     }
     """)
-  m = binding.vm.create_module_from_blob(blob)
+  binary = input_module.compile_to_sequencer_blob()
+  m = binding.vm.create_module_from_blob(binary)
   return m
 
 
diff --git a/bindings/python/pyiree/status_utils.cc b/bindings/python/pyiree/status_utils.cc
index 63f2131..9e3e91d 100644
--- a/bindings/python/pyiree/status_utils.cc
+++ b/bindings/python/pyiree/status_utils.cc
@@ -63,8 +63,9 @@
   return pybind11::error_already_set();
 }
 
-pybind11::error_already_set RaiseValueError(const char* message) {
-  PyErr_SetString(PyExc_ValueError, message);
+pybind11::error_already_set RaisePyError(PyObject* exc_class,
+                                         const char* message) {
+  PyErr_SetString(exc_class, message);
   return pybind11::error_already_set();
 }
 
diff --git a/bindings/python/pyiree/status_utils.h b/bindings/python/pyiree/status_utils.h
index feda6cb..34d77a6 100644
--- a/bindings/python/pyiree/status_utils.h
+++ b/bindings/python/pyiree/status_utils.h
@@ -32,8 +32,16 @@
 
 // Raises a value error with the given message.
 // Correct usage:
+//   throw RaiseValueError(PyExc_ValueError, "Foobar'd");
+pybind11::error_already_set RaisePyError(PyObject* exc_class,
+                                         const char* message);
+
+// Raises a value error with the given message.
+// Correct usage:
 //   throw RaiseValueError("Foobar'd");
-pybind11::error_already_set RaiseValueError(const char* message);
+inline pybind11::error_already_set RaiseValueError(const char* message) {
+  return RaisePyError(PyExc_ValueError, message);
+}
 
 // Consumes a StatusOr<T>, returning an rvalue reference to the T if the
 // status is ok(). Otherwise, throws an exception.
diff --git a/bindings/python/pyiree/tensorflow/register_tensorflow.cc b/bindings/python/pyiree/tensorflow/register_tensorflow.cc
deleted file mode 100644
index b2e09b0..0000000
--- a/bindings/python/pyiree/tensorflow/register_tensorflow.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "bindings/python/pyiree/tensorflow/register_tensorflow.h"
-
-#include <string>
-#include <vector>
-
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
-
-using namespace mlir;  // NOLINT
-
-namespace iree {
-namespace python {
-
-namespace {
-
-std::string ImportSavedModelToMlirAsm(const std::string& saved_model_dir,
-                                      std::vector<std::string> exported_names,
-                                      std::vector<std::string> tags) {
-  std::unordered_set<std::string> tags_set;
-  for (const auto& tag : tags) {
-    tags_set.insert(tag);
-  }
-
-  MLIRContext context;
-  auto module = tensorflow::SavedModelToMlirImport(
-      saved_model_dir, tags_set, absl::MakeSpan(exported_names), &context);
-
-  // Print to asm.
-  std::string asm_output;
-  llvm::raw_string_ostream sout(asm_output);
-  OpPrintingFlags print_flags;
-  module->print(sout, print_flags);
-  return sout.str();
-}
-
-}  // namespace
-
-void SetupTensorFlowBindings(pybind11::module m) {
-  m.def("import_saved_model_to_mlir_asm", &ImportSavedModelToMlirAsm,
-        py::arg("saved_model_dir"),
-        py::arg("exported_names") = std::vector<std::string>(),
-        py::arg("tags") = std::vector<std::string>({std::string("serve")}));
-}
-
-}  // namespace python
-}  // namespace iree
diff --git a/bindings/python/pyiree/tensorflow/BUILD b/bindings/python/pyiree/tf_interop/BUILD
similarity index 73%
rename from bindings/python/pyiree/tensorflow/BUILD
rename to bindings/python/pyiree/tf_interop/BUILD
index 2838997..33a8f1f 100644
--- a/bindings/python/pyiree/tensorflow/BUILD
+++ b/bindings/python/pyiree/tf_interop/BUILD
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+load(
+    "//iree:build_defs.bzl",
+    "INTREE_TENSORFLOW_PY_DEPS",
+)
+
 package(
     default_visibility = ["//visibility:public"],
     licenses = ["notice"],  # Apache 2.0
@@ -25,28 +30,21 @@
     "-use_header_modules",  # Incompatible with exceptions builds.
 ]
 
-config_setting(
-    name = "enable",
-    define_values = {
-        "iree_tensorflow": "true",
-    },
-)
-
 cc_library(
-    name = "tensorflow",
+    name = "tf_interop",
     hdrs = [
         "register_tensorflow.h",
     ],
     copts = DEFAULT_COPTS,
     defines = select({
-        ":enable": [
+        "//iree:enable_tensorflow": [
             "IREE_TENSORFLOW_ENABLED",
         ],
         "//conditions:default": [],
     }),
     features = DEFAULT_FEATURES,
     deps = select({
-        ":enable": [
+        "//iree:enable_tensorflow": [
             ":tensorflow_impl",
         ],
         "//conditions:default": [
@@ -80,6 +78,10 @@
     "@org_tensorflow//tensorflow/core/kernels:state",
 ]
 
+TF_XLA_PASS_DEPS = [
+    "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
+]
+
 cc_library(
     name = "tensorflow_impl",
     srcs = [
@@ -92,12 +94,14 @@
     features = DEFAULT_FEATURES,
     visibility = ["//visibility:private"],
     deps = [
-        "@local_config_mlir//:IR",
-        "@llvm//:support",
-        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
-        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:translate_lib",
-        "//bindings/python/pyiree:base",
-    ] + SAVED_MODEL_TF_RUNTIME_DEPS + SAVED_MODEL_REQUIRED_KERNEL_DEPS,
+               "@local_config_mlir//:IR",
+               "@llvm//:support",
+               "@org_tensorflow//tensorflow/cc/saved_model:loader",
+               "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+               "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+               "//bindings/python/pyiree:base",
+           ] + SAVED_MODEL_TF_RUNTIME_DEPS + SAVED_MODEL_REQUIRED_KERNEL_DEPS +
+           TF_XLA_PASS_DEPS,
 )
 
 cc_library(
@@ -114,3 +118,16 @@
         "//bindings/python/pyiree:base",
     ],
 )
+
+py_test(
+    name = "saved_model_test",
+    srcs = ["saved_model_test.py"],
+    python_version = "PY3",
+    tags = [
+        "oss_only",
+    ],
+    deps = [
+        "//bindings/python/pyiree:binding",
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+    ] + INTREE_TENSORFLOW_PY_DEPS,
+)
diff --git a/bindings/python/pyiree/tf_interop/register_tensorflow.cc b/bindings/python/pyiree/tf_interop/register_tensorflow.cc
new file mode 100644
index 0000000..d0e85a7
--- /dev/null
+++ b/bindings/python/pyiree/tf_interop/register_tensorflow.cc
@@ -0,0 +1,84 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/tf_interop/register_tensorflow.h"
+
+#include <string>
+#include <vector>
+
+#include "bindings/python/pyiree/compiler.h"
+#include "bindings/python/pyiree/status_utils.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+
+using namespace mlir;  // NOLINT
+
+using tensorflow::ConvertSavedModelToMlir;
+using tensorflow::RunOptions;
+using tensorflow::SavedModelBundle;
+using tensorflow::SessionOptions;
+
+namespace iree {
+namespace python {
+
+namespace {
+
+CompilerModuleBundle LoadSavedModel(
+    std::shared_ptr<CompilerContextBundle> context_bundle,
+    const std::string& saved_model_dir,
+    const std::vector<std::string>& exported_names) {
+  SessionOptions session_options;
+  RunOptions run_options;
+  SavedModelBundle bundle;
+  std::unordered_set<std::string> tags{"serve"};
+  auto load_status = LoadSavedModel(
+      session_options, run_options,
+      std::string(saved_model_dir.data(), saved_model_dir.length()), tags,
+      &bundle);
+  if (!load_status.ok()) {
+    std::stringstream msg;
+    msg << "Failed to load saved model '" << saved_model_dir
+        << "': " << load_status;
+    throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
+  }
+
+  // TODO(laurenzo): Fix the upstream ConvertSavedModelToMlir() to take a const
+  // span of external names.
+  std::vector<std::string> mutable_exported_names = exported_names;
+  auto module_or =
+      ConvertSavedModelToMlir(bundle, context_bundle->mlir_context(),
+                              absl::MakeSpan(mutable_exported_names));
+  if (!module_or.status().ok()) {
+    std::stringstream msg;
+    msg << "Failed to load saved model '" << saved_model_dir
+        << "': " << load_status;
+    throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
+  }
+  return CompilerModuleBundle(context_bundle,
+                              module_or.ConsumeValueOrDie().release());
+}
+
+}  // namespace
+
+void SetupTensorFlowBindings(pybind11::module m) {
+  m.def("load_saved_model", &LoadSavedModel, py::arg("compiler_context"),
+        py::arg("saved_model_dir"),
+        py::arg("exported_names") = std::vector<std::string>());
+}
+
+}  // namespace python
+}  // namespace iree
diff --git a/bindings/python/pyiree/tensorflow/register_tensorflow.h b/bindings/python/pyiree/tf_interop/register_tensorflow.h
similarity index 80%
rename from bindings/python/pyiree/tensorflow/register_tensorflow.h
rename to bindings/python/pyiree/tf_interop/register_tensorflow.h
index 95dbf3e..7a54404 100644
--- a/bindings/python/pyiree/tensorflow/register_tensorflow.h
+++ b/bindings/python/pyiree/tf_interop/register_tensorflow.h
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_TENSORFLOW_REGISTER_TENSORFLOW_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_TENSORFLOW_REGISTER_TENSORFLOW_H_
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
 
 #include <string>
 
@@ -27,4 +27,4 @@
 }  // namespace python
 }  // namespace iree
 
-#endif  // IREE_BINDINGS_PYTHON_PYIREE_TENSORFLOW_REGISTER_TENSORFLOW_H_
+#endif  // IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
diff --git a/bindings/python/pyiree/tensorflow/register_tensorflow_noop.cc b/bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc
similarity index 91%
rename from bindings/python/pyiree/tensorflow/register_tensorflow_noop.cc
rename to bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc
index 41e0eed..e5f83aa 100644
--- a/bindings/python/pyiree/tensorflow/register_tensorflow_noop.cc
+++ b/bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "bindings/python/pyiree/tensorflow/register_tensorflow.h"
+#include "bindings/python/pyiree/tf_interop/register_tensorflow.h"
 
 namespace iree {
 namespace python {
diff --git a/bindings/python/pyiree/tf_interop/saved_model_test.py b/bindings/python/pyiree/tf_interop/saved_model_test.py
new file mode 100644
index 0000000..4d0ee59
--- /dev/null
+++ b/bindings/python/pyiree/tf_interop/saved_model_test.py
@@ -0,0 +1,103 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+import os
+import sys
+import tempfile
+
+from pyiree import binding as binding
+
+# Determine if compiled with tf_interop support.
+if not hasattr(binding, "tf_interop"):
+  print("Not running tests because tf_interop support not compiled")
+  sys.exit(0)
+
+# Dynamically import tensorflow.
+try:
+  # Use a dynamic import so as to avoid hermetic dependency analysis
+  # (i.e. we only want the tensorflow from the environment).
+  tf = importlib.import_module("tensorflow")
+  # Just in case if linked against a pre-V2 defaulted version.
+  tf.enable_v2_behavior()
+  tf = tf.compat.v2
+except ImportError:
+  print("Not running tests because tensorflow is not available")
+  sys.exit(0)
+
+
+class StatefulModule(tf.Module):
+
+  def __init__(self):
+    self.v = tf.Variable([4], dtype=tf.float32)
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([4], tf.float32),
+      tf.TensorSpec([4], tf.float32)
+  ])
+  def add(self, a, b):
+    return tf.tanh(self.v * a + b)
+
+
+class RuntimeTest(tf.test.TestCase):
+
+  def testLoadSavedModelToXlaPipeline(self):
+    """Tests that a basic saved model to XLA workflow grossly functions.
+
+    This is largely here to verify that everything is linked in that needs to be
+    and that there are not no-ops, etc.
+    """
+    with tempfile.TemporaryDirectory() as temp_dir:
+      sm_dir = os.path.join(temp_dir, "simple.sm")
+      print("Saving to:", sm_dir)
+      my_module = StatefulModule()
+      options = tf.saved_model.SaveOptions(save_debug_info=True)
+      tf.saved_model.save(my_module, sm_dir, options=options)
+
+      # Load it up.
+      ctx = binding.compiler.CompilerContext()
+      input_module = binding.tf_interop.load_saved_model(ctx, sm_dir)
+      input_asm = input_module.to_asm()
+      print("LOADED ASM:\n", input_asm)
+      # Should have out exported name and have executor islands.
+      self.assertRegex(input_asm,
+                       r"""tf_saved_model.exported_names = \["add"\]""")
+      self.assertRegex(input_asm, r"""tf_executor\.island""")
+
+      # Run the necessary lowering passes. Makes sure that these are linked in.
+      input_module.run_pass_pipeline([
+          "tf-executor-graph-pruning",
+          "tf-standard-pipeline",
+          "canonicalize",
+      ])
+      lowered_asm = input_module.to_asm()
+      print("LOWERED ASM:\n", lowered_asm)
+      # Should have collapsed all executor islands.
+      self.assertNotRegex(lowered_asm, r"""tf_executor\.island""")
+
+      # And legalize to XLA.
+      input_module.run_pass_pipeline([
+          "xla-legalize-tf",
+      ])
+      xla_asm = input_module.to_asm()
+      print("XLA ASM:", xla_asm)
+      self.assertRegex(xla_asm, "xla_hlo.tanh")
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/colab/low_level_invoke_function.ipynb b/colab/low_level_invoke_function.ipynb
index b0757bc..bcc7bf5 100644
--- a/colab/low_level_invoke_function.ipynb
+++ b/colab/low_level_invoke_function.ipynb
@@ -65,25 +65,27 @@
       "metadata": {
         "id": "rxaiDxiq96SD",
         "colab_type": "code",
+        "outputId": "a1304fa3-b15a-4fab-eaaf-1467dc867191",
         "colab": {
           "base_uri": "https://localhost:8080/",
           "height": 51
-        },
-        "outputId": "c7c254b5-2dd4-478e-bdc4-8386a1d29953"
+        }
       },
       "source": [
         "# Compile a module.\n",
-        "blob = binding.compiler.compile_module_from_asm(\"\"\"\n",
+        "ctx = binding.compiler.CompilerContext()\n",
+        "input_module = ctx.parse_asm(\"\"\"\n",
         "  func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n",
         "        attributes { iree.module.export } {\n",
         "      %0 = \"xla_hlo.mul\"(%arg0, %arg1) {name = \"mul.1\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
         "      return %0 : tensor<4xf32>\n",
         "  }\n",
         "  \"\"\")\n",
+        "blob = input_module.compile_to_sequencer_blob()\n",
         "m = binding.vm.create_module_from_blob(blob)\n",
         "dump_module(m)"
       ],
-      "execution_count": 4,
+      "execution_count": 7,
       "outputs": [
         {
           "output_type": "stream",
@@ -100,11 +102,11 @@
       "metadata": {
         "id": "aH6VdaoXD4hV",
         "colab_type": "code",
+        "outputId": "d109d4e7-83bf-4038-c0d1-c643dbd10c8e",
         "colab": {
           "base_uri": "https://localhost:8080/",
           "height": 102
-        },
-        "outputId": "c4e8be2c-f144-41c8-952c-9aca66300899"
+        }
       },
       "source": [
         "# Initialize the runtime and register the module.\n",
@@ -137,15 +139,15 @@
         "result_ary = np.array(result, copy=False)\n",
         "print(\"NP result:\", result_ary)\n"
       ],
-      "execution_count": 5,
+      "execution_count": 8,
       "outputs": [
         {
           "output_type": "stream",
           "text": [
             "INVOKE F: simple_mul\n",
             "Status: True\n",
-            "Results: [<pyiree.binding.hal.BufferView object at 0x7fc7697105e0>]\n",
-            "Mapped result: <pyiree.binding.hal.MappedMemory object at 0x7fc769710618>\n",
+            "Results: [<pyiree.binding.hal.BufferView object at 0x00000179E51410D8>]\n",
+            "Mapped result: <pyiree.binding.hal.MappedMemory object at 0x00000179E51412D0>\n",
             "NP result: [ 4. 10. 18. 28.]\n"
           ],
           "name": "stdout"
diff --git a/colab/simple_tensorflow_module_import.ipynb b/colab/simple_tensorflow_module_import.ipynb
index ed4766e..11befd7 100644
--- a/colab/simple_tensorflow_module_import.ipynb
+++ b/colab/simple_tensorflow_module_import.ipynb
@@ -58,11 +58,11 @@
       "metadata": {
         "id": "6YGqN2uqP_7P",
         "colab_type": "code",
+        "outputId": "4cc03b22-bee5-4ffd-e4cf-6fa3055ea886",
         "colab": {
           "base_uri": "https://localhost:8080/",
-          "height": 411
-        },
-        "outputId": "4e8ba182-c7ee-402b-b6e9-15590e8617c5"
+          "height": 853
+        }
       },
       "source": [
         "class MyModule(tf.Module):\n",
@@ -80,23 +80,35 @@
         "options = tf.saved_model.SaveOptions(save_debug_info=True)\n",
         "tf.saved_model.save(my_mod, os.path.join(SAVE_PATH, \"simple.sm\"), options=options)\n",
         "\n",
-        "mlir_asm = binding.tf_interop.import_saved_model_to_mlir_asm(os.path.join(SAVE_PATH, \"simple.sm\"))\n",
-        "print(mlir_asm)"
+        "ctx = binding.compiler.CompilerContext()\n",
+        "input_module = binding.tf_interop.load_saved_model(ctx, os.path.join(SAVE_PATH, \"simple.sm\"))\n",
+        "print('LOADED ASM:', input_module.to_asm())\n",
+        "\n",
+        "# Canonicalize the TF import.\n",
+        "input_module.run_pass_pipeline([\n",
+        "  \"tf-executor-graph-pruning\",\n",
+        "  \"tf-standard-pipeline\",\n",
+        "  \"canonicalize\",\n",
+        "])\n",
+        "print(\"LOWERED TF ASM:\", input_module.to_asm())\n",
+        "\n",
+        "# Legalize to XLA (high-level).\n",
+        "input_module.run_pass_pipeline([\n",
+        "  \"xla-legalize-tf\",\n",
+        "])\n",
+        "print(\"XLA ASM:\", input_module.to_asm())"
       ],
-      "execution_count": 2,
+      "execution_count": 5,
       "outputs": [
         {
           "output_type": "stream",
           "text": [
-            "WARNING:tensorflow:From c:\\users\\laurenzo\\scoop\\apps\\python36\\current\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1785: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
-            "Instructions for updating:\n",
-            "If using Keras pass *_constraint arguments to layers.\n",
             "INFO:tensorflow:Assets written to: C:\\Users\\laurenzo\\saved_models\\simple.sm\\assets\n",
-            "\n",
+            "LOADED ASM: \n",
             "\n",
             "module attributes {tf_saved_model.semantics} {\n",
             "  \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n",
-            "  func @__inference_add_160(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n",
+            "  func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n",
             "  attributes  {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n",
             "    %0 = tf_executor.graph {\n",
             "      %1:2 = tf_executor.island wraps \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n",
@@ -109,6 +121,35 @@
             "    return %0 : tensor<4xf32>\n",
             "  }\n",
             "}\n",
+            "\n",
+            "LOWERED TF ASM: \n",
+            "\n",
+            "module attributes {tf_saved_model.semantics} {\n",
+            "  \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n",
+            "  func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n",
+            "  attributes  {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n",
+            "    %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n",
+            "    %1 = \"tf.Mul\"(%0, %arg0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"mul\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
+            "    %2 = \"tf.AddV2\"(%1, %arg1) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"add\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
+            "    %3 = \"tf.Tanh\"(%2) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Tanh\"} : (tensor<4xf32>) -> tensor<4xf32>\n",
+            "    %4 = \"tf.Identity\"(%3) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Identity\"} : (tensor<4xf32>) -> tensor<4xf32>\n",
+            "    return %4 : tensor<4xf32>\n",
+            "  }\n",
+            "}\n",
+            "\n",
+            "XLA ASM: \n",
+            "\n",
+            "module attributes {tf_saved_model.semantics} {\n",
+            "  \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n",
+            "  func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n",
+            "  attributes  {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n",
+            "    %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n",
+            "    %1 = \"xla_hlo.mul\"(%0, %arg0) : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
+            "    %2 = xla_hlo.add %1, %arg1 : tensor<4xf32>\n",
+            "    %3 = \"xla_hlo.tanh\"(%2) : (tensor<4xf32>) -> tensor<4xf32>\n",
+            "    return %3 : tensor<4xf32>\n",
+            "  }\n",
+            "}\n",
             "\n"
           ],
           "name": "stdout"
diff --git a/iree/BUILD.bazel b/iree/BUILD.bazel
index d21f74c..eed7ab7 100644
--- a/iree/BUILD.bazel
+++ b/iree/BUILD.bazel
@@ -29,3 +29,10 @@
     name = "target_config",
     defines = ["IREE_UNSPECIFIED_TARGET=1"],
 )
+
+config_setting(
+    name = "enable_tensorflow",
+    define_values = {
+        "iree_tensorflow": "true",
+    },
+)
diff --git a/iree/build_defs.bzl b/iree/build_defs.bzl
index 31fe1fb..d1bce9f 100644
--- a/iree/build_defs.bzl
+++ b/iree/build_defs.bzl
@@ -8,6 +8,10 @@
 NUMPY_DEPS = []
 PYTHON_HEADERS_DEPS = ["@iree_native_python//:python_headers"]
 
+# Optional deps to enable an intree TensorFlow python. This build configuration
+# defaults to getting TensorFlow from the python environment (empty).
+INTREE_TENSORFLOW_PY_DEPS = []
+
 def platform_trampoline_deps(basename, path = "base"):
     """Produce a list of deps for the given `basename` platform target.