Update CMake min to 3.12. STRING JOIN was introduced in 3.12. TEST=builds with cmake 3.12; fails with 3.10.2 (stock in Ubuntu). Latest Ubuntu PPAs are here: https://apt.kitware.com/ -- e65d4dfd0278185962071d30359fc0ce1232be9e by Anush Elangovan <anush@nod-labs.com>: Fix spurious reference to DLOG TEST:builds Closes #89 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/89 from powderluv:fix-cmake-ver e65d4dfd0278185962071d30359fc0ce1232be9e PiperOrigin-RevId: 276298047
diff --git a/bindings/python/BUILD b/bindings/python/BUILD new file mode 100644 index 0000000..793b93d --- /dev/null +++ b/bindings/python/BUILD
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//:build_defs.google.bzl", "iree_py_library")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# Import-path helper for the Python bindings. `imports = ["."]` asks Bazel to
# put this directory (bindings/python) on the import path of dependents, which
# is what lets tests write `from pyiree import binding`.
# NOTE(review): assumes iree_py_library forwards `imports` to py_library —
# confirm in build_defs.google.bzl.
iree_py_library(
    name = "pathsetup",
    imports = ["."],
)
diff --git a/bindings/python/README.md b/bindings/python/README.md new file mode 100644 index 0000000..4307803 --- /dev/null +++ b/bindings/python/README.md
@@ -0,0 +1,19 @@ +# IREE Python Sandbox + +This directory contains various integration-oriented Python utilities that are +not intended to be a public API. They are, however, useful for lower-level +compiler interop work. And of course, they are useful since we presently lack a +real API :) + +We're still untangling build support, Jupyter integration, etc. for OSS builds. +Stand by. + +## Issues: + +* This is called `pyiree` vs `iree` to avoid PYTHONPATH collisions that tend + to arise when an iree directory is inside of an iree directory. +* The above could be solved in the Bazel build by making iree/bindings/python + its own sub-workspace. +* However, doing so presently breaks both flatbuffer and tablegen generation + because those build rules still need fixes to become sub-workspace + aware.
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD new file mode 100644 index 0000000..151f5ee --- /dev/null +++ b/bindings/python/pyiree/BUILD
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//:build_defs.google.bzl", "NUMPY_DEPS", "iree_py_extension")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# Compiler translation libraries linked into the extension.
COMPILER_DEPS = [
    "///compiler/Translation/Sequencer",
    "///compiler/Translation/Interpreter",
    "///compiler/Translation/SPIRV",
]

# HAL driver modules linked into the extension.
DRIVER_DEPS = [
    "///hal/interpreter:interpreter_driver_module",
    "///hal/vulkan:vulkan_driver_module",
]

# The native `binding` extension module (imported as `pyiree.binding`).
iree_py_extension(
    name = "binding",
    srcs = [
        "binding.cc",
        "binding.h",
        "compiler.cc",
        "compiler.h",
        "hal.cc",
        "hal.h",
        "initialize.cc",
        "initialize.h",
        "rt.cc",
        "rt.h",
        "status_utils.cc",
        "status_utils.h",
        "vm.cc",
        "vm.h",
    ],
    copts = [
        # pybind11 reports errors by throwing C++ exceptions.
        "-fexceptions",
    ],
    features = ["-use_header_modules"],
    deps = [
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "///base:api",
        "///base:init",
        "///base:status",
        "///hal:api",
        "///rt:api",
        "///schemas",
        "///vm:api",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
        "@iree_pybind11//:pybind11",
    ] + COMPILER_DEPS + DRIVER_DEPS,
)

py_test(
    name = "compiler_test",
    srcs = ["compiler_test.py"],
    python_version = "PY3",
    deps = [
        ":binding",
        "///bindings/python:pathsetup",
        "@absl_py//absl/testing:absltest",
    ],
)

py_test(
    name = "hal_test",
    srcs = ["hal_test.py"],
    python_version = "PY3",
    deps = [
        ":binding",
        "///bindings/python:pathsetup",
        "@absl_py//absl/testing:absltest",
    ],
)

