Wire saved_model import in properly. Refactors compiler bindings to preserve diagnostics. Adds ability to run pass pipelines. Adds e2e unit test. Updates colabs to demonstrate. Closes #106 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/106 from stellaraccident:tfxla2 2106756d76af81ec57f94acff37f5d4891d227b3 PiperOrigin-RevId: 277208555
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD index 6ff4b59..e6025d3 100644 --- a/bindings/python/pyiree/BUILD +++ b/bindings/python/pyiree/BUILD
@@ -67,36 +67,48 @@ cc_library( name = "base", + srcs = [ + "compiler.cc", + "hal.cc", + "initialize.cc", + "rt.cc", + "status_utils.cc", + "vm.cc", + ], hdrs = [ "binding.h", + "compiler.h", + "hal.h", + "initialize.h", + "rt.h", "status_utils.h", + "vm.h", ], copts = DEFAULT_COPTS, features = DEFAULT_FEATURES, deps = [ + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@iree_pybind11//:pybind11", + "@local_config_mlir//:IR", "//iree/base:api", "//iree/base:status", - ] + PYTHON_HEADERS_DEPS, + "//iree/base:init", + "//iree/hal:api", + "//iree/rt:api", + "//iree/schemas", + "//iree/vm:api", + "@llvm//:support", + "@local_config_mlir//:Parser", + "@local_config_mlir//:Pass", + "@iree_pybind11//:pybind11", + ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS, ) iree_py_extension( name = "binding", srcs = [ - "compiler.cc", - "compiler.h", - "hal.cc", - "hal.h", - "initialize.cc", - "initialize.h", "initialize_module.cc", - "rt.cc", - "rt.h", - "status_utils.cc", - "vm.cc", - "vm.h", ], copts = DEFAULT_COPTS, features = DEFAULT_FEATURES, @@ -104,21 +116,8 @@ win_def_file = "export.def", deps = [ ":base", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "//bindings/python/pyiree/tensorflow", - "//iree/base:api", - "//iree/base:init", - "//iree/hal:api", - "//iree/rt:api", - "//iree/schemas", - "//iree/vm:api", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@iree_pybind11//:pybind11", - ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS, + "//bindings/python/pyiree/tf_interop", + ], ) py_test(
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc index 238973b..7f207df 100644 --- a/bindings/python/pyiree/compiler.cc +++ b/bindings/python/pyiree/compiler.cc
@@ -23,9 +23,9 @@ #include "iree/schemas/module_def_generated.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" #include "mlir/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" namespace py = pybind11; @@ -60,33 +60,122 @@ } // namespace -std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) { +CompilerContextBundle::CompilerContextBundle() { InitializeExtension({}); + // Setup a diagnostic handler. + mlir_context()->getDiagEngine().registerHandler( + [this](mlir::Diagnostic& d) { diagnostics_.push_back(std::move(d)); }); +} +CompilerContextBundle::~CompilerContextBundle() = default; - MLIRContext context; +std::string CompilerContextBundle::ConsumeDiagnosticsAsString() { + std::string s; + llvm::raw_string_ostream sout(s); + bool first = true; + for (auto& d : diagnostics_) { + if (!first) { + sout << "\n\n"; + } else { + first = false; + } + switch (d.getSeverity()) { + case DiagnosticSeverity::Note: + sout << "[NOTE]"; + break; + case DiagnosticSeverity::Warning: + sout << "[WARNING]"; + break; + case DiagnosticSeverity::Error: + sout << "[ERROR]"; + break; + case DiagnosticSeverity::Remark: + sout << "[REMARK]"; + break; + default: + sout << "[UNKNOWN]"; + } + // Message. + sout << ": " << d << "\n\t"; + + // Location. + d.getLocation().print(sout); + } + + diagnostics_.clear(); + return sout.str(); +} + +void CompilerContextBundle::ClearDiagnostics() { diagnostics_.clear(); } + +CompilerModuleBundle CompilerContextBundle::ParseAsm( + const std::string& asm_text) { // Arrange to get a view that includes a terminating null to avoid additional // copy. - const char* moduleAsmChars = moduleAsm.c_str(); - StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1); + const char* asm_chars = asm_text.c_str(); + StringRef asm_sr(asm_chars, asm_text.size() + 1); - // TODO(laurenzo): This error handling is super hoaky. Hook into the MLIR - // error reporter and plumb through properly. - OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context); - if (!mlirModule) { - throw std::runtime_error("Failed to parse MLIR asm"); + auto module_ref = parseMLIRModuleFromString(asm_sr, mlir_context()); + if (!module_ref) { + throw RaiseValueError("Failed to parse MLIR asm"); } + return CompilerModuleBundle(shared_from_this(), module_ref.release()); +} - auto moduleBlob = - mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get()); - if (moduleBlob.empty()) { +std::string CompilerModuleBundle::ToAsm() { + // Print to asm. + std::string asm_output; + llvm::raw_string_ostream sout(asm_output); + OpPrintingFlags print_flags; + module_op().print(sout, print_flags); + return sout.str(); +} + +std::shared_ptr<OpaqueBlob> CompilerModuleBundle::CompileToSequencerBlob() { + auto module_blob = + mlir::iree_compiler::translateMlirToIreeSequencerModule(module_op()); + if (module_blob.empty()) { throw std::runtime_error("Failed to translate MLIR module"); } - return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob)); + return std::make_shared<OpaqueByteVectorBlob>(std::move(module_blob)); +} + +void CompilerModuleBundle::RunPassPipeline( + const std::vector<std::string>& pipelines) { + mlir::PassManager pm(context_->mlir_context()); + + // Parse the pass pipelines. + std::string error; + llvm::raw_string_ostream error_stream(error); + for (const auto& pipeline : pipelines) { + if (failed(mlir::parsePassPipeline(pipeline, pm, error_stream))) { + throw RaiseValueError(error_stream.str().c_str()); + } + } + + // Run them. + if (failed(pm.run(module_op_))) { + throw RaisePyError(PyExc_RuntimeError, + "Error running pass pipelines (see diagnostics)"); + } } void SetupCompilerBindings(pybind11::module m) { - m.def("compile_module_from_asm", CompileModuleFromAsm); + py::class_<CompilerContextBundle, std::shared_ptr<CompilerContextBundle>>( + m, "CompilerContext") + .def(py::init<>([]() { + // Need explicit make_shared to avoid UB with enable_shared_from_this. + return std::make_shared<CompilerContextBundle>(); + })) + .def("parse_asm", &CompilerContextBundle::ParseAsm) + .def("get_diagnostics", + &CompilerContextBundle::ConsumeDiagnosticsAsString) + .def("clear_diagnostics", &CompilerContextBundle::ClearDiagnostics); + py::class_<CompilerModuleBundle>(m, "CompilerModule") + .def("to_asm", &CompilerModuleBundle::ToAsm) + .def("compile_to_sequencer_blob", + &CompilerModuleBundle::CompileToSequencerBlob) + .def("run_pass_pipeline", &CompilerModuleBundle::RunPassPipeline); } } // namespace python
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h index 0bd6624..d40c057 100644 --- a/bindings/python/pyiree/compiler.h +++ b/bindings/python/pyiree/compiler.h
@@ -18,10 +18,58 @@ #include <string> #include "bindings/python/pyiree/binding.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" namespace iree { namespace python { +class CompilerContextBundle; +class CompilerModuleBundle; + +// Wraps an MLIR module and its producing context. +class CompilerModuleBundle { + public: + CompilerModuleBundle(std::shared_ptr<CompilerContextBundle> context, + mlir::ModuleOp module_op) + : context_(std::move(context)), module_op_(std::move(module_op)) {} + + mlir::ModuleOp& module_op() { return module_op_; } + std::string ToAsm(); + + // Runs one or more pass pipelines (as is mlir::parsePassPipeline). + void RunPassPipeline(const std::vector<std::string>& pipelines); + + // Compiles the MLIR module to an IREE sequencer module. + std::shared_ptr<OpaqueBlob> CompileToSequencerBlob(); + + private: + std::shared_ptr<CompilerContextBundle> context_; + mlir::ModuleOp module_op_; +}; + +// Bundle of MLIRContext related things that facilitates interop with +// Python. +class CompilerContextBundle + : public std::enable_shared_from_this<CompilerContextBundle> { + public: + CompilerContextBundle(); + ~CompilerContextBundle(); + + mlir::MLIRContext* mlir_context() { return &mlir_context_; } + + CompilerModuleBundle ParseAsm(const std::string& asm_text); + + // Consumes/clears diagnostics. + std::string ConsumeDiagnosticsAsString(); + void ClearDiagnostics(); + + private: + mlir::MLIRContext mlir_context_; + std::vector<mlir::Diagnostic> diagnostics_; +}; + void SetupCompilerBindings(pybind11::module m); } // namespace python
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py index 5cd3de0..b90b2c4 100644 --- a/bindings/python/pyiree/compiler_test.py +++ b/bindings/python/pyiree/compiler_test.py
@@ -23,16 +23,24 @@ class CompilerTest(absltest.TestCase): - def testModuleCompileAndIntrospectFromAsm(self): + def testParseError(self): + ctx = binding.compiler.CompilerContext() + with self.assertRaises(ValueError): + ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""") + diag_str = ctx.get_diagnostics() + self.assertRegex(diag_str, "custom op 'FOOBAR' is unknown") - m = binding.compiler.compile_module_from_asm(""" - func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> - attributes { iree.module.export } { - %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> - } - """) - self.assertTrue(m) + def testParseAndCompileToSequencer(self): + ctx = binding.compiler.CompilerContext() + input_module = ctx.parse_asm(""" + func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> + attributes { iree.module.export } { + %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> + } + """) + binary = input_module.compile_to_sequencer_blob() + self.assertTrue(binary) if __name__ == '__main__':
diff --git a/bindings/python/pyiree/initialize_module.cc b/bindings/python/pyiree/initialize_module.cc index 021f6a5..5a2ebdb 100644 --- a/bindings/python/pyiree/initialize_module.cc +++ b/bindings/python/pyiree/initialize_module.cc
@@ -18,7 +18,7 @@ #include "bindings/python/pyiree/initialize.h" #include "bindings/python/pyiree/rt.h" #include "bindings/python/pyiree/status_utils.h" -#include "bindings/python/pyiree/tensorflow/register_tensorflow.h" +#include "bindings/python/pyiree/tf_interop/register_tensorflow.h" #include "bindings/python/pyiree/vm.h" namespace iree {
diff --git a/bindings/python/pyiree/runtime_test.py b/bindings/python/pyiree/runtime_test.py index 3d7d480..a6d3db6 100644 --- a/bindings/python/pyiree/runtime_test.py +++ b/bindings/python/pyiree/runtime_test.py
@@ -22,14 +22,16 @@ def create_simple_mul_module(): - blob = binding.compiler.compile_module_from_asm(""" + ctx = binding.compiler.CompilerContext() + input_module = ctx.parse_asm(""" func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> attributes { iree.module.export } { %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } """) - m = binding.vm.create_module_from_blob(blob) + binary = input_module.compile_to_sequencer_blob() + m = binding.vm.create_module_from_blob(binary) return m
diff --git a/bindings/python/pyiree/status_utils.cc b/bindings/python/pyiree/status_utils.cc index 63f2131..9e3e91d 100644 --- a/bindings/python/pyiree/status_utils.cc +++ b/bindings/python/pyiree/status_utils.cc
@@ -63,8 +63,9 @@ return pybind11::error_already_set(); } -pybind11::error_already_set RaiseValueError(const char* message) { - PyErr_SetString(PyExc_ValueError, message); +pybind11::error_already_set RaisePyError(PyObject* exc_class, + const char* message) { + PyErr_SetString(exc_class, message); return pybind11::error_already_set(); }
diff --git a/bindings/python/pyiree/status_utils.h b/bindings/python/pyiree/status_utils.h index feda6cb..34d77a6 100644 --- a/bindings/python/pyiree/status_utils.h +++ b/bindings/python/pyiree/status_utils.h
@@ -32,8 +32,16 @@ // Raises a value error with the given message. // Correct usage: +// throw RaiseValueError(PyExc_ValueError, "Foobar'd"); +pybind11::error_already_set RaisePyError(PyObject* exc_class, + const char* message); + +// Raises a value error with the given message. +// Correct usage: // throw RaiseValueError("Foobar'd"); -pybind11::error_already_set RaiseValueError(const char* message); +inline pybind11::error_already_set RaiseValueError(const char* message) { + return RaisePyError(PyExc_ValueError, message); +} // Consumes a StatusOr<T>, returning an rvalue reference to the T if the // status is ok(). Otherwise, throws an exception.
diff --git a/bindings/python/pyiree/tensorflow/register_tensorflow.cc b/bindings/python/pyiree/tensorflow/register_tensorflow.cc deleted file mode 100644 index b2e09b0..0000000 --- a/bindings/python/pyiree/tensorflow/register_tensorflow.cc +++ /dev/null
@@ -1,62 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "bindings/python/pyiree/tensorflow/register_tensorflow.h" - -#include <string> -#include <vector> - -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" - -using namespace mlir; // NOLINT - -namespace iree { -namespace python { - -namespace { - -std::string ImportSavedModelToMlirAsm(const std::string& saved_model_dir, - std::vector<std::string> exported_names, - std::vector<std::string> tags) { - std::unordered_set<std::string> tags_set; - for (const auto& tag : tags) { - tags_set.insert(tag); - } - - MLIRContext context; - auto module = tensorflow::SavedModelToMlirImport( - saved_model_dir, tags_set, absl::MakeSpan(exported_names), &context); - - // Print to asm. - std::string asm_output; - llvm::raw_string_ostream sout(asm_output); - OpPrintingFlags print_flags; - module->print(sout, print_flags); - return sout.str(); -} - -} // namespace - -void SetupTensorFlowBindings(pybind11::module m) { - m.def("import_saved_model_to_mlir_asm", &ImportSavedModelToMlirAsm, - py::arg("saved_model_dir"), - py::arg("exported_names") = std::vector<std::string>(), - py::arg("tags") = std::vector<std::string>({std::string("serve")})); -} - -} // namespace python -} // namespace iree
diff --git a/bindings/python/pyiree/tensorflow/BUILD b/bindings/python/pyiree/tf_interop/BUILD similarity index 73% rename from bindings/python/pyiree/tensorflow/BUILD rename to bindings/python/pyiree/tf_interop/BUILD index 2838997..33a8f1f 100644 --- a/bindings/python/pyiree/tensorflow/BUILD +++ b/bindings/python/pyiree/tf_interop/BUILD
@@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +load( + "//iree:build_defs.bzl", + "INTREE_TENSORFLOW_PY_DEPS", +) + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -25,28 +30,21 @@ "-use_header_modules", # Incompatible with exceptions builds. ] -config_setting( - name = "enable", - define_values = { - "iree_tensorflow": "true", - }, -) - cc_library( - name = "tensorflow", + name = "tf_interop", hdrs = [ "register_tensorflow.h", ], copts = DEFAULT_COPTS, defines = select({ - ":enable": [ + "//iree:enable_tensorflow": [ "IREE_TENSORFLOW_ENABLED", ], "//conditions:default": [], }), features = DEFAULT_FEATURES, deps = select({ - ":enable": [ + "//iree:enable_tensorflow": [ ":tensorflow_impl", ], "//conditions:default": [ @@ -80,6 +78,10 @@ "@org_tensorflow//tensorflow/core/kernels:state", ] +TF_XLA_PASS_DEPS = [ + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf", +] + cc_library( name = "tensorflow_impl", srcs = [ @@ -92,12 +94,14 @@ features = DEFAULT_FEATURES, visibility = ["//visibility:private"], deps = [ - "@local_config_mlir//:IR", - "@llvm//:support", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//bindings/python/pyiree:base", - ] + SAVED_MODEL_TF_RUNTIME_DEPS + SAVED_MODEL_REQUIRED_KERNEL_DEPS, + "@local_config_mlir//:IR", + "@llvm//:support", + "@org_tensorflow//tensorflow/cc/saved_model:loader", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//bindings/python/pyiree:base", + ] + SAVED_MODEL_TF_RUNTIME_DEPS + SAVED_MODEL_REQUIRED_KERNEL_DEPS + + TF_XLA_PASS_DEPS, ) cc_library( @@ -114,3 +118,16 @@ "//bindings/python/pyiree:base", ], ) + +py_test( + name = "saved_model_test", + srcs = ["saved_model_test.py"], + python_version = "PY3", + tags = [ + "oss_only", + ], + deps = [ + "//bindings/python/pyiree:binding", + "//bindings/python:pathsetup", # build_cleaner: keep + ] + INTREE_TENSORFLOW_PY_DEPS, +)
diff --git a/bindings/python/pyiree/tf_interop/register_tensorflow.cc b/bindings/python/pyiree/tf_interop/register_tensorflow.cc new file mode 100644 index 0000000..d0e85a7 --- /dev/null +++ b/bindings/python/pyiree/tf_interop/register_tensorflow.cc
@@ -0,0 +1,84 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/tf_interop/register_tensorflow.h" + +#include <string> +#include <vector> + +#include "bindings/python/pyiree/compiler.h" +#include "bindings/python/pyiree/status_utils.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +using namespace mlir; // NOLINT + +using tensorflow::ConvertSavedModelToMlir; +using tensorflow::RunOptions; +using tensorflow::SavedModelBundle; +using tensorflow::SessionOptions; + +namespace iree { +namespace python { + +namespace { + +CompilerModuleBundle LoadSavedModel( + std::shared_ptr<CompilerContextBundle> context_bundle, + const std::string& saved_model_dir, + const std::vector<std::string>& exported_names) { + SessionOptions session_options; + RunOptions run_options; + SavedModelBundle bundle; + std::unordered_set<std::string> tags{"serve"}; + auto load_status = LoadSavedModel( + session_options, run_options, + std::string(saved_model_dir.data(), saved_model_dir.length()), tags, + &bundle); + if (!load_status.ok()) { + std::stringstream msg; + msg << "Failed to load saved model '" << saved_model_dir + << "': " << load_status; + throw RaisePyError(PyExc_RuntimeError, msg.str().c_str()); + } + + // TODO(laurenzo): Fix the upstream ConvertSavedModelToMlir() to take a const + // span of external names. + std::vector<std::string> mutable_exported_names = exported_names; + auto module_or = + ConvertSavedModelToMlir(bundle, context_bundle->mlir_context(), + absl::MakeSpan(mutable_exported_names)); + if (!module_or.status().ok()) { + std::stringstream msg; + msg << "Failed to load saved model '" << saved_model_dir + << "': " << load_status; + throw RaisePyError(PyExc_RuntimeError, msg.str().c_str()); + } + return CompilerModuleBundle(context_bundle, + module_or.ConsumeValueOrDie().release()); +} + +} // namespace + +void SetupTensorFlowBindings(pybind11::module m) { + m.def("load_saved_model", &LoadSavedModel, py::arg("compiler_context"), + py::arg("saved_model_dir"), + py::arg("exported_names") = std::vector<std::string>()); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/tensorflow/register_tensorflow.h b/bindings/python/pyiree/tf_interop/register_tensorflow.h similarity index 80% rename from bindings/python/pyiree/tensorflow/register_tensorflow.h rename to bindings/python/pyiree/tf_interop/register_tensorflow.h index 95dbf3e..7a54404 100644 --- a/bindings/python/pyiree/tensorflow/register_tensorflow.h +++ b/bindings/python/pyiree/tf_interop/register_tensorflow.h
@@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef IREE_BINDINGS_PYTHON_PYIREE_TENSORFLOW_REGISTER_TENSORFLOW_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_TENSORFLOW_REGISTER_TENSORFLOW_H_ +#ifndef IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_ #include <string> @@ -27,4 +27,4 @@ } // namespace python } // namespace iree -#endif // IREE_BINDINGS_PYTHON_PYIREE_TENSORFLOW_REGISTER_TENSORFLOW_H_ +#endif // IREE_BINDINGS_PYTHON_PYIREE_TF_INTEROP_REGISTER_TENSORFLOW_H_
diff --git a/bindings/python/pyiree/tensorflow/register_tensorflow_noop.cc b/bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc similarity index 91% rename from bindings/python/pyiree/tensorflow/register_tensorflow_noop.cc rename to bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc index 41e0eed..e5f83aa 100644 --- a/bindings/python/pyiree/tensorflow/register_tensorflow_noop.cc +++ b/bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc
@@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "bindings/python/pyiree/tensorflow/register_tensorflow.h" +#include "bindings/python/pyiree/tf_interop/register_tensorflow.h" namespace iree { namespace python {
diff --git a/bindings/python/pyiree/tf_interop/saved_model_test.py b/bindings/python/pyiree/tf_interop/saved_model_test.py new file mode 100644 index 0000000..4d0ee59 --- /dev/null +++ b/bindings/python/pyiree/tf_interop/saved_model_test.py
@@ -0,0 +1,103 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import os +import sys +import tempfile + +from pyiree import binding as binding + +# Determine if compiled with tf_interop support. +if not hasattr(binding, "tf_interop"): + print("Not running tests because tf_interop support not compiled") + sys.exit(0) + +# Dynamically import tensorflow. +try: + # Use a dynamic import so as to avoid hermetic dependency analysis + # (i.e. we only want the tensorflow from the environment). + tf = importlib.import_module("tensorflow") + # Just in case if linked against a pre-V2 defaulted version. + tf.enable_v2_behavior() + tf = tf.compat.v2 +except ImportError: + print("Not running tests because tensorflow is not available") + sys.exit(0) + + +class StatefulModule(tf.Module): + + def __init__(self): + self.v = tf.Variable([4], dtype=tf.float32) + + @tf.function(input_signature=[ + tf.TensorSpec([4], tf.float32), + tf.TensorSpec([4], tf.float32) + ]) + def add(self, a, b): + return tf.tanh(self.v * a + b) + + +class RuntimeTest(tf.test.TestCase): + + def testLoadSavedModelToXlaPipeline(self): + """Tests that a basic saved model to XLA workflow grossly functions. + + This is largely here to verify that everything is linked in that needs to be + and that there are not no-ops, etc. + """ + with tempfile.TemporaryDirectory() as temp_dir: + sm_dir = os.path.join(temp_dir, "simple.sm") + print("Saving to:", sm_dir) + my_module = StatefulModule() + options = tf.saved_model.SaveOptions(save_debug_info=True) + tf.saved_model.save(my_module, sm_dir, options=options) + + # Load it up. + ctx = binding.compiler.CompilerContext() + input_module = binding.tf_interop.load_saved_model(ctx, sm_dir) + input_asm = input_module.to_asm() + print("LOADED ASM:\n", input_asm) + # Should have out exported name and have executor islands. + self.assertRegex(input_asm, + r"""tf_saved_model.exported_names = \["add"\]""") + self.assertRegex(input_asm, r"""tf_executor\.island""") + + # Run the necessary lowering passes. Makes sure that these are linked in. + input_module.run_pass_pipeline([ + "tf-executor-graph-pruning", + "tf-standard-pipeline", + "canonicalize", + ]) + lowered_asm = input_module.to_asm() + print("LOWERED ASM:\n", lowered_asm) + # Should have collapsed all executor islands. + self.assertNotRegex(lowered_asm, r"""tf_executor\.island""") + + # And legalize to XLA. + input_module.run_pass_pipeline([ + "xla-legalize-tf", + ]) + xla_asm = input_module.to_asm() + print("XLA ASM:", xla_asm) + self.assertRegex(xla_asm, "xla_hlo.tanh") + + +if __name__ == "__main__": + tf.test.main()
diff --git a/colab/low_level_invoke_function.ipynb b/colab/low_level_invoke_function.ipynb index b0757bc..bcc7bf5 100644 --- a/colab/low_level_invoke_function.ipynb +++ b/colab/low_level_invoke_function.ipynb
@@ -65,25 +65,27 @@ "metadata": { "id": "rxaiDxiq96SD", "colab_type": "code", + "outputId": "a1304fa3-b15a-4fab-eaaf-1467dc867191", "colab": { "base_uri": "https://localhost:8080/", "height": 51 - }, - "outputId": "c7c254b5-2dd4-478e-bdc4-8386a1d29953" + } }, "source": [ "# Compile a module.\n", - "blob = binding.compiler.compile_module_from_asm(\"\"\"\n", + "ctx = binding.compiler.CompilerContext()\n", + "input_module = ctx.parse_asm(\"\"\"\n", " func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n", " attributes { iree.module.export } {\n", " %0 = \"xla_hlo.mul\"(%arg0, %arg1) {name = \"mul.1\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", " return %0 : tensor<4xf32>\n", " }\n", " \"\"\")\n", + "blob = input_module.compile_to_sequencer_blob()\n", "m = binding.vm.create_module_from_blob(blob)\n", "dump_module(m)" ], - "execution_count": 4, + "execution_count": 7, "outputs": [ { "output_type": "stream", @@ -100,11 +102,11 @@ "metadata": { "id": "aH6VdaoXD4hV", "colab_type": "code", + "outputId": "d109d4e7-83bf-4038-c0d1-c643dbd10c8e", "colab": { "base_uri": "https://localhost:8080/", "height": 102 - }, - "outputId": "c4e8be2c-f144-41c8-952c-9aca66300899" + } }, "source": [ "# Initialize the runtime and register the module.\n", @@ -137,15 +139,15 @@ "result_ary = np.array(result, copy=False)\n", "print(\"NP result:\", result_ary)\n" ], - "execution_count": 5, + "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "INVOKE F: simple_mul\n", "Status: True\n", - "Results: [<pyiree.binding.hal.BufferView object at 0x7fc7697105e0>]\n", - "Mapped result: <pyiree.binding.hal.MappedMemory object at 0x7fc769710618>\n", + "Results: [<pyiree.binding.hal.BufferView object at 0x00000179E51410D8>]\n", + "Mapped result: <pyiree.binding.hal.MappedMemory object at 0x00000179E51412D0>\n", "NP result: [ 4. 10. 18. 28.]\n" ], "name": "stdout"
diff --git a/colab/simple_tensorflow_module_import.ipynb b/colab/simple_tensorflow_module_import.ipynb index ed4766e..11befd7 100644 --- a/colab/simple_tensorflow_module_import.ipynb +++ b/colab/simple_tensorflow_module_import.ipynb
@@ -58,11 +58,11 @@ "metadata": { "id": "6YGqN2uqP_7P", "colab_type": "code", + "outputId": "4cc03b22-bee5-4ffd-e4cf-6fa3055ea886", "colab": { "base_uri": "https://localhost:8080/", - "height": 411 - }, - "outputId": "4e8ba182-c7ee-402b-b6e9-15590e8617c5" + "height": 853 + } }, "source": [ "class MyModule(tf.Module):\n", @@ -80,23 +80,35 @@ "options = tf.saved_model.SaveOptions(save_debug_info=True)\n", "tf.saved_model.save(my_mod, os.path.join(SAVE_PATH, \"simple.sm\"), options=options)\n", "\n", - "mlir_asm = binding.tf_interop.import_saved_model_to_mlir_asm(os.path.join(SAVE_PATH, \"simple.sm\"))\n", - "print(mlir_asm)" + "ctx = binding.compiler.CompilerContext()\n", + "input_module = binding.tf_interop.load_saved_model(ctx, os.path.join(SAVE_PATH, \"simple.sm\"))\n", + "print('LOADED ASM:', input_module.to_asm())\n", + "\n", + "# Canonicalize the TF import.\n", + "input_module.run_pass_pipeline([\n", + " \"tf-executor-graph-pruning\",\n", + " \"tf-standard-pipeline\",\n", + " \"canonicalize\",\n", + "])\n", + "print(\"LOWERED TF ASM:\", input_module.to_asm())\n", + "\n", + "# Legalize to XLA (high-level).\n", + "input_module.run_pass_pipeline([\n", + " \"xla-legalize-tf\",\n", + "])\n", + "print(\"XLA ASM:\", input_module.to_asm())" ], - "execution_count": 2, + "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ - "WARNING:tensorflow:From c:\\users\\laurenzo\\scoop\\apps\\python36\\current\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1785: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "If using Keras pass *_constraint arguments to layers.\n", "INFO:tensorflow:Assets written to: C:\\Users\\laurenzo\\saved_models\\simple.sm\\assets\n", - "\n", + "LOADED ASM: \n", "\n", "module attributes {tf_saved_model.semantics} {\n", " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", - " func @__inference_add_160(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", + " func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", " %0 = tf_executor.graph {\n", " %1:2 = tf_executor.island wraps \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", @@ -109,6 +121,35 @@ " return %0 : tensor<4xf32>\n", " }\n", "}\n", + "\n", + "LOWERED TF ASM: \n", + "\n", + "module attributes {tf_saved_model.semantics} {\n", + " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", + " func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", + " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", + " %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", + " %1 = \"tf.Mul\"(%0, %arg0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"mul\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", + " %2 = \"tf.AddV2\"(%1, %arg1) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"add\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", + " %3 = \"tf.Tanh\"(%2) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Tanh\"} : (tensor<4xf32>) -> tensor<4xf32>\n", + " %4 = \"tf.Identity\"(%3) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Identity\"} : (tensor<4xf32>) -> tensor<4xf32>\n", + " return %4 : tensor<4xf32>\n", + " }\n", + "}\n", + "\n", + "XLA ASM: \n", + "\n", + "module attributes {tf_saved_model.semantics} {\n", + " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", + " func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", + " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", + " %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", + " %1 = \"xla_hlo.mul\"(%0, %arg0) : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", + " %2 = xla_hlo.add %1, %arg1 : tensor<4xf32>\n", + " %3 = \"xla_hlo.tanh\"(%2) : (tensor<4xf32>) -> tensor<4xf32>\n", + " return %3 : tensor<4xf32>\n", + " }\n", + "}\n", "\n" ], "name": "stdout"
diff --git a/iree/BUILD.bazel b/iree/BUILD.bazel index d21f74c..eed7ab7 100644 --- a/iree/BUILD.bazel +++ b/iree/BUILD.bazel
@@ -29,3 +29,10 @@ name = "target_config", defines = ["IREE_UNSPECIFIED_TARGET=1"], ) + +config_setting( + name = "enable_tensorflow", + define_values = { + "iree_tensorflow": "true", + }, +)
diff --git a/iree/build_defs.bzl b/iree/build_defs.bzl index 31fe1fb..d1bce9f 100644 --- a/iree/build_defs.bzl +++ b/iree/build_defs.bzl
@@ -8,6 +8,10 @@ NUMPY_DEPS = [] PYTHON_HEADERS_DEPS = ["@iree_native_python//:python_headers"] +# Optional deps to enable an intree TensorFlow python. This build configuration +# defaults to getting TensorFlow from the python environment (empty). +INTREE_TENSORFLOW_PY_DEPS = [] + def platform_trampoline_deps(basename, path = "base"): """Produce a list of deps for the given `basename` platform target.