Remove pyiree.compiler and switch everything to pyiree.compiler2.
* Also remove pyiree.tf.compiler and the build support for pyiree.xla.compiler (it will be transitioned later).
* Also remove Bazel build support for the Python packages.
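For downstream users, the migration from the removed in-process API to compiler2 follows the system_api_test.py change in this patch. A minimal sketch (assuming the vulkan-spirv target backend is available in the build):

    # Before (pyiree.compiler, removed here):
    #   ctx = compiler.Context()
    #   input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
    #   binary = input_module.compile()
    # After (pyiree.compiler2):
    from pyiree import compiler2 as compiler
    from pyiree import rt

    SIMPLE_MUL_ASM = """
    module @arithmetic {
      func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
            attributes { iree.module.export } {
        %0 = "mhlo.multiply"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
        return %0 : tensor<4xf32>
      }
    }
    """

    # compile_str now accepts str or bytes; keyword args map to CompilerOptions.
    binary = compiler.compile_str(SIMPLE_MUL_ASM, target_backends=["vulkan-spirv"])
    m = rt.VmModule.from_flatbuffer(binary)

The compiler2.tf path also gains optional save_temp_tf_input / save_temp_iree_input options for dumping intermediate IR (see the tf.py diff below).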
PiperOrigin-RevId: 346446168
diff --git a/bindings/python/BUILD b/bindings/python/BUILD
index 691f117..c44dcad 100644
--- a/bindings/python/BUILD
+++ b/bindings/python/BUILD
@@ -24,3 +24,8 @@
name = "pathsetup",
imports = ["."],
)
+
+filegroup(
+ name = "python_extension_headers",
+ srcs = glob(["**/*.h"]),
+)
diff --git a/bindings/python/pyiree/common/BUILD b/bindings/python/pyiree/common/BUILD
deleted file mode 100644
index 3c102c8..0000000
--- a/bindings/python/pyiree/common/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "pybind_cc_library",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-pybind_cc_library(
- name = "common",
- srcs = [
- "status_utils.cc",
- ],
- hdrs = [
- "binding.h",
- "status_utils.h",
- ],
- deps = [
- "//iree/base:api",
- "//iree/base:status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- ],
-)
diff --git a/bindings/python/pyiree/compiler/BUILD b/bindings/python/pyiree/compiler/BUILD
deleted file mode 100644
index 48917ed..0000000
--- a/bindings/python/pyiree/compiler/BUILD
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "NUMPY_DEPS",
- "PYBIND_COPTS",
- "PYBIND_EXTENSION_COPTS",
- "PYBIND_FEATURES",
- "PYTHON_CPP_EXTRA_DEPS",
- "iree_py_extension",
- "iree_py_library",
- "iree_py_test",
- "pybind_cc_library",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_py_library(
- name = "compiler",
- srcs = [
- "__init__.py",
- ],
- srcs_version = "PY3",
- deps = PYTHON_CPP_EXTRA_DEPS + [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-iree_py_extension(
- name = "binding",
- srcs = [
- "initialize_module.cc",
- ],
- copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
- features = PYBIND_FEATURES,
- linkstatic = 1,
- win_def_file = "export.def",
- deps = [
- ":compiler_library",
- "//bindings/python/pyiree/common",
- ],
-)
-
-pybind_cc_library(
- name = "compiler_library",
- srcs = [
- "compiler.cc",
- ],
- hdrs = [
- "compiler.h",
- ],
- deps = [
- "//bindings/python/pyiree/common",
- "//iree/compiler/Dialect/VM/Target:init_targets",
- "//iree/compiler/Dialect/VM/Target/Bytecode",
- "//iree/tools:init_passes_and_dialects",
- "//iree/tools:init_targets",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Parser",
- "@llvm-project//mlir:Pass",
- ],
-)
-
-iree_py_test(
- name = "compiler_test",
- srcs = ["compiler_test.py"],
- python_version = "PY3",
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree/compiler",
- ],
-)
diff --git a/bindings/python/pyiree/compiler/CMakeLists.txt b/bindings/python/pyiree/compiler/CMakeLists.txt
deleted file mode 100644
index ffc8a30..0000000
--- a/bindings/python/pyiree/compiler/CMakeLists.txt
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-iree_py_library(
- NAME
- compiler
- SRCS
- "__init__.py"
- PYEXT_DEPS
- ::PyExtCompiler
-)
-
-iree_pyext_module(
- NAME
- PyExtCompiler
- MODULE_NAME
- binding
- UNIX_LINKER_SCRIPT
- "unix_version.lds"
- SRCS
- "initialize_module.cc"
- PYEXT_DEPS
- ::PyExtCompilerLib
- bindings::python::pyiree::common::PyextCommonLib
-)
-
-iree_pyext_library(
- NAME
- PyExtCompilerLib
- SRCS
- "compiler.h"
- "compiler.cc"
- PYEXT_DEPS
- bindings::python::pyiree::common::PyextCommonLib
- DEPS
- iree::compiler::Dialect::VM::Target::Bytecode
- iree::compiler::Dialect::VM::Target::init_targets
- iree::tools::init_passes_and_dialects
- iree::tools::init_targets
- LLVMSupport
- MLIRIR
- MLIRParser
- MLIRPass
-)
-
-iree_py_test(
- NAME
- compiler_test
- SRCS
- "compiler_test.py"
-)
diff --git a/bindings/python/pyiree/compiler/__init__.py b/bindings/python/pyiree/compiler/__init__.py
deleted file mode 100644
index 94d2f7c..0000000
--- a/bindings/python/pyiree/compiler/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Module init for the python bindings."""
-
-# pylint: disable=g-multiple-import
-# pylint: disable=g-bad-import-order
-# pylint: disable=g-import-not-at-top
-# pylint: disable=wildcard-import
-
-from . import binding as binding
-
-# Native aliases.
-llvm = binding.llvm
-Context = binding.CompilerContext
-Module = binding.CompilerModule
-CompileOptions = binding.CompileOptions
-OutputFormat = binding.OutputFormat
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
deleted file mode 100644
index 94d96ff..0000000
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ /dev/null
@@ -1,473 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "bindings/python/pyiree/compiler/compiler.h"
-
-#include <stdexcept>
-#include <string>
-
-#include "bindings/python/pyiree/common/binding.h"
-#include "bindings/python/pyiree/common/status_utils.h"
-#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
-#include "iree/compiler/Dialect/VM/Target/init_targets.h"
-#include "iree/tools/init_dialects.h"
-#include "iree/tools/init_passes.h"
-#include "iree/tools/init_targets.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/PrettyStackTrace.h"
-#include "llvm/Support/Signals.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Location.h"
-#include "mlir/Parser.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Pass/PassRegistry.h"
-
-namespace py = pybind11;
-
-using namespace mlir;
-using namespace mlir::iree_compiler;
-
-using mlir::iree_compiler::IREE::HAL::TargetOptions;
-using mlir::iree_compiler::IREE::VM::BytecodeOutputFormat;
-using mlir::iree_compiler::IREE::VM::BytecodeTargetOptions;
-
-using llvm::MemoryBuffer;
-using llvm::MemoryBufferRef;
-using llvm::raw_ostream;
-using llvm::raw_string_ostream;
-using llvm::StringRef;
-
-namespace iree {
-namespace python {
-
-/* static */ std::mutex CompilerContextBundle::static_config_lock_;
-/* static */ absl::optional<std::string>
- CompilerContextBundle::default_crash_reproducer_path_;
-
-namespace {
-
-bool LLVMOnceInit() {
- // Enable LLVM's signal handler to get nice stack traces.
- llvm::sys::SetOneShotPipeSignalFunction(
- llvm::sys::DefaultOneShotPipeSignalHandler);
- llvm::sys::PrintStackTraceOnErrorSignal("pyiree");
-
- mlir::iree_compiler::registerAllPasses();
- mlir::iree_compiler::registerHALTargetBackends();
-
- // Register any pass manager command line options.
- mlir::registerPassManagerCLOptions();
-
- std::string program_name = "pyiree";
- std::vector<const char*> default_options = {program_name.c_str(), nullptr};
- llvm::cl::ParseCommandLineOptions(1, default_options.data());
- return true;
-}
-
-void SetupLLVMModule(pybind11::module m) {
- m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
- m.def(
- "add_option",
- [](std::string name, absl::optional<std::string> value) {
- auto options_map = llvm::cl::getRegisteredOptions();
- auto found_it = options_map.find(name);
- if (found_it == options_map.end()) {
- std::string message = "Unknown LLVM option: ";
- message.append(name);
- throw RaiseValueError(message.c_str());
- }
-
- std::string value_sr = value ? *value : "";
- found_it->getValue()->addOccurrence(1, name, value_sr);
- },
- py::arg("name"), py::arg("value") = absl::optional<std::string>());
- m.def(
- "reset_option",
- [](std::string name) {
- auto options_map = llvm::cl::getRegisteredOptions();
- auto found_it = options_map.find(name);
- if (found_it == options_map.end()) {
- std::string message = "Unknown LLVM option: ";
- message.append(name);
- throw RaiseValueError(message.c_str());
- }
- found_it->getValue()->setDefault();
- },
- py::arg("name"));
-}
-
-OwningModuleRef parseMLIRModuleFromString(StringRef contents,
- MLIRContext* context) {
- std::unique_ptr<MemoryBuffer> contents_buffer;
- if (contents.back() == 0) {
- // If it has a nul terminator, just use as-is.
- contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back());
- } else {
- // Otherwise, make a copy.
- contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED");
- }
-
- llvm::SourceMgr source_mgr;
- source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
- OwningModuleRef mlir_module = parseSourceFile(source_mgr, context);
- return mlir_module;
-}
-
-} // namespace
-
-DiagnosticCapture::DiagnosticCapture(mlir::MLIRContext* mlir_context,
- DiagnosticCapture* parent)
- : mlir_context_(mlir_context), parent_(parent) {
- handler_id_ = mlir_context_->getDiagEngine().registerHandler(
- [&](Diagnostic& d) -> LogicalResult {
- diagnostics_.push_back(std::move(d));
- return success();
- });
-}
-DiagnosticCapture::~DiagnosticCapture() {
- if (mlir_context_) {
- mlir_context_->getDiagEngine().eraseHandler(handler_id_);
- if (parent_) {
- for (auto& d : diagnostics_) {
- parent_->diagnostics_.push_back(std::move(d));
- }
- }
- }
-}
-
-DiagnosticCapture::DiagnosticCapture(DiagnosticCapture&& other) {
- mlir_context_ = other.mlir_context_;
- parent_ = other.parent_;
- diagnostics_.swap(other.diagnostics_);
- handler_id_ = other.handler_id_;
- other.mlir_context_ = nullptr;
-}
-
-// Custom location printer that prints prettier, multi-line file output
-// suitable for human readable error messages. The standard printer just prints
-// a long nested expression not particularly human friendly). Note that there
-// is a location pretty printer in the MLIR AsmPrinter. It is private and
-// doesn't do any path shortening, which seems to make long Python stack traces
-// a bit easier to scan.
-void PrintLocation(Location loc, raw_ostream& out) {
- TypeSwitch<Location>(loc)
- .Case<OpaqueLoc>(
- [&](OpaqueLoc loc) { PrintLocation(loc.getFallbackLocation(), out); })
- .Case<UnknownLoc>([&](UnknownLoc) { out << " [unknown location]\n"; })
- .Case<FileLineColLoc>([&](FileLineColLoc line_col_loc) {
- StringRef this_filename = line_col_loc.getFilename();
- auto slash_pos = this_filename.find_last_of("/\\");
- // We print both the basename and extended names with a structure like
- // `foo.py:35:4`. Even though technically the line/col
- // information is redundant to include in both names, having it on both
- // makes it easier to paste the paths into an editor and jump to the
- // exact location.
- std::string line_col_suffix =
- ":" + std::to_string(line_col_loc.getLine()) + ":" +
- std::to_string(line_col_loc.getColumn());
- bool has_basename = false;
- StringRef basename = this_filename;
- if (slash_pos != StringRef::npos) {
- has_basename = true;
- basename = this_filename.substr(slash_pos + 1);
- }
- out << " at: " << basename << line_col_suffix;
- if (has_basename) {
- // When running through bazel, such as in our e2e test suite,
- // the paths involved can be quite large, and will have a very long
- // prefix before the sandboxed "runfiles" directory that the program
- // runs in. Trim off that long prefix. By convention, the path names
- // with this prefix dropped will correspond to the path in the source
- // directory, which is probably what we want anyway.
- StringRef kRunfiles(".runfiles/");
- StringRef extended_name = this_filename;
- auto runfiles_pos = extended_name.rfind(kRunfiles);
- if (runfiles_pos != StringRef::npos) {
- extended_name =
- extended_name.drop_front(runfiles_pos + kRunfiles.size());
- }
- // Print out two tabs, as basenames usually vary in length by more
- // than one tab width.
- out << "\t\t( " << extended_name << line_col_suffix << " )";
- }
- out << "\n";
- })
- .Case<NameLoc>([&](NameLoc name_loc) {
- out << " @'" << name_loc.getName() << "':\n";
- auto child_loc = name_loc.getChildLoc();
- if (!child_loc.isa<UnknownLoc>()) {
- out << "(...\n";
- PrintLocation(child_loc, out);
- out << ")\n";
- }
- })
- .Case<CallSiteLoc>([&](CallSiteLoc call_site) {
- PrintLocation(call_site.getCaller(), out);
- PrintLocation(call_site.getCallee(), out);
- });
-}
-
-std::string DiagnosticCapture::ConsumeDiagnosticsAsString(
- const char* error_message) {
- std::string s;
- raw_string_ostream sout(s);
- bool first = true;
- if (error_message) {
- sout << error_message;
- first = false;
- }
- for (auto& d : diagnostics_) {
- if (!first) {
- sout << "\n\n";
- } else {
- first = false;
- }
-
- switch (d.getSeverity()) {
- case DiagnosticSeverity::Note:
- sout << "[NOTE]";
- break;
- case DiagnosticSeverity::Warning:
- sout << "[WARNING]";
- break;
- case DiagnosticSeverity::Error:
- sout << "[ERROR]";
- break;
- case DiagnosticSeverity::Remark:
- sout << "[REMARK]";
- break;
- default:
- sout << "[UNKNOWN]";
- }
- // Message.
- sout << ": " << d << "\n";
- PrintLocation(d.getLocation(), sout);
- }
-
- diagnostics_.clear();
- return sout.str();
-}
-
-void DiagnosticCapture::ClearDiagnostics() { diagnostics_.clear(); }
-
-CompilerContextBundle::CompilerContextBundle()
- : default_capture_(&mlir_context_, nullptr) {
- mlir::iree_compiler::registerAllDialects(mlir_context_.getDialectRegistry());
-}
-CompilerContextBundle::~CompilerContextBundle() = default;
-
-CompilerModuleBundle CompilerContextBundle::ParseAsm(
- const std::string& asm_text) {
- // Arrange to get a view that includes a terminating null to avoid additional
- // copy.
- const char* asm_chars = asm_text.c_str();
- StringRef asm_sr(asm_chars, asm_text.size() + 1);
-
- auto diag_capture = CaptureDiagnostics();
- auto module_ref = parseMLIRModuleFromString(asm_sr, mlir_context());
- if (!module_ref) {
- throw RaiseValueError(
- diag_capture.ConsumeDiagnosticsAsString("Error parsing ASM").c_str());
- }
- return CompilerModuleBundle(shared_from_this(), module_ref.release());
-}
-
-std::string CompilerModuleBundle::ToAsm(bool enableDebugInfo, bool prettyForm,
- int64_t largeElementLimit) {
- // Print to asm.
- std::string asm_output;
- raw_string_ostream sout(asm_output);
- OpPrintingFlags print_flags;
- if (enableDebugInfo) {
- print_flags.enableDebugInfo(prettyForm);
- }
- if (largeElementLimit >= 0) {
- print_flags.elideLargeElementsAttrs(largeElementLimit);
- }
- module_op().print(sout, print_flags);
- return sout.str();
-}
-
-std::shared_ptr<OpaqueBlob> CompilerModuleBundle::Compile(
- BytecodeTargetOptions options, std::vector<std::string> target_backends) {
- mlir::PassManager pass_manager(context_->mlir_context());
- mlir::applyPassManagerCLOptions(pass_manager);
- auto crash_reproducer_path = context_->crash_reproducer_path();
- if (crash_reproducer_path) {
- pass_manager.enableCrashReproducerGeneration(*crash_reproducer_path, false);
- }
-
- mlir::iree_compiler::IREE::HAL::TargetOptions hal_target_options;
- if (target_backends.empty()) {
- hal_target_options.targets =
- mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
- } else {
- hal_target_options.targets = std::move(target_backends);
- }
-
- auto vm_target_options =
- mlir::iree_compiler::IREE::VM::getTargetOptionsFromFlags();
-
- mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(pass_manager);
- mlir::iree_compiler::IREE::HAL::buildHALTransformPassPipeline(
- pass_manager, hal_target_options);
- mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
- pass_manager, vm_target_options);
-
- // Run primary passes.
- auto diag_capture = context_->CaptureDiagnostics();
- if (failed(pass_manager.run(module_op_))) {
- throw RaisePyError(
- PyExc_RuntimeError,
- diag_capture.ConsumeDiagnosticsAsString("Error compiling IREE module:")
- .c_str());
- }
-
- // Run serialization.
- std::string contents;
- raw_string_ostream out(contents);
- if (failed(mlir::iree_compiler::IREE::VM::translateModuleToBytecode(
- module_op_, options, out))) {
- throw RaisePyError(
- PyExc_RuntimeError,
- diag_capture
- .ConsumeDiagnosticsAsString("Error serializing to flatbuffer:")
- .c_str());
- }
-
- out.flush();
- return std::make_shared<OpaqueStringBlob>(std::move(out.str()));
-}
-
-void CompilerModuleBundle::RunPassPipeline(
- const std::vector<std::string>& pipelines) {
- mlir::PassManager pm(context_->mlir_context(),
- mlir::OpPassManager::Nesting::Implicit);
- mlir::applyPassManagerCLOptions(pm);
- auto crash_reproducer_path = context_->crash_reproducer_path();
- if (crash_reproducer_path) {
- pm.enableCrashReproducerGeneration(*crash_reproducer_path);
- }
-
- // Parse the pass pipelines.
- std::string error;
- raw_string_ostream error_stream(error);
- for (const auto& pipeline : pipelines) {
- if (failed(mlir::parsePassPipeline(pipeline, pm, error_stream))) {
- throw RaiseValueError(error_stream.str().c_str());
- }
- }
-
- // Run them.
- auto diag_capture = context_->CaptureDiagnostics();
- if (failed(pm.run(module_op_))) {
- throw RaisePyError(
- PyExc_RuntimeError,
- diag_capture.ConsumeDiagnosticsAsString("Error running pass pipelines:")
- .c_str());
- }
-}
-
-void SetupCommonCompilerBindings(pybind11::module m) {
- // Guard the once init to happen once per process (vs module, which in
- // mondo builds can happen multiple times).
- static bool llvm_init_baton = ([]() { return LLVMOnceInit(); })();
- (void)(llvm_init_baton);
-
- // llvm module
- auto llvm_m = m.def_submodule("llvm", "Global LLVM configuration");
- SetupLLVMModule(llvm_m);
-
- // OpaqueBlob class
- py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob",
- py::buffer_protocol())
- .def_buffer([](OpaqueBlob* self) -> py::buffer_info {
- return py::buffer_info(
- self->data(), // Pointer to buffer
- sizeof(uint8_t), // Size of one scalar
- py::format_descriptor<uint8_t>::value, // Python struct-style
- // format
- 1, // Number of dimensions
- {self->size()}, // Buffer dimensions
- {self->size()} // Strides
- );
- })
- .def_property_readonly("bytes",
- [](OpaqueBlob* self) -> py::bytes {
- return py::bytes(
- static_cast<const char*>(self->data()),
- self->size());
- })
- .def_property_readonly("text", [](OpaqueBlob* self) -> py::str {
- return py::str(static_cast<const char*>(self->data()), self->size());
- });
-
- // CompilerContext class
- py::class_<CompilerContextBundle, std::shared_ptr<CompilerContextBundle>>(
- m, "CompilerContext")
- .def(py::init<>([]() {
- // Need explicit make_shared to avoid UB with enable_shared_from_this.
- return std::make_shared<CompilerContextBundle>();
- }))
- .def("parse_asm", &CompilerContextBundle::ParseAsm)
- .def("get_diagnostics",
- &CompilerContextBundle::ConsumeDiagnosticsAsString)
- .def("clear_diagnostics", &CompilerContextBundle::ClearDiagnostics)
- .def_property_static(
- "default_crash_reproducer_path",
- [](py::object /* cls */) {
- return CompilerContextBundle::default_crash_reproducer_path();
- },
- [](py::object /* cls */, absl::optional<std::string> p) {
- CompilerContextBundle::set_default_crash_reproducer_path(
- std::move(p));
- })
- .def_property("crash_reproducer_path",
- &CompilerContextBundle::crash_reproducer_path,
- &CompilerContextBundle::set_crash_reproducer_path);
-
- // OutputFormat enum
- py::enum_<BytecodeOutputFormat>(m, "OutputFormat")
- .value("FLATBUFFER_BINARY", BytecodeOutputFormat::kFlatBufferBinary)
- .value("FLATBUFFER_TEXT", BytecodeOutputFormat::kFlatBufferText)
- .value("MLIR_TEXT", BytecodeOutputFormat::kMlirText)
- .export_values();
-
- // CompileOptions class
- py::class_<BytecodeTargetOptions>(m, "CompileOptions")
- .def(py::init<>())
- .def_readwrite("output_format", &BytecodeTargetOptions::outputFormat)
- .def_readwrite("optimize", &BytecodeTargetOptions::optimize)
- .def_readwrite("strip_debug_ops", &BytecodeTargetOptions::stripDebugOps)
- .def_readwrite("strip_source_map", &BytecodeTargetOptions::stripSourceMap)
- .def_readwrite("strip_symbols", &BytecodeTargetOptions::stripSymbols);
-
- // CompilerModule class
- py::class_<CompilerModuleBundle>(m, "CompilerModule")
- .def("to_asm", &CompilerModuleBundle::ToAsm,
- py::arg("debug_info") = false, py::arg("pretty") = false,
- py::arg("large_element_limit") = -1)
- .def("compile", &CompilerModuleBundle::Compile,
- py::arg("options") = BytecodeTargetOptions{},
- py::arg("target_backends") = std::vector<std::string>())
- .def("run_pass_pipeline", &CompilerModuleBundle::RunPassPipeline,
- py::arg("pipelines") = std::vector<std::string>());
-}
-
-} // namespace python
-} // namespace iree
diff --git a/bindings/python/pyiree/compiler/compiler.h b/bindings/python/pyiree/compiler/compiler.h
deleted file mode 100644
index d384170..0000000
--- a/bindings/python/pyiree/compiler/compiler.h
+++ /dev/null
@@ -1,219 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
-
-#include <mutex> // NOLINT
-#include <string>
-
-#include "bindings/python/pyiree/common/binding.h"
-#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-
-namespace iree {
-namespace python {
-
-// Wrapper around a blob of memory.
-// Used to transport blobs back and forth between C++ and Python.
-class OpaqueBlob {
- public:
- OpaqueBlob() : data_(nullptr), size_(0) {}
- OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {}
- virtual ~OpaqueBlob() = default;
-
- void* data() { return data_; }
- const void* data() const { return data_; }
- size_t size() const { return size_; }
-
- // Create a free function from the OpaqueBlob shared pointer.
- using BufferFreeFn = void (*)(void* self, iree_byte_span_t);
- static std::pair<BufferFreeFn, void*> CreateFreeFn(
- std::shared_ptr<OpaqueBlob> blob) {
- // Note that there are more efficient ways to write this which
- // don't bounce through an extra heap alloc, but this is not
- // intended to be a high impact code path.
- struct Holder {
- std::shared_ptr<OpaqueBlob> blob;
- };
- Holder* holder = new Holder{std::move(blob)};
- auto free_fn = +([](void* self, iree_byte_span_t) {
- Holder* self_holder = static_cast<Holder*>(self);
- delete self_holder;
- });
- return {free_fn, holder};
- }
-
- static iree_allocator_t CreateDeallocator(std::shared_ptr<OpaqueBlob> blob) {
- // Note that there are more efficient ways to write this which
- // don't bounce through an extra heap alloc, but this is not
- // intended to be a high impact code path.
- struct Holder {
- std::shared_ptr<OpaqueBlob> blob;
- };
- Holder* holder = new Holder{std::move(blob)};
- auto free_fn = +([](void* self, void*) {
- Holder* self_holder = static_cast<Holder*>(self);
- delete self_holder;
- });
- return {holder /* self */, nullptr /* alloc */, free_fn /* free */};
- }
-
- protected:
- void* data_;
- size_t size_;
-};
-
-// Opaque blob that owns a vector.
-class OpaqueByteVectorBlob : public OpaqueBlob {
- public:
- OpaqueByteVectorBlob(std::vector<uint8_t> v)
- : OpaqueBlob(), v_(std::move(v)) {
- data_ = v_.data();
- size_ = v_.size();
- }
-
- private:
- std::vector<uint8_t> v_;
-};
-
-class OpaqueStringBlob : public OpaqueBlob {
- public:
- OpaqueStringBlob(std::string s) : OpaqueBlob(), s_(std::move(s)) {
- data_ = &s_[0];
- size_ = s_.size();
- }
-
- private:
- std::string s_;
-};
-
-class CompilerContextBundle;
-class CompilerModuleBundle;
-
-// Wraps an MLIR module and its producing context.
-class CompilerModuleBundle {
- public:
- CompilerModuleBundle(std::shared_ptr<CompilerContextBundle> context,
- mlir::ModuleOp module_op)
- : context_(std::move(context)), module_op_(std::move(module_op)) {}
-
- mlir::ModuleOp& module_op() { return module_op_; }
- std::string ToAsm(bool enableDebugInfo, bool prettyForm,
- int64_t largeElementLimit);
-
- // Runs one or more pass pipelines (as is mlir::parsePassPipeline).
- void RunPassPipeline(const std::vector<std::string>& pipelines);
-
- // Compile to a VM module.
- std::shared_ptr<OpaqueBlob> Compile(
- mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options,
- std::vector<std::string> target_backends);
-
- private:
- std::shared_ptr<CompilerContextBundle> context_;
- mlir::ModuleOp module_op_;
-};
-
-// Registers to receive diagnostics for a scope.
-// When this goes out of scope, any remaining diagnostics will be added to
-// the parent.
-class DiagnosticCapture {
- public:
- DiagnosticCapture(mlir::MLIRContext* mlir_context, DiagnosticCapture* parent);
- ~DiagnosticCapture();
- DiagnosticCapture(DiagnosticCapture&& other);
-
- std::vector<mlir::Diagnostic>& diagnostics() { return diagnostics_; }
-
- // Consumes/clears diagnostics.
- std::string ConsumeDiagnosticsAsString(const char* error_message);
- void ClearDiagnostics();
-
- private:
- mlir::MLIRContext* mlir_context_;
- DiagnosticCapture* parent_;
- std::vector<mlir::Diagnostic> diagnostics_;
- mlir::DiagnosticEngine::HandlerID handler_id_;
-};
-
-// Bundle of MLIRContext related things that facilitates interop with
-// Python.
-class CompilerContextBundle
- : public std::enable_shared_from_this<CompilerContextBundle> {
- public:
- CompilerContextBundle();
- ~CompilerContextBundle();
-
- mlir::MLIRContext* mlir_context() { return &mlir_context_; }
-
- CompilerModuleBundle ParseAsm(const std::string& asm_text);
-
- // Gets the default diagnostic capture.
- DiagnosticCapture& DefaultDiagnosticCapture() { return default_capture_; }
-
- // Creates a new diagnostic region.
- // Note that this only supports one deep at present.
- DiagnosticCapture CaptureDiagnostics() {
- return DiagnosticCapture(&mlir_context_, &default_capture_);
- }
-
- // Consumes/clears diagnostics.
- std::string ConsumeDiagnosticsAsString() {
- return default_capture_.ConsumeDiagnosticsAsString(nullptr);
- }
- void ClearDiagnostics() { default_capture_.ClearDiagnostics(); }
-
- // Default crash reproducer path.
- static absl::optional<std::string> default_crash_reproducer_path() {
- std::lock_guard<std::mutex> lock(static_config_lock_);
- return default_crash_reproducer_path_;
- }
- static void set_default_crash_reproducer_path(
- absl::optional<std::string> default_crash_reproducer_path) {
- std::lock_guard<std::mutex> lock(static_config_lock_);
- default_crash_reproducer_path_ = std::move(default_crash_reproducer_path);
- }
-
- // Crash reproducer (if not set, uses the static default).
- // If neither are set or are the empty string, then the crash reproducer
- // will not be used.
- absl::optional<std::string> crash_reproducer_path() const {
- if (crash_reproducer_path_) {
- return crash_reproducer_path_;
- }
- return default_crash_reproducer_path();
- }
- void set_crash_reproducer_path(
- absl::optional<std::string> crash_reproducer_path) {
- crash_reproducer_path_ = std::move(crash_reproducer_path);
- }
-
- private:
- static std::mutex static_config_lock_;
- static absl::optional<std::string> default_crash_reproducer_path_;
-
- mlir::MLIRContext mlir_context_;
- DiagnosticCapture default_capture_;
- absl::optional<std::string> crash_reproducer_path_;
-};
-
-void SetupCommonCompilerBindings(pybind11::module m);
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
diff --git a/bindings/python/pyiree/compiler/compiler_test.py b/bindings/python/pyiree/compiler/compiler_test.py
deleted file mode 100644
index d3fd1f2..0000000
--- a/bindings/python/pyiree/compiler/compiler_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl.testing import absltest
-from pyiree import compiler
-
-SIMPLE_MUL_ASM = """
-func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
- attributes { iree.module.export } {
- %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- return %0 : tensor<4xf32>
-}
-"""
-
-
-class CompilerTest(absltest.TestCase):
-
- def testParseError(self):
- ctx = compiler.Context()
- with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
- ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")
-
- def testParseAndCompileToFlatbuffer(self):
- ctx = compiler.Context()
- input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
- binary = input_module.compile()
- b = binary.bytes
- print("Flatbuffer size =", len(b))
- self.assertTrue(binary.bytes)
-
- def testParseAndCompileToFlatbufferText(self):
- ctx = compiler.Context()
- input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
- options = compiler.CompileOptions()
- options.output_format = compiler.OutputFormat.FLATBUFFER_TEXT
- blob = input_module.compile(options=options)
- text = blob.text
- self.assertTrue(text)
-
- def testParseAndCompileToMlirText(self):
- ctx = compiler.Context()
- input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
- options = compiler.CompileOptions()
- options.output_format = compiler.OutputFormat.MLIR_TEXT
- blob = input_module.compile(options=options)
- text = blob.text
- self.assertTrue(text)
-
-
-if __name__ == "__main__":
- absltest.main()
diff --git a/bindings/python/pyiree/compiler/export.def b/bindings/python/pyiree/compiler/export.def
deleted file mode 100644
index 1f2a8c1..0000000
--- a/bindings/python/pyiree/compiler/export.def
+++ /dev/null
@@ -1,17 +0,0 @@
-;; Copyright 2020 Google LLC
-;;
-;; Licensed under the Apache License, Version 2.0 (the "License");
-;; you may not use this file except in compliance with the License.
-;; You may obtain a copy of the License at
-;;
-;; https://www.apache.org/licenses/LICENSE-2.0
-;;
-;; Unless required by applicable law or agreed to in writing, software
-;; distributed under the License is distributed on an "AS IS" BASIS,
-;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-;; See the License for the specific language governing permissions and
-;; limitations under the License.
-
-LIBRARY BINDING
-EXPORTS
- PyInit_binding @1
diff --git a/bindings/python/pyiree/compiler/initialize_module.cc b/bindings/python/pyiree/compiler/initialize_module.cc
deleted file mode 100644
index 786179d..0000000
--- a/bindings/python/pyiree/compiler/initialize_module.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <mutex> // NOLINT
-
-#include "bindings/python/pyiree/common/binding.h"
-#include "bindings/python/pyiree/compiler/compiler.h"
-
-namespace iree {
-namespace python {
-
-PYBIND11_MODULE(binding, m) {
- m.doc() = "IREE Compiler Interface";
- SetupCommonCompilerBindings(m);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/bindings/python/pyiree/compiler/unix_version.lds b/bindings/python/pyiree/compiler/unix_version.lds
deleted file mode 100644
index 68ef766..0000000
--- a/bindings/python/pyiree/compiler/unix_version.lds
+++ /dev/null
@@ -1,19 +0,0 @@
-/* Copyright 2020 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-{
- global: PyInit_binding;
- local: *;
-};
diff --git a/bindings/python/pyiree/compiler2/core.py b/bindings/python/pyiree/compiler2/core.py
index cdc7900..7ea85d5 100644
--- a/bindings/python/pyiree/compiler2/core.py
+++ b/bindings/python/pyiree/compiler2/core.py
@@ -173,11 +173,11 @@
return result
-def compile_str(input_str: str, **kwargs):
+def compile_str(input_str: Union[str, bytes], **kwargs):
"""Invokes the IREE compiler with an input string.
Args:
- input_str: MLIR assembly to parse/compile.
+ input_str: MLIR assembly to parse/compile (str or bytes).
**kwargs: Keyword arguments corresponding to CompilerOptions.
Returns:
Either a byte buffer of the compiled content or None if output_file
@@ -185,7 +185,9 @@
"""
options = CompilerOptions(**kwargs)
cl = build_compile_command_line("-", options)
- result = invoke_immediate(cl, immediate_input=input_str.encode("utf-8"))
+ input_bytes = input_str.encode("utf-8") if isinstance(input_str,
+ str) else input_str
+ result = invoke_immediate(cl, immediate_input=input_bytes)
if options.output_file:
return None
return result
diff --git a/bindings/python/pyiree/compiler2/tf.py b/bindings/python/pyiree/compiler2/tf.py
index c9608af..90423c5 100644
--- a/bindings/python/pyiree/compiler2/tf.py
+++ b/bindings/python/pyiree/compiler2/tf.py
@@ -85,6 +85,8 @@
import_type: Union[ImportType, str] = ImportType.OBJECT_GRAPH,
saved_model_tags: Set[str] = set(),
import_extra_args: Sequence[str] = (),
+ save_temp_tf_input: Optional[str] = None,
+ save_temp_iree_input: Optional[str] = None,
**kwargs):
"""Initialize options from keywords.
@@ -101,6 +103,10 @@
saved_model_tags: Set of tags to export (signature def/v1 saved models
only).
import_extra_args: Extra arguments to pass to the iree-tf-import tool.
+ save_temp_tf_input: Optionally save the IR that is input to the
+ TensorFlow pipeline.
+ save_temp_iree_input: Optionally save the IR that is the result of the
+ import (ready to be passed to IREE).
"""
super().__init__(**kwargs)
self.exported_names = exported_names
@@ -108,6 +114,8 @@
self.import_type = ImportType.parse(import_type)
self.saved_model_tags = saved_model_tags
self.import_extra_args = import_extra_args
+ self.save_temp_tf_input = save_temp_tf_input
+ self.save_temp_iree_input = save_temp_iree_input
def build_import_command_line(input_path: str,
@@ -132,6 +140,17 @@
# Import stage directly outputs.
if options.output_file:
cl.append(f"-o={options.output_file}")
+ # Save temps flags.
+ if options.save_temp_tf_input:
+ cl.append(f"--save-temp-tf-input={options.save_temp_tf_input}")
+ if options.save_temp_iree_input:
+ cl.append(f"--save-temp-iree-input={options.save_temp_iree_input}")
+ # Crash reproducer (locally qualified).
+ if options.crash_reproducer_path:
+ cl.append(
+ f"--pass-pipeline-crash-reproducer={options.crash_reproducer_path}"
+ f".import-tf")
+ # Extra args.
cl.extend(options.import_extra_args)
return cl
diff --git a/bindings/python/pyiree/compiler2/tools.py b/bindings/python/pyiree/compiler2/tools.py
index 480600b..6083840 100644
--- a/bindings/python/pyiree/compiler2/tools.py
+++ b/bindings/python/pyiree/compiler2/tools.py
@@ -23,7 +23,7 @@
import textwrap
import threading
-from typing import List, Optional
+from typing import List, Optional, Union
__all__ = [
"find_tool",
@@ -124,7 +124,7 @@
def invoke_immediate(command_line: List[str],
*,
- input_file: Optional[str] = None,
+ input_file: Optional[bytes] = None,
immediate_input=None):
"""Invokes an immediate command.
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
deleted file mode 100644
index 6e64af5..0000000
--- a/bindings/python/pyiree/rt/BUILD
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "NUMPY_DEPS",
- "PYBIND_COPTS",
- "PYBIND_EXTENSION_COPTS",
- "PYBIND_FEATURES",
- "iree_py_extension",
- "iree_py_library",
- "iree_py_test",
- "pybind_cc_library",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_py_library(
- name = "rt",
- srcs = [
- "__init__.py",
- "system_api.py",
- ],
- srcs_version = "PY3",
- deps = [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-iree_py_extension(
- name = "binding",
- srcs = [
- "initialize_module.cc",
- ],
- copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
- features = PYBIND_FEATURES,
- linkstatic = 1,
- win_def_file = "export.def",
- deps = [
- ":rt_library",
- "//bindings/python/pyiree/common",
- "//iree/hal/drivers",
- ],
-)
-
-pybind_cc_library(
- name = "rt_library",
- srcs = [
- "function_abi.cc",
- "hal.cc",
- "host_types.cc",
- "vm.cc",
- ],
- hdrs = [
- "function_abi.h",
- "hal.h",
- "host_types.h",
- "vm.h",
- ],
- deps = [
- "//bindings/python/pyiree/common",
- "//iree/base:api",
- "//iree/base:signature_mangle",
- "//iree/hal:api",
- "//iree/modules/hal",
- "//iree/modules/strings:strings_module",
- "//iree/modules/tensorlist:native_module",
- "//iree/vm",
- "//iree/vm:bytecode_module",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-iree_py_library(
- name = "system_api",
- srcs = ["system_api.py"],
- srcs_version = "PY3",
- deps = [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-iree_py_test(
- name = "function_abi_test",
- srcs = ["function_abi_test.py"],
- python_version = "PY3",
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree/rt",
- ],
-)
-
-iree_py_test(
- name = "hal_test",
- srcs = ["hal_test.py"],
- python_version = "PY3",
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree/rt",
- ],
-)
-
-iree_py_test(
- name = "system_api_test",
- srcs = ["system_api_test.py"],
- python_version = "PY3",
- deps = NUMPY_DEPS + [
- ":system_api",
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree/compiler",
- "//bindings/python/pyiree/rt",
- ],
-)
-
-iree_py_test(
- name = "vm_test",
- srcs = ["vm_test.py"],
- python_version = "PY3",
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree/compiler",
- "//bindings/python/pyiree/rt",
- ],
-)
diff --git a/bindings/python/pyiree/rt/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
index a27f3b0..185b496 100644
--- a/bindings/python/pyiree/rt/system_api_test.py
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -20,13 +20,13 @@
from absl import logging
from absl.testing import absltest
import numpy as np
-from pyiree import compiler
+from pyiree import compiler2 as compiler
from pyiree import rt
def create_simple_mul_module():
- ctx = compiler.Context()
- input_module = ctx.parse_asm("""
+ binary = compiler.compile_str(
+ """
module @arithmetic {
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
attributes { iree.module.export } {
@@ -34,8 +34,8 @@
return %0 : tensor<4xf32>
}
}
- """)
- binary = input_module.compile()
+ """,
+ target_backends=["vulkan-spirv"])
m = rt.VmModule.from_flatbuffer(binary)
return m
diff --git a/bindings/python/setup_tools_tf.py.in b/bindings/python/setup_tools_tf.py.in
index a1b292e..b25e7c5 100644
--- a/bindings/python/setup_tools_tf.py.in
+++ b/bindings/python/setup_tools_tf.py.in
@@ -68,9 +68,10 @@
],
python_requires=">=3.6",
package_dir={"": this_dir},
- packages=find_namespace_packages(where=this_dir,
- include=["pyiree.tools.tf"],
- exclude=["*.CMakeFiles"]),
+ packages=find_namespace_packages(
+ where=this_dir,
+ include=["pyiree.tools.tf", "pyiree.tf.support"],
+ exclude=["*.CMakeFiles"]),
# Matching the native extension as a data file keeps setuptools from
# "building" it (i.e. turning it into a static binary).
package_data={
diff --git a/colab/BUILD.bazel b/colab/BUILD.bazel
deleted file mode 100644
index 01c88f9..0000000
--- a/colab/BUILD.bazel
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- features = ["layering_check"],
- licenses = ["notice"],
-)
-
-py_binary(
- name = "everything_for_colab",
- srcs = ["dummy.py"],
- legacy_create_init = False,
- main = "dummy.py",
- python_version = "PY3",
- deps = [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "//bindings/python/pyiree/compiler", # build_cleaner: keep
- "//bindings/python/pyiree/rt", # build_cleaner: keep
- ] + select({
- "//iree:enable_tensorflow": [
- "//integrations/tensorflow/bindings/python/pyiree/tf/compiler", # build_cleaner: keep
- "//integrations/tensorflow/bindings/python/pyiree/tf/support", # build_cleaner: keep
- ],
- "//conditions:default": [
- ],
- }),
-)
diff --git a/integrations/tensorflow/bindings/python/CMakeLists.txt b/integrations/tensorflow/bindings/python/CMakeLists.txt
index 31b61a2..bafd98b 100644
--- a/integrations/tensorflow/bindings/python/CMakeLists.txt
+++ b/integrations/tensorflow/bindings/python/CMakeLists.txt
@@ -20,3 +20,5 @@
endfunction()
_add_overlay_subdirectory(pyiree/tools/tf)
+# TODO: Find another place for the TF support library.
+_add_overlay_subdirectory(pyiree/tf/support)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
deleted file mode 100644
index c2682f6..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
+++ /dev/null
@@ -1,161 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "PYBIND_COPTS",
- "PYBIND_EXTENSION_COPTS",
- "PYBIND_FEATURES",
- "iree_py_extension",
- "iree_py_library",
- "iree_py_test",
- "pybind_cc_library",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-config_setting(
- name = "disable_kernels",
- define_values = {"PYIREE_TF_DISABLE_KERNELS": "1"},
-)
-
-# Runtime deps needed to compile the tensorflow compiler.
-# As of 2020-04-13, robust constant folding is dependent on the legacy
-# TensorFlow executor, which requires kernels to operate on/simplify
-# the graph. This should become less necessary as more robust support
-# is implemented as part of the MLIR-based tf2xla bridge. This adds
-# about ~350 files to the system, many of which are quite heavy to
-# compile. Excluding them disables TensorFlow constant propagation,
-# which can cause non-optimized binaries (and tickle bugs and unimplemented
-# features). However, it is allowed, especially for development because it
-# is such a burden to build them. Disable kernels with this command line
-# options:
-# --define=PYIREE_TF_DISABLE_KERNELS=1
-# See: https://github.com/google/iree/issues/1506
-SAVED_MODEL_TF_RUNTIME_DEPS = [
- "@org_tensorflow//tensorflow/core:ops",
-] + select({
- ":disable_kernels": [],
- "//conditions:default": [
- "@org_tensorflow//tensorflow/core/kernels:array",
- "@org_tensorflow//tensorflow/core/kernels:math",
- ],
-})
-
-# TODO: Isolate SignatureDef SavedModel support into its own library to decrease buildtime cost.
-# While it would be nice to simply depend on TensorFlow, manually paring
-# down the dependencies significantly reduces the build time that this adds.
-#
-# Baseline: 449s
-# SignatureDef SavedModels: 546s – 22% increase in build time.
-# SignatureDef SavedModels + Deps for MobileBert: 572s – 27% increase in build time.
-# TF OpenSource: 664s – 49% increase in build time.
-SIGNATURE_DEF_SAVED_MODEL_TF_RUNTIME_DEPS = [
- # Deps for SignatureDef SavedModels:
- "@org_tensorflow//tensorflow/core:direct_session",
- "@org_tensorflow//tensorflow/core/kernels:resource_variable_ops", # VarHandleOp
- "@org_tensorflow//tensorflow/core/kernels:regex_full_match_op", # StaticRegexFullMatch
- "@org_tensorflow//tensorflow/core/kernels:string_join_op", # StringJoin
- "@org_tensorflow//tensorflow/core/kernels:save_op", # SharedFilename
- "@org_tensorflow//tensorflow/core/kernels:save_restore_v2_ops", # SaveV2
-
- # Deps for MobileBert:
- "@org_tensorflow//tensorflow/core/kernels:parameterized_truncated_normal_op", # TruncatedNormal
- "@org_tensorflow//tensorflow/core/kernels:state", # Assign.
- "@org_tensorflow//tensorflow/core/kernels:logging_ops", # Assert
- "@org_tensorflow//tensorflow/core/kernels:bias_op", # BiasAdd
- "@org_tensorflow//tensorflow/core/kernels:softmax_op", # Softmax
- "@org_tensorflow//tensorflow/core/kernels:relu_op", # Relu
-]
-
-TF_XLA_PASS_DEPS = [
- "//integrations/tensorflow/compiler:tensorflow",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-]
-
-iree_py_library(
- name = "compiler",
- srcs = [
- "__init__.py",
- ],
- srcs_version = "PY3",
- deps = [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- "//integrations/tensorflow/bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-iree_py_extension(
- name = "binding",
- srcs = [
- "initialize_module.cc",
- ],
- copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
- features = PYBIND_FEATURES,
- linkstatic = 1,
- win_def_file = "export.def",
- deps = [
- ":compiler_library",
- "//bindings/python/pyiree/common",
- "//bindings/python/pyiree/compiler:compiler_library",
- "//integrations/tensorflow/compiler:tensorflow",
- ],
-)
-
-pybind_cc_library(
- name = "compiler_library",
- srcs = [
- "register_tensorflow.cc",
- ],
- hdrs = [
- "register_tensorflow.h",
- ],
- deps = SAVED_MODEL_TF_RUNTIME_DEPS + TF_XLA_PASS_DEPS + SIGNATURE_DEF_SAVED_MODEL_TF_RUNTIME_DEPS + [
- "//bindings/python/pyiree/common",
- "//bindings/python/pyiree/compiler:compiler_library",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@org_tensorflow//tensorflow/cc/saved_model:loader_lite",
- "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
- "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
- "@org_tensorflow//tensorflow/core:core_cpu",
- ],
-)
-
-iree_py_test(
- name = "saved_model_test",
- srcs = ["saved_model_test.py"],
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "@absl_py//absl/testing:absltest",
- "//integrations/tensorflow/bindings/python/pyiree/tf/compiler",
- ],
-)
-
-iree_py_test(
- name = "signature_def_saved_model_test",
- srcs = ["signature_def_saved_model_test.py"],
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "@absl_py//absl/testing:absltest",
- "//integrations/tensorflow/bindings/python/pyiree/tf/compiler",
- ],
-)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
deleted file mode 100644
index 26285f9..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# Lint-as: python3
-"""Module init for the python bindings."""
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# pylint: disable=g-multiple-import
-# pylint: disable=g-bad-import-order
-# pylint: disable=g-import-not-at-top
-# pylint: disable=wildcard-import
-
-__all__ = [
- # Common
- "Context",
- "Module",
- "CompileOptions",
- "OutputFormat",
- # TensorFlow
- "TF_IMPORT_PASS_PIPELINE",
- "tf_saved_model_to_compiler_module",
- "tf_signature_def_saved_model_to_compiler_module",
- "tf_module_to_compiler_module",
- "compile_tf_saved_model",
- "compile_tf_signature_def_saved_model",
- "compile_tf_module",
-]
-
-import tempfile
-from typing import Collection, Optional, Sequence, Set
-
-from . import binding as binding
-import tensorflow as tf
-
-# Native aliases (matches those in the generic compiler).
-llvm = binding.llvm
-Context = binding.CompilerContext
-Module = binding.CompilerModule
-CompileOptions = binding.CompileOptions
-OutputFormat = binding.OutputFormat
-
-# Pass pipeline that should run to lower a TF saved_model to a form suitable
-# for input to the IREE compiler.
-TF_IMPORT_PASS_PIPELINE = (
- # IREE-specific passes to prepare TF code for IREE compilation.
- # In particular, this eliminates tf_saved_model.
- "iree-tf-import-pipeline",
-)
-
-
-def tf_saved_model_to_compiler_module(
- saved_model_dir: str,
- exported_names: Sequence[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> Module:
- """Converts a TensorFlow SavedModel into a MLIR module.
-
- See also compile_tf_saved_model() for a one-shot API to load and compile.
-
- Args:
- saved_model_dir: Directory of the saved model.
- exported_names: Optional sequence representing the exported names to keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- compiler_module = binding.load_saved_model(compiler_context,
- saved_model_dir,
- exported_names=exported_names)
- if pass_pipeline:
- compiler_module.run_pass_pipeline(pass_pipeline)
- return compiler_module
-
-
-def compile_tf_saved_model(
- saved_model_dir: str,
- exported_names: Sequence[str] = (),
- target_backends: Sequence[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
- """Compiles a TensorFlow SavedModel to IREE in one shot.
-
- Args:
- saved_model_dir: Directory of the saved model.
- exported_names: Optional sequence representing the exported names to keep.
- target_backends: Optional sequence of specific target backends to compile
- for (defaults to all compiled in targets).
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An OpaqueBlob representing the compiled module.
- """
- compiler_module = tf_saved_model_to_compiler_module(saved_model_dir,
- exported_names,
- pass_pipeline,
- compiler_context)
- return compiler_module.compile(target_backends=target_backends)
-
-
-def tf_signature_def_saved_model_to_compiler_module(
- saved_model_dir: str,
- saved_model_tags: Set[str] = set(),
- exported_names: Sequence[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> Module:
- """Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- exported_names: Optional sequence representing the exported names to keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- compiler_module = binding.load_signature_def_saved_model(
- compiler_context,
- saved_model_dir,
- saved_model_tags,
- exported_names=exported_names)
- if pass_pipeline:
- compiler_module.run_pass_pipeline(pass_pipeline)
- return compiler_module
-
-
-def compile_tf_signature_def_saved_model(
- saved_model_dir: str,
- saved_model_tags: Set[str] = set(),
- exported_names: Sequence[str] = (),
- target_backends: Sequence[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
- """Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- exported_names: Optional sequence representing the exported names to keep.
- target_backends: Optional sequence of specific target backends to compile
- for (defaults to all compiled in targets).
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An OpaqueBlob representing the compiled module.
- """
- compiler_module = tf_signature_def_saved_model_to_compiler_module(
- saved_model_dir, saved_model_tags, exported_names, pass_pipeline,
- compiler_context)
- return compiler_module.compile(target_backends=target_backends)
-
-
-def tf_module_to_compiler_module(
- module: tf.Module,
- exported_names: Sequence[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None,
- saved_model_dir: str = None) -> Module:
- """Converts a tf.Module instance into a MLIR module.
-
- Args:
- module: The tf.Module instance to convert to MLIR
- exported_names: Optional sequence representing the exported names to keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
- saved_model_dir: Optional path to save the tf.Module to. The module will not
- be saved on disk if this is not provided.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
-
- def _convert(saved_model_dir):
- options = tf.saved_model.SaveOptions(save_debug_info=True)
- tf.saved_model.save(module, saved_model_dir, options=options)
- return tf_saved_model_to_compiler_module(saved_model_dir, exported_names,
- pass_pipeline, compiler_context)
-
- if saved_model_dir is None:
- with tempfile.TemporaryDirectory() as saved_model_dir:
- compiler_module = _convert(saved_model_dir)
- else:
- compiler_module = _convert(saved_model_dir)
- return compiler_module
-
-
-def compile_tf_module(module: tf.Module,
- exported_names: Sequence[str] = (),
- target_backends: Sequence[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None,
- saved_model_dir: str = None):
- """Compiles a tf.Module to IREE in one shot.
-
- Args:
- module: The tf.Module instance to convert to MLIR
- exported_names: Optional sequence representing the exported names to keep.
- target_backends: Optional sequence of specific target backends to compile
- for (defaults to all compiled in targets).
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
- saved_model_dir: Optional path to save the tf.Module to. The module will not
- be saved on disk if this is not provided.
-
- Returns:
- An OpaqueBlob representing the compiled module.
- """
- compiler_module = tf_module_to_compiler_module(module, exported_names,
- pass_pipeline,
- compiler_context,
- saved_model_dir)
- return compiler_module.compile(target_backends=target_backends)
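For reference, the removed pyiree.tf.compiler module above exposed a one-shot compile API. Below is a minimal sketch of how it was typically called, assuming an illustrative tf.Module and using "vmla" only as a placeholder backend name:

```python
import tensorflow as tf
from pyiree.tf import compiler  # the module deleted above


class SimpleModule(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
  def scale(self, x):
    return x * 2.0


# One-shot path: save the tf.Module to a temporary SavedModel, import it via
# TF_IMPORT_PASS_PIPELINE, then compile for the requested IREE backends.
# "vmla" is a placeholder backend name, not taken from this diff.
blob = compiler.compile_tf_module(
    SimpleModule(),
    exported_names=["scale"],
    target_backends=["vmla"])
```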
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/export.def b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/export.def
deleted file mode 100644
index 1f2a8c1..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/export.def
+++ /dev/null
@@ -1,17 +0,0 @@
-;; Copyright 2020 Google LLC
-;;
-;; Licensed under the Apache License, Version 2.0 (the "License");
-;; you may not use this file except in compliance with the License.
-;; You may obtain a copy of the License at
-;;
-;; https://www.apache.org/licenses/LICENSE-2.0
-;;
-;; Unless required by applicable law or agreed to in writing, software
-;; distributed under the License is distributed on an "AS IS" BASIS,
-;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-;; See the License for the specific language governing permissions and
-;; limitations under the License.
-
-LIBRARY BINDING
-EXPORTS
- PyInit_binding @1
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/initialize_module.cc b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/initialize_module.cc
deleted file mode 100644
index 1eb7549..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/initialize_module.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <mutex> // NOLINT
-
-#include "bindings/python/pyiree/common/binding.h"
-#include "bindings/python/pyiree/compiler/compiler.h"
-#include "integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.h"
-
-namespace iree {
-namespace python {
-
-PYBIND11_MODULE(binding, m) {
- m.doc() = "IREE TensorFlow Compiler Interface";
- SetupCommonCompilerBindings(m);
- SetupTensorFlowBindings(m);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc
deleted file mode 100644
index 990371a..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.h"
-
-#include <mutex>
-#include <string>
-#include <vector>
-
-#include "bindings/python/pyiree/common/status_utils.h"
-#include "bindings/python/pyiree/compiler/compiler.h"
-#include "integrations/tensorflow/compiler/Passes.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/MLIRContext.h"
-#include "tensorflow/cc/saved_model/loader.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
-#include "tensorflow/core/public/session_options.h"
-
-using namespace mlir; // NOLINT
-
-using tensorflow::ConvertSavedModelToMlir;
-using tensorflow::ConvertSavedModelV1ToMlir;
-using tensorflow::LoadSavedModel;
-using tensorflow::RunOptions;
-using tensorflow::SavedModelBundle;
-using tensorflow::SavedModelV2Bundle;
-using tensorflow::SessionOptions;
-
-namespace iree {
-namespace python {
-
-namespace {
-
-static void initializeContextForTFImport(MLIRContext* context) {
- static std::once_flag flag;
- std::call_once(flag, [&]() {
- mlir::iree_compiler::TF::registerAllPasses();
- mlir::iree_compiler::TF::registerAllDialects(context->getDialectRegistry());
- });
-}
-
-CompilerModuleBundle LoadSavedModel(
- std::shared_ptr<CompilerContextBundle> context_bundle,
- const std::string& saved_model_dir,
- const std::vector<std::string>& exported_names) {
- initializeContextForTFImport(context_bundle->mlir_context());
-
- SavedModelV2Bundle bundle;
- auto load_status = SavedModelV2Bundle::Load(
- std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
- if (!load_status.ok()) {
- std::stringstream msg;
- msg << "Failed to load saved model '" << saved_model_dir
- << "': " << load_status;
- throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
- }
-
- // TODO(laurenzo): Fix the upstream ConvertSavedModelToMlir() to take a const
- // span of external names.
- std::vector<std::string> mutable_exported_names = exported_names;
- auto module_or =
- ConvertSavedModelToMlir(&bundle, context_bundle->mlir_context(),
- absl::MakeSpan(mutable_exported_names));
- if (!module_or.status().ok()) {
- std::stringstream msg;
- msg << "Failed to convert saved model to MLIR '" << saved_model_dir
- << "': " << module_or.status();
- throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
- }
- return CompilerModuleBundle(context_bundle,
- module_or.ConsumeValueOrDie().release());
-}
-
-CompilerModuleBundle LoadSignatureDefSavedModel(
- std::shared_ptr<CompilerContextBundle> context_bundle,
- const std::string& saved_model_dir,
- const std::unordered_set<std::string>& tags,
- const std::vector<std::string>& exported_names) {
- initializeContextForTFImport(context_bundle->mlir_context());
-
- SavedModelBundle bundle;
- auto load_status = LoadSavedModel(
- SessionOptions(), RunOptions(),
- std::string(saved_model_dir.data(), saved_model_dir.length()), tags,
- &bundle);
- if (!load_status.ok()) {
- std::stringstream msg;
- msg << "Failed to load saved model '" << saved_model_dir
- << "': " << load_status;
- throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
- }
- std::vector<std::string> mutable_exported_names = exported_names;
- auto module_or =
- ConvertSavedModelV1ToMlir(bundle, absl::MakeSpan(mutable_exported_names),
- context_bundle->mlir_context());
- if (!module_or.status().ok()) {
- std::stringstream msg;
- msg << "Failed to convert saved model to MLIR '" << saved_model_dir
- << "': " << module_or.status();
- throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
- }
- return CompilerModuleBundle(context_bundle,
- module_or.ConsumeValueOrDie().release());
-}
-
-} // namespace
-
-void SetupTensorFlowBindings(pybind11::module m) {
- m.def("load_saved_model", &LoadSavedModel, py::arg("compiler_context"),
- py::arg("saved_model_dir"),
- py::arg("exported_names") = std::vector<std::string>());
- m.def("load_signature_def_saved_model", &LoadSignatureDefSavedModel,
- py::arg("compiler_context"), py::arg("saved_model_dir"),
- py::arg("tags") = std::unordered_set<std::string>(),
- py::arg("exported_names") = std::vector<std::string>());
-}
-
-} // namespace python
-} // namespace iree
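The bindings registered above were the native entry points behind the Python wrappers deleted earlier. A rough sketch of reaching them directly, with an illustrative SavedModel path and exported name:

```python
from pyiree.tf.compiler import binding  # native module built from the deleted sources

# Illustrative path and name; load_saved_model and run_pass_pipeline mirror
# the deleted pyiree.tf.compiler wrappers shown earlier in this diff.
ctx = binding.CompilerContext()
module = binding.load_saved_model(
    ctx, "/tmp/my_saved_model", exported_names=["serve"])
module.run_pass_pipeline(("iree-tf-import-pipeline",))
print(module.to_asm())
```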
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.h b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.h
deleted file mode 100644
index 6df0a8c..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
-
-#include <string>
-
-#include "bindings/python/pyiree/common/binding.h"
-
-namespace iree {
-namespace python {
-
-void SetupTensorFlowBindings(pybind11::module m);
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
deleted file mode 100644
index 10e55c3..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import importlib
-import os
-import sys
-import tempfile
-
-from pyiree.tf import compiler
-
-# Dynamically import tensorflow.
-try:
- # Use a dynamic import so as to avoid hermetic dependency analysis
- # (i.e. we only want the tensorflow from the environment).
- tf = importlib.import_module("tensorflow")
- # Just in case we are linked against a version that defaults to pre-V2 behavior.
- if hasattr(tf, "enable_v2_behavior"):
- tf.enable_v2_behavior()
- tf = tf.compat.v2
-except ImportError:
- print("Not running tests because tensorflow is not available")
- sys.exit(0)
-
-
-class StatelessModule(tf.Module):
-
- def __init__(self):
- pass
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def add(self, a, b):
- return tf.tanh(a + b)
-
-
-class RuntimeTest(tf.test.TestCase):
-
- def testLoadSavedModelToXlaPipeline(self):
- """Tests that a basic saved model to XLA workflow grossly functions.
-
- This is largely here to verify that everything is linked in that needs to be
- and that there are not no-ops, etc.
- """
- with tempfile.TemporaryDirectory() as temp_dir:
- sm_dir = os.path.join(temp_dir, "simple.sm")
- print("Saving to:", sm_dir)
- my_module = StatelessModule()
- options = tf.saved_model.SaveOptions(save_debug_info=True)
- tf.saved_model.save(my_module, sm_dir, options=options)
-
- # Load it up.
- input_module = compiler.tf_saved_model_to_compiler_module(sm_dir)
- xla_asm = input_module.to_asm()
- print("XLA ASM:", xla_asm)
- self.assertRegex(xla_asm, "mhlo.tanh")
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
deleted file mode 100644
index e8d1e53..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import importlib
-import os
-import sys
-import tempfile
-
-from pyiree.tf import compiler
-
-# Dynamically import tensorflow.
-try:
- # Use a dynamic import so as to avoid hermetic dependency analysis
- # (i.e. we only want the tensorflow from the environment).
- tf = importlib.import_module("tensorflow")
- # Just in case we are linked against a version that defaults to pre-V2 behavior.
- if hasattr(tf, "enable_v2_behavior"):
- tf.enable_v2_behavior()
- tf = tf.compat.v2
-except ImportError:
- print("Not running tests because tensorflow is not available")
- sys.exit(0)
-
-
-class RuntimeTest(tf.test.TestCase):
-
- def testLoadSignatureDefSavedModel(self):
- """Tests loading a SignatureDef saved model with a single variable."""
-
- with tempfile.TemporaryDirectory() as temp_dir:
- sm_dir = os.path.join(temp_dir, "simple.sm")
- print("Saving to:", sm_dir)
-
- with tf.Graph().as_default() as graph:
- v = tf.Variable(10)
- result = v.read_value()
- tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(result)
- sig = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
- inputs={}, outputs={"result": tensor_info}, method_name="foo")
- builder = tf.compat.v1.saved_model.Builder(sm_dir)
- with tf.compat.v1.Session(graph=graph) as sess:
- sess.run(v.initializer)
- builder.add_meta_graph_and_variables(sess, ["bar"], {"baz": sig},
- strip_default_attrs=True)
- builder.save()
-
- module = compiler.tf_signature_def_saved_model_to_compiler_module(
- sm_dir, saved_model_tags=set(["bar"]), exported_names=["baz"])
-
- module_asm = module.to_asm(large_element_limit=100)
- self.assertRegexpMatches(module_asm, "flow.variable @[^ ]* dense<10>")
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/test/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/test/BUILD
deleted file mode 100644
index 8a2a4f0..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/test/BUILD
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//iree:build_defs.oss.bzl",
- "INTREE_FILECHECK_TARGET",
-)
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "iree_py_test",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_py_test(
- name = "saved_model_adopt_exports",
- srcs = [
- "saved_model_adopt_exports.py",
- ],
- args = [
- "--filecheck_binary=$(rootpath %s)" % INTREE_FILECHECK_TARGET,
- ],
- data = [
- INTREE_FILECHECK_TARGET,
- ],
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/test/saved_model_adopt_exports.py
deleted file mode 100644
index 0eca3f4..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/test/saved_model_adopt_exports.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests supported features of saved models."""
-
-# pylint: disable=invalid-name
-# pylint: disable=missing-docstring
-# pylint: disable=line-too-long
-
-from pyiree.tf.support import tf_test_driver
-import tensorflow.compat.v2 as tf
-
-SAVED_MODEL_IMPORT_PASSES = [
- "iree-tf-import-pipeline",
- "canonicalize",
-]
-
-
-# Tests that a simple example with flat args and a single result and no
-# captures imports properly.
-# CHECK-LABEL: RUN_TEST: T0001_FlatArgsResultsNoBoundGlobals
-# CHECK: module
-# CHECK-NOT: tf_saved_model.semantics
-# CHECK: @simple_mul_no_capture
-# CHECK: iree.module.export
-# CHECK: FINISH_TEST
-class T0001_FlatArgsResultsNoBoundGlobals(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul_no_capture(self, a, b):
- return a * b
-
-
-tf_test_driver.add_test(test_name="T0001_FlatArgsResultsNoBoundGlobals",
- tf_module_builder=T0001_FlatArgsResultsNoBoundGlobals,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-# T0002: Tests that bound global vars import properly.
-
-# CHECK-LABEL: RUN_TEST: T0002a_SimpleVarRead
-# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<f32>
-# CHECK: func @f() -> tensor<f32>
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I1!R3!_0"}
-# CHECK: flow.variable.load @v : tensor<f32>
-# CHECK: FINISH_TEST
-class T0002a_SimpleVarRead(tf.Module):
-
- def __init__(self):
- self.v = tf.Variable(0.)
-
- @tf.function(input_signature=[])
- def f(self):
- return self.v
-
-
-# CHECK-LABEL: RUN_TEST: T0002b_SimpleVarWrite
-# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<f32>
-# CHECK: func @f(%arg0: tensor<f32> {tf._user_specified_name = "a"})
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R1!"}
-# CHECK: flow.variable.store %arg0, @v : tensor<f32>
-# CHECK: FINISH_TEST
-class T0002b_SimpleVarWrite(tf.Module):
-
- def __init__(self):
- self.v = tf.Variable(0.)
-
- @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
- def f(self, a):
- self.v.assign(a)
-
-
-# CHECK-LABEL: RUN_TEST: T0002c_SimpleConst
-# CHECK: flow.variable [[CONST:@.+]] dense<0.000000e+00> : tensor<f32>
-# CHECK: func @f() -> tensor<f32>
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I1!R3!_0"}
-# NOTE: the constant variable gets inlined:
-# CHECK: = constant dense<0.000000e+00> : tensor<f32>
-# CHECK: FINISH_TEST
-class T0002c_SimpleConst(tf.Module):
-
- def __init__(self):
- self.c = tf.constant(0.)
-
- @tf.function(input_signature=[])
- def f(self):
- return self.c
-
-
-# CHECK-LABEL: RUN_TEST: T0002d_VarCompatibleShapeChange
-# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<1xf32>
-# CHECK: func @f()
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I1!R1!"}
-# CHECK-DAG: [[CONST_2xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]>
-# CHECK-DAG: [[CONST_3xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
-# CHECK-DAG: flow.variable.store [[CONST_2xf32]], @v : tensor<2xf32>
-# CHECK-DAG: flow.variable.store [[CONST_3xf32]], @v : tensor<3xf32>
-# CHECK: FINISH_TEST
-class T0002d_VarCompatibleShapeChange(tf.Module):
-
- def __init__(self):
- self.v = tf.Variable([0.], shape=[None])
-
- @tf.function(input_signature=[])
- def f(self):
- self.v.assign(tf.constant([0., 1.]))
- self.v.assign(tf.constant([0., 1., 2.]))
-
-
-# CHECK-LABEL: RUN_TEST: T0002e_Error_VarMultipleExportedNames
-# CHECK: [ERROR]: Multiple exported names for global tensor not supported yet
-# CHECK: FINISH_TEST
-class T0002e_Error_VarMultipleExportedNames(tf.Module):
-
- def __init__(self):
- self.v = tf.Variable(0.)
- self.v2 = self.v
-
-
-tf_test_driver.add_test(test_name="T0002a_SimpleVarRead",
- tf_module_builder=T0002a_SimpleVarRead,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-tf_test_driver.add_test(test_name="T0002b_SimpleVarWrite",
- tf_module_builder=T0002b_SimpleVarWrite,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-tf_test_driver.add_test(test_name="T0002c_SimpleConst",
- tf_module_builder=T0002c_SimpleConst,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-tf_test_driver.add_test(test_name="T0002d_VarCompatibleShapeChange",
- tf_module_builder=T0002d_VarCompatibleShapeChange,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-tf_test_driver.add_test(test_name="T0002e_Error_VarMultipleExportedNames",
- tf_module_builder=T0002e_Error_VarMultipleExportedNames,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True,
- expect_pass_failure=True)
-
-
-# Tests that a structured argument is handled properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0003a_StructuredArgs
-# CHECK: func @simple_mul
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I23!S19!k0D13!K2!x_0K2!y_1R3!_0"}
-# CHECK: FINISH_TEST
-class T0003a_StructuredArgs(tf.Module):
-
- @tf.function(input_signature=[{
- "x": tf.TensorSpec([4], tf.float32),
- "y": tf.TensorSpec([4], tf.float32)
- }])
- def simple_mul(self, d):
- return d["x"] * d["y"]
-
-
-tf_test_driver.add_test(test_name="T0003a_StructuredArgs",
- tf_module_builder=T0003a_StructuredArgs,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-
-# Tests that a structured argument is handled properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0003b_StructuredMultipleDictResult
-# CHECK: func @simple_mul
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I12!S9!k0_0k1_1R26!D22!K2!x_0K10!x_squared_1"}
-# CHECK: FINISH_TEST
-class T0003b_StructuredMultipleDictResult(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- product = a * b
- return {"x": product, "x_squared": product * product}
-
-
-tf_test_driver.add_test(test_name="T0003b_StructuredMultipleDictResult",
- tf_module_builder=T0003b_StructuredMultipleDictResult,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-
-# Tests that a structured argument is handled properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0003c_StructuredSingleDictResult
-# CHECK: func @simple_mul
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I12!S9!k0_0k1_1R10!D7!K2!x_0"}
-# CHECK: FINISH_TEST
-class T0003c_StructuredSingleDictResult(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- product = a * b
- return {"x": product}
-
-
-tf_test_driver.add_test(test_name="T0003c_StructuredSingleDictResult",
- tf_module_builder=T0003c_StructuredSingleDictResult,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-
-# Tests that a structured argument is handled properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0003d_StructuredSingleResult
-# CHECK: func @simple_mul
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I12!S9!k0_0k1_1R3!_0"}
-# CHECK: FINISH_TEST
-class T0003d_StructuredSingleResult(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- product = a * b
- return product
-
-
-tf_test_driver.add_test(test_name="T0003d_StructuredSingleResult",
- tf_module_builder=T0003d_StructuredSingleResult,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-
-# Tests that a structured argument is handled properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0003e_StructuredSequenceResult
-# CHECK: func @simple_mul
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I12!S9!k0_0k1_1R17!S13!k0_0k1_1k2_2"}
-# CHECK: FINISH_TEST
-class T0003e_StructuredSequenceResult(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- product = a * b
- return product, a, b
-
-
-tf_test_driver.add_test(test_name="T0003e_StructuredSequenceResult",
- tf_module_builder=T0003e_StructuredSequenceResult,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-
-# Tests that a structured argument is handled properly.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0003f_StructuredNestedResult
-# CHECK: func @simple_mul
-# CHECK: attributes
-# CHECK-SAME: iree.module.export
-# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I12!S9!k0_0k1_1R27!S23!k0_0k1D13!K2!a_1K2!b_2"}
-# CHECK: FINISH_TEST
-class T0003f_StructuredNestedResult(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- product = a * b
- return product, {"a": a, "b": b}
-
-
-tf_test_driver.add_test(test_name="T0003f_StructuredNestedResult",
- tf_module_builder=T0003f_StructuredNestedResult,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True)
-
-
- # Tests that a function with multiple exported names is reported as an error.
-# NOTE: This is currently an error and needs to be implemented
-# CHECK-LABEL: RUN_TEST: T0005_MultipleExportedFuncNames
-# CHECK: [ERROR]: Multiple exported names not supported yet
-# CHECK: FINISH_TEST_WITH_EXCEPTION
-class T0005_MultipleExportedFuncNames(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- product = a * b
- return {"x": product}
-
-
-# Force a function alias.
-T0005_MultipleExportedFuncNames.another_copy = (
- T0005_MultipleExportedFuncNames.simple_mul)
-
-tf_test_driver.add_test(test_name="T0005_MultipleExportedFuncNames",
- tf_module_builder=T0005_MultipleExportedFuncNames,
- passes=SAVED_MODEL_IMPORT_PASSES,
- print_input_module=True,
- expect_pass_failure=True)
-
-if __name__ == "__main__":
- tf_test_driver.run_tests(__file__, with_filecheck=True)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
deleted file mode 100644
index 09fe138..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "iree_py_library",
- "iree_py_test",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_py_library(
- name = "support",
- srcs = [
- "__init__.py",
- "module_utils.py",
- "tf_test_driver.py",
- "tf_test_utils.py",
- "tf_utils.py",
- "trace_utils.py",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//integrations/tensorflow/bindings/python:pathsetup", # build_cleaner: keep
- "//integrations/tensorflow/bindings/python/pyiree/tf/compiler",
- "//bindings/python/pyiree/rt",
- ],
-)
-
-iree_py_test(
- name = "module_utils_test",
- srcs = [
- "module_utils.py",
- "module_utils_test.py",
- ],
- python_version = "PY3",
- tags = [
- "driver=llvm",
- "driver=vmla",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_py_test(
- name = "tf_test_utils_test",
- srcs = [
- "tf_test_utils.py",
- "tf_test_utils_test.py",
- ],
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_py_test(
- name = "tf_utils_test",
- srcs = [
- "tf_utils.py",
- "tf_utils_test.py",
- ],
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_py_test(
- name = "trace_utils_test",
- srcs = [
- "trace_utils.py",
- "trace_utils_test.py",
- ],
- python_version = "PY3",
- tags = [
- "driver=vmla",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
diff --git a/bindings/python/pyiree/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/CMakeLists.txt
similarity index 71%
rename from bindings/python/pyiree/BUILD
rename to integrations/tensorflow/bindings/python/pyiree/tf/support/CMakeLists.txt
index 01d988d..627c40c 100644
--- a/bindings/python/pyiree/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2019 Google LLC
+# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,8 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
+iree_py_library(
+ NAME
+ support
+ SRCS
+ "__init__.py"
+ "module_utils.py"
+ "tf_test_driver.py"
+ "tf_test_utils.py"
+ "tf_utils.py"
+ "trace_utils.py"
)
+
+# TODO: Add tests.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
index fb55501..5a799c6 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
@@ -17,13 +17,13 @@
import collections
import os
import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union
from absl import flags
from absl import logging
import numpy as np
from pyiree import rt
-from pyiree.tf import compiler
+from pyiree.compiler2 import tf as tf_compiler
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -40,108 +40,64 @@
return "TEST_TMPDIR" in os.environ
-def _setup_mlir_crash_reproducer(
- function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
- artifacts_dir: str,
- backend_id: str,
-) -> Any: # Callable[Any, Any]
- """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+def _get_tf_import_output_kwargs(artifacts_dir: str,
+ backend_id: str,
+ *,
+ needs_temp_saved_model_dir: bool = False):
+ """Gets output kwargs dict to pass to tf.compile() for output generation.
- Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
-
- Args:
- function: The callable to decorate.
- artifacts_dir: The directory to write the reproducer to.
- backend_id: The unique backend name to use when writing the reproducer.
-
- Returns:
- A function with the same API as the passed function.
- """
-
- def decorator(*args, **kwargs):
- # Set up a crash reproducer for debugging.
- if artifacts_dir is not None:
- compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backend_id}.mlir")
- try:
- results = function(*args, **kwargs)
- except Exception: # pylint: disable=broad-except
- # Disable the crash reproducer (to avoid inadvertently overwriting it).
- if artifacts_dir is not None:
- compiler.Context.default_crash_reproducer_path = None
- raise
- return results
-
- return decorator
-
-
-def _incrementally_lower_compiler_module(
- compiler_module: compiler.Module,
- backend_info: "BackendInfo",
- artifacts_dir: str,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Lowers a MLIR compiler module incrementally and saves its outputs.
-
- If artifacts_dir is provided then the following artifacts will be saved:
+ When artifacts_dir is set, writes:
tf_input.mlir:
MLIR for the module in TF's input dialect.
iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
backend_id/compiled.vmfb:
A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
+ reproducer__{backend}.mlir: An MLIR crash reproducer (only written on a crash).
Args:
- compiler_module: A compiler.Module to lower.
- backend_info: BackendInfo with the details for lowering compiler_module to
- IREE.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
+ artifacts_dir: The artifacts directory.
+ backend_id: The backend id (for artifacts that are backend dependent).
+ needs_temp_saved_model_dir: Whether a temporary 'saved_model_dir' directory
+ needs to be set.
+
+ Returns:
+ A dict of output kwargs.
"""
- if artifacts_dir is not None:
- os.makedirs(artifacts_dir, exist_ok=True)
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
+ kwargs = {}
+ backend_dir = os.path.join(artifacts_dir, backend_id)
+ os.makedirs(backend_dir, exist_ok=True)
+ kwargs["output_file"] = os.path.join(backend_dir, "compiled.vmfb")
+ if needs_temp_saved_model_dir:
+ kwargs["saved_model_dir"] = os.path.join(artifacts_dir,
+ "tfmodule.saved_model")
+ kwargs["save_temp_tf_input"] = os.path.join(artifacts_dir, "tf_input.mlir")
+ kwargs["save_temp_iree_input"] = os.path.join(artifacts_dir,
+ "iree_input.mlir")
- # Manually run the passes that tf_module_to_compiler_module usually would.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- compiled_module = compiler_module.compile(
- target_backends=backend_info.compiler_targets)
-
- compiled_path = None
- if artifacts_dir is not None:
- backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
- os.makedirs(backend_dir, exist_ok=True)
- compiled_path = os.path.join(backend_dir, "compiled.vmfb")
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
- return compiled_module, compiled_path
+ # Only write a crash reproducer when the flag requests it.
+ if (FLAGS.capture_crash_reproducer):
+ kwargs["crash_reproducer_path"] = os.path.join(
+ artifacts_dir, f"reproducer__{backend_id}.mlir")
+ else:
+ logging.info("Crash reproducer suppressed")
+ logging.info(
+ "Outputting intermediate artifacts (--artifacts_dir is set):\n%s",
+ "\n".join(f" {k}: {v}" for k, v in kwargs.items()))
+ return kwargs
def _incrementally_compile_tf_module(
module: Type[tf.Module],
backend_info: "BackendInfo",
exported_names: Sequence[str] = (),
- artifacts_dir: str = None,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+ artifacts_dir: Optional[str] = None,
+) -> Tuple[bytes, Optional[str]]:
"""Compile a TensorFlow tf.Module and optionally save compilation artifacts.
The module blob this creates is not callable. See IreeCompiledModule for an
API that returns a module that can be called without any further steps.
- See _incrementally_lower_compiler_module's docstring for details about which
- artifacts will be saved.
-
Args:
module: A tf.Module.
backend_info: BackendInfo with the details for compiling this module.
@@ -154,22 +110,23 @@
A compiled IREE module blob and the path to the compiled VM FlatBuffer if
artifacts_dir is provided.
"""
+ output_kwargs = (
+ _get_tf_import_output_kwargs(
+ artifacts_dir,
+ backend_info.backend_id,
+ needs_temp_saved_model_dir=True,
+ ) if artifacts_dir else {})
+ immediate_result = tf_compiler.compile_module(
+ module,
+ target_backends=backend_info.compiler_targets,
+ exported_names=exported_names,
+ **output_kwargs)
- def _compile_module(module, backend_info, exported_names, artifacts_dir):
- compiler_module = compiler.tf_module_to_compiler_module(module,
- exported_names,
- pass_pipeline=())
- return _incrementally_lower_compiler_module(compiler_module, backend_info,
- artifacts_dir)
-
- # Avoid the crash reproducer under tests or if the flag is false.
- # Developers can run tests outside of the test runner (e.g. `bazel run`) to
- # use the crash reproducer.
- if (FLAGS.capture_crash_reproducer and not _running_bazel_test()):
- _compile_module = _setup_mlir_crash_reproducer(_compile_module,
- artifacts_dir,
- backend_info.backend_id)
- return _compile_module(module, backend_info, exported_names, artifacts_dir)
+ output_file = output_kwargs.get("output_file")
+ if output_file:
+ with open(output_file, "rb") as f:
+ immediate_result = f.read()
+ return immediate_result, output_file
def _incrementally_compile_tf_signature_def_saved_model(
@@ -180,9 +137,6 @@
The module blob this creates is not callable. See IreeCompiledModule for an
API that returns a module that can be called without any further steps.
- See _incrementally_lower_compiler_module's docstring for details about which
- artifacts will be saved.
-
Args:
saved_model_dir: Directory of the saved model.
saved_model_tags: Optional set of tags to use when loading the model.
@@ -197,24 +151,22 @@
A compiled IREE module blob and the path to the compiled VM FlatBuffer if
artifacts_dir is provided.
"""
+ output_kwargs = (
+ _get_tf_import_output_kwargs(artifacts_dir, backend_info.backend_id)
+ if artifacts_dir else {})
+ immediate_result = tf_compiler.compile_saved_model(
+ saved_model_dir,
+ import_type="SIGNATURE_DEF",
+ target_backends=backend_info.compiler_targets,
+ exported_names=[exported_name],
+ saved_model_tags=saved_model_tags,
+ **output_kwargs)
- def _compile_module(saved_model_dir, saved_model_tags, backend_info,
- exported_name, artifacts_dir):
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
- saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
- return _incrementally_lower_compiler_module(compiler_module, backend_info,
- artifacts_dir)
-
- # Avoid the crash reproducer under tests or if the flag is false.
- # Developers can run tests outside of the test runner (e.g. `bazel run`) to
- # use the crash reproducer.
- if (FLAGS.capture_crash_reproducer and not _running_bazel_test()):
- _compile_module = _setup_mlir_crash_reproducer(_compile_module,
- artifacts_dir,
- backend_info.backend_id)
- return _compile_module(saved_model_dir, saved_model_tags, backend_info,
- exported_name, artifacts_dir)
+ output_file = output_kwargs.get("output_file")
+ if output_file:
+ with open(output_file, "rb") as f:
+ immediate_result = f.read()
+ return immediate_result, output_file
class _FunctionWrapper(object):
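After this change, module_utils.py drives compilation through the pyiree.compiler2 tf frontend, passing artifact paths as keyword arguments instead of lowering a compiler.Module by hand. A minimal sketch of that flow outside the test harness, assuming a placeholder module, artifacts directory, and backend name:

```python
import os
import tensorflow as tf
from pyiree.compiler2 import tf as tf_compiler


class SimpleModule(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
  def scale(self, x):
    return x * 2.0


artifacts_dir = "/tmp/iree_artifacts"  # placeholder artifacts directory
os.makedirs(artifacts_dir, exist_ok=True)

# Mirrors _get_tf_import_output_kwargs: the compiled FlatBuffer and the
# intermediate TF/IREE MLIR are written to files rather than handled in memory.
tf_compiler.compile_module(
    SimpleModule(),
    exported_names=["scale"],
    target_backends=["vmla"],  # placeholder backend name
    output_file=os.path.join(artifacts_dir, "compiled.vmfb"),
    save_temp_tf_input=os.path.join(artifacts_dir, "tf_input.mlir"),
    save_temp_iree_input=os.path.join(artifacts_dir, "iree_input.mlir"))
```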
diff --git a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/BUILD b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/BUILD
deleted file mode 100644
index 9c9e486..0000000
--- a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/BUILD
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "PYBIND_COPTS",
- "PYBIND_EXTENSION_COPTS",
- "PYBIND_FEATURES",
- "iree_py_extension",
- "iree_py_library",
- "iree_py_test",
- "pybind_cc_library",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_py_library(
- name = "compiler",
- srcs = [
- "__init__.py",
- ],
- srcs_version = "PY3",
- tags = [
- "nokokoro",
- ],
- deps = [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- "//integrations/tensorflow/bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-iree_py_extension(
- name = "binding",
- srcs = [
- "initialize_module.cc",
- ],
- copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
- features = PYBIND_FEATURES,
- linkstatic = 1,
- tags = [
- "nokokoro",
- ],
- win_def_file = "export.def",
- deps = [
- ":compiler_library",
- "//bindings/python/pyiree/common",
- "//bindings/python/pyiree/compiler:compiler_library",
- ],
-)
-
-pybind_cc_library(
- name = "compiler_library",
- srcs = [
- "register_xla.cc",
- ],
- hdrs = [
- "register_xla.h",
- ],
- tags = [
- "nokokoro",
- ],
- deps = [
- "//bindings/python/pyiree/common",
- "//bindings/python/pyiree/compiler:compiler_library",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo_module_importer",
- "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation",
- ],
-)
-
-iree_py_test(
- name = "xla_module_proto_test",
- srcs = ["xla_module_proto_test.py"],
- python_version = "PY3",
- tags = [
- "manual",
- "nokokoro",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "@absl_py//absl/testing:absltest",
- "//integrations/tensorflow/bindings/python/pyiree/xla/compiler",
- "@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
- ],
-)
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index ce92df0..b5c8552 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -104,6 +104,7 @@
"@org_tensorflow//tensorflow/compiler/mlir:init_mlir",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", # Note: always_links in random foo that makes constant folding work
"@org_tensorflow//tensorflow/core/platform:errors",
],
)
diff --git a/integrations/tensorflow/compiler/iree-tf-import-main.cc b/integrations/tensorflow/compiler/iree-tf-import-main.cc
index edeb093..62631b8 100644
--- a/integrations/tensorflow/compiler/iree-tf-import-main.cc
+++ b/integrations/tensorflow/compiler/iree-tf-import-main.cc
@@ -37,7 +37,6 @@
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/core/platform/errors.h"
-
using namespace llvm;
using namespace mlir;
@@ -56,13 +55,14 @@
tensorflow::SavedModelV2Bundle bundle;
auto loadStatus = tensorflow::SavedModelV2Bundle::Load(inputPath, &bundle);
if (!loadStatus.ok()) {
- std::cerr << "TensorFlow reported error loading saved model:\n "
- << loadStatus << "\n\n";
+ llvm::errs() << "TensorFlow reported error loading saved model:\n "
+ << loadStatus.ToString() << "\n\n";
if (!tensorflow::errors::IsNotFound(loadStatus)) {
- std::cerr << "Note: Attempted to load V2 SavedModel. Double check that "
- "this is correct "
- << "and adjust via the flag "
- "--tf-import-type=savedmodel_v1|savedmodel_v2\n";
+ llvm::errs()
+ << "Note: Attempted to load V2 SavedModel. Double check that "
+ "this is correct "
+ << "and adjust via the flag "
+ "--tf-import-type=savedmodel_v1|savedmodel_v2\n";
}
return nullptr;
}
@@ -72,9 +72,9 @@
auto loadedModule = tensorflow::ConvertSavedModelToMlir(
&bundle, &context, absl::MakeSpan(exportedNamesVector));
if (!loadedModule.ok()) {
- std::cerr << "Error performing initial import from SavedModel to MLIR. "
- << "Reported error below (and see diagnostics):\n"
- << " " << loadedModule.status() << "\n";
+ llvm::errs() << "Error performing initial import from SavedModel to MLIR. "
+ << "Reported error below (and see diagnostics):\n"
+ << " " << loadedModule.status().ToString() << "\n";
return nullptr;
}
@@ -96,13 +96,14 @@
tensorflow::LoadSavedModel(session_options,
/*run_options=*/{}, inputPath, tags, &bundle);
if (!loadStatus.ok()) {
- std::cerr << "TensorFlow reported error loading saved model:\n "
- << loadStatus << "\n\n";
+ llvm::errs() << "TensorFlow reported error loading saved model:\n "
+ << loadStatus.ToString() << "\n\n";
if (!tensorflow::errors::IsNotFound(loadStatus)) {
- std::cerr << "Note: Attempted to load V1 SavedModel. Double check that "
- "this is correct "
- << "and adjust via the flag "
- "--tf-import-type=savedmodel_v1|savedmodel_v2\n";
+ llvm::errs()
+ << "Note: Attempted to load V1 SavedModel. Double check that "
+ "this is correct "
+ << "and adjust via the flag "
+ "--tf-import-type=savedmodel_v1|savedmodel_v2\n";
}
return nullptr;
}
@@ -115,9 +116,9 @@
/*upgrade_legacy=*/false);
if (!loadedModule.ok()) {
- std::cerr << "Error performing initial import from SavedModel to MLIR. "
- << "Reported error below (and see diagnostics):\n"
- << " " << loadedModule.status() << "\n";
+ llvm::errs() << "Error performing initial import from SavedModel to MLIR. "
+ << "Reported error below (and see diagnostics):\n"
+ << " " << loadedModule.status().ToString() << "\n";
return nullptr;
}
@@ -150,6 +151,15 @@
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
"separated by ','"),
llvm::cl::init("serve"));
+ static llvm::cl::opt<std::string> saveTempTfInput(
+ "save-temp-tf-input",
+ llvm::cl::desc("Save the TF pipeline input to this file"),
+ llvm::cl::init(""));
+ static llvm::cl::opt<std::string> saveTempIreeImport(
+ "save-temp-iree-input",
+ llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
+ "intermediate in a pipeline)"),
+ llvm::cl::init(""));
// Register any command line options.
registerAsmPrinterCLOptions();
@@ -164,6 +174,20 @@
OwningModuleRef module;
registry.loadAll(&context);
+ auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult {
+ auto outputFile = openOutputFile(savePath);
+ if (!outputFile) {
+ llvm::errs() << "Could not open output file: " << savePath << "\n";
+ return failure();
+ }
+ OpPrintingFlags printFlags;
+ printFlags.enableDebugInfo();
+ module->print(outputFile->os(), printFlags);
+ outputFile->os() << "\n";
+ outputFile->keep();
+ return success();
+ };
+
// First stage import.
switch (importType) {
case savedmodel_v2:
@@ -178,25 +202,28 @@
}
if (!module) return 1;
+ // Save temp output.
+ if (!saveTempTfInput.empty()) {
+ if (failed(saveToFile(saveTempTfInput))) return 10;
+ }
+
// Run passes.
PassManager pm(&context, PassManager::Nesting::Implicit);
+ applyPassManagerCLOptions(pm);
+
iree_compiler::TF::buildTFImportPassPipeline(pm);
if (failed(pm.run(*module))) {
- std::cerr
+ llvm::errs()
<< "Running iree-tf-import pass pipeline failed (see diagnostics)\n";
return 2;
}
- // Output.
- auto outputFile = openOutputFile(outputFilename);
- if (!outputFile) {
- return 3;
+ // Save temp output.
+ if (!saveTempIreeImport.empty()) {
+ if (failed(saveToFile(saveTempIreeImport))) return 10;
}
- OpPrintingFlags printFlags;
- printFlags.enableDebugInfo();
- module->print(outputFile->os(), printFlags);
- outputFile->os() << "\n";
- outputFile->keep();
+ // Save output.
+ if (failed(saveToFile(outputFilename))) return 3;
return 0;
}
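With the changes above, the importer accepts --save-temp-tf-input and --save-temp-iree-input for dumping intermediates. An illustrative invocation follows; the binary name and all paths are assumptions, and the final output flag is omitted since its name is not shown in this diff:

```python
import subprocess

# Only --tf-import-type and the --save-temp-* flags are taken from the change
# above; everything else here is a placeholder.
subprocess.run(
    [
        "iree-tf-import",  # assumed binary name/location
        "/tmp/my_saved_model",
        "--tf-import-type=savedmodel_v2",
        "--save-temp-tf-input=/tmp/tf_input.mlir",
        "--save-temp-iree-input=/tmp/iree_input.mlir",
    ],
    check=True)
```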
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 28f69d1..2f61a33 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -43,7 +43,7 @@
main = src,
python_version = "PY3",
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ "//third_party/py/pyiree:pylib_tf_support",
],
)
for src in glob(["*_test.py"])
@@ -87,6 +87,7 @@
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"ring_buffer_test.py", # TODO(b/148747011)
"strings_test.py",
+ "tensorlist_test.py", # TODO(suderman): Re-enable once dependencies resolved
]
# keep sorted
@@ -109,6 +110,7 @@
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
"strings_test.py",
+ "tensorlist_test.py", # TODO(suderman): Re-enable once dependencies resolved
]
# keep sorted
@@ -131,6 +133,7 @@
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
"strings_test.py",
+ "tensorlist_test.py", # TODO(suderman): Re-enable once dependencies resolved
]
TF_PASSING = glob(
@@ -169,7 +172,7 @@
},
reference_backend = "tf",
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ "//third_party/py/pyiree:pylib_tf_support",
],
)
@@ -189,7 +192,7 @@
"notap",
],
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ "//third_party/py/pyiree:pylib_tf_support",
],
)
@@ -213,6 +216,6 @@
"notap",
],
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ "//third_party/py/pyiree:pylib_tf_support",
],
)
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
deleted file mode 100644
index 9a8b540..0000000
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ /dev/null
@@ -1,216 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. Coverage tables generated from this file can be viewed here:
-# https://google.github.io/iree/tensorflow-coverage
-# Updates made to test suite names should also be reflected here:
-# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "iree_py_binary",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
- "iree_e2e_cartesian_product_test_suite",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-[
- iree_py_binary(
- name = src.replace(".py", "_manual"),
- srcs = [src],
- main = src,
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for src in glob(
- ["*_test.py"],
- exclude = ["keyword_spotting_streaming_test.py"],
- )
-]
-
-# Keyword Spotting Tests:
-KEYWORD_SPOTTING_MODELS = [
- "svdf",
- "svdf_resnet",
- "ds_cnn",
- "gru",
- "lstm",
- "cnn_stride",
- "cnn",
- "tc_resnet",
- "crnn",
- "dnn",
- "att_rnn",
- "att_mh_rnn",
- "mobilenet",
- "mobilenet_v2",
- "xception",
- "inception",
- "inception_resnet",
- "ds_tc_resnet",
-]
-
-NON_STREAMING_KEYWORD_SPOTTING_MODELS = [
- "att_mh_rnn",
- "att_rnn",
- "ds_cnn",
- "inception",
- "inception_resnet",
- "mobilenet",
- "mobilenet_v2",
- "svdf_resnet",
- "tc_resnet",
- "xception",
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "keyword_spotting_tests",
- srcs = ["keyword_spotting_streaming_test.py"],
- failing_configurations = [
- {
- # TODO(#4065): VMLA raises INVALID_ARGUMENT errors after DeviceQueue failure.
- "model": [
- "att_mh_rnn",
- "att_rnn",
- "crnn",
- "gru",
- "lstm",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Timing out on SwiftShader
- "model": [
- "att_mh_rnn",
- "att_rnn",
- ],
- "target_backends": "iree_vulkan",
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "mode": "non_streaming",
- "model": KEYWORD_SPOTTING_MODELS,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "keyword_spotting_streaming_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- "@kws_streaming//:models_lib",
- "@kws_streaming//:train_lib",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "keyword_spotting_internal_streaming_tests",
- srcs = ["keyword_spotting_streaming_test.py"],
- failing_configurations = [
- {
- # TFLite cannot compile variables.
- "target_backends": "tflite",
- },
- {
- # These models do not currently support streaming.
- "model": NON_STREAMING_KEYWORD_SPOTTING_MODELS,
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "mode": "internal_streaming",
- "model": KEYWORD_SPOTTING_MODELS,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "keyword_spotting_streaming_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- "@kws_streaming//:models_lib",
- "@kws_streaming//:train_lib",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "keyword_spotting_external_streaming_tests",
- srcs = ["keyword_spotting_streaming_test.py"],
- failing_configurations = [
- {
- # A bug in keras causes the external streaming conversion to fail
- # when TensorFlow 2.x is used.
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- {
- # These models do not currently support streaming.
- "model": NON_STREAMING_KEYWORD_SPOTTING_MODELS,
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "mode": "external_streaming",
- "model": KEYWORD_SPOTTING_MODELS,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "keyword_spotting_streaming_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- "@kws_streaming//:models_lib",
- "@kws_streaming//:train_lib",
- ],
-)
-
-iree_py_binary(
- name = "keyword_spotting_streaming_test_manual",
- srcs = ["keyword_spotting_streaming_test.py"],
- main = "keyword_spotting_streaming_test.py",
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- "@kws_streaming//:models_lib",
- "@kws_streaming//:train_lib",
- ],
-)
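
The keras suites deleted above (and the suites in the BUILD files deleted below) all rely on iree_e2e_cartesian_product_test_suite, which expands flags_to_values into one test per flag combination and skips the combinations named in failing_configurations. As a rough illustration of that expansion (a plain-Python sketch with hypothetical helper names, not the actual Starlark macro defined in iree_e2e_cartesian_product_test_suite.bzl):

import itertools

def expand_cartesian_product(flags_to_values, failing_configurations=()):
    # Normalize scalar flag values (e.g. "tf" or False) into single-element lists.
    normalized = {
        flag: list(values) if isinstance(values, (list, tuple)) else [values]
        for flag, values in flags_to_values.items()
    }
    flags = sorted(normalized)
    for combination in itertools.product(*(normalized[flag] for flag in flags)):
        config = dict(zip(flags, combination))
        if not any(_matches(config, failing) for failing in failing_configurations):
            yield config

def _matches(config, failing):
    # A failing entry matches a config when, for every flag the entry names,
    # the config's value is covered by that entry.
    for flag, values in failing.items():
        values = values if isinstance(values, (list, tuple)) else [values]
        if config.get(flag) not in values:
            return False
    return True

Under this reading, an entry such as {"target_backends": "tflite"} excludes every combination that targets tflite, regardless of the other flags.
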
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
deleted file mode 100644
index a9df330..0000000
--- a/integrations/tensorflow/e2e/keras/applications/BUILD
+++ /dev/null
@@ -1,318 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. Coverage tables generated from this file can be viewed here:
-# https://google.github.io/iree/tensorflow-coverage/vision-coverage
-# Updates made to test suite names should also be reflected here:
-# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "iree_py_binary",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
- "iree_e2e_cartesian_product_test_suite",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-# @unused
-DOC = """
-applications_test_manual is for manual testing of all keras vision models.
-The test runs only manually, with all parameters specified explicitly, for example:
-bazel run -c opt integrations/tensorflow/e2e/keras/applications:applications_test_manual -- \
---target_backends=tf,iree_vmla \
---data=imagenet \
---url=https://storage.googleapis.com/iree_models/ \
---model=ResNet50
-
-Command arguments description:
---target_backends: can be a combination of these: tf,iree_vmla
---data: can be 'imagenet' or 'cifar10'.
- imagenet - input image size (1, 224, 224, 3)
- cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
- and needs pretrained weights; we pretrained these models: ResNet50, MobileNet, MobileNetV2
---include_top: Whether or not to include the final (top) layers of the model.
---url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
- the imagenet pretrained weights URL is specified by keras
---model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
- ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
- InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
- DenseNet201, NASNetMobile, NASNetLarge
- All of the above models work with the 'imagenet' data set.
- ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
-"""
-
-[
- iree_py_binary(
- name = src.replace(".py", "_manual"),
- srcs = [src],
- main = src,
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for src in glob(
- ["*_test.py"],
- )
-]
-
-KERAS_APPLICATIONS_MODELS = [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "EfficientNetB0",
- "EfficientNetB1",
- "EfficientNetB2",
- "EfficientNetB3",
- "EfficientNetB4",
- "EfficientNetB5",
- "EfficientNetB6",
- "EfficientNetB7",
- "InceptionResNetV2",
- "InceptionV3",
- "MobileNet",
- "MobileNetV2",
- "MobileNetV3Large",
- "MobileNetV3Small",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- "ResNet50",
- "ResNet50V2",
- "VGG16",
- "VGG19",
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "large_cifar10_tests",
- size = "large",
- srcs = ["applications_test.py"],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "model": [
- # All models with runtime shorter than ResNet50.
- "MobileNet", # Max: Vulkan 61.0s
- "MobileNetV2", # Max: LLVM 96.3s
- "ResNet50", # Max: LLVM 145.6s
- "VGG16", # Max: LLVM 89.5s
- "VGG19", # Max: LLVM 94.7s
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "applications_test.py",
- tags = ["manual"],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "enormous_cifar10_tests",
- size = "enormous",
- srcs = ["applications_test.py"],
- failing_configurations = [
- {
- # Failing on vmla with negative inputs.
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing on llvm and vulkan:
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101V2",
- "ResNet152V2",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "model": [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "applications_test.py",
- tags = [
- "guitar",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
- name = "cifar10_non_hermetic_tests",
- size = "large",
- srcs = ["applications_test.py"],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "url": "https://storage.googleapis.com/iree_models/",
- "use_external_weights": True,
- "model": [
- "MobileNet",
- "MobileNetV2",
- "ResNet50",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "applications_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "no-remote",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
- name = "imagenet_non_hermetic_tests",
- size = "enormous",
- srcs = ["applications_test.py"],
- failing_configurations = [
- {
- # Failing on vmla with negative inputs.
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing vulkan:
- "model": [
- "InceptionResNetV2",
- "InceptionV3",
- ],
- "target_backends": [
- "iree_vulkan",
- ],
- },
- {
- # Failing llvm and vulkan:
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101V2",
- "ResNet152V2",
- "Xception",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "imagenet",
- "use_external_weights": True,
- "model": KERAS_APPLICATIONS_MODELS,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "applications_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# This binary is used to produce weights for keras vision models with an input
-# image size of 32x32. These models are not optimized for accuracy or latency
-# (they are for debugging only). They have the same neural net topology as the
-# keras vision models trained on the imagenet data set.
-iree_py_binary(
- name = "train_vision_models_on_cifar",
- srcs = ["train_vision_models_on_cifar.py"],
- python_version = "PY3",
- srcs_version = "PY2AND3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
deleted file mode 100644
index b3e74e5..0000000
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ /dev/null
@@ -1,549 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. Coverage tables generated from this file can be viewed here:
-# https://google.github.io/iree/tensorflow-coverage/tf-keras-coverage
-# Updates made to test suite names should also be reflected here:
-# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "iree_py_binary",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
- "iree_e2e_cartesian_product_test_suite",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-[
- iree_py_binary(
- name = src.replace(".py", "_manual"),
- srcs = [src],
- main = src,
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for src in glob(
- ["*_test.py"],
- exclude = ["keyword_spotting_streaming_test.py"],
- )
-]
-
-# These layers were selected by:
-# 1. Getting all subclasses of `tf.keras.layers.Layer`
- # 2. Removing deprecated layers based on the tf.keras docs
-# 3. Removing irrelevant layers
-# 4. Removing layers that don't fit in the testing framework (Wrappers, DenseFeatures, ...)
-LAYERS = [
- "Activation",
- "ActivityRegularization",
- "Add",
- "AdditiveAttention",
- "AlphaDropout",
- "Attention",
- "Average",
- "AveragePooling1D",
- "AveragePooling2D",
- "AveragePooling3D",
- "BatchNormalization",
- "Concatenate",
- "Conv1D",
- "Conv1DTranspose",
- "Conv2D",
- "Conv2DTranspose",
- "Conv3D",
- "Conv3DTranspose",
- # "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
- "Cropping1D",
- "Cropping2D",
- "Cropping3D",
- "Dense",
- "DepthwiseConv2D",
- "Dot",
- "Dropout",
- "ELU",
- "Embedding",
- "Flatten",
- "GRU",
- "GaussianDropout",
- "GaussianNoise",
- "GlobalAveragePooling1D",
- "GlobalAveragePooling2D",
- "GlobalAveragePooling3D",
- "GlobalMaxPool1D",
- "GlobalMaxPool2D",
- "GlobalMaxPool3D",
- "InputLayer",
- "LSTM",
- "Lambda",
- "LayerNormalization",
- "LeakyReLU",
- "LocallyConnected1D",
- "LocallyConnected2D",
- "Masking",
- "MaxPool1D",
- "MaxPool2D",
- "MaxPool3D",
- "Maximum",
- "Minimum",
- "MultiHeadAttention",
- "Multiply",
- "PReLU",
- "Permute",
- "ReLU",
- "RepeatVector",
- "Reshape",
- "SeparableConv1D",
- "SeparableConv2D",
- # "SimpleRNN", # TODO(meadowlark): Debug flakiness.
- "Softmax",
- "SpatialDropout1D",
- "SpatialDropout2D",
- "SpatialDropout3D",
- "Subtract",
- "ThresholdedReLU",
- "UpSampling1D",
- "UpSampling2D",
- "UpSampling3D",
- "ZeroPadding1D",
- "ZeroPadding2D",
- "ZeroPadding3D",
-]
-
-FAILING_STATIC = [
- {
- # Failing on TFLite
- "layer": [
- "AveragePooling3D",
- "Conv3DTranspose",
- "Conv3D",
- "ConvLSTM2D",
- "LayerNormalization",
- "Softmax",
- "MaxPool3D",
- "ZeroPadding3D",
- ],
- "target_backends": "tflite",
- },
- {
- # Failing on IREE
- "layer": [
- "ConvLSTM2D",
- "GRU",
- "LSTM", # Failing unless 'return_sequences = True'
- "LayerNormalization",
- "LeakyReLU",
- "LocallyConnected2D", # TODO(#4065): VMLA raises INVALID_ARGUMENT errors after DeviceQueue failure.
- "MultiHeadAttention",
- "UpSampling2D",
- ],
- "target_backends": [
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- {
- # Failing on VMLA
- "layer": [
- "Conv3DTranspose",
- "Conv3D",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing on LLVM and Vulkan
- "layer": [
- "Lambda",
- "Masking",
- "MaxPool1D",
- "MaxPool2D",
- "MaxPool3D",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- {
- # Failing on Vulkan
- "layer": [
- "Attention",
- "AdditiveAttention",
- "AveragePooling1D",
- "AveragePooling2D",
- "AveragePooling3D",
- "ThresholdedReLU",
- ],
- "target_backends": "iree_vulkan",
- },
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "layers_tests",
- srcs = ["layers_test.py"],
- failing_configurations = FAILING_STATIC,
- flags_to_values = {
- "reference_backend": "tf",
- "layer": LAYERS,
- "dynamic_dims": False,
- "training": False,
- "test_default_kwargs_only": True,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "layers_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# A list of all layers with non-default api tests can be generated by running:
-# bazel run integrations/tensorflow/e2e/keras/layers:layers_test_manual -- \
-# --list_layers_with_full_api_tests
-LAYERS_WITH_FULL_API_TESTS = [
- "ActivityRegularization",
- "AdditiveAttention",
- "Attention",
- "AveragePooling1D",
- "AveragePooling2D",
- "AveragePooling3D",
- "BatchNormalization",
- "Concatenate",
- "Conv1D",
- "Conv1DTranspose",
- "Conv2D",
- "Conv2DTranspose",
- "Conv3D",
- "Conv3DTranspose",
- # "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
- "Cropping1D",
- "Cropping2D",
- "Cropping3D",
- "DepthwiseConv2D",
- "GRU",
- "LSTM",
- "LocallyConnected1D",
- "LocallyConnected2D",
- "MaxPool1D",
- "MaxPool2D",
- "MaxPool3D",
- "SeparableConv1D",
- "SeparableConv2D",
- "SimpleRNN",
- # "SimpleRNN", # TODO(meadowlark): Debug flakiness.
-]
-
-FAILING_FULL_API = [
- {
- # Failing on TFLite
- "layer": [
- "AveragePooling3D",
- "Conv2DTranspose",
- "Conv3D",
- "Conv3DTranspose",
- "ConvLSTM2D",
- "DepthwiseConv2D",
- "GRU",
- "LocallyConnected1D",
- "LocallyConnected2D",
- "LSTM",
- "MaxPool1D",
- "MaxPool3D",
- "SeparableConv1D", # Failing on Kokoro.
- "SeparableConv2D",
- "SimpleRNN",
- ],
- "target_backends": "tflite",
- },
- {
- # Failing on IREE
- "layer": [
- "Conv2DTranspose",
- "Conv3DTranspose",
- "ConvLSTM2D",
- "GRU",
- "LocallyConnected1D",
- "LocallyConnected2D",
- "LSTM",
- "SimpleRNN",
- ],
- "target_backends": [
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- {
- "layer": "Conv3D",
- "target_backends": "iree_vmla",
- },
- {
- # Failing on LLVM and Vulkan
- "layer": [
- "AdditiveAttention",
- "Attention",
- "AveragePooling1D",
- "AveragePooling2D",
- "AveragePooling3D",
- "Conv1DTranspose",
- "MaxPool1D",
- "MaxPool2D",
- "MaxPool3D",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "layers_full_api_tests",
- srcs = ["layers_test.py"],
- failing_configurations = FAILING_FULL_API,
- flags_to_values = {
- "reference_backend": "tf",
- "layer": LAYERS_WITH_FULL_API_TESTS,
- "dynamic_dims": False,
- "training": False,
- "test_default_kwargs_only": False,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "layers_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-FAILING_DYNAMIC = [
- {
- # TFLite does not support dynamic shapes.
- "target_backends": "tflite",
- },
- {
- # Failing on IREE
- "layer": [
- "AdditiveAttention",
- "AveragePooling1D",
- "AveragePooling2D",
- "AveragePooling3D",
- "BatchNormalization",
- "Concatenate",
- "Conv1D",
- "Conv1DTranspose",
- "Conv2D",
- "Conv2DTranspose",
- "Conv3D",
- "Conv3DTranspose",
- "ConvLSTM2D",
- "Cropping1D",
- "Cropping2D",
- "Cropping3D",
- "Dense",
- "DepthwiseConv2D",
- "Dot",
- "ELU",
- "Flatten",
- "GRU",
- "LayerNormalization",
- "LeakyReLU",
- "LocallyConnected1D",
- "LocallyConnected2D",
- "LSTM", # TODO(silvasean): Get this test working on IREE.
- "Masking",
- "MaxPool1D",
- "MaxPool2D",
- "MaxPool3D",
- "MultiHeadAttention",
- "RepeatVector",
- "Reshape",
- "SeparableConv1D",
- "SeparableConv2D",
- "SimpleRNN",
- "ThresholdedReLU",
- "UpSampling1D",
- "UpSampling2D",
- "UpSampling3D",
- "ZeroPadding1D",
- "ZeroPadding2D",
- "ZeroPadding3D",
- ],
- "target_backends": [
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- {
- # Failing on LLVM and Vulkan
- "layer": [
- "Activation",
- "Add",
- "Attention",
- "Average",
- "GlobalAveragePooling1D",
- "GlobalAveragePooling2D",
- "GlobalAveragePooling3D",
- "Lambda",
- "Maximum",
- "Minimum",
- "Multiply",
- "PReLU",
- "ReLU",
- "Softmax",
- "Subtract",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- {
- # Failing on Vulkan
- "layer": "Embedding",
- "target_backends": "iree_vulkan",
- },
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "layers_dynamic_dims_tests",
- srcs = ["layers_test.py"],
- failing_configurations = FAILING_DYNAMIC,
- flags_to_values = {
- "reference_backend": "tf",
- "layer": LAYERS,
- "dynamic_dims": True,
- "training": False,
- "test_default_kwargs_only": True,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "layers_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# Layers that mention a training kwarg in their doc.
-LAYERS_WITH_TRAINING_BEHAVIOR = [
- "AdditiveAttention",
- "AlphaDropout",
- "Attention",
- "BatchNormalization",
- # "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
- "Dropout",
- "GRU",
- "GaussianDropout",
- "GaussianNoise",
- "LSTM",
- "MultiHeadAttention",
- # "SimpleRNN", # TODO(meadowlark): Debug flakiness.
- "SpatialDropout1D",
- "SpatialDropout2D",
- "SpatialDropout3D",
-]
-
-FAILING_TRAINING = [
- {
- # Failing on TFLite:
- "layer": [
- "AlphaDropout",
- "BatchNormalization",
- "ConvLSTM2D",
- "GaussianDropout",
- "GaussianNoise",
- "GRU",
- "LSTM",
- "SimpleRNN",
- ],
- "target_backends": "tflite",
- },
- {
- # Failing on IREE
- "layer": [
- "AdditiveAttention",
- "AlphaDropout",
- "Attention",
- "BatchNormalization",
- "ConvLSTM2D",
- "Dropout",
- "GaussianDropout",
- "GaussianNoise",
- "GRU",
- "LSTM",
- "MultiHeadAttention",
- "SimpleRNN",
- "SpatialDropout1D",
- "SpatialDropout2D",
- "SpatialDropout3D",
- ],
- "target_backends": [
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "layers_training_tests",
- srcs = ["layers_test.py"],
- failing_configurations = FAILING_TRAINING,
- flags_to_values = {
- "reference_backend": "tf",
- "layer": LAYERS_WITH_TRAINING_BEHAVIOR,
- "dynamic_dims": False,
- "training": True,
- "test_default_kwargs_only": True,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "layers_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
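
Each of these deleted BUILD files also generated a standalone "_manual" binary per test source via a list comprehension over glob(["*_test.py"]). A minimal sketch of the name derivation that comprehension relies on, in plain Python rather than Starlark (the source names below are illustrative stand-ins, not a real glob result):

# Stand-ins for glob(["*_test.py"], exclude = ["keyword_spotting_streaming_test.py"]).
test_srcs = ["layers_test.py", "applications_test.py"]
manual_targets = [src.replace(".py", "_manual") for src in test_srcs]
assert manual_targets == ["layers_test_manual", "applications_test_manual"]
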
diff --git a/integrations/tensorflow/e2e/keras/train/BUILD b/integrations/tensorflow/e2e/keras/train/BUILD
deleted file mode 100644
index 7dc0bd6..0000000
--- a/integrations/tensorflow/e2e/keras/train/BUILD
+++ /dev/null
@@ -1,128 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "iree_py_binary",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
- "iree_e2e_cartesian_product_test_suite",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-[
- iree_py_binary(
- name = src.replace(".py", "_manual"),
- srcs = [src],
- main = src,
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for src in glob(["*_test.py"])
-]
-
-iree_e2e_cartesian_product_test_suite(
- name = "classification_training_tests",
- srcs = ["classification_training_test.py"],
- failing_configurations = [
- {
- # TFLite doesn't support training.
- "target_backends": "tflite",
- },
- {
- "target_backends": [
- "tflite",
- "iree_vmla", # TODO(b/157581521)
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "optimizer": [
- "adadelta",
- "adagrad",
- "adam",
- "adamax",
- "ftrl",
- "nadam",
- "rmsprop",
- "sgd",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "classification_training_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "regression_training_tests",
- srcs = ["regression_training_test.py"],
- failing_configurations = [
- {
- # TFLite doesn't support training.
- "target_backends": "tflite",
- },
- {
- "target_backends": [
- "iree_vmla", # TODO(b/157581521)
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "optimizer": [
- "adadelta",
- "adagrad",
- "adam",
- "adamax",
- "ftrl",
- "nadam",
- "rmsprop",
- "sgd",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "regression_training_test.py",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
diff --git a/integrations/tensorflow/e2e/math/BUILD b/integrations/tensorflow/e2e/math/BUILD
deleted file mode 100644
index 4440b70..0000000
--- a/integrations/tensorflow/e2e/math/BUILD
+++ /dev/null
@@ -1,995 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. Coverage tables generated from this file can be viewed here:
-# https://google.github.io/iree/tensorflow-coverage/tf-base-coverage
-# Updates made to test suite names should also be reflected here:
-# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "NUMPY_DEPS",
- "iree_py_binary",
- "iree_py_test",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
- "set_difference",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
- "iree_e2e_cartesian_product_test_suite",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-[
- iree_py_binary(
- name = src.replace(".py", "_manual"),
- srcs = [src],
- main = src,
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for src in glob(
- ["*_test.py"],
- exclude = ["keyword_spotting_streaming_test.py"],
- )
-]
-
- # These functions were selected using all of the functions in the tf.math docs:
-# https://www.tensorflow.org/api_docs/python/tf/math
-TF_MATH_FUNCTIONS = [
- "abs",
- "accumulate_n",
- "acos",
- "acosh",
- "add",
- "add_n",
- "angle",
- "argmax",
- "argmin",
- "asin",
- "asinh",
- "atan",
- "atan2",
- "atanh",
- "bessel_i0",
- "bessel_i0e",
- "bessel_i1",
- "bessel_i1e",
- "betainc",
- "bincount",
- "ceil",
- "confusion_matrix",
- "cos",
- "cosh",
- "count_nonzero",
- "cumprod",
- "cumsum",
- "cumulative_logsumexp",
- "digamma",
- "divide",
- "divide_no_nan",
- "equal",
- "erf",
- "erfc",
- "erfinv",
- "exp",
- "expm1",
- "floor",
- "floordiv",
- "floormod",
- "greater",
- "greater_equal",
- "igamma",
- "igammac",
- "imag",
- "in_top_k",
- "invert_permutation",
- "is_finite",
- "is_inf",
- "is_nan",
- "is_non_decreasing",
- "is_strictly_increasing",
- "lbeta",
- "less",
- "less_equal",
- "lgamma",
- "log",
- "log1p",
- "log_sigmoid",
- "log_softmax",
- "logical_and",
- "logical_not",
- "logical_or",
- "logical_xor",
- "maximum",
- "minimum",
- "mod",
- "multiply",
- "multiply_no_nan",
- "ndtri",
- "negative",
- "nextafter",
- "not_equal",
- "polygamma",
- "polyval",
- "pow",
- "real",
- "reciprocal",
- "reciprocal_no_nan",
- "reduce_all",
- "reduce_any",
- "reduce_euclidean_norm",
- "reduce_logsumexp",
- "reduce_max",
- "reduce_mean",
- "reduce_min",
- "reduce_prod",
- "reduce_std",
- "reduce_sum",
- "reduce_variance",
- "rint",
- "round",
- "rsqrt",
- "scalar_mul",
- "segment_max",
- "segment_mean",
- "segment_min",
- "segment_prod",
- "segment_sum",
- "sigmoid",
- "sign",
- "sin",
- "sinh",
- "sobol_sample",
- "softmax",
- "softplus",
- "softsign",
- "sqrt",
- "square",
- "squared_difference",
- "subtract",
- "tan",
- "tanh",
- # "top_k", # TODO(meadowlark): Enable once list outputs are supported.
- "truediv",
- "unsorted_segment_max",
- "unsorted_segment_mean",
- "unsorted_segment_min",
- "unsorted_segment_prod",
- "unsorted_segment_sqrt_n",
- "unsorted_segment_sum",
- "xdivy",
- "xlog1py",
- "xlogy",
- "zero_fraction",
- "zeta",
-]
-
-# keep sorted
-TFLITE_FAILING = [
- "abs", # Failing for integer inputs.
- "acos",
- "acosh",
- "asin",
- "asinh",
- "atan",
- "atan2",
- "atanh",
- "bessel_i0",
- "bessel_i0e",
- "bessel_i1",
- "bessel_i1e",
- "betainc",
- "bincount",
- "confusion_matrix",
- "conj",
- "cosh",
- "cumprod",
- "cumulative_logsumexp",
- "digamma",
- "divide", # Failing for integer inputs.
- "erf",
- "erfc",
- "erfinv",
- "expm1",
- "igamma",
- "igammac",
- "in_top_k",
- "invert_permutation",
- "is_finite",
- "is_non_decreasing",
- "is_strictly_increasing",
- "l2_normalize",
- "lbeta",
- "lgamma",
- "log1p",
- "log_sigmoid",
- "ndtri",
- "nextafter",
- "polygamma",
- "polyval",
- "pow", # Failing for integer inputs.
- "reduce_all",
- "reduce_euclidean_norm",
- "reduce_logsumexp",
- "reduce_mean",
- "reduce_std",
- "reduce_variance",
- "rint",
- "segment_max",
- "segment_mean",
- "segment_min",
- "segment_prod",
- "sign",
- "sinh",
- "sobol_sample",
- "softmax",
- "softplus",
- "softsign",
- "tan",
- "unsorted_segment_max",
- "unsorted_segment_mean",
- "unsorted_segment_min",
- "unsorted_segment_prod",
- "unsorted_segment_sqrt_n",
- "unsorted_segment_sum",
- "xdivy",
- "xlog1py",
- "xlogy",
- "zeta",
-]
-
-# Note: The VMLA_FAILING_DYNAMIC specification extends this list. Newly-passing
-# functions removed from this list may need to be added to VMLA_FAILING_DYNAMIC.
-# keep sorted
-VMLA_FAILING = [
- "acosh",
- "argmax",
- "argmin",
- "asin",
- "asinh",
- "atan2",
- "atanh",
- "bessel_i0",
- "bessel_i0e",
- "bessel_i1",
- "bessel_i1e",
- "betainc",
- "bincount",
- "confusion_matrix",
- "cosh",
- "count_nonzero",
- "cumprod",
- "cumulative_logsumexp",
- "digamma",
- "divide", # Failing for integer inputs because iree doesn't output 'f64'.
- "erf",
- "erfc",
- "erfinv",
- "expm1",
- "igamma",
- "igammac",
- "in_top_k",
- "invert_permutation",
- "is_nan", # TODO(#4061): tf.math.is_nan miscompiles with static shapes.
- "is_non_decreasing",
- "is_strictly_increasing",
- "ndtri",
- "nextafter",
- "polygamma",
- "pow", # Failing for integer inputs.
- "reduce_euclidean_norm",
- "reduce_prod",
- "rint",
- "segment_max",
- "segment_mean",
- "segment_min",
- "segment_prod",
- "segment_sum",
- "sign",
- "sobol_sample",
- "softsign",
- "unsorted_segment_max",
- "unsorted_segment_mean",
- "unsorted_segment_min",
- "unsorted_segment_prod",
- "unsorted_segment_sqrt_n",
- "unsorted_segment_sum",
- "xdivy",
- "xlog1py",
- "xlogy",
- "zeta",
-]
-
-# Note: The LLVM_FAILING_DYNAMIC specification extends this list. Newly-passing
-# functions removed from this list may need to be added to LLVM_FAILING_DYNAMIC.
-# keep sorted
-LLVM_FAILING = [
- "acos",
- "acosh",
- "argmax",
- "argmin",
- "asin",
- "asinh",
- "atan",
- "atan2",
- "atanh",
- "bessel_i0",
- "bessel_i0e",
- "bessel_i1",
- "bessel_i1e",
- "betainc",
- "bincount",
- "confusion_matrix",
- "cosh",
- "count_nonzero",
- "cumprod",
- "cumulative_logsumexp",
- "digamma",
- "divide", # Failing for integer inputs because iree doesn't output 'f64'.
- "erf",
- "erfc",
- "erfinv",
- "expm1",
- "igamma",
- "igammac",
- "in_top_k",
- "invert_permutation",
- "is_nan",
- "is_non_decreasing",
- "is_strictly_increasing",
- "l2_normalize",
- "logical_or",
- "logical_xor",
- "ndtri",
- "nextafter",
- "polygamma",
- "pow",
- "reduce_all",
- "reduce_any",
- "reduce_euclidean_norm",
- "reduce_logsumexp",
- "reduce_max",
- "reduce_mean",
- "reduce_min",
- "reduce_prod",
- "reduce_std",
- "reduce_sum",
- "reduce_variance",
- "rint",
- "segment_max",
- "segment_mean",
- "segment_min",
- "segment_prod",
- "segment_sum",
- "sign",
- "sobol_sample",
- "softsign",
- "unsorted_segment_max",
- "unsorted_segment_mean",
- "unsorted_segment_min",
- "unsorted_segment_prod",
- "unsorted_segment_sqrt_n",
- "unsorted_segment_sum",
- "xdivy",
- "xlog1py",
- "xlogy",
- "zeta",
-]
-
-# Note: The VULKAN_FAILING_DYNAMIC specification extends this list.
-# Newly-passing functions removed from this list may need to be added to
-# VULKAN_FAILING_DYNAMIC.
-# keep sorted
-VULKAN_FAILING = [
- "acos",
- "acosh",
- "argmax",
- "argmin",
- "asin",
- "asinh",
- "atan",
- "atan2",
- "atanh",
- "bessel_i0",
- "bessel_i0e",
- "bessel_i1",
- "bessel_i1e",
- "betainc",
- "bincount",
- "confusion_matrix",
- "cosh",
- "count_nonzero",
- "cumprod",
- "cumsum",
- "cumulative_logsumexp",
- "digamma",
- "divide", # Failing for integer inputs because iree doesn't output 'f64'.
- "erf",
- "erfc",
- "erfinv",
- "expm1",
- "igamma",
- "igammac",
- "in_top_k",
- "invert_permutation",
- "is_nan",
- "is_non_decreasing",
- "is_strictly_increasing",
- "l2_normalize",
- "logical_and",
- "logical_not",
- "logical_or",
- "logical_xor",
- "mod", # Passes with swiftshader, but fails on Turing GPU
- "ndtri",
- "nextafter",
- "polygamma",
- "pow",
- "reduce_all",
- "reduce_any",
- "reduce_euclidean_norm",
- "reduce_logsumexp",
- "reduce_max",
- "reduce_mean",
- "reduce_min",
- "reduce_prod",
- "reduce_std",
- "reduce_sum",
- "reduce_variance",
- "rint",
- "segment_max",
- "segment_mean",
- "segment_min",
- "segment_prod",
- "segment_sum",
- "sign",
- "sobol_sample",
- "softsign",
- "unsorted_segment_max",
- "unsorted_segment_mean",
- "unsorted_segment_min",
- "unsorted_segment_prod",
- "unsorted_segment_sqrt_n",
- "unsorted_segment_sum",
- "xdivy",
- "xlog1py",
- "xlogy",
- "zeta",
-]
-
-# ---- INDIVIDUAL STATIC TESTS ----------------------------------------------- #
-
-# These tests allow us to generate coverage tables and give a finer-grained view
-# of the coverage, but are very slow due to bazel overhead, so they are not
-# run on the internal or OSS CI.
-iree_e2e_cartesian_product_test_suite(
- name = "math_tests",
- srcs = ["math_test.py"],
- failing_configurations = [
- {
- # Failing on TFLite.
- "functions": TFLITE_FAILING,
- "target_backends": "tflite",
- },
- {
- # Failing on vmla.
- "functions": VMLA_FAILING,
- "target_backends": "iree_vmla",
- },
- {
- # Failing on llvm.
- "functions": LLVM_FAILING,
- "target_backends": "iree_llvmjit",
- },
- {
- # Failing on vulkan.
- "functions": VULKAN_FAILING,
- "target_backends": "iree_vulkan",
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "functions": TF_MATH_FUNCTIONS,
- "dynamic_dims": False,
- "test_complex": False,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "math_test.py",
- tags = [
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# ---- MULTIPLE STATIC TESTS ------------------------------------------------ #
-
-# These tests compile all functions in tf.math at once for testing so that
-# we can run them on the CI with 5 additional targets instead of 640. The tests
-# are run sharded such that about 5 functions run per shard. This is a
-# reasonable tradeoff between shard startup overhead and critical path test
-# latency.
-
-# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
-VMLA_FAILING_MULTIPLE = VMLA_FAILING + ["multiply"]
-
-# TODO(#3810) Including 'square' causes error: Received signal 11.
-# @unused
-LLVM_FAILING_MULTIPLE = LLVM_FAILING + ["square"]
-
-# TODO(#3810) Including 'square' causes error: Received signal 11.
-VULKAN_FAILING_MULTIPLE = VULKAN_FAILING + ["square"]
-
-[
- iree_py_test(
- name = "math_tests_multiple__{}".format(target_backend),
- srcs = ["math_test.py"],
- args = [
- "--reference_backend=tf",
- "--target_backends={}".format(target_backend),
- "--functions={}".format(",".join(functions)),
- "--dynamic_dims=False",
- "--test_complex=False",
- ],
- main = "math_test.py",
- python_version = "PY3",
- shard_count = len(functions) // 5,
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for target_backend, functions in dict(
- # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
- # iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_MULTIPLE),
- iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_MULTIPLE),
- iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_MULTIPLE),
- tf = TF_MATH_FUNCTIONS,
- tflite = set_difference(TF_MATH_FUNCTIONS, TFLITE_FAILING),
- ).items()
-]
-
-# ---- INDIVIDUAL DYNAMIC TESTS ---------------------------------------------- #
-
-# keep sorted
-VMLA_FAILING_DYNAMIC = VMLA_FAILING + [
- "angle",
- "cumsum",
- "divide_no_nan",
- "floordiv", # TODO(#4065): VMLA raises INVALID_ARGUMENT errors after DeviceQueue failure.
- "floormod", # TODO(#4065): VMLA raises INVALID_ARGUMENT errors after DeviceQueue failure.
- "lbeta",
- "lgamma",
- "log_sigmoid",
- "log1p",
- "logical_and",
- "logical_not",
- "logical_or",
- "logical_xor",
- "mod", # TODO(#4065): VMLA raises INVALID_ARGUMENT errors after DeviceQueue failure.
- "multiply_no_nan",
- "round",
- "reciprocal_no_nan",
- "reduce_all",
- "reduce_any",
- "reduce_logsumexp",
- "reduce_max",
- "reduce_min",
- "reduce_sum",
- "reduce_mean",
- "reduce_std",
- "reduce_variance",
- "softplus",
- "zero_fraction",
-]
-
-# TODO(#4061): tf.math.is_nan miscompiles with static shapes.
-VMLA_FAILING_DYNAMIC.remove("is_nan")
-
-# keep sorted
-LLVM_FAILING_DYNAMIC = LLVM_FAILING + [
- "accumulate_n",
- "add",
- "add_n",
- "angle",
- "cumsum",
- "divide",
- "divide_no_nan",
- "equal",
- "floordiv",
- "floormod",
- "greater",
- "greater_equal",
- "is_finite",
- "is_inf",
- "lbeta",
- "less",
- "less_equal",
- "lgamma",
- "log_sigmoid",
- "log_softmax",
- "log1p",
- "logical_and",
- "logical_not",
- "maximum",
- "minimum",
- "mod",
- "multiply",
- "multiply_no_nan",
- "not_equal",
- "polyval",
- "reciprocal",
- "reciprocal_no_nan",
- "reduce_mean",
- "scalar_mul",
- "sigmoid",
- "sinh",
- "softmax",
- "softplus",
- "square",
- "squared_difference",
- "subtract",
- "round",
- "tan",
- "truediv",
- "zero_fraction",
-]
-
-# keep sorted
-VULKAN_FAILING_DYNAMIC = VULKAN_FAILING + [
- "abs",
- "accumulate_n",
- "add",
- "add_n",
- "angle",
- "ceil",
- "cos",
- "divide",
- "divide_no_nan",
- "equal",
- "exp",
- "floor",
- "floordiv",
- "floormod",
- "greater",
- "greater_equal",
- "imag",
- "is_finite",
- "is_inf",
- "lbeta",
- "less",
- "round",
- "less_equal",
- "lgamma",
- "log",
- "log_sigmoid",
- "log_softmax",
- "log1p",
- "maximum",
- "minimum",
- "mod",
- "multiply",
- "multiply_no_nan",
- "negative",
- "not_equal",
- "polyval",
- "reciprocal",
- "reciprocal_no_nan",
- "reduce_max",
- "reduce_mean",
- "reduce_min",
- "reduce_sum",
- "rsqrt",
- "scalar_mul",
- "sigmoid",
- "sin",
- "sinh",
- "softmax",
- "softplus",
- "sqrt",
- "square",
- "squared_difference",
- "subtract",
- "tan",
- "tanh",
- "truediv",
- "zero_fraction",
-]
-
-# These tests allow us to generate coverage tables and give a finer-grained view
-# of the coverage, but are very slow due to bazel overhead, so they are not
-# run on the internal or OSS CI.
-iree_e2e_cartesian_product_test_suite(
- name = "math_dynamic_dims_tests",
- srcs = ["math_test.py"],
- failing_configurations = [
- {
- # TFLite does not support dynamic shapes.
- "target_backends": "tflite",
- },
- {
- # Failing on vmla.
- "functions": VMLA_FAILING_DYNAMIC,
- "target_backends": "iree_vmla",
- },
- {
- # Failing on llvm.
- "functions": LLVM_FAILING_DYNAMIC,
- "target_backends": "iree_llvmjit",
- },
- {
- # Failing on vulkan.
- "functions": VULKAN_FAILING_DYNAMIC,
- "target_backends": "iree_vulkan",
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "functions": TF_MATH_FUNCTIONS,
- "dynamic_dims": True,
- "test_complex": False,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "math_test.py",
- tags = [
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# ---- MULTIPLE DYNAMIC TESTS ----------------------------------------------- #
-
-# These tests compile all functions in tf.math at once for testing so that
-# we can run them on the CI with 4 additional targets instead of 512. The tests
-# are run sharded such that about 5 functions run per shard. This is a
-# reasonable tradeoff between shard startup overhead and critical path test
-# latency.
-
-# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
-VMLA_FAILING_DYNAMIC_MULTIPLE = VMLA_FAILING_DYNAMIC + ["multiply"]
-
-[
- iree_py_test(
- name = "math_dynamic_dims_tests_multiple__{}".format(target_backend),
- srcs = ["math_test.py"],
- args = [
- "--reference_backend=tf",
- "--target_backends={}".format(target_backend),
- "--functions={}".format(",".join(functions)),
- "--dynamic_dims=True",
- "--test_complex=False",
- ],
- main = "math_test.py",
- python_version = "PY3",
- shard_count = len(functions) // 5,
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for target_backend, functions in dict(
- # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
- # iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_DYNAMIC),
- iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_DYNAMIC_MULTIPLE),
- iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_DYNAMIC),
- tf = TF_MATH_FUNCTIONS,
- ).items()
-]
-
-# ---- INDIVIDUAL COMPLEX TESTS ---------------------------------------------- #
-
-# This list was generated by running:
-# bazel run integrations/tensorflow/e2e/math:math_test_manual -- --list_functions_with_complex_tests
-COMPLEX_FUNCTIONS = [
- "abs",
- "add",
- "angle",
- "asinh",
- "atanh",
- "conj",
- "cos",
- "cosh",
- "count_nonzero",
- "cumprod",
- "cumsum",
- "divide",
- "divide_no_nan",
- "exp",
- "expm1",
- "imag",
- "l2_normalize",
- "log",
- "log1p",
- "multiply",
- "multiply_no_nan",
- "negative",
- "pow",
- "real",
- "reciprocal",
- "reciprocal_no_nan",
- "reduce_euclidean_norm",
- "reduce_std",
- "reduce_variance",
- "rsqrt",
- "sigmoid",
- "sign",
- "sin",
- "sinh",
- "sqrt",
- "square",
- "squared_difference",
- "subtract",
- "tan",
- "tanh",
- "truediv",
- "xdivy",
- "xlog1py",
- "xlogy",
- "zero_fraction",
-]
-
-# keep sorted
-FAILING_COMPLEX = [
- "angle",
- "cos",
- "cumsum",
- "divide_no_nan",
- "log",
- "log1p",
- "multiply_no_nan",
- "negative",
- "reciprocal",
- "reciprocal_no_nan",
- "reduce_std",
- "reduce_variance",
- "rsqrt",
- "sigmoid",
- "sin",
- "sinh",
- "sqrt",
- "tan",
- "tanh",
- "zero_fraction",
-]
-
-VMLA_FAILING_COMPLEX = VMLA_FAILING + FAILING_COMPLEX
-
-LLVM_FAILING_COMPLEX = LLVM_FAILING + FAILING_COMPLEX
-
-VULKAN_FAILING_COMPLEX = VULKAN_FAILING + FAILING_COMPLEX
-
-# These tests allow us to generate coverage tables and give a finer-grained view
-# of the coverage, but are very slow due to bazel overhead, so they are not
-# run on the internal or OSS CI.
-iree_e2e_cartesian_product_test_suite(
- name = "math_complex_tests",
- srcs = ["math_test.py"],
- failing_configurations = [
- {
- # TFLite does not support complex numbers.
- "target_backends": "tflite",
- },
- {
- # Failing on vmla.
- "functions": VMLA_FAILING_COMPLEX,
- "target_backends": "iree_vmla",
- },
- {
- # Failing on llvm.
- "functions": LLVM_FAILING_COMPLEX,
- "target_backends": "iree_llvmjit",
- },
- {
- # Failing on vulkan.
- "functions": VULKAN_FAILING_COMPLEX,
- "target_backends": "iree_vulkan",
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "functions": COMPLEX_FUNCTIONS,
- "dynamic_dims": False,
- "test_complex": True,
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "math_test.py",
- tags = [
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# ---- MULTIPLE COMPLEX TESTS ----------------------------------------------- #
-
-# These tests compile all functions in tf.math at once for testing so that
-# we can run them on the CI with 4 additional targets instead of 512. The tests
-# are run sharded such that about 5 functions run per shard. This is a
-# reasonable tradeoff between shard startup overhead and critical path test
-# latency.
-
-# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
-VMLA_FAILING_COMPLEX_MULTIPLE = VMLA_FAILING_COMPLEX + ["multiply"]
-
-# TODO(#3810) Including 'square' causes error: Received signal 11.
-# @unused
-LLVM_FAILING_COMPLEX_MULTIPLE = LLVM_FAILING_COMPLEX + ["square"]
-
-# TODO(#3810) Including 'square' causes error: Received signal 11.
-VULKAN_FAILING_COMPLEX_MULTIPLE = VULKAN_FAILING_COMPLEX + ["square"]
-
-[
- iree_py_test(
- name = "math_complex_tests_multiple__{}".format(target_backend),
- srcs = ["math_test.py"],
- args = [
- "--reference_backend=tf",
- "--target_backends={}".format(target_backend),
- "--functions={}".format(",".join(functions)),
- "--dynamic_dims=False",
- "--test_complex=True",
- ],
- main = "math_test.py",
- python_version = "PY3",
- shard_count = len(functions) // 5,
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for target_backend, functions in dict(
- # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
- # iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_COMPLEX_MULTIPLE),
- iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_COMPLEX_MULTIPLE),
- iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_COMPLEX_MULTIPLE),
- tf = TF_MATH_FUNCTIONS,
- ).items()
-]
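
The "multiple" math targets above combine two pieces of arithmetic: the passing function list is the full tf.math list minus the per-backend failing list, and shard_count targets roughly five functions per shard. A hedged plain-Python sketch of that logic follows; set_difference here is an assumed re-implementation for illustration, not the definition actually loaded from iree_e2e_test_suite.bzl, and the sample lists are abbreviated stand-ins.

def set_difference(include, exclude):
    # Keep the order of `include`, dropping anything that appears in `exclude`.
    excluded = set(exclude)
    return [item for item in include if item not in excluded]

TF_MATH_FUNCTIONS = ["abs", "add", "cos", "exp", "log", "multiply", "pow", "sin", "sqrt", "square"]
VMLA_FAILING_MULTIPLE = ["pow", "multiply"]  # stand-in for VMLA_FAILING + ["multiply"]

functions = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_MULTIPLE)
shard_count = len(functions) // 5  # about five functions per shard, as the BUILD comment states
assert functions == ["abs", "add", "cos", "exp", "log", "sin", "sqrt", "square"]
assert shard_count == 1
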
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index 67278ae..d0d8726 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -25,7 +25,6 @@
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
-from pyiree.tf import compiler
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
deleted file mode 100644
index f871fe2..0000000
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. Coverage tables generated from this file can be viewed here:
-# https://google.github.io/iree/tensorflow-coverage/vision-coverage
-# Updates made to test suite names should also be reflected here:
-# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
-
-load(
- "//bindings/python:build_defs.oss.bzl",
- "INTREE_TENSORFLOW_PY_DEPS",
- "INTREE_TF_HUB_DEPS",
- "NUMPY_DEPS",
- "iree_py_binary",
-)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
- "iree_e2e_cartesian_product_test_suite",
-)
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-# Create binaries for all test srcs to allow them to be run manually.
-iree_py_binary(
- name = "slim_vision_model_test_manual",
- srcs = ["slim_vision_model_test.py"],
- args = ["--tf_hub_url=https://tfhub.dev/google/imagenet/"],
- main = "slim_vision_model_test.py",
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + INTREE_TF_HUB_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "slim_vision_tests",
- size = "enormous",
- srcs = ["slim_vision_model_test.py"],
- failing_configurations = [
- {
- # SavedModelV2 (classification/4) not available.
- "model": "amoebanet_a_n18_f448",
- },
- {
- # Failing on vmla with negative inputs.
- "model": [
- "nasnet_large",
- "nasnet_mobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing llvmjit and vulkan:
- "model": [
- "nasnet_mobile",
- "nasnet_large",
- "pnasnet_large",
- "resnet_v2_50",
- "resnet_v2_101",
- "resnet_v2_152",
- ],
- "target_backends": [
- "iree_vulkan",
- ],
- },
- {
- # Failing vulkan:
- "model": [
- # [ERROR]: cannot separate Linalg/Parallel ops into multiple kernels
- "inception_v1",
- "inception_v2",
- "inception_v3",
- "inception_resnet_v2",
- ],
- "target_backends": [
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "tf_hub_url": ["https://tfhub.dev/google/imagenet/"],
- "model": [
- "amoebanet_a_n18_f448",
- "inception_resnet_v2",
- "inception_v1",
- "inception_v2",
- "inception_v3",
- # MobileNetV1
- "mobilenet_v1_025_128",
- "mobilenet_v1_025_160",
- "mobilenet_v1_025_192",
- "mobilenet_v1_025_224",
- "mobilenet_v1_050_128",
- "mobilenet_v1_050_160",
- "mobilenet_v1_050_192",
- "mobilenet_v1_050_224",
- "mobilenet_v1_075_128",
- "mobilenet_v1_075_160",
- "mobilenet_v1_075_192",
- "mobilenet_v1_075_224",
- "mobilenet_v1_100_128",
- "mobilenet_v1_100_160",
- "mobilenet_v1_100_192",
- "mobilenet_v1_100_224",
- # MobileNetV2:
- "mobilenet_v2_035_96",
- "mobilenet_v2_035_128",
- "mobilenet_v2_035_160",
- "mobilenet_v2_035_192",
- "mobilenet_v2_035_224",
- "mobilenet_v2_050_96",
- "mobilenet_v2_050_128",
- "mobilenet_v2_050_160",
- "mobilenet_v2_050_192",
- "mobilenet_v2_050_224",
- "mobilenet_v2_075_96",
- "mobilenet_v2_075_128",
- "mobilenet_v2_075_160",
- "mobilenet_v2_075_192",
- "mobilenet_v2_075_224",
- "mobilenet_v2_100_96",
- "mobilenet_v2_100_128",
- "mobilenet_v2_100_160",
- "mobilenet_v2_100_192",
- "mobilenet_v2_100_224",
- "mobilenet_v2_130_224",
- "mobilenet_v2_140_224",
- "nasnet_mobile",
- "nasnet_large",
- "pnasnet_large",
- # ResNetV1
- "resnet_v1_50",
- "resnet_v1_101",
- "resnet_v1_152",
- # ResNetV2
- "resnet_v2_50",
- "resnet_v2_101",
- "resnet_v2_152",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_vulkan",
- ],
- },
- main = "slim_vision_model_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "no-remote",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + INTREE_TF_HUB_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)