Fork and adapt old iree.compiler -> iree.compiler.tools. (#7164)
* Fork and adapt old iree.compiler -> iree.compiler.tools.
This is the first step to removing the original iree.compiler Python package in favor of the more full featured version developed under the dedicated iree-compiler-api project.
* Forks iree.compiler.[core|tf|tflite|xla] and support code to iree.compiler.tools
* Sets up temporary redirects from the root iree.compiler namespace so we have time to transition folks to new names.
* Adds a CAPI for invoking iree-translate as a function.
* Adds a standalone `ireec` binary and links it against the main CAPI shared library (ends up being 6KB on my machine).
* Sets up some testing for the project.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1d68351..652a204 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -441,7 +441,12 @@
endif()
if(IREE_BUILD_PYTHON_BINDINGS)
- add_subdirectory(third_party/pybind11 EXCLUDE_FROM_ALL)
+ if(NOT TARGET pybind11::module)
+ message(STATUS "Using bundled pybind11")
+ add_subdirectory(third_party/pybind11 EXCLUDE_FROM_ALL)
+ else()
+ message(STATUS "Not including bundled pybind11 (already configured)")
+ endif()
endif()
if(IREE_TARGET_BACKEND_METAL-SPIRV)
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index c633808..d6fe747 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -352,8 +352,9 @@
)
cc_library(
- name = "iree_translate_main",
- srcs = ["iree-translate-main.cc"],
+ name = "iree_translate_lib",
+ srcs = ["iree_translate_lib.cc"],
+ hdrs = ["iree_translate_lib.h"],
deps = [
":init_compiler_modules",
":init_iree_passes_and_dialects",
@@ -379,9 +380,10 @@
cc_binary(
name = "iree-translate",
+ srcs = ["iree-translate-main.cc"],
tags = ["hostonly"],
deps = [
- ":iree_translate_main",
+ ":iree_translate_lib",
],
)
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 0bee09c..6502a13 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -336,9 +336,11 @@
iree_cc_library(
NAME
- iree_translate_main
+ iree_translate_lib
+ HDRS
+ "iree_translate_lib.h"
SRCS
- "iree-translate-main.cc"
+ "iree_translate_lib.cc"
DEPS
::init_compiler_modules
::init_iree_passes_and_dialects
@@ -365,8 +367,10 @@
iree_cc_binary(
NAME
iree-translate
+ SRCS
+ "iree-translate-main.cc"
DEPS
- ::iree_translate_main
+ ::iree_translate_lib
DATA
lld
HOSTONLY
diff --git a/iree/tools/iree-translate-main.cc b/iree/tools/iree-translate-main.cc
index f506926..448f1a7 100644
--- a/iree/tools/iree-translate-main.cc
+++ b/iree/tools/iree-translate-main.cc
@@ -4,137 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// IREE translation main entry function.
-//
-// Note that this differs from mlir-translate and similar because we use the
-// PassManger and do transformations on the IR before translating to other
-// formats. Thus we use our own main entry function because we register
-// Dialects and PassManager CLI options.
-
-#include <functional>
-#include <memory>
-#include <string>
-#include <type_traits>
-
-#include "iree/compiler/Dialect/VM/Target/init_targets.h"
-#include "iree/tools/init_compiler_modules.h"
-#include "iree/tools/init_iree_dialects.h"
-#include "iree/tools/init_mlir_dialects.h"
-#include "iree/tools/init_passes.h"
-#include "iree/tools/init_targets.h"
-#include "iree/tools/init_translations.h"
-#include "iree/tools/init_xla_dialects.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/AsmState.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/FileUtilities.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/Timing.h"
-#include "mlir/Support/ToolUtilities.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-#include "mlir/Translation.h"
-
-static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
- llvm::cl::desc("<input file>"),
- llvm::cl::init("-"));
-
-static llvm::cl::opt<std::string> outputFilename(
- "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
- llvm::cl::init("-"));
-
-static llvm::cl::opt<bool> splitInputFile(
- "split-input-file",
- llvm::cl::desc("Split the input file into pieces and "
- "process each chunk independently"),
- llvm::cl::init(false));
-
-static llvm::cl::opt<bool> printMainAddress(
- "print-main-address",
- llvm::cl::desc("Print the address of main to stderr to aid in symbolizing "
- "stack traces after the fact"),
- llvm::cl::init(false));
+#include "iree/tools/iree_translate_lib.h"
int main(int argc, char **argv) {
- llvm::InitLLVM y(argc, argv);
- mlir::DialectRegistry registry;
- mlir::registerMlirDialects(registry);
- mlir::registerLLVMDialectTranslation(registry);
- mlir::registerXLADialects(registry);
- mlir::iree_compiler::registerAllPasses();
- mlir::iree_compiler::registerIreeDialects(registry);
- mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
- mlir::iree_compiler::registerHALTargetBackends();
- mlir::iree_compiler::registerVMTargets();
- mlir::registerMlirTranslations();
- mlir::iree_compiler::registerIreeTranslations();
- // Make sure command line options are registered.
- (void)mlir::iree_compiler::IREE::HAL::getTargetOptionsFromFlags();
-
- // Register MLIRContext command-line options like
- // -mlir-print-op-on-diagnostic.
- mlir::registerMLIRContextCLOptions();
- // Register assembly printer command-line options like
- // -mlir-print-op-generic.
- mlir::registerAsmPrinterCLOptions();
- // Register pass manager command-line options like -print-ir-*.
- mlir::registerPassManagerCLOptions();
- mlir::registerDefaultTimingManagerCLOptions();
-
- // Add flags for all the registered translations.
- llvm::cl::opt<const mlir::TranslateFunction *, false, mlir::TranslationParser>
- translationRequested("", llvm::cl::desc("Translation to perform"),
- llvm::cl::Required);
-
- llvm::cl::ParseCommandLineOptions(argc, argv, "IREE translation driver\n");
-
- if (printMainAddress) {
- llvm::errs() << "iree-translate main is at "
- << reinterpret_cast<void *>(&main) << "\n";
- }
-
- std::string errorMessage;
- auto input = mlir::openInputFile(inputFilename, &errorMessage);
- if (!input) {
- llvm::errs() << errorMessage << "\n";
- return 1;
- }
-
- auto output = mlir::openOutputFile(outputFilename, &errorMessage);
- if (!output) {
- llvm::errs() << errorMessage << "\n";
- return 1;
- }
-
- /// Processes the memory buffer with a new MLIRContext.
- auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
- llvm::raw_ostream &os) {
- mlir::MLIRContext context;
- context.allowUnregisteredDialects();
- context.appendDialectRegistry(registry);
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
- mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
- return (*translationRequested)(sourceMgr, os, &context);
- };
-
- if (splitInputFile) {
- if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
- output->os())))
- return 1;
- } else {
- if (failed(processBuffer(std::move(input), output->os()))) return 1;
- }
-
- output->keep();
- return 0;
+ return mlir::iree_compiler::runIreeTranslateMain(argc, argv);
}
diff --git a/iree/tools/iree_translate_lib.cc b/iree/tools/iree_translate_lib.cc
new file mode 100644
index 0000000..ded9121
--- /dev/null
+++ b/iree/tools/iree_translate_lib.cc
@@ -0,0 +1,137 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/tools/iree_translate_lib.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "iree/compiler/Dialect/VM/Target/init_targets.h"
+#include "iree/tools/init_compiler_modules.h"
+#include "iree/tools/init_iree_dialects.h"
+#include "iree/tools/init_mlir_dialects.h"
+#include "iree/tools/init_passes.h"
+#include "iree/tools/init_targets.h"
+#include "iree/tools/init_translations.h"
+#include "iree/tools/init_xla_dialects.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/Timing.h"
+#include "mlir/Support/ToolUtilities.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Translation.h"
+
+int mlir::iree_compiler::runIreeTranslateMain(int argc, char **argv) {
+ llvm::InitLLVM y(argc, argv);
+ mlir::DialectRegistry registry;
+ mlir::registerMlirDialects(registry);
+ mlir::registerLLVMDialectTranslation(registry);
+ mlir::registerXLADialects(registry);
+ mlir::iree_compiler::registerAllPasses();
+ mlir::iree_compiler::registerIreeDialects(registry);
+ mlir::iree_compiler::registerIreeCompilerModuleDialects(registry);
+ mlir::iree_compiler::registerHALTargetBackends();
+ mlir::iree_compiler::registerVMTargets();
+ mlir::registerMlirTranslations();
+ mlir::iree_compiler::registerIreeTranslations();
+ // Make sure command line options are registered.
+ (void)mlir::iree_compiler::IREE::HAL::getTargetOptionsFromFlags();
+
+ // Register MLIRContext command-line options like
+ // -mlir-print-op-on-diagnostic.
+ mlir::registerMLIRContextCLOptions();
+ // Register assembly printer command-line options like
+ // -mlir-print-op-generic.
+ mlir::registerAsmPrinterCLOptions();
+ // Register pass manager command-line options like -print-ir-*.
+ mlir::registerPassManagerCLOptions();
+ mlir::registerDefaultTimingManagerCLOptions();
+
+ // General command line flags.
+ llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+
+ llvm::cl::opt<std::string> outputFilename(
+ "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+ llvm::cl::init("-"));
+
+ llvm::cl::opt<bool> splitInputFile(
+ "split-input-file",
+ llvm::cl::desc("Split the input file into pieces and "
+ "process each chunk independently"),
+ llvm::cl::init(false));
+
+ llvm::cl::opt<bool> printMainAddress(
+ "print-main-address",
+ llvm::cl::desc(
+ "Print the address of main to stderr to aid in symbolizing "
+ "stack traces after the fact"),
+ llvm::cl::init(false));
+
+ // Add flags for all the registered translations.
+ llvm::cl::opt<const mlir::TranslateFunction *, false, mlir::TranslationParser>
+ translationRequested("", llvm::cl::desc("Translation to perform"),
+ llvm::cl::Optional);
+
+ llvm::cl::ParseCommandLineOptions(argc, argv, "IREE translation driver\n");
+
+ if (printMainAddress) {
+ llvm::errs() << "iree-translate main is at "
+ << reinterpret_cast<void *>(&runIreeTranslateMain) << "\n";
+ }
+
+ std::string errorMessage;
+ auto input = mlir::openInputFile(inputFilename, &errorMessage);
+ if (!input) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ auto output = mlir::openOutputFile(outputFilename, &errorMessage);
+ if (!output) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ /// Processes the memory buffer with a new MLIRContext.
+ auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
+ llvm::raw_ostream &os) {
+ mlir::MLIRContext context;
+ context.allowUnregisteredDialects();
+ context.appendDialectRegistry(registry);
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
+ mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
+ return (*translationRequested)(sourceMgr, os, &context);
+ };
+
+ if (splitInputFile) {
+ if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
+ output->os())))
+ return 1;
+ } else {
+ if (failed(processBuffer(std::move(input), output->os()))) return 1;
+ }
+
+ output->keep();
+ return 0;
+}
diff --git a/iree/tools/iree_translate_lib.h b/iree/tools/iree_translate_lib.h
new file mode 100644
index 0000000..572e03a
--- /dev/null
+++ b/iree/tools/iree_translate_lib.h
@@ -0,0 +1,18 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_TOOLS_IREE_TRANSLATE_LIB_H
+#define IREE_TOOLS_IREE_TRANSLATE_LIB_H
+
+namespace mlir {
+namespace iree_compiler {
+
+int runIreeTranslateMain(int argc, char **argv);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_TOOLS_IREE_TRANSLATE_LIB_H
diff --git a/llvm-external-projects/iree-compiler-api/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/CMakeLists.txt
index 76781b9..9e9f53f 100644
--- a/llvm-external-projects/iree-compiler-api/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/CMakeLists.txt
@@ -35,11 +35,12 @@
set(IREE_COMPILER_API_INTREE ON)
if(IREE_COMPILER_API_INTREE)
+ set(IREE_BUILD_TESTS OFF) # Conflicts with our tests if we are top-level.
set(LLVM_EXTERNAL_MLIR_IREE_DIALECTS_SOURCE_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../iree-dialects")
set(IREE_MAIN_SOURCE_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../..")
set(LLVM_MAIN_SRC_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../../third_party/llvm-project/llvm")
set(LLVM_EXTERNAL_MLIR_HLO_SOURCE_DIR "${IREE_COMPILER_API_SOURCE_DIR}/../../third_party/mlir-hlo")
-
+ enable_testing()
else()
message(FATAL_ERROR "Non intree (source package) not yet supported")
endif()
@@ -140,3 +141,4 @@
include(HandleLLVMOptions)
add_subdirectory(lib)
add_subdirectory(python)
+add_subdirectory(unittests)
diff --git a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Tools.h b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Tools.h
new file mode 100644
index 0000000..ce2cf79
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Tools.h
@@ -0,0 +1,24 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_COMPILER_API_TOOLS_H
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_COMPILER_API_TOOLS_H
+
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Runs the IREE compiler main function. This is used to build ireec-like
+/// binaries that link against a common shared library.
+MLIR_CAPI_EXPORTED int ireeCompilerRunMain(int argc, char **argv);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_COMPILER_API_TOOLS_H
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
index e5fcc0d..0d50421 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_public_c_api_library(IREECompilerAPICompilerCAPI
Compiler.cpp
+ Tools.cpp
# TODO: If installing, complains about IREEVM not being in any export set.
DISABLE_INSTALL
LINK_COMPONENTS
@@ -12,6 +13,9 @@
# All HAL Targets.
iree::tools::init_targets
+
+ # Tools.
+ iree::tools::iree_translate_lib
)
# TODO: Fix upstream so there is a way to know what the actual compile target
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/Tools.cpp b/llvm-external-projects/iree-compiler-api/lib/CAPI/Tools.cpp
new file mode 100644
index 0000000..e3b8416
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/Tools.cpp
@@ -0,0 +1,13 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-compiler-c/Tools.h"
+
+#include "iree/tools/iree_translate_lib.h"
+
+int ireeCompilerRunMain(int argc, char **argv) {
+ return mlir::iree_compiler::runIreeTranslateMain(argc, argv);
+}
diff --git a/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
index d471da7..896dd14 100644
--- a/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
@@ -16,6 +16,17 @@
)
declare_mlir_python_sources(IREECompilerAPIPythonExtensions)
+declare_mlir_python_sources(IREECompilerAPIPythonTools
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
+ SOURCES
+ __init__.py
+ tf.py
+ tflite.py
+ xla.py
+ SOURCES_GLOB
+ tools/*.py
+)
+
################################################################################
# Extensions
################################################################################
@@ -39,6 +50,7 @@
# Local sources.
IREECompilerAPIPythonSources
IREECompilerAPIPythonExtensions
+ IREECompilerAPIPythonTools
# TODO: Core is now implicitly building/registering all dialects, increasing
# build burden by ~5x. Make it stop.
@@ -82,3 +94,18 @@
COMMON_CAPI_LINK_LIBS
IREECompilerAggregateCAPI
)
+
+
+# Build the ireec tool into _mlir_libs.
+add_executable(
+ IREECompilerIREECTool
+ IREECTool.c
+)
+target_link_libraries(IREECompilerIREECTool IREECompilerAggregateCAPI)
+set_target_properties(IREECompilerIREECTool
+ PROPERTIES
+ OUTPUT_NAME "ireec"
+ RUNTIME_OUTPUT_DIRECTORY "${IREE_COMPILER_API_BINARY_DIR}/python_package/iree/compiler/_mlir_libs"
+ BUILD_RPATH_USE_ORIGIN ON
+)
+add_dependencies(IREECompilerPythonModules IREECompilerIREECTool)
diff --git a/llvm-external-projects/iree-compiler-api/python/IREECTool.c b/llvm-external-projects/iree-compiler-api/python/IREECTool.c
new file mode 100644
index 0000000..5428add
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/IREECTool.c
@@ -0,0 +1,9 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-compiler-c/Tools.h"
+
+int main(int argc, char **argv) { return ireeCompilerRunMain(argc, argv); }
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/__init__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/__init__.py
new file mode 100644
index 0000000..b49ab00
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/__init__.py
@@ -0,0 +1,9 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Re-export some legacy APIs from the tools package to this top-level.
+# TODO: Deprecate and remove these names once clients are migrated.
+from .tools import *
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tf.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tf.py
new file mode 100644
index 0000000..7eba07c
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tf.py
@@ -0,0 +1,10 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+from .tools import tf
+
+sys.modules[__name__] = tf
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tflite.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tflite.py
new file mode 100644
index 0000000..9a9a9cb
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tflite.py
@@ -0,0 +1,10 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+from .tools import tflite
+
+sys.modules[__name__] = tflite
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/__init__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/__init__.py
new file mode 100644
index 0000000..2adf9e5
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/__init__.py
@@ -0,0 +1,10 @@
+# Lint-as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .core import *
+from .debugging import TempFileSaver
+from .binaries import CompilerToolError
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/binaries.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/binaries.py
new file mode 100644
index 0000000..d45da87
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/binaries.py
@@ -0,0 +1,283 @@
+# Lint-as: python3
+"""Utilities for locating and invoking compiler tool binaries."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import importlib
+import io
+import os
+import platform
+import subprocess
+import sys
+import textwrap
+import threading
+
+from typing import List, Optional, Union
+
+__all__ = [
+ "find_tool",
+ "invoke_immediate",
+ "invoke_pipeline",
+ "get_tool_path",
+ "CompilerToolError",
+]
+
+_BUILTIN_TOOLS = [
+ "ireec",
+]
+
+# In normal distribution circumstances, each named tool is associated with
+# a python module that provides a `get_tool` function for getting its absolute
+# path. This dictionary maps the tool name to the module.
+_TOOL_MODULE_MAP = {
+ "iree-import-tflite": "iree.tools.tflite",
+ "iree-import-xla": "iree.tools.xla",
+ "iree-import-tf": "iree.tools.tf",
+}
+
+# Map of tool module to package name as distributed to archives (used for
+# error messages).
+_TOOL_MODULE_PACKAGES = {
+ "iree.tools.tf": "iree-tools-tf",
+ "iree.tools.tflite": "iree-tools-tflite",
+ "iree.tools.xla": "iree-tools-xla",
+}
+
+# Environment variable holding directories to be searched for named tools.
+# Delimitted by os.pathsep.
+_TOOL_PATH_ENVVAR = "IREE_TOOL_PATH"
+
+
+class CompilerToolError(Exception):
+ """Compiler exception that preserves the command line and error output."""
+
+ def __init__(self, process: subprocess.CompletedProcess):
+ try:
+ errs = process.stderr.decode("utf-8")
+ except:
+ errs = str(process.stderr) # Decode error or other: best we can do.
+
+ tool_name = os.path.basename(process.args[0])
+ super().__init__(f"Error invoking IREE compiler tool {tool_name}\n"
+ f"Diagnostics:\n{errs}\n\n"
+ f"Invoked with:\n {tool_name} {' '.join(process.args)}")
+
+
+def get_tool_path() -> List[str]:
+ """Returns list of paths to search for tools."""
+ list_str = os.environ.get(_TOOL_PATH_ENVVAR)
+ if not list_str:
+ return []
+ return list_str.split(os.pathsep)
+
+
+def find_tool(exe_name: str) -> str:
+ """Finds a tool by its (extension-less) executable name.
+
+ Args:
+ exe_name: The name of the executable (extension-less).
+ Returns:
+ An absolute path to the tool.
+ Raises:
+ ValueError: If the tool is not known or not found.
+ """
+ is_builtin = exe_name in _BUILTIN_TOOLS
+ if not is_builtin and exe_name not in _TOOL_MODULE_MAP:
+ raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool")
+ # First search an explicit tool path.
+ tool_path = get_tool_path()
+ for path_entry in tool_path:
+ if not path_entry:
+ continue
+ candidate_exe = os.path.join(path_entry, exe_name)
+ if os.path.isfile(candidate_exe) and os.access(candidate_exe, os.X_OK):
+ return candidate_exe
+
+ if is_builtin:
+ # Get builtin tool.
+ candidate_exe = _get_builtin_tool(exe_name)
+ else:
+ # Attempt to load the tool module.
+ tool_module_name = _TOOL_MODULE_MAP[exe_name]
+ tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name]
+ try:
+ tool_module = importlib.import_module(tool_module_name)
+ except ModuleNotFoundError:
+ raise ValueError(
+ f"IREE compiler tool '{exe_name}' is not installed (it should have been "
+ f"found in the python module '{tool_module_name}', typically installed "
+ f"via the package {tool_module_package}).\n\n"
+ f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment "
+ f"variable to contain the path of the tool executable "
+ f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)})") from None
+
+ # Ask the module for its tool.
+ candidate_exe = tool_module.get_tool(exe_name)
+
+ if (not candidate_exe or not os.path.isfile(candidate_exe) or
+ not os.access(candidate_exe, os.X_OK)):
+ raise ValueError(
+ f"IREE compiler tool '{exe_name}' was located in module "
+ f"'{tool_module_name}' but the file was not found or not executable: "
+ f"{candidate_exe}")
+ return candidate_exe
+
+
+def _get_builtin_tool(exe_name: str) -> Optional[str]:
+ if platform.system() == "Windows":
+ exe_name = exe_name + ".exe"
+ this_path = os.path.dirname(__file__)
+ tool_path = os.path.join(this_path, "..", "_mlir_libs", exe_name)
+ return tool_path
+
+
+def invoke_immediate(command_line: List[str],
+ *,
+ input_file: Optional[bytes] = None,
+ immediate_input=None):
+ """Invokes an immediate command.
+
+ This is separate from invoke_pipeline as it is simpler and supports more
+ complex input redirection, using recommended facilities for sub-processes
+ (less magic).
+
+ Note that this differs from the usual way of using subprocess.run or
+ subprocess.Popen().communicate() because we need to pump all of the error
+ streams individually and only pump pipes not connected to a different stage.
+ Uses threads to pump everything that is required.
+ """
+ run_args = {}
+ input_file_handle = None
+ stderr_handle = sys.stderr
+ try:
+ # Redirect input.
+ if input_file is not None:
+ input_file_handle = open(input_file, "rb")
+ run_args["stdin"] = input_file_handle
+ elif immediate_input is not None:
+ run_args["input"] = immediate_input
+
+ # Capture output.
+ # TODO(#4131) python>=3.7: Use capture_output=True.
+ run_args["stdout"] = subprocess.PIPE
+ run_args["stderr"] = subprocess.PIPE
+ process = subprocess.run(command_line, **run_args)
+ if process.returncode != 0:
+ raise CompilerToolError(process)
+ # Emit stderr contents.
+ _write_binary_stderr(stderr_handle, process.stderr)
+ return process.stdout
+ finally:
+ if input_file_handle:
+ input_file_handle.close()
+
+
+def invoke_pipeline(command_lines: List[List[str]], immediate_input=None):
+ """Invoke a pipeline of commands.
+
+ The first stage of the pipeline will have its stdin set to DEVNULL and each
+ subsequent stdin will derive from the prior stdout. The final stdout will
+ be accumulated and returned. All stderr contents are accumulated and printed
+ to stderr on completion or the first failing stage of the pipeline will have
+ an exception raised with its stderr output.
+ """
+ stages = []
+ pipeline_input = (subprocess.DEVNULL
+ if immediate_input is None else subprocess.PIPE)
+ prev_out = pipeline_input
+ stderr_handle = sys.stderr
+
+ # Create all stages.
+ for i in range(len(command_lines)):
+ command_line = command_lines[i]
+ popen_args = {
+ "stdin": prev_out,
+ "stdout": subprocess.PIPE,
+ "stderr": subprocess.PIPE,
+ }
+ process = subprocess.Popen(command_line, **popen_args)
+ prev_out = process.stdout
+ capture_output = (i == (len(command_lines) - 1))
+ stages.append(_PipelineStage(process, capture_output))
+
+ # Start stages.
+ for stage in stages:
+ stage.start()
+
+ # Pump input.
+ pipe_success = True
+ if immediate_input is not None:
+ try:
+ pipe_success = False
+ stages[0].process.stdin.write(immediate_input)
+ pipe_success = True
+ finally:
+ stages[0].process.stdin.close()
+
+ # Join.
+ for stage in stages:
+ stage.join()
+
+ # Check for errors.
+ for stage in stages:
+ assert stage.completed
+ if stage.completed.returncode != 0:
+ raise CompilerToolError(stage.completed)
+
+ # Broken pipe.
+ if not pipe_success:
+ raise CompilerToolError(stages[0].completed)
+
+ # Print any stderr output.
+ for stage in stages:
+ _write_binary_stderr(stderr_handle, stage.errs)
+ return stages[-1].outs
+
+
+class _PipelineStage(threading.Thread):
+ """Wraps a process and pumps its handles, waiting for completion."""
+
+ def __init__(self, process, capture_output):
+ super().__init__()
+ self.process = process
+ self.capture_output = capture_output
+ self.completed: Optional[subprocess.CompletedProcess] = None
+ self.outs = None
+ self.errs = None
+
+ def pump_stderr(self):
+ self.errs = self.process.stderr.read()
+
+ def pump_stdout(self):
+ self.outs = self.process.stdout.read()
+
+ def run(self):
+ stderr_thread = threading.Thread(target=self.pump_stderr)
+ stderr_thread.start()
+ if self.capture_output:
+ stdout_thread = threading.Thread(target=self.pump_stdout)
+ stdout_thread.start()
+ self.process.wait()
+ stderr_thread.join()
+ if self.capture_output:
+ stdout_thread.join()
+ self.completed = subprocess.CompletedProcess(self.process.args,
+ self.process.returncode,
+ self.outs, self.errs)
+ self.process.stderr.close()
+ self.process.stdout.close()
+
+
+def _write_binary_stderr(out_handle, contents):
+ # Fast-paths buffered text-io (which stderr is by default) while allowing
+ # full decode for non buffered and binary io.
+ if hasattr(out_handle, "buffer"):
+ out_handle.buffer.write(contents)
+ elif isinstance(out_handle, io.TextIOBase):
+ out_handle.write(contents.decode("utf-8"))
+ else:
+ out_handle.write(contents)
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/core.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/core.py
new file mode 100644
index 0000000..20962b8
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/core.py
@@ -0,0 +1,252 @@
+# Lint-as: python3
+"""Core compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import subprocess
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from .debugging import TempFileSaver
+from .binaries import *
+
+__all__ = [
+ "DEFAULT_TESTING_BACKENDS",
+ "compile_file",
+ "compile_str",
+ "CompilerOptions",
+ "OutputFormat",
+]
+
+# Default testing backend for invoking the compiler.
+# TODO: Remove these. In the absence of default profiles, though, it is better
+# to centralize.
+DEFAULT_TESTING_BACKENDS = ["dylib-llvm-aot"]
+DEFAULT_TESTING_DRIVER = "dylib"
+
+
+class InputType(Enum):
+ """The type of input pipeline to run prior to the core compiler."""
+ NONE = "none"
+ MHLO = "mhlo"
+ TOSA = "tosa"
+
+ @staticmethod
+ def parse(spec: Union[str, "InputType"]) -> "InputType":
+ """Parses or returns an InputType.
+
+ Args:
+ spec: An InputType instance or the case-insensitive name of one of the
+ enum values.
+ Returns:
+ An InputType instance.
+ """
+ if isinstance(spec, InputType):
+ return spec
+ spec = spec.upper().replace("-", "_")
+ if spec not in InputType.__members__:
+ raise ValueError(f"For input_type= argument, expected one of: "
+ f"{', '.join(InputType.__members__.keys())}")
+ return InputType[spec]
+
+
+class OutputFormat(Enum):
+ """The output format of the compiler."""
+ FLATBUFFER_BINARY = "flatbuffer-binary"
+ FLATBUFFER_TEXT = "flatbuffer-text"
+ MLIR_TEXT = "mlir-text"
+
+ @staticmethod
+ def parse(spec: Union[str, "OutputFormat"]) -> "OutputFormat":
+ """Parses or returns an OutputFormat.
+
+ Args:
+ spec: An OutputFormat instance or the case-insensitive name of one of
+ the enum values.
+ Returns:
+ An OutputFormat instance.
+ """
+ if isinstance(spec, OutputFormat):
+ return spec
+ spec = spec.upper().replace("-", "_")
+ if spec not in OutputFormat.__members__:
+ raise ValueError(f"For output_format= argument, expected one of: "
+ f"{', '.join(OutputFormat.__members__.keys())}")
+ return OutputFormat[spec]
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class CompilerOptions:
+ """Options to the compiler backend.
+
+ Keyword options:
+ output_file: Optionally save the compiled binary to a file instead of
+ returning it.
+ target_backends: List of str names of target backends to compile into
+ the binary. The resulting binary will run on targets that match one
+ or more of the compiled backends.
+ input_type: The type of input legalization to perform prior to full
+ compilation. Defaults to none.
+ output_format: Override the output format. See the OutputFormat enum.
+ Values can either be an enum value or a case-insensitive name of
+ the option. Typically used for debugging
+ extra_args: Optional list of additional arguments to pass to the compiler.
+ Example: ["--print-ir-after-all"]
+ optimize: Whether to apply some default high level optimizations (default
+ True).
+ output_mlir_debuginfo: Include debuginfo (including paths) in any saved or
+ returned MLIR.
+ output_generic_mlir: Use the generic (and more portable) MLIR formatting for
+ any saved or returned MLIR instead of the per-dialect custom assembly.
+ extended_diagnostics: Outputs extended information on diagnostics,
+ potentially outputting very verbosely (defaults to False).
+ strip_debug_ops: Whether to strip high level operations used to aid
+ debugging.
+ strip_source_map: Whether to strip source map information (used to generate
+ better errors).
+ crash_reproducer_path: File name to output an MLIR crash dump to if there
+ is a compiler failure.
+ enable_tflite_bindings: Support the IREE TFLite runtime bindings API shim.
+ enable_benchmark: Whether to generate instrumented binaries suitable
+ for benchmarking.
+ """
+
+ def __init__(self,
+ *,
+ output_file: Optional[str] = None,
+ target_backends: Sequence[str] = (),
+ input_type: Union[InputType, str] = InputType.NONE,
+ output_format: Union[OutputFormat,
+ str] = OutputFormat.FLATBUFFER_BINARY,
+ extra_args: Sequence[str] = (),
+ optimize: bool = True,
+ output_mlir_debuginfo: bool = True,
+ output_generic_mlir: bool = False,
+ extended_diagnostics: bool = False,
+ strip_debug_ops: bool = False,
+ strip_source_map: bool = False,
+ crash_reproducer_path: Optional[str] = None,
+ enable_tflite_bindings: bool = False,
+ enable_benchmark: bool = False):
+ self.output_file = output_file
+ self.target_backends = target_backends
+ self.input_type = InputType.parse(input_type)
+ self.output_format = OutputFormat.parse(output_format)
+ self.extra_args = extra_args
+ self.optimize = optimize
+ self.output_mlir_debuginfo = output_mlir_debuginfo
+ self.output_generic_mlir = output_generic_mlir
+ self.extended_diagnostics = extended_diagnostics
+ self.strip_debug_ops = strip_debug_ops
+ self.strip_source_map = strip_source_map
+ self.crash_reproducer_path = crash_reproducer_path
+ self.enable_tflite_bindings = enable_tflite_bindings
+ self.enable_benchmark = enable_benchmark
+
+
+def build_compile_command_line(input_file: str, tfs: TempFileSaver,
+ options: CompilerOptions) -> List[str]:
+ """Builds a command line for invoking the compiler.
+
+ Args:
+ input_file: The input file name.
+ tfs: TempFileSaver.
+ options: Compiler options.
+ Returns:
+ List of strings of command line.
+ """
+ iree_translate = find_tool("ireec")
+ if not options.target_backends:
+ raise ValueError("Expected a non-empty list for 'target_backends'")
+
+ cl = [
+ iree_translate,
+ input_file,
+ f"--iree-input-type={options.input_type.value}",
+ f"--iree-vm-bytecode-module-output-format={options.output_format.value}",
+ ]
+ for target_backend in options.target_backends:
+ cl.append(f"--iree-hal-target-backends={target_backend}")
+
+ # Output file.
+ output_file = tfs.alloc_optional("core-output.bin",
+ export_as=options.output_file)
+ if output_file:
+ cl.append(f"-o={output_file}")
+
+ # Translation to perform.
+ cl.append("--iree-mlir-to-vm-bytecode-module")
+
+ # MLIR flags.
+ if options.output_mlir_debuginfo:
+ cl.append("--mlir-print-debuginfo")
+ if options.output_generic_mlir:
+ cl.append("--mlir-print-op-generic")
+ if options.extended_diagnostics:
+ # Note that different tools have different defaults, so be explicit.
+ cl.append("--mlir-print-op-on-diagnostic=true")
+ else:
+ cl.append("--mlir-print-op-on-diagnostic=false")
+
+ # Other options to set if specified.
+ if options.strip_debug_ops:
+ cl.append("--iree-vm-bytecode-module-strip-debug-ops")
+ if options.strip_source_map:
+ cl.append("--iree-vm-bytecode-module-strip-source-map")
+ crash_reproducer_path = tfs.alloc_optional(
+ "core-reproducer.mlir", export_as=options.crash_reproducer_path)
+ if crash_reproducer_path:
+ cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+ if options.enable_tflite_bindings:
+ cl.append("--iree-tflite-bindings-support")
+ if options.enable_benchmark:
+ cl.append("--iree-flow-export-benchmark-funcs")
+
+ cl.extend(options.extra_args)
+ return cl
+
+
+def compile_file(input_file: str, **kwargs):
+ """Invokes the IREE compiler on an input file.
+
+ Args:
+ input_file: File containing MLIR assembly to compile.
+ **kwargs: Keyword arguments corresponding to CompilerOptions.
+ Returns:
+ Either a byte buffer of the compiled content or None if output_file
+ was specified in the options.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = CompilerOptions(**kwargs)
+ cl = build_compile_command_line(input_file, tfs, options)
+ result = invoke_immediate(cl)
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_str(input_str: Union[str, bytes], **kwargs):
+ """Invokes the IREE compiler with an input string.
+
+ Args:
+ input_str: MLIR assembly to parse/compile (str or bytes).
+ **kwargs: Keyword arguments corresponding to CompilerOptions.
+ Returns:
+ Either a byte buffer of the compiled content or None if output_file
+ was specified in the options.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = CompilerOptions(**kwargs)
+ cl = build_compile_command_line("-", tfs, options)
+ input_bytes = input_str.encode("utf-8") if isinstance(input_str,
+ str) else input_str
+ result = invoke_immediate(cl, immediate_input=input_bytes)
+ if options.output_file:
+ return None
+ return result
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/debugging.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/debugging.py
new file mode 100644
index 0000000..fc897dd
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/debugging.py
@@ -0,0 +1,176 @@
+"""Debugging support."""
+
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+import logging
+import os
+import shutil
+import sys
+import threading
+
+_thread_locals = threading.local()
+_invocation_id = 0
+
+
+def _get_temp_file_saver_stack():
+ try:
+ return _thread_locals.temp_file_saver_stack
+ except AttributeError:
+ stack = []
+ _thread_locals.temp_file_saver_stack = stack
+ return stack
+
+
+def _interpolate_path_pattern(path_pattern: str, *, invocation_id: str):
+ # We do not use str.format() because we do not know the providence of
+ # path_pattern. Instead, handle a fixed set of replacements.
+ path_pattern = path_pattern.replace("{id}", str(invocation_id))
+ path_pattern = path_pattern.replace("{pid}", str(os.getpid()))
+ path_pattern = path_pattern.replace("{main}", os.path.basename(sys.argv[0]))
+ return path_pattern
+
+
+class TempFileSaver:
+ """Manages the saving of temp files resulting from tool invocations.
+
+ The TempFileSaver is a thread-local context bound object. An attempt to
+ create a new one will return the most recent instance created and entered
+ as a context manager. This allows up-stack callers to establish the
+ policy for saving temporaries and deep implementations will inherit it.
+
+ Proper usage from users wishing to establish a saver context:
+ with TempFileSaver():
+ # Do things with temp files.
+
+ Proper usage for implementors wishing to use an established saver context
+ or set up a new one:
+ with TempFileSaver.implicit() as tfs:
+ # Do things with temp files.
+
+ The outer-most creator can customize it with explicit arguments to __init__
+ but these will be ignored if an instance is already thread bound.
+ """
+ TEMP_PATH_ENV_KEY = "IREE_SAVE_TEMPS"
+
+ @staticmethod
+ def implicit():
+ stack = _get_temp_file_saver_stack()
+ if stack:
+ return stack[-1]
+ return TempFileSaver()
+
+ def __init__(self,
+ temp_path_pattern: Optional[str] = None,
+ *,
+ invocation_id: Optional[str] = None):
+ self.retained = False
+ self._refcount = 0
+ if temp_path_pattern is None:
+ temp_path_pattern = os.environ.get(TempFileSaver.TEMP_PATH_ENV_KEY)
+ if temp_path_pattern is None:
+ return
+
+ global _invocation_id
+ if invocation_id is not None:
+ self.invocation_id = invocation_id
+ else:
+ self.invocation_id = _invocation_id
+ _invocation_id += 1
+
+ self.retained_path = _interpolate_path_pattern(
+ temp_path_pattern, invocation_id=self.invocation_id)
+ self.retained = True
+ self._retained_file_names = set()
+ self._copy_on_finalize = list() # Of (source_path, target_path)
+
+ def __enter__(self):
+ _get_temp_file_saver_stack().append(self)
+ self._refcount += 1
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ del _get_temp_file_saver_stack()[-1]
+ self._refcount -= 1
+ if self._refcount == 0:
+ self._finalize()
+
+ @staticmethod
+ def current():
+ try:
+ return _get_temp_file_saver_stack()[-1]
+ except KeyError:
+ raise RuntimeError("No current TempFileSaver")
+
+ def alloc_optional(self,
+ file_name: str,
+ *,
+ export_as: Optional[str] = None) -> Optional[str]:
+ """Allocates an optional temporary file.
+
+
+ When in non-retained mode, the return value is 'export_as', meaning that the
+ file is just some user specified output file.
+
+ When in retained mode, the output file will be an index-mangled variant
+ of 'file_name' under the temp_path. In addition, a mapping will be added
+ so that upon finalization, the file is also exported to 'export_as' if
+ specified.
+
+ Returns None if neither a user-specified 'export_as' is specified nor in
+ retained mode.
+
+ The distinction between retained temporaries and exports is to help in
+ cases for when the caller has requested that an artifact be written to
+ a specific place (i.e. an output file) but for debuggability, we also
+ want to save it as a temporary. In this case, we save it to the temporary
+ location and then conclude by moving artifacts to their final location
+ once the saver goes out of scope.
+ """
+ if not self.retained:
+ return export_as
+ alloced_path = self._alloc_retained_path(file_name)
+ if export_as:
+ self._copy_on_finalize.append((alloced_path, export_as))
+ return alloced_path
+
+ def _alloc_retained_path(self, file_name: str) -> str:
+ assert self.retained
+ index = 0
+ original_file_name = file_name
+ while True:
+ if file_name not in self._retained_file_names:
+ # First use of this name.
+ self._retained_file_names.add(file_name)
+ os.makedirs(self.retained_path, exist_ok=True)
+ return os.path.join(self.retained_path, file_name)
+ index += 1
+ stem, ext = os.path.splitext(original_file_name)
+ file_name = f"{stem}_{index}{ext}"
+
+ def _finalize(self):
+ if not self.retained:
+ return
+ # See which files were materialized.
+ was_materialized = []
+ for file_name in self._retained_file_names:
+ file_path = os.path.join(self.retained_path, file_name)
+ if os.path.exists(file_path):
+ was_materialized.append((file_name, file_path))
+ if was_materialized:
+ logging.info(
+ "**** IREE Compiler retained temporary files (%s)***:\n%s",
+ self.invocation_id, "\n".join([
+ f" * {file_name} : {file_path}"
+ for file_name, file_path in was_materialized
+ ]))
+ for source_path, target_path in self._copy_on_finalize:
+ if os.path.exists(source_path):
+ logging.info("Copy retained file to output: %s -> %s", source_path,
+ target_path)
+ shutil.copyfile(source_path, target_path)
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/__init__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/__init__.py
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/ireec/__main__.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/ireec/__main__.py
new file mode 100644
index 0000000..2ded4b8
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/scripts/ireec/__main__.py
@@ -0,0 +1,21 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import subprocess
+import sys
+
+from iree.compiler.tools import binaries
+
+
+def main(args=None):
+ if args is None:
+ args = sys.argv[1:]
+ exe = binaries.find_tool("ireec")
+ return subprocess.call(args=[exe] + args)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tf.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tf.py
new file mode 100644
index 0000000..7056f33
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tf.py
@@ -0,0 +1,238 @@
+# Lint-as: python3
+"""TensorFlow compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+from .debugging import TempFileSaver
+from .binaries import find_tool, invoke_immediate, invoke_pipeline
+
+__all__ = [
+ "compile_saved_model",
+ "compile_module",
+ "is_available",
+ "DEFAULT_TESTING_BACKENDS",
+ "ImportOptions",
+ "ImportType",
+]
+
+_TF_IMPORT_TOOL = "iree-import-tf"
+
+
+def is_available():
+ """Determine if TensorFlow and the compiler are available."""
+ try:
+ import tensorflow as tf
+ except ModuleNotFoundError:
+ logging.warn("Unable to import tensorflow")
+ return False
+ try:
+ find_tool(_TF_IMPORT_TOOL)
+ except ValueError:
+ logging.warning("Unable to find IREE tool %s", _TF_IMPORT_TOOL)
+ return False
+ return True
+
+
+class ImportType(Enum):
+ """Import type of the model."""
+ OBJECT_GRAPH = "savedmodel_v2"
+ V2 = "savedmodel_v2"
+ SIGNATURE_DEF = "savedmodel_v1"
+ V1 = "savedmodel_v1"
+
+ @staticmethod
+ def parse(spec: Union[str, "ImportType"]) -> "ImportType":
+ """Parses or returns an ImportType.
+
+ Args:
+ spec: An ImportType instance or the case-insensitive name of one of
+ the enum values.
+ Returns:
+ An ImportType instance.
+ """
+ if isinstance(spec, ImportType):
+ return spec
+ spec = spec.upper()
+ if spec not in ImportType.__members__:
+ raise ValueError(f"For import_type= argument, expected one of: "
+ f"{', '.join(ImportType.__members__.keys())}")
+ return ImportType[spec]
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+ """Import options layer on top of the backend compiler options."""
+
+ def __init__(self,
+ exported_names: Sequence[str] = (),
+ import_only: bool = False,
+ import_type: Union[ImportType, str] = ImportType.OBJECT_GRAPH,
+ saved_model_tags: Set[str] = set(),
+ import_extra_args: Sequence[str] = (),
+ save_temp_tf_input: Optional[str] = None,
+ save_temp_mid_level_input: Optional[str] = None,
+ save_temp_iree_input: Optional[str] = None,
+ use_tosa: bool = False,
+ **kwargs):
+ """Initialize options from keywords.
+
+ Args:
+ exported_names: Optional sequence representing the exported names to
+ keep (object graph/v2 models only).
+ import_only: Only import the module. If True, the result will be textual
+ MLIR that can be further fed to the IREE compiler. If False (default),
+ the result will be the fully compiled IREE binary. In both cases,
+ bytes-like output is returned. Note that if the output_file= is
+ specified and import_only=True, then the MLIR form will be written to
+ the output file.
+ import_type: Type of import to perform. See ImportType enum.
+ saved_model_tags: Set of tags to export (signature def/v1 saved models
+ only).
+ import_extra_args: Extra arguments to pass to the iree-import-tf tool.
+ save_temp_tf_input: Optionally save the IR that is input to the
+ TensorFlow pipeline.
+ save_temp_mid_level_input: Optionally save the IR that is input to the
+ mid level IR.
+ save_temp_iree_input: Optionally save the IR that is the result of the
+ import (ready to be passed to IREE).
+ """
+ super().__init__(**kwargs)
+ self.exported_names = exported_names
+ self.import_only = import_only
+ self.import_type = ImportType.parse(import_type)
+ self.saved_model_tags = saved_model_tags
+ self.import_extra_args = import_extra_args
+ self.save_temp_tf_input = save_temp_tf_input
+ self.save_temp_mid_level_input = save_temp_mid_level_input
+ self.save_temp_iree_input = save_temp_iree_input
+ self.use_tosa = use_tosa
+
+
+def build_import_command_line(input_path: str, tfs: TempFileSaver,
+ options: ImportOptions) -> List[str]:
+ """Builds a command line for invoking the import stage.
+
+ Args:
+ input_path: The input path.
+ tfs: TempFileSaver.
+ options: Import options.
+ Returns:
+ List of strings of command line.
+ """
+ tf_import = find_tool(_TF_IMPORT_TOOL)
+ cl = [
+ tf_import,
+ input_path,
+ f"--tf-import-type={options.import_type.value}",
+ f"--tf-savedmodel-exported-names={','.join(options.exported_names)}",
+ f"--tf-savedmodel-tags={','.join(options.saved_model_tags)}",
+ ]
+
+ if options.import_only and options.output_file:
+ # Import stage directly outputs.
+ output_file = tfs.alloc_optional("tf-output.mlir",
+ export_as=options.output_file)
+ cl.append(f"-o={output_file}")
+
+ # MLIR flags.
+ if options.output_mlir_debuginfo:
+ cl.append("--mlir-print-debuginfo")
+ if options.output_generic_mlir:
+ cl.append("--mlir-print-op-generic")
+
+ # Save temps flags.
+ save_tf_input = tfs.alloc_optional("tf-input.mlir",
+ export_as=options.save_temp_tf_input)
+ if save_tf_input:
+ cl.append(f"--save-temp-tf-input={save_tf_input}")
+ save_mid_level_input = tfs.alloc_optional(
+ "tf-mid-level-input.mlir", export_as=options.save_temp_mid_level_input)
+ if save_mid_level_input:
+ cl.append(f"--save-temp-mid-level-input={save_mid_level_input}")
+ save_iree_input = tfs.alloc_optional("tf-iree-input.mlir",
+ export_as=options.save_temp_iree_input)
+ if save_iree_input:
+ cl.append(f"--save-temp-iree-input={save_iree_input}")
+
+ if options.use_tosa:
+ cl.append(f"--use-tosa")
+
+ # Crash reproducer (locally qualified).
+ requested_crash_reproducer_path = options.crash_reproducer_path
+ if requested_crash_reproducer_path:
+ requested_crash_reproducer_path = (requested_crash_reproducer_path +
+ ".import-tf")
+ crash_reproducer_path = tfs.alloc_optional(
+ "tf-reproducer.mlir", export_as=requested_crash_reproducer_path)
+ if crash_reproducer_path:
+ cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+
+ # Extra args.
+ cl.extend(options.import_extra_args)
+ return cl
+
+
+def compile_saved_model(saved_model_dir: str, **kwargs):
+ """Compiles an on-disk saved model to an IREE binary.
+
+ Args:
+ saved_model_dir: Path to directory where the model was saved.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line(saved_model_dir, tfs, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", tfs, options)
+ result = invoke_pipeline([import_cl, compile_cl])
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_module(module, saved_model_dir: Optional[str] = None, **kwargs):
+ """Compiles a tf.Module to an IREE binary (by saving to disk).
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ saved_model_dir: Optional path to save the tf.Module to. The module will not
+ be persisted on disk outside of this call if this is not provided.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ Same as compile_saved_model().
+ """
+ with TempFileSaver.implicit() as tfs:
+
+ def do_it(saved_model_dir):
+ import tensorflow as tf
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(module, saved_model_dir, options=options)
+ return compile_saved_model(saved_model_dir, **kwargs)
+
+ if saved_model_dir:
+ return do_it(saved_model_dir)
+ else:
+ with tempfile.TemporaryDirectory(suffix=".sm") as td:
+ return do_it(td)
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tflite.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tflite.py
new file mode 100644
index 0000000..327becc
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/tflite.py
@@ -0,0 +1,198 @@
+# Lint-as: python3
+"""TFLite compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .debugging import TempFileSaver
+from .binaries import find_tool, invoke_immediate, invoke_pipeline
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+
+__all__ = [
+ "compile_file",
+ "compile_str",
+ "is_available",
+ "DEFAULT_TESTING_BACKENDS",
+ "ImportOptions",
+]
+
+_IMPORT_TOOL = "iree-import-tflite"
+
+
+def is_available():
+ """Determine if the XLA frontend is available."""
+ try:
+ find_tool(_IMPORT_TOOL)
+ except ValueError:
+ logging.warning("Unable to find IREE tool %s", _IMPORT_TOOL)
+ return False
+ return True
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+ """Import options layer on top of the backend compiler options."""
+
+ def __init__(self,
+ input_arrays: Sequence[str] = (),
+ output_arrays: Sequence[str] = (),
+ import_only: bool = False,
+ import_extra_args: Sequence[str] = (),
+ save_temp_tfl_input: Optional[str] = None,
+ save_temp_iree_input: Optional[str] = None,
+ input_type: Optional[str] = "tosa",
+ **kwargs):
+ """Initialize options from keywords.
+
+ Args:
+ input_arrays: Sequence of input array node names (if different from
+ default).
+ output_arrays: Sequence of output array node names (if different from
+ default).
+ import_only: Only import the module. If True, the result will be textual
+ MLIR that can be further fed to the IREE compiler. If False (default),
+ the result will be the fully compiled IREE binary. In both cases,
+ bytes-like output is returned. Note that if the output_file= is
+ specified and import_only=True, then the MLIR form will be written to
+ the output file.
+ import_extra_args: Extra arguments to pass to the iree-import-tf tool.
+ save_temp_tfl_input: Optionally save the IR that results from importing
+ the flatbuffer (prior to any further transformations).
+ save_temp_iree_input: Optionally save the IR that is the result of the
+ import (ready to be passed to IREE).
+ """
+ super().__init__(input_type=input_type, **kwargs)
+ self.input_arrays = input_arrays
+ self.output_arrays = output_arrays
+ self.import_only = import_only
+ self.import_extra_args = import_extra_args
+ self.save_temp_tfl_input = save_temp_tfl_input
+ self.save_temp_iree_input = save_temp_iree_input
+
+
+def build_import_command_line(input_path: str, tfs: TempFileSaver,
+ options: ImportOptions) -> List[str]:
+ """Builds a command line for invoking the import stage.
+
+ Args:
+ input_path: The input path.
+ tfs: TempFileSaver.
+ options: Import options.
+ Returns:
+ List of strings of command line.
+ """
+ import_tool = find_tool(_IMPORT_TOOL)
+ cl = [
+ import_tool,
+ input_path,
+ ]
+
+ if options.import_only and options.output_file:
+ # Import stage directly outputs.
+ output_file = tfs.alloc_optional("tflite-output.mlir",
+ export_as=options.output_file)
+ cl.append(f"-o={options.output_file}")
+
+ # Input arrays.
+ if options.input_arrays:
+ for input_array in options.input_arrays:
+ cl.append(f"--input-array={input_array}")
+ for output_array in options.output_arrays:
+ cl.append(f"--output-array={output_array}")
+
+ # MLIR flags.
+ if options.output_mlir_debuginfo:
+ cl.append("--mlir-print-debuginfo")
+ if options.output_generic_mlir:
+ cl.append("--mlir-print-op-generic")
+
+ # Save temps flags.
+ tfl_input = tfs.alloc_optional("tflite-input.mlir",
+ export_as=options.save_temp_tfl_input)
+ if tfl_input:
+ cl.append(f"--save-temp-tfl-input={tfl_input}")
+ iree_input = tfs.alloc_optional("tflite-iree-input.mlir",
+ export_as=options.save_temp_iree_input)
+ if iree_input:
+ cl.append(f"--save-temp-iree-input={iree_input}")
+
+ # Crash reproducer (locally qualified).
+ requested_crash_reproducer_path = options.crash_reproducer_path
+ if requested_crash_reproducer_path:
+ requested_crash_reproducer_path = (requested_crash_reproducer_path +
+ ".import-tflite")
+ crash_reproducer_path = tfs.alloc_optional(
+ "tflite-reproducer.mlir", export_as=requested_crash_reproducer_path)
+ if crash_reproducer_path:
+ cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+
+ # Extra args.
+ cl.extend(options.import_extra_args)
+ return cl
+
+
+def compile_file(fb_path: str, **kwargs):
+ """Compiles a TFLite flatbuffer file to an IREE binary.
+
+ Args:
+ fb_path: Path to the flatbuffer.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line(fb_path, tfs, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", tfs, options)
+ result = invoke_pipeline([import_cl, compile_cl])
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_str(fb_content: bytes, **kwargs):
+ """Compiles in-memory TFLite flatbuffer to an IREE binary.
+
+ Args:
+ xla_content: Flatbuffer content as bytes.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line("-", tfs, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl, immediate_input=fb_content)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", tfs, options)
+ result = invoke_pipeline([import_cl, compile_cl],
+ immediate_input=fb_content)
+ if options.output_file:
+ return None
+ return result
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/xla.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/xla.py
new file mode 100644
index 0000000..9b2b1c2
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/tools/xla.py
@@ -0,0 +1,212 @@
+# Lint-as: python3
+"""XLA compiler interface."""
+
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+from enum import Enum
+import logging
+import tempfile
+from typing import List, Optional, Sequence, Set, Union
+
+from .core import CompilerOptions, DEFAULT_TESTING_BACKENDS, build_compile_command_line
+from .debugging import TempFileSaver
+from .binaries import find_tool, invoke_immediate, invoke_pipeline
+
+__all__ = [
+ "compile_file",
+ "compile_str",
+ "is_available",
+ "DEFAULT_TESTING_BACKENDS",
+ "ImportOptions",
+ "ImportFormat",
+]
+
+_IMPORT_TOOL = "iree-import-xla"
+
+
+def is_available():
+ """Determine if the XLA frontend is available."""
+ try:
+ find_tool(_IMPORT_TOOL)
+ except ValueError:
+ logging.warning("Unable to find IREE tool %s", _IMPORT_TOOL)
+ return False
+ return True
+
+
+class ImportFormat(Enum):
+ """Import type of the model."""
+ BINARY_PROTO = "binary_proto"
+ TEXT_PROTO = "text_proto"
+ HLO_TEXT = "hlo_text"
+
+ @staticmethod
+ def parse(spec: Union[str, "ImportFormat"]) -> "ImportFormat":
+ """Parses or returns an ImportFormat.
+
+ Args:
+ spec: An ImportFormat instance or the case-insensitive name of one of
+ the enum values.
+ Returns:
+ An ImportFormat instance.
+ """
+ if isinstance(spec, ImportFormat):
+ return spec
+ spec = spec.upper()
+ if spec not in ImportFormat.__members__:
+ raise ValueError(f"For import_format= argument, expected one of: "
+ f"{', '.join(ImportFormat.__members__.keys())}")
+ return ImportFormat[spec]
+
+
+# TODO(#4131) python>=3.7: Consider using a dataclass.
+class ImportOptions(CompilerOptions):
+ """Import options layer on top of the backend compiler options."""
+
+ def __init__(self,
+ import_only: bool = False,
+ import_format: Union[ImportFormat,
+ str] = ImportFormat.BINARY_PROTO,
+ import_extra_args: Sequence[str] = (),
+ save_temp_mhlo_input: Optional[str] = None,
+ save_temp_iree_input: Optional[str] = None,
+ **kwargs):
+ """Initialize options from keywords.
+
+ Args:
+ import_format: Format of the proto (text or binary).
+ save_temp_iree_input: Optionally save the IR that is the result of the
+ import (ready to be passed to IREE).
+ """
+ super().__init__(**kwargs)
+ self.import_only = import_only
+ self.import_format = ImportFormat.parse(import_format)
+ self.import_extra_args = import_extra_args
+ self.save_temp_mhlo_input = save_temp_mhlo_input
+ self.save_temp_iree_input = save_temp_iree_input
+
+
+def build_import_command_line(input_path: str, tfs: TempFileSaver,
+ options: ImportOptions) -> List[str]:
+ """Builds a command line for invoking the import stage.
+
+ Args:
+ input_path: The input path.
+ tfs: TempFileSaver.
+ options: Import options.
+ Returns:
+ List of strings of command line.
+ """
+ import_tool = find_tool(_IMPORT_TOOL)
+ cl = [
+ import_tool,
+ input_path,
+ f"--xla-format={options.import_format.value}",
+ ]
+
+ if options.import_only and options.output_file:
+ # Import stage directly outputs.
+ output_file = tfs.alloc_optional("xla-output.mlir",
+ export_as=options.output_file)
+ cl.append(f"-o={output_file}")
+
+ # MLIR flags.
+ if options.output_mlir_debuginfo:
+ cl.append("--mlir-print-debuginfo")
+ if options.output_generic_mlir:
+ cl.append("--mlir-print-op-generic")
+
+ # Save temps flags.
+ save_mhlo_input = tfs.alloc_optional("tf-mhlo.mlir",
+ export_as=options.save_temp_mhlo_input)
+ if save_mhlo_input:
+ cl.append(f"--save-temp-mhlo-input={save_mhlo_input}")
+ iree_input = tfs.alloc_optional("xla-iree-input.mlir",
+ export_as=options.save_temp_iree_input)
+ if iree_input:
+ cl.append(f"--save-temp-iree-input={iree_input}")
+
+ # Crash reproducer (locally qualified).
+ requested_crash_reproducer_path = options.crash_reproducer_path
+ if requested_crash_reproducer_path:
+ requested_crash_reproducer_path = (requested_crash_reproducer_path +
+ ".import-xla")
+ crash_reproducer_path = tfs.alloc_optional(
+ "xla-reproducer.mlir", export_as=requested_crash_reproducer_path)
+ if crash_reproducer_path:
+ cl.append(f"--pass-pipeline-crash-reproducer={crash_reproducer_path}")
+
+ # Extra args.
+ cl.extend(options.import_extra_args)
+ return cl
+
+
+def compile_file(xla_file_path: str, **kwargs):
+ """Compiles an on-disk XLA protocol buffer to an IREE binary.
+
+ Args:
+ xla_file_path: Path to the XLA proto file.
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = ImportOptions(**kwargs)
+ import_cl = build_import_command_line(xla_file_path, tfs, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", tfs, options)
+ result = invoke_pipeline([import_cl, compile_cl])
+ if options.output_file:
+ return None
+ return result
+
+
+def compile_str(xla_content: Union[bytes, str], **kwargs):
+ """Compiles in-memory XLA content to an IREE binary.
+
+ Args:
+ xla_content: Either bytes or str content (str is only valid for text
+ formats).
+ **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions.
+ Returns:
+ A bytes-like object with the compiled output or None if output_file=
+ was specified.
+ """
+ with TempFileSaver.implicit() as tfs:
+ options = ImportOptions(**kwargs)
+ if isinstance(xla_content, str):
+ if options.import_format not in [
+ ImportFormat.TEXT_PROTO, ImportFormat.HLO_TEXT
+ ]:
+ raise ValueError("If passing a string, ImportFormat must be TEXT_PROTO")
+ xla_content = xla_content.encode("utf-8")
+
+ import_cl = build_import_command_line("-", tfs, options)
+ if options.import_only:
+ # One stage tool pipeline.
+ result = invoke_immediate(import_cl, immediate_input=xla_content)
+ if options.output_file:
+ return None
+ return result
+
+ # Full compilation pipeline.
+ compile_cl = build_compile_command_line("-", tfs, options)
+ result = invoke_pipeline([import_cl, compile_cl],
+ immediate_input=xla_content)
+ if options.output_file:
+ return None
+ return result
diff --git a/llvm-external-projects/iree-compiler-api/python/iree/compiler/xla.py b/llvm-external-projects/iree-compiler-api/python/iree/compiler/xla.py
new file mode 100644
index 0000000..e66d5fe
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/python/iree/compiler/xla.py
@@ -0,0 +1,10 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+from .tools import xla
+
+sys.modules[__name__] = xla
diff --git a/llvm-external-projects/iree-compiler-api/unittests/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/unittests/CMakeLists.txt
new file mode 100644
index 0000000..f7c6b56
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(tools)
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/unittests/tools/CMakeLists.txt
new file mode 100644
index 0000000..9233c3f
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/CMakeLists.txt
@@ -0,0 +1,52 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+function(iree_compiler_api_py_test)
+ cmake_parse_arguments(
+ ARG
+ ""
+ "NAME;MAIN"
+ ""
+ ${ARGN}
+ )
+ set(TEST_NAME "iree-compiler-api-${ARG_NAME}")
+ add_test(
+ NAME
+ ${TEST_NAME}
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND "${Python3_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/${ARG_MAIN}"
+ )
+ set_tests_properties(${TEST_NAME} PROPERTIES
+ ENVIRONMENT PYTHONPATH=${IREE_COMPILER_API_BINARY_DIR}/python_package)
+endfunction()
+
+iree_compiler_api_py_test(
+ NAME
+ compiler_core_test
+ MAIN
+ "compiler_core_test.py"
+)
+
+iree_compiler_api_py_test(
+ NAME
+ compiler_tf_test
+ MAIN
+ "compiler_tf_test.py"
+)
+
+iree_compiler_api_py_test(
+ NAME
+ compiler_tflite_test
+ MAIN
+ "compiler_tflite_test.py"
+)
+
+iree_compiler_api_py_test(
+ NAME
+ compiler_xla_test
+ MAIN
+ "compiler_xla_test.py"
+)
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/README.md b/llvm-external-projects/iree-compiler-api/unittests/tools/README.md
new file mode 100644
index 0000000..e6e6c29
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/README.md
@@ -0,0 +1,8 @@
+# Python API Tests
+
+These tests are run in an environment where all available Python bindings
+are setup on the `PYTHONPATH`. Each will internally skip itself if optional
+components are not available.
+
+Note that IREE compiler tool locations can be overridden by specifying the
+`IREE_TOOL_PATH` environment variable.
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_core_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_core_test.py
new file mode 100644
index 0000000..996fc29
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_core_test.py
@@ -0,0 +1,216 @@
+# Lint as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import contextlib
+import logging
+import os
+import io
+import tempfile
+import unittest
+
+import iree.compiler.tools
+
+SIMPLE_MUL_ASM = """
+func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+"""
+
+
+class CompilerTest(unittest.TestCase):
+
+ def setUp(self):
+ if "IREE_SAVE_TEMPS" in os.environ:
+ del os.environ["IREE_SAVE_TEMPS"]
+
+ def testNoTargetBackends(self):
+ with self.assertRaisesRegex(
+ ValueError, "Expected a non-empty list for 'target_backends'"):
+ binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM)
+
+ def testCompileStr(self):
+ binary = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertTrue(binary)
+
+ # Compiling the string form means that the compiler does not have a valid
+ # source file name, which can cause issues on the AOT side. Verify
+ # specifically. See: https://github.com/google/iree/issues/4439
+ def testCompileStrLLVMAOT(self):
+ binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ target_backends=["dylib-llvm-aot"])
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertTrue(binary)
+
+ # Verifies that multiple target_backends are accepted. Which two are not
+ # load bearing.
+ # See: https://github.com/google/iree/issues/4436
+ def testCompileMultipleBackends(self):
+ binary = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ target_backends=["dylib-llvm-aot", "vulkan-spirv"])
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertTrue(binary)
+
+ def testCompileInputFile(self):
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.write(SIMPLE_MUL_ASM)
+ f.close()
+ binary = iree.compiler.tools.compile_file(
+ f.name,
+ input_type="mhlo",
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+ finally:
+ os.remove(f.name)
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertIn(b"simple_mul", binary)
+
+ def testCompileOutputFile(self):
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ output_file=f.name,
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+ self.assertIsNone(output)
+
+ with open(f.name, "rb") as f_read:
+ binary = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertIn(b"simple_mul", binary)
+
+ def testOutputFbText(self):
+ text = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ output_format=iree.compiler.tools.OutputFormat.FLATBUFFER_TEXT,
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
+ "utf-8")
+ # Just check for an arbitrary JSON-tag.
+ self.assertIn('"exported_functions"', text)
+
+ def testBadInputType(self):
+ with self.assertRaisesRegex(
+ ValueError, "For input_type= argument, expected one of: "
+ "NONE, MHLO, TOSA"):
+ _ = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="not-existing",
+ output_format="foobar",
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+ def testBadOutputFormat(self):
+ with self.assertRaisesRegex(
+ ValueError, "For output_format= argument, expected one of: "
+ "FLATBUFFER_BINARY, FLATBUFFER_TEXT, MLIR_TEXT"):
+ _ = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ output_format="foobar",
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+ def testOutputFbTextParsed(self):
+ text = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type='mhlo',
+ output_format='flatbuffer_text',
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
+ "utf-8")
+ # Just check for an arbitrary JSON-tag.
+ self.assertIn('"exported_functions"', text)
+
+ def testOutputMlirText(self):
+ text = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ output_format=iree.compiler.tools.OutputFormat.MLIR_TEXT,
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode(
+ "utf-8")
+ # Just check for a textual op name.
+ self.assertIn("vm.module", text)
+
+ def testExtraArgsStderr(self):
+ # mlir-timing is not special: it just does something and emits to stderr.
+ with io.StringIO() as buf, contextlib.redirect_stderr(buf):
+ iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ extra_args=["--mlir-timing"],
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+ stderr = buf.getvalue()
+ self.assertIn("Execution time report", stderr)
+
+ def testAllOptions(self):
+ binary = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ optimize=False,
+ strip_debug_ops=True,
+ strip_source_map=True,
+ crash_reproducer_path="foobar.txt",
+ # Re-enable when benchmarking pass is fixed: #6196
+ # enable_benchmark=True,
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+ def testException(self):
+ with self.assertRaisesRegex(iree.compiler.tools.CompilerToolError,
+ "Invoked with"):
+ _ = iree.compiler.tools.compile_str(
+ "I'm a little teapot but not a valid program",
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+
+ def testExplicitTempFileSaver(self):
+ temp_dir = tempfile.TemporaryDirectory()
+ output_file = tempfile.NamedTemporaryFile("wt")
+ output_file.close()
+ with iree.compiler.tools.TempFileSaver(temp_dir.name):
+ output = iree.compiler.tools.compile_str(
+ SIMPLE_MUL_ASM,
+ input_type="mhlo",
+ output_file=output_file.name,
+ target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS)
+ self.assertIsNone(output)
+
+ # There should be an output file and a core-output.bin in the temp dir.
+ self.assertTrue(os.path.exists(output_file.name))
+ expected_temp_file = os.path.join(temp_dir.name, "core-output.bin")
+ self.assertTrue(os.path.exists(expected_temp_file))
+
+ # And they should have the same contents.
+ with open(output_file.name, "rb") as f:
+ output_contents = f.read()
+ with open(expected_temp_file, "rb") as f:
+ temp_contents = f.read()
+ self.assertEqual(temp_contents, output_contents)
+ temp_dir.cleanup()
+
+ def testEnvTempFileSaver(self):
+ temp_dir = tempfile.TemporaryDirectory()
+ os.environ["IREE_SAVE_TEMPS"] = temp_dir.name
+ with iree.compiler.tools.TempFileSaver() as tfs:
+ self.assertTrue(tfs.retained)
+ self.assertEqual(tfs.retained_path, temp_dir.name)
+
+ def testTempFileSaverDisabled(self):
+ with iree.compiler.tools.TempFileSaver() as tfs:
+ self.assertFalse(tfs.retained)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tf_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tf_test.py
new file mode 100644
index 0000000..fc1ee71
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tf_test.py
@@ -0,0 +1,83 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+# TODO: No idea why pytype cannot find names from this module.
+# pytype: disable=name-error
+import iree.compiler.tools.tf
+
+if not iree.compiler.tools.tf.is_available():
+ print(f"Skipping test {__file__} because the IREE TensorFlow compiler "
+ f"is not installed")
+ sys.exit(0)
+
+import tensorflow as tf
+
+
+class SimpleArithmeticModule(tf.Module):
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.float32),
+ tf.TensorSpec([4], tf.float32)
+ ])
+ def simple_mul(self, a, b):
+ return a * b
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([128, 3072], tf.float32),
+ tf.TensorSpec([3072, 256], tf.float32),
+ ])
+ def simple_matmul(self, a, b):
+ return tf.matmul(a, b)
+
+
+# TODO(laurenzo): More test cases needed (may need additional files).
+# Specifically, figure out how to test v1 models.
+class TfCompilerTest(tf.test.TestCase):
+
+ def testImportSavedModel(self):
+ import_mlir = iree.compiler.tools.tf.compile_saved_model(
+ self.smdir, import_only=True, output_generic_mlir=True).decode("utf-8")
+ self.assertIn("sym_name = \"simple_matmul\"", import_mlir)
+
+ def testCompileSavedModel(self):
+ binary = iree.compiler.tools.tf.compile_saved_model(
+ self.smdir,
+ target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS)
+ logging.info("Compiled len: %d", len(binary))
+ self.assertIn(b"simple_matmul", binary)
+ self.assertIn(b"simple_mul", binary)
+
+ def testCompileModule(self):
+ binary = iree.compiler.tools.tf.compile_module(
+ self.m, target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS)
+ logging.info("Compiled len: %d", len(binary))
+ self.assertIn(b"simple_matmul", binary)
+ self.assertIn(b"simple_mul", binary)
+
+ @classmethod
+ def setUpClass(cls):
+ cls.m = SimpleArithmeticModule()
+ cls.tempdir = tempfile.TemporaryDirectory()
+ cls.smdir = os.path.join(cls.tempdir.name, "arith.sm")
+ tf.saved_model.save(
+ cls.m,
+ cls.smdir,
+ options=tf.saved_model.SaveOptions(save_debug_info=True))
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.tempdir.cleanup()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ tf.test.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tflite_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tflite_test.py
new file mode 100644
index 0000000..671f515
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_tflite_test.py
@@ -0,0 +1,102 @@
+# Lint as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+# TODO: No idea why pytype cannot find names from this module.
+# pytype: disable=name-error
+import iree.compiler.tools.tflite
+
+if not iree.compiler.tools.tflite.is_available():
+ print(f"Skipping test {__file__} because the IREE TFLite compiler "
+ f"is not installed")
+ sys.exit(0)
+
+
+class CompilerTest(unittest.TestCase):
+
+ def testImportBinaryPbFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ text = iree.compiler.tools.tflite.compile_file(
+ path, import_only=True).decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("tosa.mul", text)
+
+ def testCompileBinaryPbFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ binary = iree.compiler.tools.tflite.compile_file(
+ path,
+ target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportBinaryPbFileOutputFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = iree.compiler.tools.tflite.compile_file(path,
+ import_only=True,
+ output_file=f.name)
+ self.assertIsNone(output)
+ with open(f.name, "rt") as f_read:
+ text = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("%s", text)
+ self.assertIn("tosa.mul", text)
+
+ def testCompileBinaryPbFileOutputFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = iree.compiler.tools.tflite.compile_file(
+ path,
+ output_file=f.name,
+ target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS)
+ self.assertIsNone(output)
+ with open(f.name, "rb") as f_read:
+ binary = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportBinaryPbBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with open(path, "rb") as f:
+ content = f.read()
+ text = iree.compiler.tools.tflite.compile_str(
+ content, import_only=True).decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("tosa.mul", text)
+
+ def testCompileBinaryPbBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata",
+ "tflite_sample.fb")
+ with open(path, "rb") as f:
+ content = f.read()
+ binary = iree.compiler.tools.tflite.compile_str(
+ content,
+ target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_xla_test.py b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_xla_test.py
new file mode 100644
index 0000000..ad73db4
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/compiler_xla_test.py
@@ -0,0 +1,118 @@
+# Lint as: python3
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import sys
+import tempfile
+import unittest
+
+import iree.compiler.tools.xla
+
+if not iree.compiler.tools.xla.is_available():
+ print(f"Skipping test {__file__} because the IREE XLA compiler "
+ f"is not installed")
+ sys.exit(0)
+
+
+class CompilerTest(unittest.TestCase):
+
+ def testImportBinaryPbFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+ text = iree.compiler.tools.xla.compile_file(
+ path, import_only=True).decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("linalg.generic", text)
+
+ def testCompileBinaryPbFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+ binary = iree.compiler.tools.xla.compile_file(
+ path, target_backends=iree.compiler.tools.xla.DEFAULT_TESTING_BACKENDS)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportBinaryPbFileOutputFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = iree.compiler.tools.xla.compile_file(path,
+ import_only=True,
+ output_file=f.name)
+ self.assertIsNone(output)
+ with open(f.name, "rt") as f_read:
+ text = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("%s", text)
+ self.assertIn("linalg.generic", text)
+
+ def testCompileBinaryPbFileOutputFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+ with tempfile.NamedTemporaryFile("wt", delete=False) as f:
+ try:
+ f.close()
+ output = iree.compiler.tools.xla.compile_file(
+ path,
+ output_file=f.name,
+ target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS)
+ self.assertIsNone(output)
+ with open(f.name, "rb") as f_read:
+ binary = f_read.read()
+ finally:
+ os.remove(f.name)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportBinaryPbBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+ with open(path, "rb") as f:
+ content = f.read()
+ text = iree.compiler.tools.xla.compile_str(content,
+ import_only=True).decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("linalg.generic", text)
+
+ def testCompileBinaryPbBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.pb")
+ with open(path, "rb") as f:
+ content = f.read()
+ binary = iree.compiler.tools.xla.compile_str(
+ content,
+ target_backends=iree.compiler.tools.xla.DEFAULT_TESTING_BACKENDS)
+ logging.info("Binary length = %d", len(binary))
+ self.assertIn(b"main", binary)
+
+ def testImportHloTextFile(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.hlo")
+ text = iree.compiler.tools.xla.compile_file(
+ path, import_only=True, import_format="hlo_text").decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("linalg.generic", text)
+
+ def testImportHloTextStr(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.hlo")
+ with open(path, "rt") as f:
+ content = f.read()
+ text = iree.compiler.tools.xla.compile_str(
+ content, import_only=True, import_format="hlo_text").decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("linalg.generic", text)
+
+ def testImportHloTextBytes(self):
+ path = os.path.join(os.path.dirname(__file__), "testdata", "xla_sample.hlo")
+ with open(path, "rb") as f:
+ content = f.read()
+ text = iree.compiler.tools.xla.compile_str(
+ content, import_only=True, import_format="hlo_text").decode("utf-8")
+ logging.info("%s", text)
+ self.assertIn("linalg.generic", text)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_tflite.py b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_tflite.py
new file mode 100644
index 0000000..a325089
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_tflite.py
@@ -0,0 +1,29 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+
+import tensorflow as tf
+
+
+class Squared(tf.Module):
+
+ @tf.function
+ def __call__(self, x):
+ return tf.square(x)
+
+
+model = Squared()
+concrete_func = model.__call__.get_concrete_function(
+ tf.TensorSpec(shape=[4, 3], dtype=tf.float32))
+
+converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func],
+ model)
+tflite_model = converter.convert()
+
+this_dir = os.path.dirname(__file__)
+with open(os.path.join(this_dir, "tflite_sample.fb"), "wb") as f:
+ f.write(tflite_model)
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_xla.py b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_xla.py
new file mode 100644
index 0000000..c583291
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/generate_xla.py
@@ -0,0 +1,28 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+
+import numpy as np
+
+# Jax is the most accessible way to get at an xla_client.
+# python -m pip install --upgrade pip
+# python -m pip install --upgrade jax jaxlib
+from jaxlib import xla_client
+
+ops = xla_client.ops
+
+builder = xla_client.XlaBuilder("testbuilder")
+in_shape = np.array([4], dtype=np.float32)
+in_feed = ops.Parameter(builder, 0, xla_client.shape_from_pyval(in_shape))
+result = ops.Add(in_feed, ops.Constant(builder, np.float32(1.0)))
+xla_computation = builder.Build(result)
+
+this_dir = os.path.dirname(__file__)
+with open(os.path.join(this_dir, "xla_sample.pb"), "wb") as f:
+ f.write(xla_computation.as_serialized_hlo_module_proto())
+with open(os.path.join(this_dir, "xla_sample.hlo"), "wt") as f:
+ f.write(xla_computation.as_hlo_text())
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/tflite_sample.fb b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/tflite_sample.fb
new file mode 100644
index 0000000..52cb9e4
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/tflite_sample.fb
Binary files differ
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.hlo b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.hlo
new file mode 100644
index 0000000..2617141
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.hlo
@@ -0,0 +1,9 @@
+HloModule testbuilder.5
+
+ENTRY testbuilder.5 {
+ parameter.1 = f32[1] parameter(0)
+ constant.2 = f32[] constant(1)
+ broadcast.3 = f32[1]{0} broadcast(constant.2), dimensions={}
+ ROOT add.4 = f32[1]{0} add(parameter.1, broadcast.3)
+}
+
diff --git a/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.pb b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.pb
new file mode 100644
index 0000000..1f69a64
--- /dev/null
+++ b/llvm-external-projects/iree-compiler-api/unittests/tools/testdata/xla_sample.pb
Binary files differ