Fork and adapt old iree.compiler -> iree.compiler.tools. (#7164)

* Fork and adapt old iree.compiler -> iree.compiler.tools.

This is the first step to removing the original iree.compiler Python package in favor of the more full featured version developed under the dedicated iree-compiler-api project.

* Forks iree.compiler.[core|tf|tflite|xla] and support code to iree.compiler.tools
* Sets up temporary redirects from the root iree.compiler namespace so we have time to transition folks to new names.
* Adds a CAPI for invoking iree-translate as a function.
* Adds a standalone `ireec` binary and links it against the main CAPI shared library (ends up being 6KB on my machine).
* Sets up some testing for the project.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1d68351..652a204 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -441,7 +441,12 @@
 endif()
 
 if(IREE_BUILD_PYTHON_BINDINGS)
-  add_subdirectory(third_party/pybind11 EXCLUDE_FROM_ALL)
+  if(NOT TARGET pybind11::module)
+    message(STATUS "Using bundled pybind11")
+    add_subdirectory(third_party/pybind11 EXCLUDE_FROM_ALL)
+  else()
+    message(STATUS "Not including bundled pybind11 (already configured)")
+  endif()
 endif()
 
 if(IREE_TARGET_BACKEND_METAL-SPIRV)
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index c633808..d6fe747 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -352,8 +352,9 @@
 )
 
 cc_library(
-    name = "iree_translate_main",
-    srcs = ["iree-translate-main.cc"],
+    name = "iree_translate_lib",
+    srcs = ["iree_translate_lib.cc"],
+    hdrs = ["iree_translate_lib.h"],
     deps = [
         ":init_compiler_modules",
         ":init_iree_passes_and_dialects",
@@ -379,9 +380,10 @@
 
 cc_binary(
     name = "iree-translate",
+    srcs = ["iree-translate-main.cc"],
     tags = ["hostonly"],
     deps = [
-        ":iree_translate_main",
+        ":iree_translate_lib",
     ],
 )
 
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 0bee09c..6502a13 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -336,9 +336,11 @@
 
   iree_cc_library(
     NAME
-      iree_translate_main
+      iree_translate_lib
+    HDRS
+      "iree_translate_lib.h"
     SRCS
-      "iree-translate-main.cc"
+      "iree_translate_lib.cc"
     DEPS
       ::init_compiler_modules
       ::init_iree_passes_and_dialects
@@ -365,8 +367,10 @@
   iree_cc_binary(
     NAME
       iree-translate
+    SRCS
+      "iree-translate-main.cc"
     DEPS
-      ::iree_translate_main
+      ::iree_translate_lib
     DATA
       lld
     HOSTONLY
diff --git a/iree/tools/iree-translate-main.cc b/iree/tools/iree-translate-main.cc
index f506926..448f1a7 100644
--- a/iree/tools/iree-translate-main.cc
+++ b/iree/tools/iree-translate-main.cc
@@ -4,137 +4,8 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-// IREE translation main entry function.
-//
-// Note that this differs from mlir-translate and similar because we use the
-// PassManger and do transformations on the IR before translating to other
-// formats. Thus we use our own main entry function because we register
-// Dialects and PassManager CLI options.
-
-#include <functional>
-#include <memory>
-#include <string>
-#include <type_traits>
-
-#include "iree/compiler/Dialect/VM/Target/init_targets.h"
-#include "iree/tools/init_compiler_modules.h"
-#include "iree/tools/init_iree_dialects.h"
-#include "iree/tools/init_mlir_dialects.h"
-#include "iree/tools/init_passes.h"
-#include "iree/tools/init_targets.h"
-#include "iree/tools/init_translations.h"
-#include "iree/tools/init_xla_dialects.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/AsmState.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/FileUtilities.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/Timing.h"
-#include "mlir/Support/ToolUtilities.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-#include "mlir/Translation.h"
-
-static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
-                                                llvm::cl::desc("<input file>"),
-                                                llvm::cl::init("-"));
-
-static llvm::cl::opt<std::string> outputFilename(
-    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
-    llvm::cl::init("-"));
-
-static llvm::cl::opt<bool> splitInputFile(
-    "split-input-file",
-    llvm::cl::desc("Split the input file into pieces and "
-                   "process each chunk independently"),
-    llvm::cl::init(false));
-
-static llvm::cl::opt<bool> printMainAddress(
-    "print-main-address",
-    llvm::cl::desc("Print the address of main to stderr to aid in symbolizing "
-                   "stack traces after the fact"),
-    llvm::cl::init(false));
+#include "iree/tools/iree_translate_lib.h"
 
 int main(int argc, char **argv) {
-  llvm::InitLLVM y(argc, argv);
-  mlir::DialectRegistry registry;
-  mlir::registerMlirDialects(registry);
-  mlir::registerLLVMDialectTranslation(registry);
-  mlir::registerXLADialects(registry);
-  mlir::iree_compiler::registerAllPasses();
-  mlir::iree_compiler::registerIreeDialects(registry);
-  mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
-  mlir::iree_compiler::registerHALTargetBackends();
-  mlir::iree_compiler::registerVMTargets();
-  mlir::registerMlirTranslations();
-  mlir::iree_compiler::registerIreeTranslations();
-  // Make sure command line options are registered.
-  (void)mlir::iree_compiler::IREE::HAL::getTargetOptionsFromFlags();
-
-  // Register MLIRContext command-line options like
-  // -mlir-print-op-on-diagnostic.
-  mlir::registerMLIRContextCLOptions();
-  // Register assembly printer command-line options like
-  // -mlir-print-op-generic.
-  mlir::registerAsmPrinterCLOptions();
-  // Register pass manager command-line options like -print-ir-*.
-  mlir::registerPassManagerCLOptions();
-  mlir::registerDefaultTimingManagerCLOptions();
-
-  // Add flags for all the registered translations.
-  llvm::cl::opt<const mlir::TranslateFunction *, false, mlir::TranslationParser>
-      translationRequested("", llvm::cl::desc("Translation to perform"),
-                           llvm::cl::Required);
-
-  llvm::cl::ParseCommandLineOptions(argc, argv, "IREE translation driver\n");
-
-  if (printMainAddress) {
-    llvm::errs() << "iree-translate main is at "
-                 << reinterpret_cast<void *>(&main) << "\n";
-  }
-
-  std::string errorMessage;
-  auto input = mlir::openInputFile(inputFilename, &errorMessage);
-  if (!input) {
-    llvm::errs() << errorMessage << "\n";
-    return 1;
-  }
-
-  auto output = mlir::openOutputFile(outputFilename, &errorMessage);
-  if (!output) {
-    llvm::errs() << errorMessage << "\n";
-    return 1;
-  }
-
-  /// Processes the memory buffer with a new MLIRContext.
-  auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
-                           llvm::raw_ostream &os) {
-    mlir::MLIRContext context;
-    context.allowUnregisteredDialects();
-    context.appendDialectRegistry(registry);
-    llvm::SourceMgr sourceMgr;
-    sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
-    mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
-    return (*translationRequested)(sourceMgr, os, &context);
-  };
-
-  if (splitInputFile) {
-    if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
-                                           output->os())))
-      return 1;
-  } else {
-    if (failed(processBuffer(std::move(input), output->os()))) return 1;
-  }
-
-  output->keep();
-  return 0;
+  return mlir::iree_compiler::runIreeTranslateMain(argc, argv);
 }