py_test(
    name = "runtime_test",
    srcs = ["runtime_test.py"],
    python_version = "PY3",
    deps = NUMPY_DEPS + [
        ":binding",
        # Like the sibling tests, needed so `pyiree` resolves as a top-level
        # package. NOTE(review): assumes runtime_test.py imports
        # `from pyiree import binding` like compiler_test/hal_test — confirm.
        "///bindings/python:pathsetup",
        "@absl_py//absl/testing:absltest",
    ],
)
diff --git a/bindings/python/pyiree/binding.cc b/bindings/python/pyiree/binding.cc new file mode 100644 index 0000000..1cbdef7 --- /dev/null +++ b/bindings/python/pyiree/binding.cc
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/binding.h" + +#include "bindings/python/pyiree/compiler.h" +#include "bindings/python/pyiree/hal.h" +#include "bindings/python/pyiree/initialize.h" +#include "bindings/python/pyiree/rt.h" +#include "bindings/python/pyiree/status_utils.h" +#include "bindings/python/pyiree/vm.h" + +namespace iree { +namespace python { + +PYBIND11_MODULE(binding, m) { + m.doc() = "IREE Binding Backend Helpers"; + py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob"); + m.def("initialize_extension", &InitializeExtension); + + auto compiler_m = m.def_submodule("compiler", "IREE compiler support"); + SetupCompilerBindings(compiler_m); + + auto hal_m = m.def_submodule("hal", "IREE HAL support"); + SetupHalBindings(hal_m); + + auto rt_m = m.def_submodule("rt", "IREE RT api"); + SetupRtBindings(rt_m); + + auto vm_m = m.def_submodule("vm", "IREE VM api"); + SetupVmBindings(vm_m); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/binding.h b/bindings/python/pyiree/binding.h new file mode 100644 index 0000000..e470ee6 --- /dev/null +++ b/bindings/python/pyiree/binding.h
@@ -0,0 +1,147 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "base/api.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace pybind11 { +namespace detail { +#if !defined(ABSL_HAVE_STD_OPTIONAL) +// Make absl::optional act like the future C++17 optional for pybind11. +// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional +// and the default type caster is sufficient. +template <typename T> +struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {}; +#endif +} // namespace detail +} // namespace pybind11 + +namespace iree { +namespace python { + +namespace py = pybind11; + +// Wrapper around a blob of memory. +// Used to transport blobs back and forth between C++ and Python. +class OpaqueBlob { + public: + OpaqueBlob() : data_(nullptr), size_(0) {} + OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {} + virtual ~OpaqueBlob() = default; + + void* data() { return data_; } + const void* data() const { return data_; } + size_t size() const { return size_; } + + // Create a free function from the OpaqueBlob shared pointer. 
+ using BufferFreeFn = void (*)(void* self, iree_byte_span_t); + static std::pair<BufferFreeFn, void*> CreateFreeFn( + std::shared_ptr<OpaqueBlob> blob) { + // Note that there are more efficient ways to write this which + // don't bounce through an extra heap alloc, but this is not + // intended to be a high impact code path. + struct Holder { + std::shared_ptr<OpaqueBlob> blob; + }; + Holder* holder = new Holder{std::move(blob)}; + auto free_fn = +([](void* self, iree_byte_span_t) { + Holder* self_holder = static_cast<Holder*>(self); + delete self_holder; + }); + return {free_fn, holder}; + } + + protected: + void* data_; + size_t size_; +}; + +// Opaque blob that owns a vector. +class OpaqueByteVectorBlob : public OpaqueBlob { + public: + OpaqueByteVectorBlob(std::vector<uint8_t> v) + : OpaqueBlob(), v_(std::move(v)) { + data_ = v_.data(); + size_ = v_.size(); + } + + private: + std::vector<uint8_t> v_; +}; + +template <typename T> +struct ApiPtrAdapter {}; + +template <typename Self, typename T> +class ApiRefCounted { + public: + ApiRefCounted() : instance_(nullptr) {} + ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) { + other.instance_ = nullptr; + } + void operator=(const ApiRefCounted&) = delete; + + ~ApiRefCounted() { Release(); } + + // Creates an instance of the ref counted wrapper based on an instance + // that has already been retained. Ownership is transferred to the + // wrapper. + static Self CreateRetained(T* retained_inst) { + auto self = Self(); + self.instance_ = retained_inst; + return self; + } + + // Creates a new instance, retaining the underlying object. 
+ static Self RetainAndCreate(T* non_retained_inst) { + auto self = Self(); + self.instance_ = non_retained_inst; + if (non_retained_inst) { + ApiPtrAdapter<T>::Retain(non_retained_inst); + } + return self; + } + + T* raw_ptr() { + if (!instance_) { + throw std::invalid_argument("API object is null"); + } + return instance_; + } + void Retain() { + if (instance_) { + ApiPtrAdapter<T>::Retain(instance_); + } + } + void Release() { + if (instance_) { + ApiPtrAdapter<T>::Release(instance_); + } + } + + private: + T* instance_; +}; + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc new file mode 100644 index 0000000..562f1c9 --- /dev/null +++ b/bindings/python/pyiree/compiler.cc
@@ -0,0 +1,93 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/compiler.h" + +#include <stdexcept> + +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/initialize.h" +#include "bindings/python/pyiree/status_utils.h" +#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "schemas/module_def_generated.h" + +namespace py = pybind11; + +using namespace mlir; +using namespace mlir::iree_compiler; + +using llvm::MemoryBuffer; +using llvm::MemoryBufferRef; +using llvm::StringRef; + +namespace iree { +namespace python { + +namespace { + +OwningModuleRef parseMLIRModuleFromString(StringRef contents, + MLIRContext* context) { + std::unique_ptr<MemoryBuffer> contents_buffer; + if (contents.back() == 0) { + // If it has a nul terminator, just use as-is. + contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back()); + } else { + // Otherwise, make a copy. 
+ contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED"); + } + + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc()); + OwningModuleRef mlir_module = parseSourceFile(source_mgr, context); + return mlir_module; +} + +} // namespace + +std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) { + InitializeExtension({}); + + MLIRContext context; + + // Arrange to get a view that includes a terminating null to avoid additional + // copy. + const char* moduleAsmChars = moduleAsm.c_str(); + StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1); + + // TODO(laurenzo): This error handling is super hoaky. Hook into the MLIR + // error reporter and plumb through properly. + OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context); + if (!mlirModule) { + throw std::runtime_error("Failed to parse MLIR asm"); + } + + auto moduleBlob = + mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get()); + if (moduleBlob.empty()) { + throw std::runtime_error("Failed to translate MLIR module"); + } + return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob)); +} + +void SetupCompilerBindings(pybind11::module m) { + m.def("compile_module_from_asm", CompileModuleFromAsm); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h new file mode 100644 index 0000000..0bd6624 --- /dev/null +++ b/bindings/python/pyiree/compiler.h
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_ + +#include <string> + +#include "bindings/python/pyiree/binding.h" + +namespace iree { +namespace python { + +void SetupCompilerBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py new file mode 100644 index 0000000..5cd3de0 --- /dev/null +++ b/bindings/python/pyiree/compiler_test.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke test: compile MLIR assembly through pyiree.binding.compiler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest

from pyiree import binding

# One exported function multiplying two 4xf32 tensors elementwise.
_SIMPLE_MUL_ASM = """
      func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
            attributes { iree.module.export } {
          %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
          return %0 : tensor<4xf32>
      }
      """


class CompilerTest(absltest.TestCase):

  def testModuleCompileAndIntrospectFromAsm(self):
    compiled_blob = binding.compiler.compile_module_from_asm(_SIMPLE_MUL_ASM)
    self.assertTrue(compiled_blob)


if __name__ == '__main__':
  absltest.main()
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/hal.cc new file mode 100644 index 0000000..e7a59a2 --- /dev/null +++ b/bindings/python/pyiree/hal.cc
@@ -0,0 +1,135 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/hal.h" + +#include "hal/api.h" + +namespace iree { +namespace python { + +namespace { + +class HalMappedMemory { + public: + HalMappedMemory(iree_hal_mapped_memory_t mapped_memory, + iree_hal_buffer_view_t* bv) + : mapped_memory_(mapped_memory), bv_(bv) { + iree_hal_buffer_view_retain(bv_); + } + ~HalMappedMemory() { + if (bv_) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); + CHECK_EQ(iree_hal_buffer_unmap(buffer, &mapped_memory_), IREE_STATUS_OK); + iree_hal_buffer_view_release(bv_); + } + } + HalMappedMemory(HalMappedMemory&& other) + : mapped_memory_(other.mapped_memory_), bv_(other.bv_) { + other.bv_ = nullptr; + } + + static HalMappedMemory Create(HalBufferView& bv) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr()); + iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); + iree_hal_mapped_memory_t mapped_memory; + CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, + 0 /* element_offset */, byte_length, + &mapped_memory), + "Could not map memory"); + return HalMappedMemory(mapped_memory, bv.raw_ptr()); + } + + py::buffer_info ToBufferInfo() { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); + iree_shape_t shape; + CheckApiStatus(iree_hal_buffer_view_shape(bv_, &shape), + "Error getting buffer view shape"); + int8_t element_size = 
iree_hal_buffer_view_element_size(bv_); + iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); + absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> dims; + dims.resize(shape.rank); + for (int i = 0; i < shape.rank; ++i) { + dims[i] = shape.dims[i]; + } + absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> strides; + strides.resize(shape.rank); + for (int i = 1; i < shape.rank; ++i) { + strides[i - 1] = shape.dims[i] * element_size; + } + if (!strides.empty()) { + strides.back() = 1 * element_size; + } + + // TODO(laurenzo): We need to figure out how to propagate dtype in the + // buffer view. + return py::buffer_info( + mapped_memory_.contents.data, element_size, + py::format_descriptor<float>::format(), // TODO(laurenzo): DTYPE! + shape.rank, dims, strides); + } + + private: + iree_hal_mapped_memory_t mapped_memory_; + iree_hal_buffer_view_t* bv_; +}; + +} // namespace + +void SetupHalBindings(pybind11::module m) { + // Enums. + py::enum_<iree_hal_memory_type_t>(m, "MemoryType") + .value("NONE", IREE_HAL_MEMORY_TYPE_NONE) + .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT) + .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) + .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT) + .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED) + .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL) + .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) + .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) + .export_values(); + py::enum_<iree_hal_buffer_usage_t>(m, "BufferUsage") + .value("NONE", IREE_HAL_BUFFER_USAGE_NONE) + .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT) + .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER) + .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING) + .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH) + .value("ALL", IREE_HAL_BUFFER_USAGE_ALL) + .export_values(); + py::enum_<iree_hal_memory_access_t>(m, "MemoryAccess") + .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE) + .value("READ", 
IREE_HAL_MEMORY_ACCESS_READ) + .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE) + .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD) + .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) + .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL) + .export_values(); + + py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector)); + py::class_<HalBufferView>(m, "BufferView") + .def("map", HalMappedMemory::Create); + py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol()) + .def_buffer(&HalMappedMemory::ToBufferInfo); + py::class_<HalBuffer>(m, "Buffer") + .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer, + py::arg("memory_type"), py::arg("usage"), + py::arg("allocation_size")) + .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), + py::arg("byte_length")) + .def("create_view", &HalBuffer::CreateView, py::arg("shape"), + py::arg("element_size")); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/hal.h new file mode 100644 index 0000000..d26bcf0 --- /dev/null +++ b/bindings/python/pyiree/hal.h
@@ -0,0 +1,97 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_HAL_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_HAL_H_ + +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/status_utils.h" +#include "hal/api.h" + +namespace iree { +namespace python { + +template <> +struct ApiPtrAdapter<iree_hal_buffer_t> { + static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); } + static void Release(iree_hal_buffer_t* b) { iree_hal_buffer_release(b); } +}; + +template <> +struct ApiPtrAdapter<iree_hal_buffer_view_t> { + static void Retain(iree_hal_buffer_view_t* bv) { + iree_hal_buffer_view_retain(bv); + } + static void Release(iree_hal_buffer_view_t* bv) { + iree_hal_buffer_view_release(bv); + } +}; + +struct HalShape { + public: + static HalShape FromIntVector(std::vector<int32_t> indices) { + if (indices.size() > IREE_SHAPE_MAX_RANK) { + throw RaiseValueError("Shape exceeded maximum rank"); + } + HalShape s; + s.s.rank = indices.size(); + for (size_t i = 0, e = indices.size(); i < e; ++i) { + s.s.dims[i] = indices[i]; + } + return s; + } + + iree_shape_t s; +}; + +class HalBufferView + : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> { + public: +}; + +class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> { + public: + static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage, + iree_host_size_t allocation_size) { + 
iree_hal_buffer_t* buffer = nullptr; + CheckApiStatus( + iree_hal_heap_buffer_allocate( + static_cast<iree_hal_memory_type_t>(memory_type), + static_cast<iree_hal_buffer_usage_t>(usage), allocation_size, + IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &buffer), + "Error allocating heap buffer"); + return HalBuffer::CreateRetained(buffer); + } + + void FillZero(iree_device_size_t byte_offset, + iree_device_size_t byte_length) { + CheckApiStatus(iree_hal_buffer_zero(raw_ptr(), byte_offset, byte_length), + "Error zero filling buffer"); + } + + HalBufferView CreateView(HalShape& shape, size_t element_size) { + iree_hal_buffer_view_t* bv; + CheckApiStatus(iree_hal_buffer_view_create(raw_ptr(), shape.s, element_size, + IREE_ALLOCATOR_DEFAULT, &bv), + "Error creating buffer view"); + return HalBufferView::CreateRetained(bv); + } +}; + +void SetupHalBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
diff --git a/bindings/python/pyiree/hal_test.py b/bindings/python/pyiree/hal_test.py new file mode 100644 index 0000000..469e7f3 --- /dev/null +++ b/bindings/python/pyiree/hal_test.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke tests for the pyiree.binding.hal submodule."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest

from pyiree import binding


class HalTest(absltest.TestCase):

  def testEnums(self):
    memory_type = binding.hal.MemoryType
    print("MemoryType =", memory_type)
    print("HOST_VISIBLE =", int(memory_type.HOST_VISIBLE))

  def testAllocateHeap(self):
    hal = binding.hal
    heap_buffer = hal.Buffer.allocate_heap(
        memory_type=int(hal.MemoryType.HOST_LOCAL),
        usage=int(hal.BufferUsage.ALL),
        allocation_size=4096)
    self.assertIsNot(heap_buffer, None)
    heap_buffer.fill_zero(0, 4096)
    # 1x1024 elements of 4 bytes each covers the full 4096-byte allocation.
    shape = hal.Shape([1, 1024])
    unused_view = heap_buffer.create_view(shape, 4)


if __name__ == "__main__":
  absltest.main()
diff --git a/bindings/python/pyiree/initialize.cc b/bindings/python/pyiree/initialize.cc new file mode 100644 index 0000000..acf5cf0 --- /dev/null +++ b/bindings/python/pyiree/initialize.cc
@@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/initialize.h" + +#include <string.h> + +#include <mutex> // NOLINT + +#include "base/init.h" + +namespace iree { +namespace python { + +namespace { + +void InternalInitialize(const std::vector<std::string>& arguments) { + int argc = arguments.size() + 1; // plus one for program name. + char** argv = static_cast<char**>( + malloc(sizeof(char*) * (argc + 1))); // plus one for null terminator. + char** orig_argv = argv; + argv[0] = strdup("<python_extension>"); + for (int i = 1; i < argc; ++i) { + argv[i] = strdup(arguments[i - 1].c_str()); + } + argv[argc] = nullptr; + InitializeEnvironment(&argc, &argv); + for (int i = 0; i < argc; ++i) { + free(argv[i]); + } + free(orig_argv); +} + +} // namespace + +void InitializeExtension(const std::vector<std::string>& arguments) { + static std::once_flag init_once; + std::call_once(init_once, InternalInitialize, arguments); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/initialize.h b/bindings/python/pyiree/initialize.h new file mode 100644 index 0000000..38dff00 --- /dev/null +++ b/bindings/python/pyiree/initialize.h
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef IREE_BINDINGS_PYTHON_PYIREE_INITIALIZE_H_
#define IREE_BINDINGS_PYTHON_PYIREE_INITIALIZE_H_

// <string> is required for std::string in the signature below; it was
// previously only available transitively.
#include <string>
#include <vector>

namespace iree {
namespace python {

// Performs once-only initialization of the extension, which is required
// prior to any use of the runtime. Optionally, arguments can be provided.
// If automatic initialization has already taken place, then does nothing.
// In the future, it would be nice to have more of the process level init
// happen automatically and rely less on this kind of init the world
// function.
void InitializeExtension(const std::vector<std::string>& arguments);

}  // namespace python
}  // namespace iree