diff --git a/iree/tools/iree_translate_lib.cc b/iree/tools/iree_translate_lib.cc
new file mode 100644
index 0000000..ded9121
--- /dev/null
+++ b/iree/tools/iree_translate_lib.cc
@@ -0,0 +1,137 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/tools/iree_translate_lib.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "iree/compiler/Dialect/VM/Target/init_targets.h"
+#include "iree/tools/init_compiler_modules.h"
+#include "iree/tools/init_iree_dialects.h"
+#include "iree/tools/init_mlir_dialects.h"
+#include "iree/tools/init_passes.h"
+#include "iree/tools/init_targets.h"
+#include "iree/tools/init_translations.h"
+#include "iree/tools/init_xla_dialects.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/Timing.h"
+#include "mlir/Support/ToolUtilities.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Translation.h"
+
+int mlir::iree_compiler::runIreeTranslateMain(int argc, char **argv) {
+  llvm::InitLLVM y(argc, argv);
+  mlir::DialectRegistry registry;
+  mlir::registerMlirDialects(registry);
+  mlir::registerLLVMDialectTranslation(registry);
+  mlir::registerXLADialects(registry);
+  mlir::iree_compiler::registerAllPasses();
+  mlir::iree_compiler::registerIreeDialects(registry);
+  mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
+  mlir::iree_compiler::registerHALTargetBackends();
+  mlir::iree_compiler::registerVMTargets();
+  mlir::registerMlirTranslations();
+  mlir::iree_compiler::registerIreeTranslations();
+  // Make sure command line options are registered.
+  (void)mlir::iree_compiler::IREE::HAL::getTargetOptionsFromFlags();
+
+  // Register MLIRContext command-line options like
+  // -mlir-print-op-on-diagnostic.
+  mlir::registerMLIRContextCLOptions();
+  // Register assembly printer command-line options like
+  // -mlir-print-op-generic.
+  mlir::registerAsmPrinterCLOptions();
+  // Register pass manager command-line options like -print-ir-*.
+  mlir::registerPassManagerCLOptions();
+  mlir::registerDefaultTimingManagerCLOptions();
+
+  // General command line flags.
+  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                           llvm::cl::desc("<input file>"),
+                                           llvm::cl::init("-"));
+
+  llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+
+  llvm::cl::opt<bool> splitInputFile(
+      "split-input-file",
+      llvm::cl::desc("Split the input file into pieces and "
+                     "process each chunk independently"),
+      llvm::cl::init(false));
+
+  llvm::cl::opt<bool> printMainAddress(
+      "print-main-address",
+      llvm::cl::desc(
+          "Print the address of main to stderr to aid in symbolizing "
+          "stack traces after the fact"),
+      llvm::cl::init(false));
+
+  // Add flags for all the registered translations.
+  llvm::cl::opt<const mlir::TranslateFunction *, false, mlir::TranslationParser>
+      translationRequested("", llvm::cl::desc("Translation to perform"),
+                           llvm::cl::Optional);
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, "IREE translation driver\n");
+
+  if (printMainAddress) {
+    llvm::errs() << "iree-translate main is at "
+                 << reinterpret_cast<void *>(&runIreeTranslateMain) << "\n";
+  }
+
+  std::string errorMessage;
+  auto input = mlir::openInputFile(inputFilename, &errorMessage);
+  if (!input) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  auto output = mlir::openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  /// Processes the memory buffer with a new MLIRContext.
+  auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
+                           llvm::raw_ostream &os) {
+    mlir::MLIRContext context;
+    context.allowUnregisteredDialects();
+    context.appendDialectRegistry(registry);
+    llvm::SourceMgr sourceMgr;
+    sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
+    mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
+    return (*translationRequested)(sourceMgr, os, &context);
+  };
+
+  if (splitInputFile) {
+    if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
+                                           output->os())))
+      return 1;
+  } else {
+    if (failed(processBuffer(std::move(input), output->os()))) return 1;
+  }
+
+  output->keep();
+  return 0;
+}
diff --git a/iree/tools/iree_translate_lib.h b/iree/tools/iree_translate_lib.h
new file mode 100644
index 0000000..572e03a
--- /dev/null
+++ b/iree/tools/iree_translate_lib.h
@@ -0,0 +1,18 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_TOOLS_IREE_TRANSLATE_LIB_H
+#define IREE_TOOLS_IREE_TRANSLATE_LIB_H
+
+namespace mlir {
+namespace iree_compiler {
+
+int runIreeTranslateMain(int argc, char **argv);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_TOOLS_IREE_TRANSLATE_LIB_H
diff --git a/llvm-external-projects/iree-compiler-api/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/CMakeLists.txt
index 76781b9..9e9f53f 100644
--- a/llvm-external-projects/iree-compiler-api/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/CMakeLists.txt
@@ -35,11 +35,12 @@
 
 set(IREE_COMPILER_API_INTREE ON)
 if(IREE_COMPILER_API_INTREE)
+  set(IREE_BUILD_TESTS OFF)  # Conflicts with our tests if we are top-level.
   set(LLVM_EXTERNAL_MLIR_IREE_DIALECTS_SOURCE_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../iree-dialects")
   set(IREE_MAIN_SOURCE_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../..")
   set(LLVM_MAIN_SRC_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../../third_party/llvm-project/llvm")
   set(LLVM_EXTERNAL_MLIR_HLO_SOURCE_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../../third_party/mlir-hlo")
-
+  enable_testing()
 else()
   message(FATAL_ERROR "Non intree (source package) not yet supported")
 endif()
@@ -140,3 +141,4 @@
 include(HandleLLVMOptions)
 add_subdirectory(lib)
 add_subdirectory(python)
+add_subdirectory(unittests)
diff --git a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Tools.h b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Tools.h
new file mode 100644
index 0000000..ce2cf79
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Tools.h
@@ -0,0 +1,24 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_COMPILER_API_TOOLS_H
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_COMPILER_API_TOOLS_H
+
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Runs the IREE compiler main function. This is used to build ireec-like
+/// binaries that link against a common shared library.
+MLIR_CAPI_EXPORTED int ireeCompilerRunMain(int argc, char **argv);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // IREE_LLVM_EXTERNAL_PROJECTS_IREE_COMPILER_API_TOOLS_H
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
index e5fcc0d..0d50421 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_public_c_api_library(IREECompilerAPICompilerCAPI
   Compiler.cpp
+  Tools.cpp
   # TODO: If installing, complains about IREEVM not being in any export set.
   DISABLE_INSTALL
   LINK_COMPONENTS
@@ -12,6 +13,9 @@
 
     # All HAL Targets.
     iree::tools::init_targets
+
+    # Tools.
+    iree::tools::iree_translate_lib
 )
 
 # TODO: Fix upstream so there is a way to know what the actual compile target
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/Tools.cpp b/llvm-external-projects/iree-compiler-api/lib/CAPI/Tools.cpp
new file mode 100644
index 0000000..e3b8416
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/Tools.cpp
@@ -0,0 +1,13 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-compiler-c/Tools.h"
+
+#include "iree/tools/iree_translate_lib.h"
+
+int ireeCompilerRunMain(int argc, char **argv) {
+  return mlir::iree_compiler::runIreeTranslateMain(argc, argv);
+}
diff --git a/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
index d471da7..896dd14 100644
--- a/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
@@ -16,6 +16,17 @@
 )
 declare_mlir_python_sources(IREECompilerAPIPythonExtensions)
 
+declare_mlir_python_sources(IREECompilerAPIPythonTools
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
+  SOURCES
+    __init__.py
+    tf.py
+    tflite.py
+    xla.py
+  SOURCES_GLOB
+    tools/*.py
+)
+
 ################################################################################
 # Extensions
 ################################################################################
@@ -39,6 +50,7 @@
   # Local sources.
   IREECompilerAPIPythonSources
   IREECompilerAPIPythonExtensions
+  IREECompilerAPIPythonTools
 
   # TODO: Core is now implicitly building/registering all dialects, increasing
   # build burden by ~5x. Make it stop.
@@ -82,3 +94,18 @@
   COMMON_CAPI_LINK_LIBS
     IREECompilerAggregateCAPI
   )
+
+
+# Build the ireec tool into _mlir_libs.
+add_executable(
+  IREECompilerIREECTool
+    IREECTool.c
+)
+target_link_libraries(IREECompilerIREECTool IREECompilerAggregateCAPI)
+set_target_properties(IREECompilerIREECTool
+  PROPERTIES
+    OUTPUT_NAME "ireec"
+    RUNTIME_OUTPUT_DIRECTORY "${IREE_COMPILER_API_BINARY_DIR}/python_package/iree/compiler/_mlir_libs"
+    BUILD_RPATH_USE_ORIGIN ON
+)
+add_dependencies(IREECompilerPythonModules IREECompilerIREECTool)
diff --git a/llvm-external-projects/iree-compiler-api/python/IREECTool.c b/llvm-external-projects/iree-compiler-api/python/IREECTool.c
new file mode 100644
index 0000000..5428add
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/IREECTool.c
@@ -0,0 +1,9 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-compiler-c/Tools.h"
+
+int main(int argc, char **argv) { return ireeCompilerRunMain(argc, argv); }
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/__init__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/__init__.py
new file mode 100644
index 0000000..b49ab00
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/__init__.py
@@ -0,0 +1,9 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Re-export some legacy APIs from the tools package to this top-level.
+# TODO: Deprecate and remove these names once clients are migrated.
+from .tools import *
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tf.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tf.py
new file mode 100644
index 0000000..7eba07c
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tf.py
@@ -0,0 +1,10 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+from .tools import tf
+
+sys.modules[__name__] = tf
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tflite.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tflite.py
new file mode 100644
index 0000000..9a9a9cb
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tflite.py
@@ -0,0 +1,10 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+from .tools import tflite
+
+sys.modules[__name__] = tflite
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/__init__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/__init__.py
new file mode 100644
index 0000000..2adf9e5
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/__init__.py
@@ -0,0 +1,10 @@
+# Lint-as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .core import *
+from .debugging import TempFileSaver
+from .binaries import CompilerToolError
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/binaries.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/binaries.py
new file mode 100644
index 0000000..d45da87
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/binaries.py
@@ -0,0 +1,283 @@
+# Lint-as: python3
+"""Utilities for locating and invoking compiler tool binaries."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import importlib
+import io
+import os
+import platform
+import subprocess
+import sys
+import textwrap
+import threading
+
+from typing import List, Optional, Union
+
+__all__ = [
+    "find_tool",
+    "invoke_immediate",
+    "invoke_pipeline",
+    "get_tool_path",
+    "CompilerToolError",
+]
+
+_BUILTIN_TOOLS = [
+    "ireec",
+]
+
+# In normal distribution circumstances, each named tool is associated with
+# a python module that provides a `get_tool` function for getting its absolute
+# path. This dictionary maps the tool name to the module.
+_TOOL_MODULE_MAP = {
+    "iree-import-tflite": "iree.tools.tflite",
+    "iree-import-xla": "iree.tools.xla",
+    "iree-import-tf": "iree.tools.tf",
+}
+
+# Map of tool module to package name as distributed to archives (used for
+# error messages).
+_TOOL_MODULE_PACKAGES = {
+    "iree.tools.tf": "iree-tools-tf",
+    "iree.tools.tflite": "iree-tools-tflite",
+    "iree.tools.xla": "iree-tools-xla",
+}
+
+# Environment variable holding directories to be searched for named tools.
+# Delimitted by os.pathsep.
+_TOOL_PATH_ENVVAR = "IREE_TOOL_PATH"
+
+
+class CompilerToolError(Exception):
+  """Compiler exception that preserves the command line and error output."""
+
+  def __init__(self, process: subprocess.CompletedProcess):
+    try:
+      errs = process.stderr.decode("utf-8")
+    except:
+      errs = str(process.stderr)  # Decode error or other: best we can do.
+
+    tool_name = os.path.basename(process.args[0])
+    super().__init__(f"Error invoking IREE compiler tool {tool_name}\n"
+                     f"Diagnostics:\n{errs}\n\n"
+                     f"Invoked with:\n {tool_name} {' '.join(process.args)}")
+
+
+def get_tool_path() -> List[str]:
+  """Returns list of paths to search for tools."""
+  list_str = os.environ.get(_TOOL_PATH_ENVVAR)
+  if not list_str:
+    return []
+  return list_str.split(os.pathsep)
+
+
+def find_tool(exe_name: str) -> str:
+  """Finds a tool by its (extension-less) executable name.
+
+  Args:
+    exe_name: The name of the executable (extension-less).
+  Returns:
+    An absolute path to the tool.
+  Raises:
+    ValueError: If the tool is not known or not found.
+  """
+  is_builtin = exe_name in _BUILTIN_TOOLS
+  if not is_builtin and exe_name not in _TOOL_MODULE_MAP:
+    raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool")
+  # First search an explicit tool path.
+  tool_path = get_tool_path()
+  for path_entry in tool_path:
+    if not path_entry:
+      continue
+    candidate_exe = os.path.join(path_entry, exe_name)
+    if os.path.isfile(candidate_exe) and os.access(candidate_exe, os.X_OK):
+      return candidate_exe
+
+  if is_builtin:
+    # Get builtin tool.
+    candidate_exe = _get_builtin_tool(exe_name)
+  else:
+    # Attempt to load the tool module.
+    tool_module_name = _TOOL_MODULE_MAP[exe_name]
+    tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name]
+    try:
+      tool_module = importlib.import_module(tool_module_name)
+    except ModuleNotFoundError:
+      raise ValueError(
+          f"IREE compiler tool '{exe_name}' is not installed (it should have been "
+          f"found in the python module '{tool_module_name}', typically installed "
+          f"via the package {tool_module_package}).\n\n"
+          f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment "
+          f"variable to contain the path of the tool executable "
+          f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)})") from None
+
+    # Ask the module for its tool.
+    candidate_exe = tool_module.get_tool(exe_name)
+
+  if (not candidate_exe or not os.path.isfile(candidate_exe) or
+      not os.access(candidate_exe, os.X_OK)):
+    raise ValueError(
+        f"IREE compiler tool '{exe_name}' was located in module "
+        f"'{tool_module_name}' but the file was not found or not executable: "
+        f"{candidate_exe}")
+  return candidate_exe
+
+
+def _get_builtin_tool(exe_name: str) -> Optional[str]:
+  if platform.system() == "Windows":
+    exe_name = exe_name + ".exe"
+  this_path = os.path.dirname(__file__)
+  tool_path = os.path.join(this_path, "..", "_mlir_libs", exe_name)
+  return tool_path
+
+
+def invoke_immediate(command_line: List[str],
+                     *,
+                     input_file: Optional[bytes] = None,
+                     immediate_input=None):
+  """Invokes an immediate command.
+
+  This is separate from invoke_pipeline as it is simpler and supports more
+  complex input redirection, using recommended facilities for sub-processes
+  (less magic).
+
+  Note that this differs from the usual way of using subprocess.run or
+  subprocess.Popen().communicate() because we need to pump all of the error
+  streams individually and only pump pipes not connected to a different stage.
+  Uses threads to pump everything that is required.
+  """
+  run_args = {}
+  input_file_handle = None
+  stderr_handle = sys.stderr
+  try:
+    # Redirect input.
+    if input_file is not None:
+      input_file_handle = open(input_file, "rb")
+      run_args["stdin"] = input_file_handle
+    elif immediate_input is not None:
+      run_args["input"] = immediate_input
+
+    # Capture output.
+    # TODO(#4131) python>=3.7: Use capture_output=True.
+    run_args["stdout"] = subprocess.PIPE
+    run_args["stderr"] = subprocess.PIPE
+    process = subprocess.run(command_line, **run_args)
+    if process.returncode != 0:
+      raise CompilerToolError(process)
+    # Emit stderr contents.
+    _write_binary_stderr(stderr_handle, process.stderr)
+    return process.stdout
+  finally:
+    if input_file_handle:
+      input_file_handle.close()
+
+
+def invoke_pipeline(command_lines: List[List[str]], immediate_input=None):
+  """Invoke a pipeline of commands.
+
+  The first stage of the pipeline will have its stdin set to DEVNULL and each
+  subsequent stdin will derive from the prior stdout. The final stdout will
+  be accumulated and returned. All stderr contents are accumulated and printed
+  to stderr on completion or the first failing stage of the pipeline will have
+  an exception raised with its stderr output.
+  """
+  stages = []
+  pipeline_input = (subprocess.DEVNULL
+                    if immediate_input is None else subprocess.PIPE)
+  prev_out = pipeline_input
+  stderr_handle = sys.stderr
+
+  # Create all stages.
+  for i in range(len(command_lines)):
+    command_line = command_lines[i]
+    popen_args = {
+        "stdin": prev_out,
+        "stdout": subprocess.PIPE,
+        "stderr": subprocess.PIPE,
+    }
+    process = subprocess.Popen(command_line, **popen_args)
+    prev_out = process.stdout
+    capture_output = (i == (len(command_lines) - 1))
+    stages.append(_PipelineStage(process, capture_output))
+
+  # Start stages.
+  for stage in stages:
+    stage.start()
+
+  # Pump input.
+  pipe_success = True
+  if immediate_input is not None:
+    try:
+      pipe_success = False
+      stages[0].process.stdin.write(immediate_input)
+      pipe_success = True
+    finally:
+      stages[0].process.stdin.close()
+
+  # Join.
+  for stage in stages:
+    stage.join()
+
+  # Check for errors.
+  for stage in stages:
+    assert stage.completed
+    if stage.completed.returncode != 0:
+      raise CompilerToolError(stage.completed)
+
+  # Broken pipe.
+  if not pipe_success:
+    raise CompilerToolError(stages[0].completed)
+
+  # Print any stderr output.
+  for stage in stages:
+    _write_binary_stderr(stderr_handle, stage.errs)
+  return stages[-1].outs
+
+
+class _PipelineStage(threading.Thread):
+  """Wraps a process and pumps its handles, waiting for completion."""
+
+  def __init__(self, process, capture_output):
+    super().__init__()
+    self.process = process
+    self.capture_output = capture_output
+    self.completed: Optional[subprocess.CompletedProcess] = None
+    self.outs = None
+    self.errs = None
+
+  def pump_stderr(self):
+    self.errs = self.process.stderr.read()
+
+  def pump_stdout(self):
+    self.outs = self.process.stdout.read()
+
+  def run(self):
+    stderr_thread = threading.Thread(target=self.pump_stderr)
+    stderr_thread.start()
+    if self.capture_output:
+      stdout_thread = threading.Thread(target=self.pump_stdout)
+      stdout_thread.start()
+    self.process.wait()
+    stderr_thread.join()
+    if self.capture_output:
+      stdout_thread.join()
+    self.completed = subprocess.CompletedProcess(self.process.args,
+                                                 self.process.returncode,
+                                                 self.outs, self.errs)
+    self.process.stderr.close()
+    self.process.stdout.close()
+
+
+def _write_binary_stderr(out_handle, contents):
+  # Fast-paths buffered text-io (which stderr is by default) while allowing
+  # full decode for non buffered and binary io.
+  if hasattr(out_handle, "buffer"):
+    out_handle.buffer.write(contents)
+  elif isinstance(out_handle, io.TextIOBase):
+    out_handle.write(contents.decode("utf-8"))
+  else:
+    out_handle.write(contents)
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/core.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/core.py
new file mode 100644
index 0000000..20962b8
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/core.py
@@ -0,0 +1,252 @@
+# Lint-as: python3
+"""Core compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import subprocess
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from .debugging import TempFileSaver
+from .binaries import *
+
+__all__ = [
+    "DEFAULT_TESTING_BACKENDS",
+    "compile_file",
+    "compile_str",
+    "CompilerOptions",
+    "OutputFormat",
+]
+
+# Default testing backend for invoking the compiler.
+# TODO: Remove these. In the absence of default profiles, though, it is better
+# to centralize.
+DEFAULT_TESTING_BACKENDS = ["dylib-llvm-aot"]
+DEFAULT_TESTING_DRIVER = "dylib"
+
+
+class InputType(Enum):
+  """The type of input pipeline to run prior to the core compiler."""
+  NONE = "none"
+  MHLO = "mhlo"
+  TOSA = "tosa"
+
+  @staticmethod
+  def parse(spec: Union[str, "InputType"]) -> "InputType":
+    """Parses or returns an InputType.
+
+    Args:
+      spec: An InputType instance or the case-insensitive name of one of the
+        enum values.
+    Returns:
+      An InputType instance.
+    """
+    if isinstance(spec, InputType):
+      return spec
+    spec = spec.upper().replace("-", "_")
+    if spec not in InputType.__members__:
+      raise ValueError(f"For input_type= argument, expected one of: "
+                       f"{', '.join(InputType.__members__.keys())}")
+    return InputType[spec]
+
+
+class OutputFormat(Enum):
+  """The output format of the compiler."""
+  FLATBUFFER_BINARY = "flatbuffer-binary"
+  FLATBUFFER_TEXT = "flatbuffer-text"
+  MLIR_TEXT = "mlir-text"
+
+  @staticmethod
+  def parse(spec: Union[str, "OutputFormat"]) -> "OutputFormat":
+    """Parses or returns an OutputFormat.
+
+    Args:
+      spec: An OutputFormat instance or the case-insensitive name of one of
+        the enum values.
+    Returns:
+      An OutputFormat instance.
+    """
+    if isinstance(spec, OutputFormat):
+      return spec
+    spec = spec.upper().replace("-", "_")
+    if spec not in OutputFormat.__members__:
+      raise ValueError(f"For output_format= argument, expected one of: "
+                       f"{', '.join(OutputFormat.__members__.keys())}")
+    return OutputFormat[spec]
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class CompilerOptions:
+  """Options to the compiler backend.
+
+  Keyword options:
+    output_file: Optionally save the compiled binary to a file instead of
+      returning it.
+    target_backends: List of str names of target backends to compile into
+      the binary. The resulting binary will run on targets that match one
+      or more of the compiled backends.
+    input_type: The type of input legalization to perform prior to full
+      compilation. Defaults to none.
+    output_format: Override the output format. See the OutputFormat enum.
+      Values can either be an enum value or a case-insensitive name of
+      the option. Typically used for debugging
+    extra_args: Optional list of additional arguments to pass to the compiler.
+      Example: ["--print-ir-after-all"]
+    optimize: Whether to apply some default high level optimizations (default
+      True).
+    output_mlir_debuginfo: Include debuginfo (including paths) in any saved or
+      returned MLIR.
+    output_generic_mlir: Use the generic (and more portable) MLIR formatting for
+      any saved or returned MLIR instead of the per-dialect custom assembly.
+    extended_diagnostics: Outputs extended information on diagnostics,
+      potentially outputting very verbosely (defaults to False).
+    strip_debug_ops: Whether to strip high level operations used to aid
+      debugging.
+    strip_source_map: Whether to strip source map information (used to generate
+      better errors).
+    crash_reproducer_path: File name to output an MLIR crash dump to if there
+      is a compiler failure.
+    enable_tflite_bindings: Support the IREE TFLite runtime bindings API shim.
+    enable_benchmark: Whether to generate instrumented binaries suitable
+      for benchmarking.
+  """
+
+  def __init__(self,
+               *,
+               output_file: Optional[str] = None,
+               target_backends: Sequence[str] = (),
+               input_type: Union[InputType, str] = InputType.NONE,
+               output_format: Union[OutputFormat,
+                                    str] = OutputFormat.FLATBUFFER_BINARY,
+               extra_args: Sequence[str] = (),
+               optimize: bool = True,
+               output_mlir_debuginfo: bool = True,
+               output_generic_mlir: bool = False,
+               extended_diagnostics: bool = False,
+               strip_debug_ops: bool = False,
+               strip_source_map: bool = False,
+               crash_reproducer_path: Optional[str] = None,
+               enable_tflite_bindings: bool = False,
+               enable_benchmark: bool = False):
+    self.output_file = output_file
+    self.target_backends = target_backends
+    self.input_type = InputType.parse(input_type)
+    self.output_format = OutputFormat.parse(output_format)
+    self.extra_args = extra_args
+    self.optimize = optimize
+    self.output_mlir_debuginfo = output_mlir_debuginfo
+    self.output_generic_mlir = output_generic_mlir
+    self.extended_diagnostics = extended_diagnostics
+    self.strip_debug_ops = strip_debug_ops
+    self.strip_source_map = strip_source_map
+    self.crash_reproducer_path = crash_reproducer_path
+    self.enable_tflite_bindings = enable_tflite_bindings
+    self.enable_benchmark = enable_benchmark
+
+
+def build_compile_command_line(input_file: str, tfs: TempFileSaver,
+                               options: CompilerOptions) -> List[str]:
+  """Builds a command line for invoking the compiler.
+
+  Args:
+    input_file: The input file name.
+    tfs: TempFileSaver.
+    options: Compiler options.
+  Returns:
+    List of strings of command line.
+  """
+  iree_translate = find_tool("ireec")
+  if not options.target_backends:
+    raise ValueError("Expected a non-empty list for 'target_backends'")
+
+  cl = [
+      iree_translate,
+      input_file,
+      f"--iree-input-type={options.input_type.value}",
+      f"--iree-vm-bytecode-module-output-format={options.output_format.value}",
+  ]
+  for target_backend in options.target_backends:
+    cl.append(f"--iree-hal-target-backends={target_backend}")
+
+  # Output file.
+  output_file = tfs.alloc_optional("core-output.bin",
+                                   export_as=options.output_file)
+  if output_file:
+    cl.append(f"-o={output_file}")
+
+  # Translation to perform.
+  cl.append("--iree-mlir-to-vm-bytecode-module")
+
+  # MLIR flags.
+  if options.output_mlir_debuginfo:
+    cl.append("--mlir-print-debuginfo")
+  if options.output_generic_mlir:
+    cl.append("--mlir-print-op-generic")
+  if options.extended_diagnostics:
+    # Note that different tools have different defaults, so be explicit.
+    cl.append("--mlir-print-op-on-diagnostic=true")
+  else:
+    cl.append("--mlir-print-op-on-diagnostic=false")
+
+  # Other options to set if specified.
+  if options.strip_debug_ops:
+    cl.append("--iree-vm-bytecode-module-strip-debug-ops")
+  if options.strip_source_map:
+    cl.append("--iree-vm-bytecode-module-strip-source-map")
+  crash_reproducer_path = tfs.alloc_optional(
+      "core-reproducer.mlir", export_as=options.crash_reproducer_path)
+  if crash_reproducer_path:
+    cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+  if options.enable_tflite_bindings:
+    cl.append("--iree-tflite-bindings-support")
+  if options.enable_benchmark:
+    cl.append("--iree-flow-export-benchmark-funcs")
+
+  cl.extend(options.extra_args)
+  return cl
+
+
+def compile_file(input_file: str, **kwargs):
+  """Invokes the IREE compiler on an input file.
+
+  Args:
+    input_file: File containing MLIR assembly to compile.
+    **kwargs: Keyword arguments corresponding to CompilerOptions.
+  Returns:
+    Either a byte buffer of the compiled content or None if output_file
+    was specified in the options.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = CompilerOptions(**kwargs)
+    cl = build_compile_command_line(input_file, tfs, options)
+    result = invoke_immediate(cl)
+    if options.output_file:
+      return None
+    return result
+
+
+def compile_str(input_str: Union[str, bytes], **kwargs):
+  """Invokes the IREE compiler with an input string.
+
+  Args:
+    input_str: MLIR assembly to parse/compile (str or bytes).
+    **kwargs: Keyword arguments corresponding to CompilerOptions.
+  Returns:
+    Either a byte buffer of the compiled content or None if output_file
+    was specified in the options.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = CompilerOptions(**kwargs)
+    cl = build_compile_command_line("-", tfs, options)
+    input_bytes = input_str.encode("utf-8") if isinstance(input_str,
+                                                          str) else input_str
+    result = invoke_immediate(cl, immediate_input=input_bytes)
+    if options.output_file:
+      return None
+    return result
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/debugging.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/debugging.py
new file mode 100644
index 0000000..fc897dd
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/debugging.py
@@ -0,0 +1,176 @@
+"""Debugging support."""
+
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+import logging
+import os
+import shutil
+import sys
+import threading
+
+_thread_locals = threading.local()
+_invocation_id = 0
+
+
+def _get_temp_file_saver_stack():
+  try:
+    return _thread_locals.temp_file_saver_stack
+  except AttributeError:
+    stack = []
+    _thread_locals.temp_file_saver_stack = stack
+    return stack
+
+
+def _interpolate_path_pattern(path_pattern: str, *, invocation_id: str):
+  # We do not use str.format() because we do not know the providence of
+  # path_pattern. Instead, handle a fixed set of replacements.
+  path_pattern = path_pattern.replace("{id}", str(invocation_id))
+  path_pattern = path_pattern.replace("{pid}", str(os.getpid()))
+  path_pattern = path_pattern.replace("{main}", os.path.basename(sys.argv[0]))
+  return path_pattern
+
+
+class TempFileSaver:
+  """Manages the saving of temp files resulting from tool invocations.
+
+  The TempFileSaver is a thread-local context bound object. An attempt to
+  create a new one will return the most recent instance created and entered
+  as a context manager. This allows up-stack callers to establish the
+  policy for saving temporaries and deep implementations will inherit it.
+
+  Proper usage from users wishing to establish a saver context:
+    with TempFileSaver():
+      # Do things with temp files.
+
+  Proper usage for implementors wishing to use an established saver context
+  or set up a new one:
+    with TempFileSaver.implicit() as tfs:
+      # Do things with temp files.
+
+  The outer-most creator can customize it with explicit arguments to __init__
+  but these will be ignored if an instance is already thread bound.
+  """
+  TEMP_PATH_ENV_KEY = "IREE_SAVE_TEMPS"
+
+  @staticmethod
+  def implicit():
+    stack = _get_temp_file_saver_stack()
+    if stack:
+      return stack[-1]
+    return TempFileSaver()
+
+  def __init__(self,
+               temp_path_pattern: Optional[str] = None,
+               *,
+               invocation_id: Optional[str] = None):
+    self.retained = False
+    self._refcount = 0
+    if temp_path_pattern is None:
+      temp_path_pattern = os.environ.get(TempFileSaver.TEMP_PATH_ENV_KEY)
+    if temp_path_pattern is None:
+      return
+
+    global _invocation_id
+    if invocation_id is not None:
+      self.invocation_id = invocation_id
+    else:
+      self.invocation_id = _invocation_id
+      _invocation_id += 1
+
+    self.retained_path = _interpolate_path_pattern(
+        temp_path_pattern, invocation_id=self.invocation_id)
+    self.retained = True
+    self._retained_file_names = set()
+    self._copy_on_finalize = list()  # Of (source_path, target_path)
+
+  def __enter__(self):
+    _get_temp_file_saver_stack().append(self)
+    self._refcount += 1
+    return self
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    del _get_temp_file_saver_stack()[-1]
+    self._refcount -= 1
+    if self._refcount == 0:
+      self._finalize()
+
+  @staticmethod
+  def current():
+    try:
+      return _get_temp_file_saver_stack()[-1]
+    except KeyError:
+      raise RuntimeError("No current TempFileSaver")
+
+  def alloc_optional(self,
+                     file_name: str,
+                     *,
+                     export_as: Optional[str] = None) -> Optional[str]:
+    """Allocates an optional temporary file.
+
+
+    When in non-retained mode, the return value is 'export_as', meaning that the
+    file is just some user specified output file.
+
+    When in retained mode, the output file will be an index-mangled variant
+    of 'file_name' under the temp_path. In addition, a mapping will be added
+    so that upon finalization, the file is also exported to 'export_as' if
+    specified.
+
+    Returns None if neither a user-specified 'export_as' is specified nor in
+    retained mode.
+
+    The distinction between retained temporaries and exports is to help in
+    cases for when the caller has requested that an artifact be written to
+    a specific place (i.e. an output file) but for debuggability, we also
+    want to save it as a temporary. In this case, we save it to the temporary
+    location and then conclude by moving artifacts to their final location
+    once the saver goes out of scope.
+    """
+    if not self.retained:
+      return export_as
+    alloced_path = self._alloc_retained_path(file_name)
+    if export_as:
+      self._copy_on_finalize.append((alloced_path, export_as))
+    return alloced_path
+
+  def _alloc_retained_path(self, file_name: str) -> str:
+    assert self.retained
+    index = 0
+    original_file_name = file_name
+    while True:
+      if file_name not in self._retained_file_names:
+        # First use of this name.
+        self._retained_file_names.add(file_name)
+        os.makedirs(self.retained_path, exist_ok=True)
+        return os.path.join(self.retained_path, file_name)
+      index += 1
+      stem, ext = os.path.splitext(original_file_name)
+      file_name = f"{stem}_{index}{ext}"
+
+  def _finalize(self):
+    if not self.retained:
+      return
+    # See which files were materialized.
+    was_materialized = []
+    for file_name in self._retained_file_names:
+      file_path = os.path.join(self.retained_path, file_name)
+      if os.path.exists(file_path):
+        was_materialized.append((file_name, file_path))
+    if was_materialized:
+      logging.info(
+          "**** IREE Compiler retained temporary files (%s)***:\n%s",
+          self.invocation_id, "\n".join([
+              f"  * {file_name} : {file_path}"
+              for file_name, file_path in was_materialized
+          ]))
+    for source_path, target_path in self._copy_on_finalize:
+      if os.path.exists(source_path):
+        logging.info("Copy retained file to output: %s -> %s", source_path,
+                     target_path)
+        shutil.copyfile(source_path, target_path)
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/__init__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/__init__.py
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/ireec/__main__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/ireec/__main__.py
new file mode 100644
index 0000000..2ded4b8
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/ireec/__main__.py
@@ -0,0 +1,21 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import subprocess
+import sys
+
+from iree.compiler.tools import binaries
+
+
+def main(args=None):
+  if args is None:
+    args = sys.argv[1:]
+  exe = binaries.find_tool("ireec")
+  return subprocess.call(args=[exe] + args)
+
+
+if __name__ == "__main__":
+  sys.exit(main())
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tf.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tf.py
new file mode 100644
index 0000000..7056f33
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tf.py
@@ -0,0 +1,238 @@
+# Lint-as: python3
+"""TensorFlow compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+from .debugging import TempFileSaver
+from .binaries import find_tool, invoke_immediate, invoke_pipeline
+
+__all__ = [
+    "compile_saved_model",
+    "compile_module",
+    "is_available",
+    "DEFAULT_TESTING_BACKENDS",
+    "ImportOptions",
+    "ImportType",
+]
+
+_TF_IMPORT_TOOL = "iree-import-tf"
+
+
+def is_available():
+  """Determine if TensorFlow and the compiler are available."""
+  try:
+    import tensorflow as tf
+  except ModuleNotFoundError:
+    logging.warn("Unable to import tensorflow")
+    return False
+  try:
+    find_tool(_TF_IMPORT_TOOL)
+  except ValueError:
+    logging.warning("Unable to find IREE tool %s", _TF_IMPORT_TOOL)
+    return False
+  return True
+
+
+class ImportType(Enum):
+  """Import type of the model."""
+  OBJECT_GRAPH = "savedmodel_v2"
+  V2 = "savedmodel_v2"
+  SIGNATURE_DEF = "savedmodel_v1"
+  V1 = "savedmodel_v1"
+
+  @staticmethod
+  def parse(spec: Union[str, "ImportType"]) -> "ImportType":
+    """Parses or returns an ImportType.
+
+    Args:
+      spec: An ImportType instance or the case-insensitive name of one of
+        the enum values.
+    Returns:
+      An ImportType instance.
+    """
+    if isinstance(spec, ImportType):
+      return spec
+    spec = spec.upper()
+    if spec not in ImportType.__members__:
+      raise ValueError(f"For import_type= argument, expected one of: "
+                       f"{', '.join(ImportType.__members__.keys())}")
+    return ImportType[spec]
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+  """Import options layer on top of the backend compiler options."""
+
+  def __init__(self,
+               exported_names: Sequence[str] = (),
+               import_only: bool = False,
+               import_type: Union[ImportType, str] = ImportType.OBJECT_GRAPH,
+               saved_model_tags: Set[str] = set(),
+               import_extra_args: Sequence[str] = (),
+               save_temp_tf_input: Optional[str] = None,
+               save_temp_mid_level_input: Optional[str] = None,
+               save_temp_iree_input: Optional[str] = None,
+               use_tosa: bool = False,
+               **kwargs):
+    """Initialize options from keywords.
+
+    Args:
+      exported_names: Optional sequence representing the exported names to
+        keep (object graph/v2 models only).
+      import_only: Only import the module. If True, the result will be textual
+        MLIR that can be further fed to the IREE compiler. If False (default),
+        the result will be the fully compiled IREE binary. In both cases,
+        bytes-like output is returned. Note that if the output_file= is
+        specified and import_only=True, then the MLIR form will be written to
+        the output file.
+      import_type: Type of import to perform. See ImportType enum.
+      saved_model_tags: Set of tags to export (signature def/v1 saved models
+        only).
+      import_extra_args: Extra arguments to pass to the iree-import-tf tool.
+      save_temp_tf_input: Optionally save the IR that is input to the
+        TensorFlow pipeline.
+      save_temp_mid_level_input: Optionally save the IR that is input to the
+        mid level IR.
+      save_temp_iree_input: Optionally save the IR that is the result of the
+        import (ready to be passed to IREE).
+    """
+    super().__init__(**kwargs)
+    self.exported_names = exported_names
+    self.import_only = import_only
+    self.import_type = ImportType.parse(import_type)
+    self.saved_model_tags = saved_model_tags
+    self.import_extra_args = import_extra_args
+    self.save_temp_tf_input = save_temp_tf_input
+    self.save_temp_mid_level_input = save_temp_mid_level_input
+    self.save_temp_iree_input = save_temp_iree_input
+    self.use_tosa = use_tosa
+
+
+def build_import_command_line(input_path: str, tfs: TempFileSaver,
+                              options: ImportOptions) -> List[str]:
+  """Builds a command line for invoking the import stage.
+
+  Args:
+    input_path: The input path.
+    tfs: TempFileSaver.
+    options: Import options.
+  Returns:
+    List of strings of command line.
+  """
+  tf_import = find_tool(_TF_IMPORT_TOOL)
+  cl = [
+      tf_import,
+      input_path,
+      f"--tf-import-type={options.import_type.value}",
+      f"--tf-savedmodel-exported-names={','.join(options.exported_names)}",
+      f"--tf-savedmodel-tags={','.join(options.saved_model_tags)}",
+  ]
+
+  if options.import_only and options.output_file:
+    # Import stage directly outputs.
+    output_file = tfs.alloc_optional("tf-output.mlir",
+                                     export_as=options.output_file)
+    cl.append(f"-o={output_file}")
+
+  # MLIR flags.
+  if options.output_mlir_debuginfo:
+    cl.append("--mlir-print-debuginfo")
+  if options.output_generic_mlir:
+    cl.append("--mlir-print-op-generic")
+
+  # Save temps flags.
+  save_tf_input = tfs.alloc_optional("tf-input.mlir",
+                                     export_as=options.save_temp_tf_input)
+  if save_tf_input:
+    cl.append(f"--save-temp-tf-input={save_tf_input}")
+  save_mid_level_input = tfs.alloc_optional(
+      "tf-mid-level-input.mlir", export_as=options.save_temp_mid_level_input)
+  if save_mid_level_input:
+    cl.append(f"--save-temp-mid-level-input={save_mid_level_input}")
+  save_iree_input = tfs.alloc_optional("tf-iree-input.mlir",
+                                       export_as=options.save_temp_iree_input)
+  if save_iree_input:
+    cl.append(f"--save-temp-iree-input={save_iree_input}")
+
+  if options.use_tosa:
+    cl.append(f"--use-tosa")
+
+  # Crash reproducer (locally qualified).
+  requested_crash_reproducer_path = options.crash_reproducer_path
+  if requested_crash_reproducer_path:
+    requested_crash_reproducer_path = (requested_crash_reproducer_path +
+                                       ".import-tf")
+  crash_reproducer_path = tfs.alloc_optional(
+      "tf-reproducer.mlir", export_as=requested_crash_reproducer_path)
+  if crash_reproducer_path:
+    cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+
+  # Extra args.
+  cl.extend(options.import_extra_args)
+  return cl
+
+
+def compile_saved_model(saved_model_dir: str, **kwargs):
+  """Compiles an on-disk saved model to an IREE binary.
+
+  Args:
+    saved_model_dir: Path to directory where the model was saved.
+    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+  Returns:
+    A bytes-like object with the compiled output or None if output_file=
+    was specified.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = ImportOptions(**kwargs)
+    import_cl = build_import_command_line(saved_model_dir, tfs, options)
+    if options.import_only:
+      # One stage tool pipeline.
+      result = invoke_immediate(import_cl)
+      if options.output_file:
+        return None
+      return result
+
+    # Full compilation pipeline.
+    compile_cl = build_compile_command_line("-", tfs, options)
+    result = invoke_pipeline([import_cl, compile_cl])
+    if options.output_file:
+      return None
+    return result
+
+
+def compile_module(module, saved_model_dir: Optional[str] = None, **kwargs):
+  """Compiles a tf.Module to an IREE binary (by saving to disk).
+
+  Args:
+    module: The tf.Module instance to convert to MLIR
+    saved_model_dir: Optional path to save the tf.Module to. The module will not
+      be persisted on disk outside of this call if this is not provided.
+    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+  Returns:
+    Same as compile_saved_model().
+  """
+  with TempFileSaver.implicit() as tfs:
+
+    def do_it(saved_model_dir):
+      import tensorflow as tf
+      options = tf.saved_model.SaveOptions(save_debug_info=True)
+      tf.saved_model.save(module, saved_model_dir, options=options)
+      return compile_saved_model(saved_model_dir, **kwargs)
+
+    if saved_model_dir:
+      return do_it(saved_model_dir)
+    else:
+      with tempfile.TemporaryDirectory(suffix=".sm") as td:
+        return do_it(td)
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tflite.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tflite.py
new file mode 100644
index 0000000..327becc
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tflite.py
@@ -0,0 +1,198 @@
+# Lint-as: python3
+"""TFLite compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .debugging import TempFileSaver
+from .binaries import find_tool, invoke_immediate, invoke_pipeline
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+
+__all__ = [
+    "compile_file",
+    "compile_str",
+    "is_available",
+    "DEFAULT_TESTING_BACKENDS",
+    "ImportOptions",
+]
+
+_IMPORT_TOOL = "iree-import-tflite"
+
+
+def is_available():
+  """Determine if the XLA frontend is available."""
+  try:
+    find_tool(_IMPORT_TOOL)
+  except ValueError:
+    logging.warning("Unable to find IREE tool %s", _IMPORT_TOOL)
+    return False
+  return True
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+  """Import options layer on top of the backend compiler options."""
+
+  def __init__(self,
+               input_arrays: Sequence[str] = (),
+               output_arrays: Sequence[str] = (),
+               import_only: bool = False,
+               import_extra_args: Sequence[str] = (),
+               save_temp_tfl_input: Optional[str] = None,
+               save_temp_iree_input: Optional[str] = None,
+               input_type: Optional[str] = "tosa",
+               **kwargs):
+    """Initialize options from keywords.
+
+    Args:
+      input_arrays: Sequence of input array node names (if different from
+        default).
+      output_arrays: Sequence of output array node names (if different from
+        default).
+      import_only: Only import the module. If True, the result will be textual
+        MLIR that can be further fed to the IREE compiler. If False (default),
+        the result will be the fully compiled IREE binary. In both cases,
+        bytes-like output is returned. Note that if the output_file= is
+        specified and import_only=True, then the MLIR form will be written to
+        the output file.
+      import_extra_args: Extra arguments to pass to the iree-import-tf tool.
+      save_temp_tfl_input: Optionally save the IR that results from importing
+        the flatbuffer (prior to any further transformations).
+      save_temp_iree_input: Optionally save the IR that is the result of the
+        import (ready to be passed to IREE).
+    """
+    super().__init__(input_type=input_type, **kwargs)
+    self.input_arrays = input_arrays
+    self.output_arrays = output_arrays
+    self.import_only = import_only
+    self.import_extra_args = import_extra_args
+    self.save_temp_tfl_input = save_temp_tfl_input
+    self.save_temp_iree_input = save_temp_iree_input
+
+
+def build_import_command_line(input_path: str, tfs: TempFileSaver,
+                              options: ImportOptions) -> List[str]:
+  """Builds a command line for invoking the import stage.
+
+  Args:
+    input_path: The input path.
+    tfs: TempFileSaver.
+    options: Import options.
+  Returns:
+    List of strings of command line.
+  """
+  import_tool = find_tool(_IMPORT_TOOL)
+  cl = [
+      import_tool,
+      input_path,
+  ]
+
+  if options.import_only and options.output_file:
+    # Import stage directly outputs.
+    output_file = tfs.alloc_optional("tflite-output.mlir",
+                                     export_as=options.output_file)
+    cl.append(f"-o={options.output_file}")
+
+  # Input arrays.
+  if options.input_arrays:
+    for input_array in options.input_arrays:
+      cl.append(f"--input-array={input_array}")
+    for output_array in options.output_arrays:
+      cl.append(f"--output-array={output_array}")
+
+  # MLIR flags.
+  if options.output_mlir_debuginfo:
+    cl.append("--mlir-print-debuginfo")
+  if options.output_generic_mlir:
+    cl.append("--mlir-print-op-generic")
+
+  # Save temps flags.
+  tfl_input = tfs.alloc_optional("tflite-input.mlir",
+                                 export_as=options.save_temp_tfl_input)
+  if tfl_input:
+    cl.append(f"--save-temp-tfl-input={tfl_input}")
+  iree_input = tfs.alloc_optional("tflite-iree-input.mlir",
+                                  export_as=options.save_temp_iree_input)
+  if iree_input:
+    cl.append(f"--save-temp-iree-input={iree_input}")
+
+  # Crash reproducer (locally qualified).
+  requested_crash_reproducer_path = options.crash_reproducer_path
+  if requested_crash_reproducer_path:
+    requested_crash_reproducer_path = (requested_crash_reproducer_path +
+                                       ".import-tflite")
+  crash_reproducer_path = tfs.alloc_optional(
+      "tflite-reproducer.mlir", export_as=requested_crash_reproducer_path)
+  if crash_reproducer_path:
+    cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+
+  # Extra args.
+  cl.extend(options.import_extra_args)
+  return cl
+
+
+def compile_file(fb_path: str, **kwargs):
+  """Compiles a TFLite flatbuffer file to an IREE binary.
+
+  Args:
+    fb_path: Path to the flatbuffer.
+    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+  Returns:
+    A bytes-like object with the compiled output or None if output_file=
+    was specified.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = ImportOptions(**kwargs)
+    import_cl = build_import_command_line(fb_path, tfs, options)
+    if options.import_only:
+      # One stage tool pipeline.
+      result = invoke_immediate(import_cl)
+      if options.output_file:
+        return None
+      return result
+
+    # Full compilation pipeline.
+    compile_cl = build_compile_command_line("-", tfs, options)
+    result = invoke_pipeline([import_cl, compile_cl])
+    if options.output_file:
+      return None
+    return result
+
+
+def compile_str(fb_content: bytes, **kwargs):
+  """Compiles in-memory TFLite flatbuffer to an IREE binary.
+
+  Args:
+    xla_content: Flatbuffer content as bytes.
+    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+  Returns:
+    A bytes-like object with the compiled output or None if output_file=
+    was specified.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = ImportOptions(**kwargs)
+    import_cl = build_import_command_line("-", tfs, options)
+    if options.import_only:
+      # One stage tool pipeline.
+      result = invoke_immediate(import_cl, immediate_input=fb_content)
+      if options.output_file:
+        return None
+      return result
+
+    # Full compilation pipeline.
+    compile_cl = build_compile_command_line("-", tfs, options)
+    result = invoke_pipeline([import_cl, compile_cl],
+                             immediate_input=fb_content)
+    if options.output_file:
+      return None
+    return result
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/xla.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/xla.py
new file mode 100644
index 0000000..9b2b1c2
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/xla.py
@@ -0,0 +1,212 @@
+# Lint-as: python3
+"""XLA compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+from .debugging import TempFileSaver
+from .binaries import find_tool, invoke_immediate, invoke_pipeline
+
+__all__ = [
+    "compile_file",
+    "compile_str",
+    "is_available",
+    "DEFAULT_TESTING_BACKENDS",
+    "ImportOptions",
+    "ImportFormat",
+]
+
+_IMPORT_TOOL = "iree-import-xla"
+
+
+def is_available():
+  """Determine if the XLA frontend is available."""
+  try:
+    find_tool(_IMPORT_TOOL)
+  except ValueError:
+    logging.warning("Unable to find IREE tool %s", _IMPORT_TOOL)
+    return False
+  return True
+
+
+class ImportFormat(Enum):
+  """Import type of the model."""
+  BINARY_PROTO = "binary_proto"
+  TEXT_PROTO = "text_proto"
+  HLO_TEXT = "hlo_text"
+
+  @staticmethod
+  def parse(spec: Union[str, "ImportFormat"]) -> "ImportFormat":
+    """Parses or returns an ImportFormat.
+
+    Args:
+      spec: An ImportFormat instance or the case-insensitive name of one of
+        the enum values.
+    Returns:
+      An ImportFormat instance.
+    """
+    if isinstance(spec, ImportFormat):
+      return spec
+    spec = spec.upper()
+    if spec not in ImportFormat.__members__:
+      raise ValueError(f"For import_format= argument, expected one of: "
+                       f"{', '.join(ImportFormat.__members__.keys())}")
+    return ImportFormat[spec]
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+  """Import options layer on top of the backend compiler options."""
+
+  def __init__(self,
+               import_only: bool = False,
+               import_format: Union[ImportFormat,
+                                    str] = ImportFormat.BINARY_PROTO,
+               import_extra_args: Sequence[str] = (),
+               save_temp_mhlo_input: Optional[str] = None,
+               save_temp_iree_input: Optional[str] = None,
+               **kwargs):
+    """Initialize options from keywords.
+
+    Args:
+      import_format: Format of the proto (text or binary).
+      save_temp_iree_input: Optionally save the IR that is the result of the
+        import (ready to be passed to IREE).
+    """
+    super().__init__(**kwargs)
+    self.import_only = import_only
+    self.import_format = ImportFormat.parse(import_format)
+    self.import_extra_args = import_extra_args
+    self.save_temp_mhlo_input = save_temp_mhlo_input
+    self.save_temp_iree_input = save_temp_iree_input
+
+
+def build_import_command_line(input_path: str, tfs: TempFileSaver,
+                              options: ImportOptions) -> List[str]:
+  """Builds a command line for invoking the import stage.
+
+  Args:
+    input_path: The input path.
+    tfs: TempFileSaver.
+    options: Import options.
+  Returns:
+    List of strings of command line.
+  """
+  import_tool = find_tool(_IMPORT_TOOL)
+  cl = [
+      import_tool,
+      input_path,
+      f"--xla-format={options.import_format.value}",
+  ]
+
+  if options.import_only and options.output_file:
+    # Import stage directly outputs.
+    output_file = tfs.alloc_optional("xla-output.mlir",
+                                     export_as=options.output_file)
+    cl.append(f"-o={output_file}")
+
+  # MLIR flags.
+  if options.output_mlir_debuginfo:
+    cl.append("--mlir-print-debuginfo")
+  if options.output_generic_mlir:
+    cl.append("--mlir-print-op-generic")
+
+  # Save temps flags.
+  save_mhlo_input = tfs.alloc_optional("tf-mhlo.mlir",
+                                       export_as=options.save_temp_mhlo_input)
+  if save_mhlo_input:
+    cl.append(f"--save-temp-mhlo-input={save_mhlo_input}")
+  iree_input = tfs.alloc_optional("xla-iree-input.mlir",
+                                  export_as=options.save_temp_iree_input)
+  if iree_input:
+    cl.append(f"--save-temp-iree-input={iree_input}")
+
+  # Crash reproducer (locally qualified).
+  requested_crash_reproducer_path = options.crash_reproducer_path
+  if requested_crash_reproducer_path:
+    requested_crash_reproducer_path = (requested_crash_reproducer_path +
+                                       ".import-xla")
+  crash_reproducer_path = tfs.alloc_optional(
+      "xla-reproducer.mlir", export_as=requested_crash_reproducer_path)
+  if crash_reproducer_path:
+    cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+
+  # Extra args.
+  cl.extend(options.import_extra_args)
+  return cl
+
+
+def compile_file(xla_file_path: str, **kwargs):
+  """Compiles an on-disk XLA protocol buffer to an IREE binary.
+
+  Args:
+    xla_file_path: Path to the XLA proto file.
+    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+  Returns:
+    A bytes-like object with the compiled output or None if output_file=
+    was specified.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = ImportOptions(**kwargs)
+    import_cl = build_import_command_line(xla_file_path, tfs, options)
+    if options.import_only:
+      # One stage tool pipeline.
+      result = invoke_immediate(import_cl)
+      if options.output_file:
+        return None
+      return result
+
+    # Full compilation pipeline.
+    compile_cl = build_compile_command_line("-", tfs, options)
+    result = invoke_pipeline([import_cl, compile_cl])
+    if options.output_file:
+      return None
+    return result
+
+
+def compile_str(xla_content: Union[bytes, str], **kwargs):
+  """Compiles in-memory XLA content to an IREE binary.
+
+  Args:
+    xla_content: Either bytes or str content (str is only valid for text
+      formats).
+    **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+  Returns:
+    A bytes-like object with the compiled output or None if output_file=
+    was specified.
+  """
+  with TempFileSaver.implicit() as tfs:
+    options = ImportOptions(**kwargs)
+    if isinstance(xla_content, str):
+      if options.import_format not in [
+          ImportFormat.TEXT_PROTO, ImportFormat.HLO_TEXT
+      ]:
+        raise ValueError("If passing a string, ImportFormat must be TEXT_PROTO")
+      xla_content = xla_content.encode("utf-8")
+
+    import_cl = build_import_command_line("-", tfs, options)
+    if options.import_only:
+      # One stage tool pipeline.
+      result = invoke_immediate(import_cl, immediate_input=xla_content)
+      if options.output_file:
+        return None
+      return result
+
+    # Full compilation pipeline.
+    compile_cl = build_compile_command_line("-", tfs, options)
+    result = invoke_pipeline([import_cl, compile_cl],
+                             immediate_input=xla_content)
+    if options.output_file:
+      return None
+    return result
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/xla.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/xla.py
new file mode 100644
index 0000000..e66d5fe
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/xla.py
@@ -0,0 +1,10 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+from .tools import xla
+
+sys.modules[__name__] = xla
diff --git a/llvm-external-projects/iree-compiler-api/unittests/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/unittests/CMakeLists.txt
new file mode 100644
index 0000000..f7c6b56
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(tools)
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/unittests/tools/CMakeLists.txt
new file mode 100644
index 0000000..9233c3f
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/CMakeLists.txt
@@ -0,0 +1,52 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+function(iree_compiler_api_py_test)
+  cmake_parse_arguments(
+    ARG
+    ""
+    "NAME;MAIN"
+    ""
+    ${ARGN}
+  )
+  set(TEST_NAME "iree-compiler-api-${ARG_NAME}")
+  add_test(
+    NAME
+      ${TEST_NAME}
+    WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+    COMMAND "${Python3_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/${ARG_MAIN}"
+  )
+  set_tests_properties(${TEST_NAME} PROPERTIES
+    ENVIRONMENT PYTHONPATH=${IREE_COMPILER_API_BINARY_DIR}/python_package)
+endfunction()
+
+iree_compiler_api_py_test(
+  NAME
+    compiler_core_test
+  MAIN
+    "compiler_core_test.py"
+)
+
+iree_compiler_api_py_test(
+  NAME
+    compiler_tf_test
+  MAIN
+    "compiler_tf_test.py"
+)
+
+iree_compiler_api_py_test(
+  NAME
+    compiler_tflite_test
+  MAIN
+    "compiler_tflite_test.py"
+)
+
+iree_compiler_api_py_test(
+  NAME
+    compiler_xla_test
+  MAIN
+    "compiler_xla_test.py"
+)
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/README.md b/llvm-external-projects/iree-compiler-api/unittests/tools/README.md
new file mode 100644
index 0000000..e6e6c29
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/README.md
@@ -0,0 +1,8 @@
+# Python API Tests
+
+These tests are run in an environment where all available Python bindings
+are setup on the `PYTHONPATH`. Each will internally skip itself if optional
+components are not available.
+
+Note that IREE compiler tool locations can be overridden by specifying the
+`IREE_TOOL_PATH` environment variable.
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_core_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_core_test.py
new file mode 100644
index 0000000..996fc29
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_core_test.py
@@ -0,0 +1,216 @@
+# Lint as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import contextlib
+import logging
+import os
+import io
+import tempfile
+import unittest
+
+import iree.compiler.tools
+
+SIMPLE_MUL_ASM = """
+func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+    %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    return %0 : tensor<4xf32>
+}
+"""
+
+
+class CompilerTest(unittest.TestCase):
+
+  def setUp(self):
+    if "IREE_SAVE_TEMPS" in os.environ:
+      del os.environ["IREE_SAVE_TEMPS"]
+
+  def testNoTargetBackends(self):
+    with self.assertRaisesRegex(
+        ValueError, "Expected a non-empty list for 'target_backends'"):
+      binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM)
+
+  def testCompileStr(self):
+    binary = iree.compiler.tools.compile_str(
+        SIMPLE_MUL_ASM,
+        input_type="mhlo",
+        target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+    logging.info("Flatbuffer size = %d", len(binary))
+    self.assertTrue(binary)
+
+  # Compiling the string form means that the compiler does not have a valid
+  # source file name, which can cause issues on the AOT side. Verify
+  # specifically. See: https://github.com/google/iree/issues/4439
+  def testCompileStrLLVMAOT(self):
+    binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM,
+                                             input_type="mhlo",
+                                             target_backends=["dylib-llvm-aot"])
+    logging.info("Flatbuffer size = %d", len(binary))
+    self.assertTrue(binary)
+
+  # Verifies that multiple target_backends are accepted. Which two are not
+  # load bearing.
+  # See: https://github.com/google/iree/issues/4436
+  def testCompileMultipleBackends(self):
+    binary = iree.compiler.tools.compile_str(
+        SIMPLE_MUL_ASM,
+        input_type="mhlo",
+        target_backends=["dylib-llvm-aot", "vulkan-spirv"])
+    logging.info("Flatbuffer size = %d", len(binary))
+    self.assertTrue(binary)
+
+  def testCompileInputFile(self):
+    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+      try:
+        f.write(SIMPLE_MUL_ASM)
+        f.close()
+        binary = iree.compiler.tools.compile_file(
+            f.name,
+            input_type="mhlo",
+            target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+      finally:
+        os.remove(f.name)
+    logging.info("Flatbuffer size = %d", len(binary))
+    self.assertIn(b"simple_mul", binary)
+
+  def testCompileOutputFile(self):
+    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+      try:
+        f.close()
+        output = iree.compiler.tools.compile_str(
+            SIMPLE_MUL_ASM,
+            input_type="mhlo",
+            output_file=f.name,
+            target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+        self.assertIsNone(output)
+
+        with open(f.name, "rb") as f_read:
+          binary = f_read.read()
+      finally:
+        os.remove(f.name)
+    logging.info("Flatbuffer size = %d", len(binary))
+    self.assertIn(b"simple_mul", binary)
+
+  def testOutputFbText(self):
+    text = iree.compiler.tools.compile_str(
+        SIMPLE_MUL_ASM,
+        input_type="mhlo",
+        output_format=iree.compiler.tools.OutputFormat.FLATBUFFER_TEXT,
+        target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
+            "utf-8")
+    # Just check for an arbitrary JSON-tag.
+    self.assertIn('"exported_functions"', text)
+
+  def testBadInputType(self):
+    with self.assertRaisesRegex(
+        ValueError, "For input_type= argument, expected one of: "
+        "NONE, MHLO, TOSA"):
+      _ = iree.compiler.tools.compile_str(
+          SIMPLE_MUL_ASM,
+          input_type="not-existing",
+          output_format="foobar",
+          target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+  def testBadOutputFormat(self):
+    with self.assertRaisesRegex(
+        ValueError, "For output_format= argument, expected one of: "
+        "FLATBUFFER_BINARY, FLATBUFFER_TEXT, MLIR_TEXT"):
+      _ = iree.compiler.tools.compile_str(
+          SIMPLE_MUL_ASM,
+          input_type="mhlo",
+          output_format="foobar",
+          target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+  def testOutputFbTextParsed(self):
+    text = iree.compiler.tools.compile_str(
+        SIMPLE_MUL_ASM,
+        input_type='mhlo',
+        output_format='flatbuffer_text',
+        target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
+            "utf-8")
+    # Just check for an arbitrary JSON-tag.
+    self.assertIn('"exported_functions"', text)
+
+  def testOutputMlirText(self):
+    text = iree.compiler.tools.compile_str(
+        SIMPLE_MUL_ASM,
+        input_type="mhlo",
+        output_format=iree.compiler.tools.OutputFormat.MLIR_TEXT,
+        target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
+            "utf-8")
+    # Just check for a textual op name.
+    self.assertIn("vm.module", text)
+
+  def testExtraArgsStderr(self):
+    # mlir-timing is not special: it just does something and emits to stderr.
+    with io.StringIO() as buf, contextlib.redirect_stderr(buf):
+      iree.compiler.tools.compile_str(
+          SIMPLE_MUL_ASM,
+          input_type="mhlo",
+          extra_args=["--mlir-timing"],
+          target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+      stderr = buf.getvalue()
+    self.assertIn("Execution time report", stderr)
+
+  def testAllOptions(self):
+    binary = iree.compiler.tools.compile_str(
+        SIMPLE_MUL_ASM,
+        input_type="mhlo",
+        optimize=False,
+        strip_debug_ops=True,
+        strip_source_map=True,
+        crash_reproducer_path="foobar.txt",
+        # Re-enable when benchmarking pass is fixed: #6196
+        # enable_benchmark=True,
+        target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+  def testException(self):
+    with self.assertRaisesRegex(iree.compiler.tools.CompilerToolError,
+                                "Invoked with"):
+      _ = iree.compiler.tools.compile_str(
+          "I'm a little teapot but not a valid program",
+          target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+  def testExplicitTempFileSaver(self):
+    temp_dir = tempfile.TemporaryDirectory()
+    output_file = tempfile.NamedTemporaryFile("wt")
+    output_file.close()
+    with iree.compiler.tools.TempFileSaver(temp_dir.name):
+      output = iree.compiler.tools.compile_str(
+          SIMPLE_MUL_ASM,
+          input_type="mhlo",
+          output_file=output_file.name,
+          target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+      self.assertIsNone(output)
+
+    # There should be an output file and a core-output.bin in the temp dir.
+    self.assertTrue(os.path.exists(output_file.name))
+    expected_temp_file = os.path.join(temp_dir.name, "core-output.bin")
+    self.assertTrue(os.path.exists(expected_temp_file))
+
+    # And they should have the same contents.
+    with open(output_file.name, "rb") as f:
+      output_contents = f.read()
+    with open(expected_temp_file, "rb") as f:
+      temp_contents = f.read()
+    self.assertEqual(temp_contents, output_contents)
+    temp_dir.cleanup()
+
+  def testEnvTempFileSaver(self):
+    temp_dir = tempfile.TemporaryDirectory()
+    os.environ["IREE_SAVE_TEMPS"] = temp_dir.name
+    with iree.compiler.tools.TempFileSaver() as tfs:
+      self.assertTrue(tfs.retained)
+      self.assertEqual(tfs.retained_path, temp_dir.name)
+
+  def testTempFileSaverDisabled(self):
+    with iree.compiler.tools.TempFileSaver() as tfs:
+      self.assertFalse(tfs.retained)
+
+
+if __name__ == "__main__":
+  logging.basicConfig(level=logging.DEBUG)
+  unittest.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tf_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tf_test.py
new file mode 100644
index 0000000..fc1ee71
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tf_test.py
@@ -0,0 +1,83 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+# TODO: No idea why pytype cannot find names from this module.
+# pytype: disable=name-error
+import iree.compiler.tools.tf
+
+if not iree.compiler.tools.tf.is_available():
+  print(f"Skipping test {__file__} because the IREE TensorFlow compiler "
+        f"is not installed")
+  sys.exit(0)
+
+import tensorflow as tf
+
+
+class SimpleArithmeticModule(tf.Module):
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([4], tf.float32),
+      tf.TensorSpec([4], tf.float32)
+  ])
+  def simple_mul(self, a, b):
+    return a * b
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([128, 3072], tf.float32),
+      tf.TensorSpec([3072, 256], tf.float32),
+  ])
+  def simple_matmul(self, a, b):
+    return tf.matmul(a, b)
+
+
+# TODO(laurenzo): More test cases needed (may need additional files).
+# Specifically, figure out how to test v1 models.
+class TfCompilerTest(tf.test.TestCase):
+
+  def testImportSavedModel(self):
+    import_mlir = iree.compiler.tools.tf.compile_saved_model(
+        self.smdir, import_only=True, output_generic_mlir=True).decode("utf-8")
+    self.assertIn("sym_name = \"simple_matmul\"", import_mlir)
+
+  def testCompileSavedModel(self):
+    binary = iree.compiler.tools.tf.compile_saved_model(
+        self.smdir,
+        target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS)
+    logging.info("Compiled len: %d", len(binary))
+    self.assertIn(b"simple_matmul", binary)
+    self.assertIn(b"simple_mul", binary)
+
+  def testCompileModule(self):
+    binary = iree.compiler.tools.tf.compile_module(
+        self.m, target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS)
+    logging.info("Compiled len: %d", len(binary))
+    self.assertIn(b"simple_matmul", binary)
+    self.assertIn(b"simple_mul", binary)
+
+  @classmethod
+  def setUpClass(cls):
+    cls.m = SimpleArithmeticModule()
+    cls.tempdir = tempfile.TemporaryDirectory()
+    cls.smdir = os.path.join(cls.tempdir.name, "arith.sm")
+    tf.saved_model.save(
+        cls.m,
+        cls.smdir,
+        options=tf.saved_model.SaveOptions(save_debug_info=True))
+
+  @classmethod
+  def tearDownClass(cls):
+    cls.tempdir.cleanup()
+
+
+if __name__ == "__main__":
+  logging.basicConfig(level=logging.DEBUG)
+  tf.test.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tflite_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tflite_test.py
new file mode 100644
index 0000000..671f515
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tflite_test.py
@@ -0,0 +1,102 @@
+# Lint as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+# TODO: No idea why pytype cannot find names from this module.
+# pytype: disable=name-error
+import iree.compiler.tools.tflite
+
+if not iree.compiler.tools.tflite.is_available():
+  print(f"Skipping test {__file__} because the IREE TFLite compiler "
+        f"is not installed")
+  sys.exit(0)
+
+
+class CompilerTest(unittest.TestCase):
+
+  def testImportBinaryPbFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata",
+                        "tflite_sample.fb")
+    text = iree.compiler.tools.tflite.compile_file(
+        path, import_only=True).decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("tosa.mul", text)
+
+  def testCompileBinaryPbFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata",
+                        "tflite_sample.fb")
+    binary = iree.compiler.tools.tflite.compile_file(
+        path,
+        target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS)
+    logging.info("Binary length = %d", len(binary))
+    self.assertIn(b"main", binary)
+
+  def testImportBinaryPbFileOutputFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata",
+                        "tflite_sample.fb")
+    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+      try:
+        f.close()
+        output = iree.compiler.tools.tflite.compile_file(path,
+                                                         import_only=True,
+                                                         output_file=f.name)
+        self.assertIsNone(output)
+        with open(f.name, "rt") as f_read:
+          text = f_read.read()
+      finally:
+        os.remove(f.name)
+    logging.info("%s", text)
+    self.assertIn("tosa.mul", text)
+
+  def testCompileBinaryPbFileOutputFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata",
+                        "tflite_sample.fb")
+    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+      try:
+        f.close()
+        output = iree.compiler.tools.tflite.compile_file(
+            path,
+            output_file=f.name,
+            target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS)
+        self.assertIsNone(output)
+        with open(f.name, "rb") as f_read:
+          binary = f_read.read()
+      finally:
+        os.remove(f.name)
+    logging.info("Binary length = %d", len(binary))
+    self.assertIn(b"main", binary)
+
+  def testImportBinaryPbBytes(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata",
+                        "tflite_sample.fb")
+    with open(path, "rb") as f:
+      content = f.read()
+    text = iree.compiler.tools.tflite.compile_str(
+        content, import_only=True).decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("tosa.mul", text)
+
+  def testCompileBinaryPbBytes(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata",
+                        "tflite_sample.fb")
+    with open(path, "rb") as f:
+      content = f.read()
+    binary = iree.compiler.tools.tflite.compile_str(
+        content,
+        target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS)
+    logging.info("Binary length = %d", len(binary))
+    self.assertIn(b"main", binary)
+
+
+if __name__ == "__main__":
+  logging.basicConfig(level=logging.DEBUG)
+  unittest.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_xla_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_xla_test.py
new file mode 100644
index 0000000..ad73db4
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_xla_test.py
@@ -0,0 +1,118 @@
+# Lint as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+import iree.compiler.tools.xla
+
+if not iree.compiler.tools.xla.is_available():
+  print(f"Skipping test {__file__} because the IREE XLA compiler "
+        f"is not installed")
+  sys.exit(0)
+
+
+class CompilerTest(unittest.TestCase):
+
+  def testImportBinaryPbFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+    text = iree.compiler.tools.xla.compile_file(
+        path, import_only=True).decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("linalg.generic", text)
+
+  def testCompileBinaryPbFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+    binary = iree.compiler.tools.xla.compile_file(
+        path, target_backends=iree.compiler.tools.xla.DEFAULT_TESTING_BACKENDS)
+    logging.info("Binary length = %d", len(binary))
+    self.assertIn(b"main", binary)
+
+  def testImportBinaryPbFileOutputFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+      try:
+        f.close()
+        output = iree.compiler.tools.xla.compile_file(path,
+                                                      import_only=True,
+                                                      output_file=f.name)
+        self.assertIsNone(output)
+        with open(f.name, "rt") as f_read:
+          text = f_read.read()
+      finally:
+        os.remove(f.name)
+    logging.info("%s", text)
+    self.assertIn("linalg.generic", text)
+
+  def testCompileBinaryPbFileOutputFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+      try:
+        f.close()
+        output = iree.compiler.tools.xla.compile_file(
+            path,
+            output_file=f.name,
+            target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS)
+        self.assertIsNone(output)
+        with open(f.name, "rb") as f_read:
+          binary = f_read.read()
+      finally:
+        os.remove(f.name)
+    logging.info("Binary length = %d", len(binary))
+    self.assertIn(b"main", binary)
+
+  def testImportBinaryPbBytes(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+    with open(path, "rb") as f:
+      content = f.read()
+    text = iree.compiler.tools.xla.compile_str(content,
+                                               import_only=True).decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("linalg.generic", text)
+
+  def testCompileBinaryPbBytes(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+    with open(path, "rb") as f:
+      content = f.read()
+    binary = iree.compiler.tools.xla.compile_str(
+        content,
+        target_backends=iree.compiler.tools.xla.DEFAULT_TESTING_BACKENDS)
+    logging.info("Binary length = %d", len(binary))
+    self.assertIn(b"main", binary)
+
+  def testImportHloTextFile(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.hlo")
+    text = iree.compiler.tools.xla.compile_file(
+        path, import_only=True, import_format="hlo_text").decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("linalg.generic", text)
+
+  def testImportHloTextStr(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.hlo")
+    with open(path, "rt") as f:
+      content = f.read()
+    text = iree.compiler.tools.xla.compile_str(
+        content, import_only=True, import_format="hlo_text").decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("linalg.generic", text)
+
+  def testImportHloTextBytes(self):
+    path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.hlo")
+    with open(path, "rb") as f:
+      content = f.read()
+    text = iree.compiler.tools.xla.compile_str(
+        content, import_only=True, import_format="hlo_text").decode("utf-8")
+    logging.info("%s", text)
+    self.assertIn("linalg.generic", text)
+
+
+if __name__ == "__main__":
+  logging.basicConfig(level=logging.DEBUG)
+  unittest.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_tflite.py b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_tflite.py
new file mode 100644
index 0000000..a325089
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_tflite.py
@@ -0,0 +1,29 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+
+import tensorflow as tf
+
+
+class Squared(tf.Module):
+
+  @tf.function
+  def __call__(self, x):
+    return tf.square(x)
+
+
+model = Squared()
+concrete_func = model.__call__.get_concrete_function(
+    tf.TensorSpec(shape=[4, 3], dtype=tf.float32))
+
+converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func],
+                                                            model)
+tflite_model = converter.convert()
+
+this_dir = os.path.dirname(__file__)
+with open(os.path.join(this_dir, "tflite_sample.fb"), "wb") as f:
+  f.write(tflite_model)
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_xla.py b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_xla.py
new file mode 100644
index 0000000..c583291
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_xla.py
@@ -0,0 +1,28 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+
+import numpy as np
+
+# Jax is the most accessible way to get at an xla_client.
+# python -m pip install --upgrade pip
+# python -m pip install --upgrade jax jaxlib
+from jaxlib import xla_client
+
+ops = xla_client.ops
+
+builder = xla_client.XlaBuilder("testbuilder")
+in_shape = np.array([4], dtype=np.float32)
+in_feed = ops.Parameter(builder, 0, xla_client.shape_from_pyval(in_shape))
+result = ops.Add(in_feed, ops.Constant(builder, np.float32(1.0)))
+xla_computation = builder.Build(result)
+
+this_dir = os.path.dirname(__file__)
+with open(os.path.join(this_dir, "xla_sample.pb"), "wb") as f:
+  f.write(xla_computation.as_serialized_hlo_module_proto())
+with open(os.path.join(this_dir, "xla_sample.hlo"), "wt") as f:
+  f.write(xla_computation.as_hlo_text())
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/tflite_sample.fb b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/tflite_sample.fb
new file mode 100644
index 0000000..52cb9e4
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/tflite_sample.fb
Binary files differ
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.hlo b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.hlo
new file mode 100644
index 0000000..2617141
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.hlo
@@ -0,0 +1,9 @@
+HloModule testbuilder.5
+
+ENTRY testbuilder.5 {
+  parameter.1 = f32[1] parameter(0)
+  constant.2 = f32[] constant(1)
+  broadcast.3 = f32[1]{0} broadcast(constant.2), dimensions={}
+  ROOT add.4 = f32[1]{0} add(parameter.1, broadcast.3)
+}
+
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.pb b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.pb
new file mode 100644
index 0000000..1f69a64
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.pb
Binary files differ