#endif  // IREE_BINDINGS_PYTHON_PYIREE_INITIALIZE_H_
diff --git a/bindings/python/pyiree/rt.cc b/bindings/python/pyiree/rt.cc new file mode 100644 index 0000000..b683d02 --- /dev/null +++ b/bindings/python/pyiree/rt.cc
@@ -0,0 +1,150 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/rt.h" + +#include "base/api.h" +#include "bindings/python/pyiree/status_utils.h" +#include "hal/api.h" + +namespace iree { +namespace python { + +HalBufferView RtContext::WrapPyBufferForInput(py::buffer py_buffer) { + auto py_buffer_info = py_buffer.request(false /* writable */); + if (py_buffer_info.ndim > IREE_SHAPE_MAX_RANK || py_buffer_info.ndim < 0) { + RaiseValueError("Unsupported buffer rank"); + } + if (py_buffer_info.size < 0) { + RaiseValueError("Illegal buffer size"); + } + + // For the moment, allocate a device visible buffer of equivalent size and + // copy into it. + // TODO(laurenzo): Once sequencer is in place, switch to HeapBuffer, wrap + // and retain the original buffer. + iree_host_size_t byte_size = py_buffer_info.size * py_buffer_info.itemsize; + HalBuffer buffer = + AllocateDeviceVisible(byte_size, IREE_HAL_BUFFER_USAGE_CONSTANT | + IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH); + CheckApiStatus(iree_hal_buffer_write_data(buffer.raw_ptr(), 0, + py_buffer_info.ptr, byte_size), + "Error writing to input buffer"); + + // Create the buffer view. + // TODO(laurenzo): This does no validation on dtype and only cares if the + // elementsize matches. Figure out where to enforce actual dtype. + iree_shape_t shape; + shape.rank = py_buffer_info.ndim; + + // Verify strides are row-major. 
+ // TODO(laurenzo): Test this with rank>1. + for (int i = 1; i < shape.rank; ++i) { + if ((py_buffer_info.strides[i - 1] * py_buffer_info.itemsize) != + py_buffer_info.shape[i]) { + RaiseValueError("Expected row-major layout"); + } + } + if (!py_buffer_info.strides.empty()) { + if (py_buffer_info.strides.back() != 1) { + RaiseValueError("Expected row-major layout"); + } + } + + // Populate shape. + for (int i = 0; i < shape.rank; ++i) { + ssize_t dim = py_buffer_info.shape[i]; + if (dim < 0) { + RaiseValueError("Unsupported negative dim"); + } + shape.dims[i] = dim; + } + + iree_hal_buffer_view_t* bv; + CheckApiStatus(iree_hal_buffer_view_create(buffer.raw_ptr(), shape, + py_buffer_info.itemsize, + IREE_ALLOCATOR_DEFAULT, &bv), + "Error allocating buffer view"); + + return HalBufferView::CreateRetained(bv); +} + +void SetupRtBindings(pybind11::module m) { + // BufferPlacement. + py::enum_<BufferPlacement>(m, "BufferPlacement") + .value("HEAP", BufferPlacement::kHeap) + .value("DEVICE_VISIBLE", BufferPlacement::kDeviceVisible) + .value("DEVICE_LOCAL", BufferPlacement::kDeviceLocal) + .export_values(); + + // RtModule. + py::class_<RtModule>(m, "Module") + .def_property_readonly("name", &RtModule::name) + .def("lookup_function_by_ordinal", &RtModule::lookup_function_by_ordinal) + .def("lookup_function_by_name", &RtModule::lookup_function_by_name); + // RtFunction. + py::class_<RtFunction>(m, "Function") + .def_property_readonly("name", &RtFunction::name) + .def_property_readonly("signature", &RtFunction::signature); + py::class_<iree_rt_function_signature_t>(m, "FunctionSignature") + .def_readonly("argument_count", + &iree_rt_function_signature_t::argument_count) + .def_readonly("result_count", + &iree_rt_function_signature_t::result_count); + + // RtPolicy. + py::class_<RtPolicy>(m, "Policy").def(py::init(&RtPolicy::Create)); + + // RtInstance. 
+ py::class_<RtInstance>(m, "Instance") + .def(py::init(&RtInstance::Create), + py::arg_v("driver_name", absl::optional<std::string>())); + + // RtContext. + py::class_<RtContext>(m, "Context") + .def(py::init(&RtContext::Create), py::arg("instance"), py::arg("policy")) + .def_property_readonly("context_id", &RtContext::context_id) + .def("register_modules", &RtContext::RegisterModules, py::arg("modules")) + .def("register_module", &RtContext::RegisterModule, py::arg("module")) + .def("lookup_module_by_name", &RtContext::LookupModuleByName, + py::arg("name")) + .def("resolve_function", &RtContext::ResolveFunction, + py::arg("full_name")) + .def("allocate", &RtContext::Allocate, py::arg("allocation_size"), + py::arg("placement") = BufferPlacement::kHeap, + py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) + .def("allocate_device_local", &RtContext::AllocateDeviceLocal, + py::arg("allocation_size"), + py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) + .def("allocate_device_visible", &RtContext::AllocateDeviceVisible, + py::arg("allocation_size"), + py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) + .def("wrap_for_input", &RtContext::WrapPyBufferForInput, py::arg("v")) + .def("invoke", &RtContext::Invoke, py::arg("f"), py::arg("policy"), + py::arg("arguments"), + py::arg("results") = absl::optional<std::vector<HalBufferView*>>()); + + // RtInvocation. + py::class_<RtInvocation>(m, "Invocation") + .def("query_status", &RtInvocation::QueryStatus) + .def("await", &RtInvocation::Await, + py::arg("deadline") = IREE_TIME_INFINITE_FUTURE) + .def("await_optional", &RtInvocation::AwaitOptional, + py::arg("deadline") = IREE_TIME_INFINITE_FUTURE) + .def_property_readonly("results", &RtInvocation::ConsumeResults); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/rt.h b/bindings/python/pyiree/rt.h new file mode 100644 index 0000000..85e85da --- /dev/null +++ b/bindings/python/pyiree/rt.h
@@ -0,0 +1,390 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_RT_H_ + +#include "absl/container/inlined_vector.h" +#include "base/api.h" +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/hal.h" +#include "bindings/python/pyiree/initialize.h" +#include "bindings/python/pyiree/status_utils.h" +#include "hal/api.h" +#include "rt/api.h" + +namespace iree { +namespace python { + +// When creating a buffer via the context, switch between the different +// allocation entry-points via an enum (these are separate functions in the +// C API). +enum class BufferPlacement { + kHeap, + kDeviceVisible, + kDeviceLocal, +}; + +// Adapts API pointer access to retain/release API calls. 
+template <> +struct ApiPtrAdapter<iree_rt_module_t> { + static void Retain(iree_rt_module_t* m) { iree_rt_module_retain(m); } + static void Release(iree_rt_module_t* m) { iree_rt_module_release(m); } +}; + +template <> +struct ApiPtrAdapter<iree_rt_instance_t> { + static void Retain(iree_rt_instance_t* inst) { + iree_rt_instance_retain(inst); + } + static void Release(iree_rt_instance_t* inst) { + iree_rt_instance_release(inst); + } +}; + +template <> +struct ApiPtrAdapter<iree_rt_policy_t> { + static void Retain(iree_rt_policy_t* p) { iree_rt_policy_retain(p); } + static void Release(iree_rt_policy_t* p) { iree_rt_policy_release(p); } +}; + +template <> +struct ApiPtrAdapter<iree_rt_context_t> { + static void Retain(iree_rt_context_t* c) { iree_rt_context_retain(c); } + static void Release(iree_rt_context_t* c) { iree_rt_context_release(c); } +}; + +template <> +struct ApiPtrAdapter<iree_rt_invocation_t> { + static void Retain(iree_rt_invocation_t* inv) { + iree_rt_invocation_retain(inv); + } + static void Release(iree_rt_invocation_t* inv) { + iree_rt_invocation_release(inv); + } +}; + +// Wrapper classes. These mirror the Python declarations. +class RtFunction { + public: + // Note that this will retain the module. 
+ RtFunction(iree_rt_function_t function) : function_(function) { + iree_rt_module_retain(function_.module); + } + ~RtFunction() { + if (function_.module) iree_rt_module_release(function_.module); + } + RtFunction(RtFunction&& other) : function_(other.function_) { + other.function_.module = nullptr; + } + void operator=(const RtFunction&) = delete; + + std::string name() { + auto sv = iree_rt_function_name(&function_); + return std::string(sv.data, sv.size); + } + + iree_rt_function_signature_t signature() { + iree_rt_function_signature_t sig; + CheckApiStatus(iree_rt_function_signature(&function_, &sig), + "Error getting function signature"); + return sig; + } + + iree_rt_function_t& raw_function() { return function_; } + + private: + iree_rt_function_t function_; +}; + +class RtModule : public ApiRefCounted<RtModule, iree_rt_module_t> { + public: + std::string name() { + auto sv = iree_rt_module_name(raw_ptr()); + return std::string(sv.data, sv.size); + } + + absl::optional<RtFunction> lookup_function_by_ordinal(int32_t ordinal) { + iree_rt_function_t f; + // TODO(laurenzo): Support an optional linkage argument. + auto module = raw_ptr(); + auto status = iree_rt_module_lookup_function_by_ordinal( + module, IREE_RT_FUNCTION_LINKAGE_EXPORT, ordinal, &f); + if (status == IREE_STATUS_NOT_FOUND) { + return absl::optional<RtFunction>(); + } + CheckApiStatus(status, "Error looking up function"); + return RtFunction(f); + } + + absl::optional<RtFunction> lookup_function_by_name(const std::string& name) { + iree_rt_function_t f; + // TODO(laurenzo): Support an optional linkage argument. 
+ auto module = raw_ptr(); + iree_string_view_t name_sv{name.data(), name.size()}; + auto status = iree_rt_module_lookup_function_by_name( + module, IREE_RT_FUNCTION_LINKAGE_EXPORT, name_sv, &f); + if (status == IREE_STATUS_NOT_FOUND) { + return absl::optional<RtFunction>(); + } + CheckApiStatus(status, "Error looking up function"); + return RtFunction(f); + } +}; + +class RtInstance : public ApiRefCounted<RtInstance, iree_rt_instance_t> { + public: + // TODO(laurenzo): Support optional allocator argument. + static RtInstance Create(absl::optional<std::string> driver_name) { + InitializeExtension({}); + iree_rt_instance_t* raw_inst; + CheckApiStatus(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &raw_inst), + "Error creating instance"); + RtInstance inst = RtInstance::CreateRetained(raw_inst); + + if (!driver_name) { + driver_name = "interpreter"; + } + CheckApiStatus(iree_rt_instance_register_driver_ex( + raw_inst, iree_string_view_t{driver_name->c_str(), + driver_name->size()}), + "Error registering drivers"); + + return inst; + } +}; + +class RtPolicy : public ApiRefCounted<RtPolicy, iree_rt_policy_t> { + public: + // TODO(laurenzo): Support optional allocator argument. + static RtPolicy Create() { + iree_rt_policy_t* policy; + CheckApiStatus(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &policy), + "Error creating policy"); + return RtPolicy::CreateRetained(policy); + } +}; + +class RtInvocation : public ApiRefCounted<RtInvocation, iree_rt_invocation_t> { + public: + // Returns whether ready. + // Raises exception on error. + bool QueryStatus() { + auto status = iree_rt_invocation_query_status(raw_ptr()); + if (status == IREE_STATUS_OK) { + return true; + } else if (status == IREE_STATUS_UNAVAILABLE) { + return false; + } else { + CheckApiStatus(status, "Error in function invocation"); + return false; + } + } + + // TODO(laurenzo): Convert to the pybind chrono support. + // Returns whether the invocation is ready. 
+ bool AwaitOptional(iree_time_t epoch_nanos_deadline) { + auto status = iree_rt_invocation_await(raw_ptr(), epoch_nanos_deadline); + if (status == IREE_STATUS_OK) { + return true; + } else if (status == IREE_STATUS_DEADLINE_EXCEEDED) { + return false; + } else { + CheckApiStatus(status, "Error in invocation"); + return false; + } + } + + // Similar to AwaitOptional but will raise an error unless if the status + // is ready. + void Await(iree_time_t epoch_nanos_deadline) { + if (!AwaitOptional(epoch_nanos_deadline)) { + RaiseValueError("Deadline expired"); + } + } + + std::vector<HalBufferView> ConsumeResults() { + static constexpr size_t kInlineSize = 8; + iree_host_size_t result_count; + absl::InlinedVector<iree_hal_buffer_view_t*, kInlineSize> result_bvs; + result_bvs.resize(kInlineSize); + auto status = iree_rt_invocation_consume_results( + raw_ptr(), kInlineSize, IREE_ALLOCATOR_DEFAULT, &result_bvs[0], + &result_count); + if (status == IREE_STATUS_OUT_OF_RANGE) { + // Resize/retry. + result_bvs.resize(result_count); + status = iree_rt_invocation_consume_results( + raw_ptr(), result_count, IREE_ALLOCATOR_DEFAULT, &result_bvs[0], + &result_count); + } + CheckApiStatus(status, "Error consuming invocation results"); + result_bvs.resize(result_count); + std::vector<HalBufferView> results; + for (auto* raw_bv : result_bvs) { + results.push_back(HalBufferView::CreateRetained(raw_bv)); + } + return results; + } +}; + +class RtContext : public ApiRefCounted<RtContext, iree_rt_context_t> { + public: + static RtContext Create(RtInstance* instance, RtPolicy* policy) { + iree_rt_context_t* context; + // TODO(laurenzo): Support optional allocator argument. 
+ CheckApiStatus( + iree_rt_context_create(instance->raw_ptr(), policy->raw_ptr(), + IREE_ALLOCATOR_DEFAULT, &context), + "Error creating instance"); + return RtContext::CreateRetained(context); + } + + int context_id() { return iree_rt_context_id(raw_ptr()); } + + void RegisterModules(std::vector<RtModule*> modules) { + std::vector<iree_rt_module_t*> module_raw_ptrs; + module_raw_ptrs.resize(modules.size()); + for (size_t i = 0, e = modules.size(); i < e; ++i) { + auto module_raw_ptr = modules[i]->raw_ptr(); + module_raw_ptrs[i] = module_raw_ptr; + } + CheckApiStatus( + iree_rt_context_register_modules(raw_ptr(), module_raw_ptrs.data(), + module_raw_ptrs.size()), + "Error registering modules"); + } + + void RegisterModule(RtModule* module) { + iree_rt_module_t* module_raw_ptr = module->raw_ptr(); + CheckApiStatus( + iree_rt_context_register_modules(raw_ptr(), &module_raw_ptr, 1), + "Error registering module"); + } + + absl::optional<RtModule> LookupModuleByName(const std::string& name) { + iree_rt_module_t* module = iree_rt_context_lookup_module_by_name( + raw_ptr(), {name.data(), name.size()}); + if (!module) { + return absl::optional<RtModule>(); + } + return RtModule::RetainAndCreate(module); + } + + absl::optional<RtFunction> ResolveFunction(const std::string& full_name) { + iree_rt_function_t f; + auto status = iree_rt_context_resolve_function( + raw_ptr(), {full_name.data(), full_name.size()}, &f); + if (status == IREE_STATUS_NOT_FOUND) { + return absl::optional<RtFunction>(); + } + CheckApiStatus(status, "Error resolving function"); + return RtFunction(f); + } + + // Convenience method to allocate host, device-visible or device-local + // buffers. 
+ HalBuffer Allocate(iree_host_size_t allocation_size, + BufferPlacement placement, int32_t usage) { + iree_hal_buffer_t* raw_buffer = nullptr; + switch (placement) { + case BufferPlacement::kHeap: + // Even though allocating a heap buffer does not require the context, + // provide it here to make the API easier to navigate. + CheckApiStatus( + iree_hal_heap_buffer_allocate( + IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + static_cast<iree_hal_buffer_usage_t>(usage), allocation_size, + IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &raw_buffer), + "Error allocating heap buffer"); + break; + case BufferPlacement::kDeviceLocal: + CheckApiStatus( + iree_rt_context_allocate_device_local_buffer( + raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage), + allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer), + "Error allocating device local buffer"); + break; + case BufferPlacement::kDeviceVisible: + CheckApiStatus( + iree_rt_context_allocate_device_visible_buffer( + raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage), + allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer), + "Error allocating device visible buffer"); + break; + default: + throw RaiseValueError("Unknown BufferPlacement"); + } + + return HalBuffer::CreateRetained(raw_buffer); + } + + HalBuffer AllocateHeap(iree_host_size_t allocation_size, int32_t usage) { + return Allocate(allocation_size, BufferPlacement::kHeap, usage); + } + + HalBuffer AllocateDeviceLocal(iree_host_size_t allocation_size, + int32_t usage) { + return Allocate(allocation_size, BufferPlacement::kDeviceLocal, usage); + } + + HalBuffer AllocateDeviceVisible(iree_host_size_t allocation_size, + int32_t usage) { + return Allocate(allocation_size, BufferPlacement::kDeviceVisible, usage); + } + + // One stop convenience method for wrapping a python buffer protocol buffer + // for input to a function. 
At the runtime's discretion, this may make a copy + // or do something smarter, meaning the data in the backing python buffer + // will either be accessed immediately or at some future point. + HalBufferView WrapPyBufferForInput(py::buffer py_buffer); + + RtInvocation Invoke(RtFunction& f, RtPolicy& policy, + std::vector<HalBufferView*> arguments, + absl::optional<std::vector<HalBufferView*>> results) { + absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_arguments; + raw_arguments.resize(arguments.size()); + for (size_t i = 0, e = arguments.size(); i < e; ++i) { + auto inst = arguments[i]; + CheckApiNotNull(inst, "Argument buffer view cannot be None"); + raw_arguments[i] = inst->raw_ptr(); + } + absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_results; + if (results) { + raw_results.resize(results->size()); + for (size_t i = 0, e = results->size(); i < e; ++i) { + auto inst = (*results)[i]; + CheckApiNotNull(inst, "Result buffer view cannot be None"); + raw_results[i] = inst->raw_ptr(); + } + } + + iree_rt_invocation_t* invocation; + CheckApiStatus(iree_rt_invocation_create( + raw_ptr(), &f.raw_function(), policy.raw_ptr(), + nullptr /* dependencies */, raw_arguments.data(), + raw_arguments.size(), raw_results.data(), + raw_results.size(), IREE_ALLOCATOR_DEFAULT, &invocation), + "Error invoking function"); + + return RtInvocation::CreateRetained(invocation); + } +}; + +void SetupRtBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_H_
diff --git a/bindings/python/pyiree/runtime_test.py b/bindings/python/pyiree/runtime_test.py new file mode 100644 index 0000000..6935b87 --- /dev/null +++ b/bindings/python/pyiree/runtime_test.py
@@ -0,0 +1,128 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +import numpy as np +from pyiree import binding as binding + + +def create_simple_mul_module(): + blob = binding.compiler.compile_module_from_asm(""" + func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> + attributes { iree.module.export } { + %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> + } + """) + m = binding.vm.create_module_from_blob(blob) + return m + + +def create_host_buffer_view(context): + b = context.allocate_device_visible(16) + b.fill_zero(0, 16) + bv = b.create_view(binding.hal.Shape([4]), 4) + print("BUFFER VIEW:", bv) + return bv + + +class RuntimeTest(absltest.TestCase): + + def testModuleAndFunction(self): + m = create_simple_mul_module() + print("Module:", m) + print("Module name:", m.name) + self.assertEqual("module", m.name) + + # Function 0. + f = m.lookup_function_by_ordinal(0) + print("Function 0:", f) + self.assertEqual("simple_mul", f.name) + sig = f.signature + self.assertEqual(2, sig.argument_count) + self.assertEqual(1, sig.result_count) + + # Function 1. + f = m.lookup_function_by_ordinal(1) + self.assertIs(f, None) + + # By name. 
+ f = m.lookup_function_by_name("simple_mul") + self.assertEqual("simple_mul", f.name) + sig = f.signature + self.assertEqual(2, sig.argument_count) + self.assertEqual(1, sig.result_count) + + # By name not found. + f = m.lookup_function_by_name("not_here") + self.assertIs(f, None) + + def testInitialization(self): + policy = binding.rt.Policy() + print("policy =", policy) + instance = binding.rt.Instance() + print("instance =", instance) + context = binding.rt.Context(instance=instance, policy=policy) + print("context =", context) + context_id = context.context_id + print("context_id =", context.context_id) + self.assertGreater(context_id, 0) + + def testRegisterModule(self): + policy = binding.rt.Policy() + instance = binding.rt.Instance() + context = binding.rt.Context(instance=instance, policy=policy) + m = create_simple_mul_module() + context.register_module(m) + self.assertIsNot(context.lookup_module_by_name("module"), None) + self.assertIs(context.lookup_module_by_name("nothere"), None) + f = context.resolve_function("module.simple_mul") + self.assertIsNot(f, None) + print("Resolved function:", f.name) + self.assertIs(context.resolve_function("module.nothere"), None) + + def testInvoke(self): + policy = binding.rt.Policy() + instance = binding.rt.Instance() + context = binding.rt.Context(instance=instance, policy=policy) + m = create_simple_mul_module() + context.register_module(m) + f = context.resolve_function("module.simple_mul") + print("INVOKE F:", f) + arg0 = context.wrap_for_input(np.array([1., 2., 3., 4.], dtype=np.float32)) + arg1 = context.wrap_for_input(np.array([4., 5., 6., 7.], dtype=np.float32)) + + inv = context.invoke(f, policy, [arg0, arg1]) + print("Status:", inv.query_status()) + inv.await() + results = inv.results + print("Results:", results) + result = results[0].map() + print("Mapped result:", result) + result_ary = np.array(result, copy=False) + print("NP result:", result_ary) + self.assertEqual(4., result_ary[0]) + 
self.assertEqual(10., result_ary[1]) + self.assertEqual(18., result_ary[2]) + self.assertEqual(28., result_ary[3]) + + +if __name__ == "__main__": + # Uncomment to initialize the extension with custom flags. + binding.initialize_extension(["--logtostderr"]) + absltest.main()
diff --git a/bindings/python/pyiree/status_utils.cc b/bindings/python/pyiree/status_utils.cc new file mode 100644 index 0000000..63f2131 --- /dev/null +++ b/bindings/python/pyiree/status_utils.cc
@@ -0,0 +1,72 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/status_utils.h" + +#include "absl/strings/str_cat.h" + +namespace iree { +namespace python { + +namespace { + +PyObject* StatusToPyExcClass(const Status& status) { + switch (status.code()) { + case StatusCode::kInvalidArgument: + return PyExc_ValueError; + case StatusCode::kOutOfRange: + return PyExc_IndexError; + case StatusCode::kUnimplemented: + return PyExc_NotImplementedError; + default: + return PyExc_RuntimeError; + } +} + +PyObject* ApiStatusToPyExcClass(iree_status_t status) { + switch (status) { + case IREE_STATUS_INVALID_ARGUMENT: + return PyExc_ValueError; + case IREE_STATUS_OUT_OF_RANGE: + return PyExc_IndexError; + case IREE_STATUS_UNIMPLEMENTED: + return PyExc_NotImplementedError; + default: + return PyExc_RuntimeError; + } +} + +} // namespace + +pybind11::error_already_set StatusToPyExc(const Status& status) { + assert(!status.ok()); + PyErr_SetString(StatusToPyExcClass(status), status.error_message().c_str()); + return pybind11::error_already_set(); +} + +pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, + const char* message) { + assert(status != IREE_STATUS_OK); + auto full_message = absl::StrCat(message, ": ", static_cast<int>(status)); + PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str()); + return pybind11::error_already_set(); +} + +pybind11::error_already_set 
RaiseValueError(const char* message) { + PyErr_SetString(PyExc_ValueError, message); + return pybind11::error_already_set(); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/status_utils.h b/bindings/python/pyiree/status_utils.h new file mode 100644 index 0000000..ef89ba8 --- /dev/null +++ b/bindings/python/pyiree/status_utils.h
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_ + +#include "base/api.h" +#include "base/status.h" +#include "pybind11/pytypes.h" + +namespace iree { +namespace python { + +// Converts a failing status to a throwable exception, setting Python +// error information. +// Correct usage is something like: +// if (!status.ok()) { +// throw StatusToPyExc(status); +// } +pybind11::error_already_set StatusToPyExc(const Status& status); + +// Raises a value error with the given message. +// Correct usage: +// throw RaiseValueError("Foobar'd"); +pybind11::error_already_set RaiseValueError(const char* message); + +// Consumes a StatusOr<T>, returning an rvalue reference to the T if the +// status is ok(). Otherwise, throws an exception. 
+template <typename T> +T&& PyConsumeStatusOr(iree::StatusOr<T>&& sor) { + if (sor.ok()) { + return std::move(*sor); + } + throw StatusToPyExc(sor.status()); +} + +pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, + const char* message); + +static void CheckApiStatus(iree_status_t status, const char* message) { + if (status == IREE_STATUS_OK) { + return; + } + throw ApiStatusToPyExc(status, message); +} + +static void CheckApiNotNull(const void* p, const char* message) { + if (!p) { + throw RaiseValueError(message); + } +} + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc new file mode 100644 index 0000000..3d03d97 --- /dev/null +++ b/bindings/python/pyiree/vm.cc
@@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/vm.h" + +#include "bindings/python/pyiree/status_utils.h" + +namespace iree { +namespace python { + +RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) { + iree_rt_module_t* module; + auto free_fn = OpaqueBlob::CreateFreeFn(blob); + auto status = iree_vm_bytecode_module_create_from_buffer( + {static_cast<const uint8_t*>(blob->data()), blob->size()}, free_fn.first, + free_fn.second, IREE_ALLOCATOR_DEFAULT, &module); + CheckApiStatus(status, "Error creating vm module from blob"); + return RtModule::CreateRetained(module); +} + +void SetupVmBindings(pybind11::module m) { + m.def("create_module_from_blob", CreateModuleFromBlob); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/vm.h new file mode 100644 index 0000000..e7338cc --- /dev/null +++ b/bindings/python/pyiree/vm.h
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_ + +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/rt.h" +#include "vm/api.h" + +namespace iree { +namespace python { + +void SetupVmBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_