Update CMake min to 3.12
STRING JOIN was introduced in 3.12
TEST=builds with cmake 3.12; fails with 3.10.2 (stock in Ubuntu)
Latest Ubuntu PPAs are here: https://apt.kitware.com/
--
e65d4dfd0278185962071d30359fc0ce1232be9e by Anush Elangovan <anush@nod-labs.com>:
Fix spurious reference to DLOG
TEST:builds
Closes #89
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/89 from powderluv:fix-cmake-ver e65d4dfd0278185962071d30359fc0ce1232be9e
PiperOrigin-RevId: 276298047
diff --git a/BUILD.bazel b/BUILD.bazel
deleted file mode 100644
index 2b86f73..0000000
--- a/BUILD.bazel
+++ /dev/null
@@ -1,4 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
diff --git a/BUILD.bazel.oss b/BUILD.bazel.oss
new file mode 100644
index 0000000..f1f9d53
--- /dev/null
+++ b/BUILD.bazel.oss
@@ -0,0 +1,3 @@
+package(
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/BUILD.oss b/BUILD.oss
new file mode 100644
index 0000000..f5c6905
--- /dev/null
+++ b/BUILD.oss
@@ -0,0 +1,30 @@
+# Main IREE build file.
+# Note that project-wide, bazel repo aliases are used:
+# "//third_party/absl/python"
+# "//third_party/absl"
+# "//third_party/benchmark"
+# "//third_party/llvm/llvm/projects/google_mlir"
+# "//third_party/llvm/llvm"
+# "//third_party/flatbuffers"
+# "//third_party/tensorflow"
+#
+# Various scripts and helpers operate on these prefixes textually, so
+# avoid doing any systematic construction that would break the matching.
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Enables the debug service and other profiling features.
+# $ bazel build --define=IREE_DEBUG=1 :some_target
+config_setting(
+ name = "debug",
+ define_values = {"IREE_DEBUG": "1"},
+)
+
+# Marker library which can be extended to provide flags for things that
+# need to know the platform target.
+cc_library(
+ name = "target_config",
+ defines = ["IREE_UNSPECIFIED_TARGET=1"],
+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fbf93d0..ac138f0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-cmake_minimum_required(VERSION 3.5)
+cmake_minimum_required(VERSION 3.12)
cmake_policy(SET CMP0077 NEW)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
diff --git a/base/BUILD b/base/BUILD
new file mode 100644
index 0000000..80bce9f
--- /dev/null
+++ b/base/BUILD
@@ -0,0 +1,362 @@
+# Common types and utilities used in the IREE codebase.
+
+load("//:build_defs.google.bzl", "platform_trampoline_deps")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "api",
+ srcs = ["api.cc"],
+ hdrs = ["api.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":api_hdrs",
+ ":api_util",
+ ":file_mapping",
+ ":tracing",
+ ],
+)
+
+cc_library(
+ name = "api_hdrs",
+ hdrs = ["api.h"],
+)
+
+cc_library(
+ name = "api_util",
+ hdrs = ["api_util.h"],
+ deps = [
+ ":api_hdrs",
+ ":logging",
+ ":shape",
+ ":status",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "arena",
+ srcs = ["arena.cc"],
+ hdrs = ["arena.h"],
+ deps = [
+ ":logging",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_test(
+ name = "arena_test",
+ srcs = ["arena_test.cc"],
+ deps = [
+ ":arena",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "bitfield",
+ hdrs = ["bitfield.h"],
+ deps = [
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_test(
+ name = "bitfield_test",
+ srcs = ["bitfield_test.cc"],
+ deps = [
+ ":bitfield",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "file_io",
+ hdrs = ["file_io.h"],
+ deps = [
+ ":status",
+ ":target_platform",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ] + platform_trampoline_deps("file_io"),
+)
+
+cc_library(
+ name = "file_io_hdrs",
+ hdrs = ["file_io.h"],
+ deps = [":status"],
+)
+
+cc_library(
+ name = "file_mapping",
+ hdrs = ["file_mapping.h"],
+ deps = [
+ ":ref_ptr",
+ ":status",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ] + platform_trampoline_deps("file_mapping"),
+)
+
+cc_library(
+ name = "file_mapping_hdrs",
+ hdrs = ["file_mapping.h"],
+ deps = [
+ ":ref_ptr",
+ ":status",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "file_path",
+ srcs = ["file_path.cc"],
+ hdrs = ["file_path.h"],
+ deps = [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "flatbuffer_util",
+ srcs = ["flatbuffer_util.cc"],
+ hdrs = ["flatbuffer_util.h"],
+ deps = [
+ ":file_mapping",
+ ":memory",
+ ":source_location",
+ ":status",
+ ":tracing",
+ "@com_github_google_flatbuffers//:flatbuffers",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "init",
+ hdrs = ["init.h"],
+ deps = platform_trampoline_deps("init"),
+)
+
+cc_library(
+ name = "intrusive_list",
+ hdrs = [
+ "intrusive_list.h",
+ "intrusive_list_ref_ptr.inc",
+ "intrusive_list_unique_ptr.inc",
+ ],
+ deps = [
+ ":logging",
+ ":ref_ptr",
+ ],
+)
+
+cc_test(
+ name = "intrusive_list_test",
+ srcs = [
+ "intrusive_list_ref_ptr_test.cc",
+ "intrusive_list_test.cc",
+ "intrusive_list_unique_ptr_test.cc",
+ ],
+ deps = [
+ ":intrusive_list",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "logging",
+ hdrs = ["logging.h"],
+ deps = platform_trampoline_deps("logging"),
+)
+
+cc_library(
+ name = "math",
+ hdrs = ["math.h"],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ ],
+)
+
+cc_library(
+ name = "memory",
+ hdrs = ["memory.h"],
+ deps = [
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "ref_ptr",
+ hdrs = ["ref_ptr.h"],
+ deps = [
+ ":logging",
+ "@com_google_absl//absl/base:core_headers",
+ ],
+)
+
+cc_test(
+ name = "ref_ptr_test",
+ size = "small",
+ srcs = ["ref_ptr_test.cc"],
+ deps = [
+ ":ref_ptr",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "shape",
+ srcs = ["shape.cc"],
+ hdrs = ["shape.h"],
+ deps = [
+ ":logging",
+ ":source_location",
+ ":status",
+ "@com_google_absl//absl/meta:type_traits",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_test(
+ name = "shape_test",
+ srcs = ["shape_test.cc"],
+ deps = [
+ ":shape",
+ ":status",
+ ":status_matchers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "source_location",
+ hdrs = ["source_location.h"],
+ deps = platform_trampoline_deps("source_location"),
+)
+
+cc_library(
+ name = "status",
+ hdrs = ["status.h"],
+ deps = [
+ ":source_location",
+ ] + platform_trampoline_deps("status"),
+)
+
+cc_library(
+ name = "status_matchers",
+ testonly = 1,
+ hdrs = ["status_matchers.h"],
+ deps = platform_trampoline_deps("status_matchers"),
+)
+
+cc_library(
+ name = "target_platform",
+ hdrs = ["target_platform.h"],
+)
+
+cc_library(
+ name = "time",
+ hdrs = ["time.h"],
+ deps = [
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "tracing",
+ hdrs = ["tracing.h"],
+ deps = [
+ "//:target_config",
+ "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+ ] + select({
+ "@com_google_tracing_framework_cpp//:wtf_enable": [":tracing_enabled"],
+ "//conditions:default": [":tracing_disabled"],
+ }),
+)
+
+cc_library(
+ name = "tracing_disabled",
+ srcs = [
+ "tracing.h",
+ "tracing_disabled.cc",
+ ],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":init",
+ ":logging",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tracing_enabled",
+ srcs = [
+ "tracing.cc",
+ "tracing.h",
+ ],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":file_io",
+ ":file_path",
+ ":init",
+ ":logging",
+ ":status",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+ ],
+ alwayslink = 1,
+)
+
+# Dependent code has been removed and wait_handle is currently incompatible
+# with Windows, so excluding entirely.
+# See google/iree/65
+# cc_library(
+# name = "wait_handle",
+# srcs = ["wait_handle.cc"],
+# hdrs = ["wait_handle.h"],
+# deps = [
+# ":logging",
+# ":ref_ptr",
+# ":source_location",
+# ":status",
+# ":time",
+# "@com_google_absl//absl/base:core_headers",
+# "@com_google_absl//absl/container:fixed_array",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/time",
+# "@com_google_absl//absl/types:span",
+# ],
+# )
+
+# cc_test(
+# name = "wait_handle_test",
+# srcs = ["wait_handle_test.cc"],
+# deps = [
+# ":status",
+# ":status_matchers",
+# ":wait_handle",
+# "@com_google_absl//absl/time",
+# "@com_google_googletest//:gtest_main",
+# ],
+# )
diff --git a/iree/base/CMakeLists.txt b/base/CMakeLists.txt
similarity index 100%
rename from iree/base/CMakeLists.txt
rename to base/CMakeLists.txt
diff --git a/base/api.cc b/base/api.cc
new file mode 100644
index 0000000..c187a34
--- /dev/null
+++ b/base/api.cc
@@ -0,0 +1,124 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/api.h"
+
+#include <cstdlib>
+#include <string>
+
+#include "base/api_util.h"
+#include "base/file_mapping.h"
+#include "base/tracing.h"
+
+namespace iree {
+
+//===----------------------------------------------------------------------===//
+// iree Core API
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_api_version_check(iree_api_version_t expected_version,
+                       iree_api_version_t* out_actual_version) {
+  if (!out_actual_version) return IREE_STATUS_INVALID_ARGUMENT;  // Required.
+  *out_actual_version = IREE_API_VERSION_0;  // Reported even on mismatch.
+  if (expected_version != IREE_API_VERSION_0) return IREE_STATUS_OUT_OF_RANGE;
+  return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_allocator_alloc(void* self, iree_host_size_t byte_length, void** out_ptr) {
+ IREE_TRACE_SCOPE0("iree_allocator_alloc");
+
+ if (!out_ptr) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_ptr = nullptr;
+
+ if (byte_length <= 0) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ *out_ptr = std::malloc(byte_length);
+ if (!*out_ptr) {
+ return IREE_STATUS_RESOURCE_EXHAUSTED;
+ }
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_free(void* self,
+ void* ptr) {
+ IREE_TRACE_SCOPE0("iree_allocator_free");
+ if (ptr) {
+ std::free(ptr);
+ }
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::FileMapping
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_file_mapping_open_read(iree_string_view_t path, iree_allocator_t allocator,
+ iree_file_mapping_t** out_file_mapping) {
+ IREE_TRACE_SCOPE0("iree_file_mapping_open_read");
+
+ if (!out_file_mapping) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_file_mapping = nullptr;
+
+ IREE_API_ASSIGN_OR_RETURN(
+ auto file_mapping,
+ FileMapping::OpenRead(std::string(path.data, path.size)));
+
+ *out_file_mapping =
+ reinterpret_cast<iree_file_mapping_t*>(file_mapping.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_file_mapping_retain(iree_file_mapping_t* file_mapping) {
+ IREE_TRACE_SCOPE0("iree_file_mapping_retain");
+ auto* handle = reinterpret_cast<FileMapping*>(file_mapping);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_file_mapping_release(iree_file_mapping_t* file_mapping) {
+ IREE_TRACE_SCOPE0("iree_file_mapping_release");
+ auto* handle = reinterpret_cast<FileMapping*>(file_mapping);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_byte_span_t IREE_API_CALL
+iree_file_mapping_data(iree_file_mapping_t* file_mapping) {
+ IREE_TRACE_SCOPE0("iree_file_mapping_data");
+ auto* handle = reinterpret_cast<FileMapping*>(file_mapping);
+ CHECK(handle) << "NULL file_mapping handle";
+ auto data = handle->data();
+ return {const_cast<uint8_t*>(data.data()), data.size()};
+}
+
+} // namespace iree
diff --git a/iree/base/api.h b/base/api.h
similarity index 100%
rename from iree/base/api.h
rename to base/api.h
diff --git a/base/api_util.h b/base/api_util.h
new file mode 100644
index 0000000..1f750ad
--- /dev/null
+++ b/base/api_util.h
@@ -0,0 +1,125 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_API_UTIL_H_
+#define IREE_BASE_API_UTIL_H_
+
+#include "absl/base/macros.h"
+#include "absl/time/time.h"
+#include "base/api.h"
+#include "base/logging.h"
+#include "base/shape.h"
+#include "base/status.h"
+
+namespace iree {
+
+inline iree_status_t ToApiStatus(Status status) {
+ return static_cast<iree_status_t>(status.code());
+}
+
+inline Status FromApiStatus(iree_status_t status_code, SourceLocation loc) {
+ return StatusBuilder(static_cast<StatusCode>(status_code), loc);
+}
+
+// Internal helper for concatenating macro values.
+#define IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
+#define IREE_API_STATUS_MACROS_IMPL_CONCAT_(x, y) \
+ IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
+
+// clang-format off
+#define IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT
+// clang-format on
+
+namespace status_macro_internal {
+class StatusAdaptorForApiMacros {
+ public:
+ StatusAdaptorForApiMacros(const Status& status) : status_(status) {}
+ StatusAdaptorForApiMacros(Status&& status) : status_(std::move(status)) {}
+ StatusAdaptorForApiMacros(const StatusAdaptorForApiMacros&) = delete;
+ StatusAdaptorForApiMacros& operator=(const StatusAdaptorForApiMacros&) =
+ delete;
+ explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); }
+ Status&& Consume() { return std::move(status_); }
+
+ private:
+ Status status_;
+};
+} // namespace status_macro_internal
+
+#define IREE_API_RETURN_IF_ERROR(expr) \
+ IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ if (::iree::status_macro_internal::StatusAdaptorForApiMacros \
+ status_adaptor = {expr}) { \
+ } else /* NOLINT */ \
+ return ::iree::ToApiStatus(status_adaptor.Consume())
+
+#define IREE_API_RETURN_IF_API_ERROR(expr) \
+ IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ if (iree_status_t status = (expr)) { \
+ return status; \
+ }
+
+#define IREE_API_ASSIGN_OR_RETURN(...) \
+ IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_( \
+ (__VA_ARGS__, IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \
+ IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \
+ (__VA_ARGS__)
+
+#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, \
+ ...) \
+ NAME
+#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_(args) \
+ IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args
+
+#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
+ IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_))
+#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
+ error_expression) \
+ IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
+ IREE_API_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \
+ rexpr, error_expression)
+#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
+ error_expression) \
+ auto statusor = (rexpr); \
+ if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
+ return ::iree::ToApiStatus(std::move(statusor).status()); \
+ } \
+ lhs = std::move(statusor).ValueOrDie()
+
+// Converts an iree_time_t to its equivalent absl::Time.
+inline absl::Time ToAbslTime(iree_time_t time) {
+ if (time == IREE_TIME_INFINITE_PAST) {
+ return absl::InfinitePast();
+ } else if (time == IREE_TIME_INFINITE_FUTURE) {
+ return absl::InfiniteFuture();
+ } else {
+ return absl::FromUnixNanos(time);
+ }
+}
+
+// Converts a Shape to an iree_shape_t.
+inline iree_status_t ToApiShape(const Shape& shape, iree_shape_t* out_shape) {
+ out_shape->rank = shape.size();
+ if (shape.size() > ABSL_ARRAYSIZE(out_shape->dims)) {
+ return IREE_STATUS_OUT_OF_RANGE;
+ }
+ for (int i = 0; i < out_shape->rank; ++i) {
+ out_shape->dims[i] = shape[i];
+ }
+ return IREE_STATUS_OK;
+}
+
+} // namespace iree
+
+#endif // IREE_BASE_API_UTIL_H_
diff --git a/base/arena.cc b/base/arena.cc
new file mode 100644
index 0000000..e51aab4
--- /dev/null
+++ b/base/arena.cc
@@ -0,0 +1,125 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/arena.h"
+
+#include <memory>
+
+#include "absl/base/attributes.h"
+#include "base/logging.h"
+
+namespace iree {
+
+namespace {
+
+// Rounds up to the next alignment value, if it is not already aligned.
+template <typename T>
+ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr T RoundToAlignment(
+ T value, T alignment) noexcept {
+ return ((value + alignment - 1) / alignment) * alignment;
+}
+
+} // namespace
+
+Arena::Arena(size_t block_size) : block_size_(block_size) {}
+
+Arena::~Arena() { Clear(); }
+
+void Arena::Clear() {
+ // Deallocate all memory.
+ auto block_header = block_list_head_;
+ while (block_header) {
+ auto next_block = block_header->next_block;
+ std::free(block_header);
+ block_header = next_block;
+ }
+ block_list_head_ = nullptr;
+ block_header = unused_block_list_head_;
+ while (block_header) {
+ auto next_block = block_header->next_block;
+ std::free(block_header);
+ block_header = next_block;
+ }
+ unused_block_list_head_ = nullptr;
+
+ bytes_allocated_ = 0;
+ block_bytes_allocated_ = 0;
+}
+
+void Arena::Reset() {
+ // Move all blocks to the unused list and reset allocation count only.
+ auto block_header = block_list_head_;
+ while (block_header) {
+ auto next_block = block_header->next_block;
+ block_header->bytes_allocated = 0;
+ block_header->next_block = unused_block_list_head_;
+ unused_block_list_head_ = block_header;
+ block_header = next_block;
+ }
+ block_list_head_ = nullptr;
+
+ bytes_allocated_ = 0;
+}
+
+// Carves |length| bytes from the head block, adding a new block when full.
+uint8_t* Arena::AllocateBytes(size_t length) {
+  if (!length) {
+    // Guarantee zero-length allocations return nullptr.
+    return nullptr;
+  }
+
+  // Pad length allocated so we are machine word aligned.
+  // This ensures the next allocation starts at the right boundary.
+  size_t aligned_length = RoundToAlignment(length, sizeof(uintptr_t));
+
+  if (aligned_length < length || aligned_length > block_size_) {
+    // Either rounding up wrapped around or the request is larger than an
+    // entire block. We could allocate this with malloc (and then keep track
+    // of those to free things), but for now let's just die.
+    CHECK(false);
+    return nullptr;
+  }
+
+  if (!block_list_head_ ||
+      block_list_head_->bytes_allocated + aligned_length > block_size_) {
+    // Check to see if we have an existing unused block we can use.
+    if (unused_block_list_head_) {
+      // Move block from unused list to main list.
+      auto block_header = unused_block_list_head_;
+      unused_block_list_head_ = block_header->next_block;
+      block_header->next_block = block_list_head_;
+      block_header->bytes_allocated = 0;
+      block_list_head_ = block_header;
+    } else {
+      // Allocate a new block; die on OOM instead of writing through nullptr.
+      auto block_header = reinterpret_cast<BlockHeader*>(
+          std::malloc(sizeof(BlockHeader) + block_size_));
+      CHECK(block_header) << "Failed to allocate arena block";
+      block_header->next_block = block_list_head_;
+      block_header->bytes_allocated = 0;
+      block_list_head_ = block_header;
+      block_bytes_allocated_ += sizeof(BlockHeader) + block_size_;
+    }
+  }
+
+  BlockHeader* target_block = block_list_head_;
+  auto data_ptr = reinterpret_cast<uint8_t*>(target_block) +
+                  sizeof(BlockHeader) + target_block->bytes_allocated;
+  target_block->bytes_allocated += aligned_length;
+
+  bytes_allocated_ += length;
+  return data_ptr;
+}
+
+} // namespace iree
diff --git a/iree/base/arena.h b/base/arena.h
similarity index 100%
rename from iree/base/arena.h
rename to base/arena.h
diff --git a/base/arena_test.cc b/base/arena_test.cc
new file mode 100644
index 0000000..95b557f
--- /dev/null
+++ b/base/arena_test.cc
@@ -0,0 +1,148 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/arena.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace {
+
+// Tests basic block allocations.
+TEST(ArenaTest, BasicAllocation) {
+ Arena arena(64);
+ EXPECT_EQ(64, arena.block_size());
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ // Zero byte allocations should return nullptr and not allocate bytes.
+ auto zero_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(0));
+ EXPECT_EQ(0, zero_ptr);
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ arena.Clear();
+
+ // Allocations must be machine word aligned.
+ auto one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1));
+ EXPECT_NE(0, one_ptr);
+ EXPECT_EQ(0, one_ptr % sizeof(uintptr_t));
+ one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1));
+ EXPECT_NE(0, one_ptr);
+ EXPECT_EQ(0, one_ptr % sizeof(uintptr_t));
+ EXPECT_EQ(2, arena.bytes_allocated());
+ EXPECT_LT(2, arena.block_bytes_allocated());
+
+ arena.Clear();
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+}
+
+// Tests typed allocations.
+TEST(ArenaTest, TypedAllocations) {
+ Arena arena(64);
+
+ EXPECT_NE(nullptr, arena.Allocate<int>());
+ EXPECT_EQ(4, arena.bytes_allocated());
+ EXPECT_EQ(64 + Arena::kBlockOverhead, arena.block_bytes_allocated());
+ arena.Clear();
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ struct MyType {
+ MyType() {}
+ explicit MyType(int initial_value) : value(initial_value) {}
+
+ int value = 5;
+ };
+ auto my_type_ptr = arena.Allocate<MyType>();
+ EXPECT_NE(nullptr, my_type_ptr);
+ EXPECT_EQ(sizeof(MyType), arena.bytes_allocated());
+ EXPECT_EQ(5, my_type_ptr->value); // Default ctor must be called.
+ arena.Clear();
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ my_type_ptr = arena.Allocate<MyType>(10);
+ EXPECT_NE(nullptr, my_type_ptr);
+ EXPECT_EQ(sizeof(MyType), arena.bytes_allocated());
+ EXPECT_EQ(10, my_type_ptr->value); // Ctor should have been called.
+ arena.Clear();
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+}
+
+// Tests multiple blocks.
+TEST(ArenaTest, MultipleBlocks) {
+ Arena arena(16);
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ // Allocate one entire block.
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(16, arena.bytes_allocated());
+ EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated());
+
+ // Allocate into the next block.
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(32, arena.bytes_allocated());
+ EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
+
+ // Clear.
+ arena.Clear();
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ // Allocate again.
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(16, arena.bytes_allocated());
+ EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated());
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(32, arena.bytes_allocated());
+ EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
+}
+
+// Tests fast reset.
+TEST(ArenaTest, FastReset) {
+ Arena arena(16);
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(0, arena.block_bytes_allocated());
+
+ // Allocate one entire block.
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(16, arena.bytes_allocated());
+ EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated());
+
+ // Allocate into the next block.
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(32, arena.bytes_allocated());
+ EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
+
+ // Reset (without deallocating).
+ arena.Reset();
+ EXPECT_EQ(0, arena.bytes_allocated());
+ EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
+
+ // Allocate again.
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(16, arena.bytes_allocated());
+ EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
+ EXPECT_NE(nullptr, arena.AllocateBytes(16));
+ EXPECT_EQ(32, arena.bytes_allocated());
+ EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
+}
+
+} // namespace
+} // namespace iree
diff --git a/iree/base/bitfield.h b/base/bitfield.h
similarity index 100%
rename from iree/base/bitfield.h
rename to base/bitfield.h
diff --git a/base/bitfield_test.cc b/base/bitfield_test.cc
new file mode 100644
index 0000000..80ead1a
--- /dev/null
+++ b/base/bitfield_test.cc
@@ -0,0 +1,82 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/bitfield.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace iree {
+
+// NOTE: define here so that we don't get internal linkage warnings.
+enum class MyValue : uint32_t {
+ kNone = 0,
+ kA = 1 << 0,
+ kB = 1 << 1,
+ kAll = kA | kB,
+};
+IREE_BITFIELD(MyValue);
+
+namespace {
+
+// Tests general usage.
+TEST(BitfieldTest, FormatBitfieldValue) {
+ std::vector<std::pair<MyValue, const char *>> mappings = {
+ {MyValue::kA, "kA"},
+ {MyValue::kB, "kB"},
+ };
+ EXPECT_EQ("",
+ FormatBitfieldValue(MyValue::kNone, absl::MakeConstSpan(mappings)));
+ EXPECT_EQ("kA",
+ FormatBitfieldValue(MyValue::kA, absl::MakeConstSpan(mappings)));
+ EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB,
+ absl::MakeConstSpan(mappings)));
+}
+
+// Tests that empty mapping tables are fine.
+TEST(BitfieldTest, FormatBitfieldValueEmpty) {
+ EXPECT_EQ("", FormatBitfieldValue(MyValue::kNone, {}));
+}
+
+// Tests that values not found in the mappings are still displayed.
+TEST(BitfieldTest, FormatBitfieldValueUnhandledValues) {
+ EXPECT_EQ("kA|2h", FormatBitfieldValue(MyValue::kA | MyValue::kB,
+ {
+ {MyValue::kA, "kA"},
+ }));
+}
+
+// Tests priority order in the mapping table.
+TEST(BitfieldTest, FormatBitfieldValuePriority) {
+ // No priority, will do separate.
+ EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB,
+ {
+ {MyValue::kA, "kA"},
+ {MyValue::kB, "kB"},
+ {MyValue::kAll, "kAll"},
+ }));
+
+ // Priority on the combined flag, use that instead.
+ EXPECT_EQ("kAll", FormatBitfieldValue(MyValue::kA | MyValue::kB,
+ {
+ {MyValue::kAll, "kAll"},
+ {MyValue::kA, "kA"},
+ {MyValue::kB, "kB"},
+ }));
+}
+
+} // namespace
+} // namespace iree
diff --git a/base/file_io.h b/base/file_io.h
new file mode 100644
index 0000000..4bf1b72
--- /dev/null
+++ b/base/file_io.h
@@ -0,0 +1,48 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FILE_IO_H_
+#define IREE_BASE_FILE_IO_H_
+
+#include <string>
+
+#include "base/status.h"
+
+namespace iree {
+namespace file_io {
+
+// Checks if a file exists at the provided path.
+//
+// Returns an OK status if the file definitely exists.
+// Errors can include PermissionDeniedError, NotFoundError, etc.
+Status FileExists(const std::string& path);
+
+// Synchronously reads a file's contents into a string.
+StatusOr<std::string> GetFileContents(const std::string& path);
+
+// Deletes the file at the provided path.
+Status DeleteFile(const std::string& path);
+
+// Moves a file from 'source_path' to 'destination_path'.
+//
+// This may simply rename the file, but may fall back to a full copy and delete
+// of the original if renaming is not possible (for example when moving between
+// physical storage locations).
+Status MoveFile(const std::string& source_path,
+ const std::string& destination_path);
+
+} // namespace file_io
+} // namespace iree
+
+#endif // IREE_BASE_FILE_IO_H_
diff --git a/base/file_mapping.h b/base/file_mapping.h
new file mode 100644
index 0000000..23a2769
--- /dev/null
+++ b/base/file_mapping.h
@@ -0,0 +1,51 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FILE_MAPPING_H_
+#define IREE_BASE_FILE_MAPPING_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/types/span.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+
+namespace iree {
+
+// A memory-mapped file handle; instances are handed out as ref_ptr<FileMapping>.
+class FileMapping : public RefObject<FileMapping> {
+ public:
+ // Opens a file and maps it into the calling process memory.
+ // The file will be opened for shared read access.
+ static StatusOr<ref_ptr<FileMapping>> OpenRead(std::string path);
+
+ virtual ~FileMapping() = default;
+
+ // Read-only contents of the file, returned as a non-owning span.
+ inline absl::Span<const uint8_t> data() const noexcept { return data_; }
+
+ protected:
+ explicit FileMapping(absl::Span<const uint8_t> data) : data_(data) {}
+ // Span over the file bytes, exactly as passed to the constructor.
+ absl::Span<const uint8_t> data_;
+
+ private:
+ FileMapping(const FileMapping&) = delete;  // Not copyable.
+ FileMapping& operator=(const FileMapping&) = delete;
+};
+
+} // namespace iree
+
+#endif // IREE_BASE_FILE_MAPPING_H_
diff --git a/base/file_path.cc b/base/file_path.cc
new file mode 100644
index 0000000..60384f7
--- /dev/null
+++ b/base/file_path.cc
@@ -0,0 +1,83 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/file_path.h"
+
+#include "absl/strings/str_cat.h"
+
+namespace iree {
+namespace file_path {
+
+namespace {
+
+std::pair<absl::string_view, absl::string_view> SplitPath(
+    absl::string_view path) {
+  // Splits |path| around its last '/' into {directory, basename}.
+  const size_t slash = path.find_last_of('/');
+  if (slash == absl::string_view::npos) {
+    // No separator: the directory is empty and everything is basename.
+    return std::make_pair(path.substr(0, 0), path);
+  }
+  // A separator at index 0 is kept as the directory part ("/"); a separator
+  // anywhere else is dropped from both halves.
+  const size_t dir_length = (slash == 0) ? 1 : slash;
+  return std::make_pair(path.substr(0, dir_length),
+                        absl::ClippedSubstr(path, slash + 1));
+}
+
+// Splits the basename of |path| at its final '.' into {stem, extension}; the
+// extension is empty if there is no '.' or if the '.' is the last character.
+std::pair<absl::string_view, absl::string_view> SplitBasename(
+    absl::string_view path) {
+  const absl::string_view name = Basename(path);
+  const size_t dot = name.find_last_of('.');
+  // With no '.', |dot| is npos: substr(0, npos) then yields the whole name.
+  const size_t ext_start =
+      (dot == absl::string_view::npos) ? name.size() : dot + 1;
+  return std::make_pair(name.substr(0, dot),
+                        absl::ClippedSubstr(name, ext_start));
+}
+
+} // namespace
+
+std::string JoinPaths(absl::string_view path1, absl::string_view path2) {
+  if (path1.empty()) return std::string(path2);
+  if (path2.empty()) return std::string(path1);
+  const bool left_slash = path1.back() == '/';
+  const bool right_slash = path2.front() == '/';
+  // Collapse a doubled '/' at the seam, or insert the one that is missing.
+  if (left_slash && right_slash)
+    return absl::StrCat(path1, absl::ClippedSubstr(path2, 1));
+  if (!left_slash && !right_slash) return absl::StrCat(path1, "/", path2);
+  return absl::StrCat(path1, path2);
+}
+
+absl::string_view DirectoryName(absl::string_view path) {
+ return SplitPath(path).first;  // Portion before the last '/' (see SplitPath).
+}
+
+absl::string_view Basename(absl::string_view path) {
+ return SplitPath(path).second;  // Portion after the last '/' (see SplitPath).
+}
+
+absl::string_view Stem(absl::string_view path) {
+ return SplitBasename(path).first;  // Basename up to its final '.'.
+}
+
+absl::string_view Extension(absl::string_view path) {
+ return SplitBasename(path).second;  // Basename after its final '.', if any.
+}
+
+} // namespace file_path
+} // namespace iree
diff --git a/iree/base/file_path.h b/base/file_path.h
similarity index 100%
rename from iree/base/file_path.h
rename to base/file_path.h
diff --git a/base/flatbuffer_util.cc b/base/flatbuffer_util.cc
new file mode 100644
index 0000000..72a9e7c
--- /dev/null
+++ b/base/flatbuffer_util.cc
@@ -0,0 +1,145 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/flatbuffer_util.h"
+
+#include <cerrno>
+#include <cstring>
+
+#include "absl/memory/memory.h"
+#include "base/file_mapping.h"
+#include "base/memory.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+
+namespace iree {
+
+FlatBufferFileBase::~FlatBufferFileBase() {
+  // Nothing to release when no deleter was ever installed.
+  if (!deleter_) return;
+  deleter_();
+  deleter_ = []() {};  // Replace it so whatever the old deleter held is dropped.
+}
+
+Status FlatBufferFileBase::Create(const void* root_ptr,
+                                  std::function<void()> deleter) {
+  IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
+  // No verification happens here: the root pointer is stored as given and
+  // |deleter| is kept for the destructor to invoke.
+  deleter_ = std::move(deleter);
+  root_ptr_ = root_ptr;
+  return OkStatus();
+}
+
+Status FlatBufferFileBase::CreateWithBackingBuffer(
+    const void* root_ptr, ::flatbuffers::DetachedBuffer backing_buffer) {
+  // NOTE(review): the trace label says "Create"; confirm that is intentional.
+  IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
+
+  // Hand the buffer to the deleter closure so that its storage stays alive
+  // for as long as this FlatBufferFileBase does.
+  auto buffer_baton = IreeMoveToLambda(backing_buffer);
+  deleter_ = [buffer_baton]() { (void)buffer_baton.value; };
+  root_ptr_ = root_ptr;
+
+  return OkStatus();
+}
+
+Status FlatBufferFileBase::Wrap(const void* root_ptr) {
+ IREE_TRACE_SCOPE0("FlatBufferFileBase::Wrap");
+ return Create(root_ptr, []() {});  // No-op deleter: nothing is released here.
+}
+
+Status FlatBufferFileBase::FromBuffer(Identifier identifier,
+ absl::Span<const uint8_t> buffer_data,
+ std::function<void()> deleter,
+ size_t root_type_size,
+ VerifierFn verifier_fn) {
+ IREE_TRACE_SCOPE("FlatBufferFileBase::FromBuffer:size", int)
+ (static_cast<int>(buffer_data.size()));
+ // NOTE(review): |root_type_size| is not referenced anywhere in this body.
+ // Sanity check buffer for the minimum size as FlatBuffers doesn't.
+ if (buffer_data.size() < 16) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Provided serialized flatbuffer buffer is too small to be legit "
+ "at size="
+ << buffer_data.size();
+ }
+
+ // If an identifier was given, require the buffer to carry its magic bytes.
+ if (identifier.has_value() && !::flatbuffers::BufferHasIdentifier(
+ buffer_data.data(), identifier.value())) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Provided serialized buffer does not contain the expected type; "
+ "magic bytes mismatch (expected "
+ << identifier.value() << ")";
+ }
+
+ // Verify the FlatBuffer contains valid offsets and won't try to read out of
+ // bounds of the buffer. We inline a bit of VerifyBufferFromStart so this code
+ // can stay generic.
+ {
+ IREE_TRACE_SCOPE0("FlatBufferFileBase::FromBufferVerification");
+ ::flatbuffers::Verifier verifier{buffer_data.data(), buffer_data.size()};
+ if (!verifier_fn(identifier.value_or(nullptr), &verifier)) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "FlatBuffer failed to verify as expected type; possibly "
+ "corrupt input";
+ }
+ }
+
+ // Resolve the root pointer in the buffer: the leading uoffset_t is added to
+ // the buffer start. Done by hand so that we don't need to know T.
+ root_ptr_ = buffer_data.data() +
+ ::flatbuffers::EndianScalar(
+ *reinterpret_cast<const ::flatbuffers::uoffset_t*>(
+ buffer_data.data()));
+ if (!root_ptr_) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Unable to resolve root table";
+ }
+ deleter_ = std::move(deleter);
+ // (Every error return above leaves |deleter| unstored and uninvoked.)
+ return OkStatus();
+}
+
+Status FlatBufferFileBase::WrapBuffer(Identifier identifier,
+                                      absl::Span<const uint8_t> buffer_data,
+                                      size_t root_type_size,
+                                      VerifierFn verifier_fn) {
+  IREE_TRACE_SCOPE0("FlatBufferFileBase::WrapBuffer");
+  // A do-nothing deleter is installed: nothing is released for |buffer_data|.
+  return FromBuffer(identifier, buffer_data, /*deleter=*/[]() {},
+                    root_type_size, verifier_fn);
+}
+
+Status FlatBufferFileBase::LoadFile(Identifier identifier, std::string path,
+                                    size_t root_type_size,
+                                    VerifierFn verifier_fn) {
+  IREE_TRACE_SCOPE0("FlatBufferFileBase::LoadFile");
+
+  // Map the file and grab the span of its bytes before the mapping reference
+  // is handed over to the deleter closure below.
+  ASSIGN_OR_RETURN(auto mapping, FileMapping::OpenRead(path));
+  const auto mapped_bytes = mapping->data();
+
+  // The closure exists only to hold on to the mapping.
+  auto mapping_baton = IreeMoveToLambda(mapping);
+  return FromBuffer(
+      identifier, mapped_bytes,
+      [mapping_baton]() { (void)mapping_baton.value; },  // Keeps mmap alive.
+      root_type_size, verifier_fn);
+}
+
+} // namespace iree
diff --git a/base/flatbuffer_util.h b/base/flatbuffer_util.h
new file mode 100644
index 0000000..79b8fe4
--- /dev/null
+++ b/base/flatbuffer_util.h
@@ -0,0 +1,321 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FLATBUFFER_UTIL_H_
+#define IREE_BASE_FLATBUFFER_UTIL_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "base/memory.h"
+#include "base/status.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace iree {
+
+// Views a FlatBuffer String as an absl::string_view; nullptr maps to "".
+inline absl::string_view WrapString(const ::flatbuffers::String* value) {
+  if (!value) return "";
+  return absl::string_view{value->data(), value->size()};
+}
+
+// Non-template base of FlatBufferFile<T>; holds the root pointer and deleter.
+class FlatBufferFileBase {
+ public:
+ using Identifier = absl::optional<const char*>;  // nullopt: skip id check.
+
+ virtual ~FlatBufferFileBase();
+
+ protected:
+ template <typename T>
+ friend class FlatBufferFile;
+ // Type-erased verification hook; FlatBufferFile<T> passes its VerifierFnT.
+ using VerifierFn = bool (*)(const char* identifier,
+ ::flatbuffers::Verifier* verifier);
+
+ FlatBufferFileBase() = default;
+
+ const void* root_ptr() const { return root_ptr_; }
+
+ // Redirections of template static methods on FlatBufferFile so we can put the
+ // implementations in a shared compilation unit.
+ // See FlatBufferFile<T> for doc comments.
+ Status Create(const void* root_ptr, std::function<void()> deleter);
+ Status CreateWithBackingBuffer(const void* root_ptr,
+ ::flatbuffers::DetachedBuffer backing_buffer);
+ Status Wrap(const void* root);
+ Status FromBuffer(Identifier identifier,
+ absl::Span<const uint8_t> buffer_data,
+ std::function<void()> deleter, size_t root_type_size,
+ VerifierFn verifier_fn);
+ // Initializes from an STL byte based container (string and vector of
+ // char/byte should be compatible).
+ template <typename Container>
+ Status FromContainer(Identifier identifier, Container container,
+ size_t root_type_size, VerifierFn verifier_fn);
+ Status WrapBuffer(Identifier identifier,
+ absl::Span<const uint8_t> buffer_data,
+ size_t root_type_size, VerifierFn verifier_fn);
+ Status LoadFile(Identifier identifier, std::string path,
+ size_t root_type_size, VerifierFn verifier_fn);
+
+ private:
+ const void* root_ptr_ = nullptr;  // Null until an init method assigns it.
+ std::function<void()> deleter_;  // Invoked by the destructor, if set.
+};
+
+// Immutable root FlatBuffer type wrapper with support for loading and backing
+// buffer management.
+//
+// Immutable and thread-safe.
+template <typename T>
+class FlatBufferFile final : public FlatBufferFileBase {
+ public:
+ // Creates a FlatBufferFile from an in-memory root pointer.
+ // The provided |deleter| will be called when the FlatBufferFile is destructed
+ // and can be used to deallocate/clean up resources.
+ //
+ // This assumes that the root pointer has already been verified as valid.
+ // If verification is required instead use FromBuffer on the original buffer.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Create(
+ const T* root, std::function<void()> deleter);
+
+ // Creates a FlatBufferFile from an in-memory root pointer and the detached
+ // backing buffer storing it.
+ //
+ // Example:
+ // FlatBufferBuilder fbb;
+ // MyTypeBuilder mtb(fbb);
+ // fbb.Finish(mtb.Finish());
+ // auto my_type = FlatBufferFile<MyType>::CreateWithBackingBuffer(
+ // fbb.Release());
+ // my_type->foo();
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> CreateWithBackingBuffer(
+ ::flatbuffers::DetachedBuffer backing_buffer);
+
+ // Wraps a caller-owned in-memory root pointer.
+ // The provided |root| must remain valid for the lifetime of the returned
+ // FlatBufferFile.
+ //
+ // This assumes that the root pointer has already been verified as valid.
+ // If verification is required instead use FromBuffer on the original buffer.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Wrap(const T* root);
+
+ // Creates a FlatBufferFile wrapping an external data buffer with a deleter
+ // function that will be called when the FlatBufferFile is destructed.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer(
+ Identifier identifier, absl::Span<const uint8_t> buffer_data,
+ std::function<void()> deleter);
+
+ // Creates a FlatBufferFile from a serialized data buffer.
+ // The FlatBufferFile takes ownership of the vector.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer(
+ Identifier identifier, std::vector<uint8_t> buffer_data);
+
+ // Loads a FlatBufferFile from an external buffer owned by the caller.
+ // The buffer must remain valid until the FlatBufferFile is destroyed.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> WrapBuffer(
+ Identifier identifier, absl::Span<const uint8_t> buffer_data);
+
+ // Loads the FlatBufferFile from a serialized byte-based STL container.
+ template <typename Container>
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromContainer(
+ Identifier identifier, Container buffer_data);
+
+ // Loads a FlatBufferFile from a serialized string.
+ // The FlatBufferFile takes ownership of the string.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromString(
+ Identifier identifier, std::string buffer_data) {
+ return FromContainer(identifier, std::move(buffer_data));
+ }
+
+ // Loads a FlatBufferFile from a serialized byte vector.
+ // The FlatBufferFile takes ownership of the vector.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromVector(
+ Identifier identifier, std::vector<uint8_t> buffer_data) {
+ return FromContainer(identifier, std::move(buffer_data));
+ }
+
+ // Loads a FlatBufferFile from a serialized file on the file system.
+ // This will attempt to mmap the file and is the preferred way of loading as
+ // only those pages that contain requested tables will be read.
+ static StatusOr<std::unique_ptr<FlatBufferFile<T>>> LoadFile(
+ Identifier identifier, std::string path);
+
+ // Returns a vector of |count| file references that share the same underlying
+ // data buffer. The buffer will be kept alive until the last file is released.
+ static StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>>
+ CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file, int count);
+
+ ~FlatBufferFile() override = default;
+
+ // Typed root pointer of the file.
+ const T* root() const { return reinterpret_cast<const T*>(root_ptr()); }
+
+ private:
+ FlatBufferFile() = default;  // Instances come from the static factories.
+
+ // Conforms to VerifierFn; verifies the buffer as a T with |identifier|.
+ static bool VerifierFnT(const char* identifier,
+ ::flatbuffers::Verifier* verifier) {
+ return verifier->VerifyBuffer<T>(identifier);
+ }
+};
+
+// Initializes from a byte container, which this object then keeps alive.
+template <typename Container>
+Status FlatBufferFileBase::FromContainer(Identifier identifier,
+                                         Container container,
+                                         size_t root_type_size,
+                                         VerifierFn verifier_fn) {
+  static_assert(sizeof(*container.data()) == 1,
+                "Expected container of byte sized elements");
+  // Move the container to its final heap location *before* taking the span.
+  // Moving a container may change data() (e.g. a std::string short enough to
+  // be stored inline), which would leave a span taken earlier dangling.
+  auto owned = std::make_shared<Container>(std::move(container));
+  auto buffer_data = absl::MakeConstSpan(
+      // Double static_cast through void is safer than reinterpret_cast.
+      static_cast<const uint8_t*>(static_cast<const void*>(owned->data())),
+      owned->size());
+  // The deleter's only job is to keep |owned| alive until the
+  // FlatBufferFileBase is destroyed.
+  return FromBuffer(
+      identifier, buffer_data, [owned]() { (void)owned; }, root_type_size,
+      verifier_fn);
+}
+
+// static; initialization is delegated to the non-template base class.
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Create(
+    const T* root, std::function<void()> deleter) {
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(base.Create(root, std::move(deleter)));
+  return std::move(file);
+}
+
+// static; the root is located before the buffer is moved into the base.
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>>
+FlatBufferFile<T>::CreateWithBackingBuffer(
+    ::flatbuffers::DetachedBuffer backing_buffer) {
+  const auto* root = ::flatbuffers::GetRoot<T>(backing_buffer.data());
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(
+      base.CreateWithBackingBuffer(root, std::move(backing_buffer)));
+  return std::move(file);
+}
+
+// static; |root| is passed straight to the base, which does not verify it.
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Wrap(
+    const T* root) {
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(base.Wrap(root));
+  return std::move(file);
+}
+
+// static; verifies |buffer_data| as a T, passing sizeof(T) and VerifierFnT.
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
+    Identifier identifier, absl::Span<const uint8_t> buffer_data,
+    std::function<void()> deleter) {
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(base.FromBuffer(identifier, buffer_data, std::move(deleter),
+                                  sizeof(T), VerifierFnT));
+  return std::move(file);
+}
+
+// static
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
+    Identifier identifier, std::vector<uint8_t> buffer_data) {
+  // shared_ptr capture: bytes are freed even if the deleter is never invoked.
+  auto owned = std::make_shared<std::vector<uint8_t>>(std::move(buffer_data));
+  return FromBuffer(identifier, absl::MakeConstSpan(*owned),
+                    [owned]() { (void)owned; });
+}
+
+// static; verifies |buffer_data| as a T, passing sizeof(T) and VerifierFnT.
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::WrapBuffer(
+    Identifier identifier, absl::Span<const uint8_t> buffer_data) {
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(
+      base.WrapBuffer(identifier, buffer_data, sizeof(T), VerifierFnT));
+  return std::move(file);
+}
+
+// static; |buffer_data| is moved into the base class initializer.
+template <typename T>
+template <typename Container>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromContainer(
+    Identifier identifier, Container buffer_data) {
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(base.FromContainer(identifier, std::move(buffer_data),
+                                     sizeof(T), VerifierFnT));
+  return std::move(file);
+}
+
+// static; the file at |path| is opened and verified by the base class.
+template <typename T>
+StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::LoadFile(
+    Identifier identifier, std::string path) {
+  std::unique_ptr<FlatBufferFile<T>> file(new FlatBufferFile<T>);
+  FlatBufferFileBase& base = *file;
+  RETURN_IF_ERROR(
+      base.LoadFile(identifier, std::move(path), sizeof(T), VerifierFnT));
+  return std::move(file);
+}
+
+// static
+template <typename T>
+StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>>
+FlatBufferFile<T>::CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file,
+                                    int count) {
+  // Turn the sole owner into a shared owner that every file in the group can
+  // hold a reference to.
+  std::shared_ptr<FlatBufferFile<T>> group_owner(file.release());
+
+  // Create |count| wrappers around the same root via Create(), which does not
+  // re-verify the buffer; each wrapper's deleter closure pins |group_owner|.
+  std::vector<std::unique_ptr<FlatBufferFile<T>>> group;
+
+  for (int remaining = count; remaining > 0; --remaining) {
+    ASSIGN_OR_RETURN(auto member,
+                     FlatBufferFile<T>::Create(
+                         group_owner->root(),
+                         [group_owner]() { (void)group_owner; }));
+    group.push_back(std::move(member));
+  }
+  return std::move(group);
+}
+
+} // namespace iree
+
+#endif // IREE_BASE_FLATBUFFER_UTIL_H_
diff --git a/base/init.h b/base/init.h
new file mode 100644
index 0000000..dab07f4
--- /dev/null
+++ b/base/init.h
@@ -0,0 +1,52 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INIT_H_
+#define IREE_BASE_INIT_H_
+
+// Initializer macros are defined in separate files:
+// IREE_DECLARE_MODULE_INITIALIZER(name)
+// IREE_REGISTER_MODULE_INITIALIZER(name, body)
+// IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2)
+// IREE_REQUIRE_MODULE_INITIALIZED(name)
+// IREE_RUN_MODULE_INITIALIZERS()
+// IREE_REQUIRE_MODULE_LINKED(name)
+//
+// These macros allow for arranging pieces of initialization code to be
+// executed at a well-defined time and in a well-defined order.
+//
+// Initialization happens automatically during InitializeEnvironment(), which
+// should be called early in main(), before other code runs.
+
+#ifdef IREE_CONFIG_GOOGLE_INTERNAL
+#include "base/google/init_google.h"
+#else
+#include "base/internal/init_internal.h"
+#endif // IREE_CONFIG_GOOGLE_INTERNAL
+
+namespace iree {
+
+// Initializes the system environment in a binary.
+//
+// This first parses command line flags, then resolves module initializers
+// by calling IREE_RUN_MODULE_INITIALIZERS().
+//
+// 'argc' and 'argv' are the command line flags to parse.
+//
+// This should typically be called early in main(), before other code runs.
+void InitializeEnvironment(int* argc, char*** argv);
+
+} // namespace iree
+
+#endif // IREE_BASE_INIT_H_
diff --git a/base/internal/BUILD b/base/internal/BUILD
new file mode 100644
index 0000000..3e91f20
--- /dev/null
+++ b/base/internal/BUILD
@@ -0,0 +1,118 @@
+# Implementations for iree/base/
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "file_handle_win32",
+ srcs = ["file_handle_win32.cc"],  # Contents guarded by IREE_PLATFORM_WINDOWS.
+ hdrs = ["file_handle_win32.h"],  # Contents guarded by IREE_PLATFORM_WINDOWS.
+ deps = [
+ "///base:status",
+ "///base:target_platform",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "file_io_internal",
+ srcs = [
+ "file_io_posix.cc",  # Contents #if-guarded to Android/Apple/Linux.
+ "file_io_win32.cc",  # Contents #if-guarded to Windows.
+ ],
+ deps = [
+ ":file_handle_win32",
+ "///base:file_io_hdrs",
+ "///base:status",
+ "///base:target_platform",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "file_mapping_internal",
+ srcs = [
+ "file_mapping_posix.cc",
+ "file_mapping_win32.cc",
+ ],
+ deps = [
+ ":file_handle_win32",
+ "///base:file_mapping_hdrs",
+ "///base:target_platform",
+ "///base:tracing",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "init_internal",
+ srcs = ["init_internal.cc"],
+ hdrs = ["init_internal.h"],
+ deps = [
+ "///base:target_platform",
+ "@com_google_absl//absl/flags:parse",
+ ],
+)
+
+cc_library(
+ name = "logging_internal",
+ srcs = ["logging.cc"],
+ hdrs = ["logging.h"],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ ],
+)
+
+cc_library(
+ name = "source_location_internal",
+ hdrs = ["source_location.h"],
+)
+
+cc_library(
+ name = "status_internal",
+ srcs = [
+ "status.cc",
+ "status_builder.cc",
+ "status_errno.cc",
+ "status_errors.cc",
+ "status_win32_errors.cc",
+ "statusor.cc",
+ ],
+ hdrs = [
+ "status.h",
+ "status_builder.h",
+ "status_errno.h",
+ "status_errors.h",
+ "status_macros.h",
+ "status_win32_errors.h",
+ "statusor.h",
+ ],
+ deps = [
+ ":logging_internal",
+ "///base:source_location",
+ "///base:target_platform",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/debugging:stacktrace",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "status_matchers_internal",
+ testonly = 1,
+ hdrs = ["status_matchers.h"],
+ deps = [
+ "///base:status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/base/internal/file_handle_win32.cc b/base/internal/file_handle_win32.cc
new file mode 100644
index 0000000..37c3352
--- /dev/null
+++ b/base/internal/file_handle_win32.cc
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/file_handle_win32.h"
+
+#include "absl/memory/memory.h"
+#include "base/target_platform.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+#include <windows.h>
+
+namespace iree {
+
+// static. Opens |path| for shared reading and records the file's size.
+StatusOr<std::unique_ptr<FileHandle>> FileHandle::OpenRead(std::string path,
+                                                           DWORD file_flags) {
+  HANDLE handle = ::CreateFileA(
+      /*lpFileName=*/path.c_str(), /*dwDesiredAccess=*/GENERIC_READ,
+      /*dwShareMode=*/FILE_SHARE_READ, /*lpSecurityAttributes=*/nullptr,
+      /*dwCreationDisposition=*/OPEN_EXISTING,
+      /*dwFlagsAndAttributes=*/FILE_ATTRIBUTE_NORMAL | file_flags,
+      /*hTemplateFile=*/nullptr);
+  if (handle == INVALID_HANDLE_VALUE) {
+    return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
+           << "Unable to open file " << path;
+  }
+  BY_HANDLE_FILE_INFORMATION file_info;
+  if (::GetFileInformationByHandle(handle, &file_info) == FALSE) {
+    const DWORD error = GetLastError();  // Read before CloseHandle resets it.
+    ::CloseHandle(handle);  // No FileHandle owns it yet: close it here.
+    return Win32ErrorToCanonicalStatusBuilder(error, IREE_LOC)
+           << "Unable to query file info for " << path;
+  }
+  uint64_t file_size = (static_cast<uint64_t>(file_info.nFileSizeHigh) << 32) |
+                       file_info.nFileSizeLow;
+  return absl::make_unique<FileHandle>(handle, file_size);
+}
+
+FileHandle::~FileHandle() { ::CloseHandle(handle_); }
+
+} // namespace iree
+
+#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/file_handle_win32.h b/base/internal/file_handle_win32.h
new file mode 100644
index 0000000..099e16b
--- /dev/null
+++ b/base/internal/file_handle_win32.h
@@ -0,0 +1,57 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
+#define IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "base/status.h"
+#include "base/target_platform.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+#include <windows.h>
+
+namespace iree {
+
+class FileHandle {  // Pairs a Win32 file HANDLE with the file's size.
+ public:
+ static StatusOr<std::unique_ptr<FileHandle>> OpenRead(std::string path,
+ DWORD file_flags);
+
+ FileHandle(HANDLE handle, size_t size) : handle_(handle), size_(size) {}
+ ~FileHandle();  // Defined in the .cc file, where it closes handle_.
+
+ absl::string_view path() const { return path_; }
+ HANDLE handle() const { return handle_; }
+ size_t size() const { return size_; }
+
+ private:
+ FileHandle(const FileHandle&) = delete;  // Not copyable.
+ FileHandle& operator=(const FileHandle&) = delete;
+
+ std::string path_;  // NOTE(review): never assigned here; path() stays empty.
+ HANDLE handle_;
+ size_t size_;
+};
+
+} // namespace iree
+
+#endif // IREE_PLATFORM_WINDOWS
+
+#endif // IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
diff --git a/base/internal/file_io_posix.cc b/base/internal/file_io_posix.cc
new file mode 100644
index 0000000..95de126
--- /dev/null
+++ b/base/internal/file_io_posix.cc
@@ -0,0 +1,89 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdio>
+
+#include "base/file_io.h"
+#include "base/status.h"
+#include "base/target_platform.h"
+
+#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \
+ defined(IREE_PLATFORM_LINUX)
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+namespace iree {
+namespace file_io {
+
+Status FileExists(const std::string& path) {
+  struct stat stat_buf;
+  if (stat(path.c_str(), &stat_buf) == 0) return OkStatus();
+  return NotFoundErrorBuilder(IREE_LOC);  // Any stat() failure maps to NotFound.
+}
+
+StatusOr<std::string> GetFileContents(const std::string& path) {
+  std::unique_ptr<FILE, void (*)(FILE*)> file = {std::fopen(path.c_str(), "r"),
+                                                 +[](FILE* file) {
+                                                   if (file) fclose(file);
+                                                 }};
+  if (file == nullptr) {
+    return ErrnoToCanonicalStatusBuilder(errno, "Failed to open file",
+                                         IREE_LOC);
+  }
+  if (std::fseek(file.get(), 0, SEEK_END) == -1) {
+    return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file",
+                                         IREE_LOC);
+  }
+  const long file_length = std::ftell(file.get());  // End offset == file size.
+  if (file_length == -1L) {
+    return ErrnoToCanonicalStatusBuilder(errno, "Failed to read file length",
+                                         IREE_LOC);
+  }
+  if (std::fseek(file.get(), 0, SEEK_SET) == -1) {
+    return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file",
+                                         IREE_LOC);
+  }
+  std::string contents(static_cast<size_t>(file_length), '\0');
+  // fread() returns the count of complete items read: request 1-byte items.
+  if (std::fread(const_cast<char*>(contents.data()), 1, contents.size(),
+                 file.get()) != contents.size()) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Unable to read entire file contents";
+  }
+  return contents;
+}
+
+Status DeleteFile(const std::string& path) {
+  // Anything other than -1 from remove() is treated as success.
+  const bool removed = ::remove(path.c_str()) != -1;
+  if (removed) return OkStatus();
+  return ErrnoToCanonicalStatusBuilder(errno, "Failed to delete file",
+                                       IREE_LOC);
+}
+
+Status MoveFile(const std::string& source_path,
+                const std::string& destination_path) {
+  // Only a plain rename() is attempted here; there is no copy fallback.
+  const int rc = ::rename(source_path.c_str(), destination_path.c_str());
+  if (rc != -1) return OkStatus();
+  return ErrnoToCanonicalStatusBuilder(errno, "Failed to rename file",
+                                       IREE_LOC);
+}
+
+} // namespace file_io
+} // namespace iree
+
+#endif // IREE_PLATFORM_*
diff --git a/base/internal/file_io_win32.cc b/base/internal/file_io_win32.cc
new file mode 100644
index 0000000..fc16c20
--- /dev/null
+++ b/base/internal/file_io_win32.cc
@@ -0,0 +1,76 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "base/file_io.h"
+#include "base/internal/file_handle_win32.h"
+#include "base/target_platform.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+#include <windows.h>
+
+namespace iree {
+namespace file_io {
+
+Status FileExists(const std::string& path) {
+ DWORD attrs = ::GetFileAttributesA(path.c_str());
+ if (attrs == INVALID_FILE_ATTRIBUTES) {
+ return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
+ << "Unable to find/access file: " << path;
+ }
+ return OkStatus();
+}
+
+StatusOr<std::string> GetFileContents(const std::string& path) {
+ ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path),
+ FILE_FLAG_SEQUENTIAL_SCAN));
+ std::string result;
+ result.resize(file->size());
+ DWORD bytes_read = 0;
+ if (::ReadFile(file->handle(), const_cast<char*>(result.data()),
+ result.size(), &bytes_read, nullptr) == FALSE) {
+ return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
+ << "Unable to read file span of " << result.size() << " bytes";
+ } else if (bytes_read != file->size()) {
+ return ResourceExhaustedErrorBuilder(IREE_LOC)
+ << "Unable to read all " << file->size()
+ << " bytes from the file (got " << bytes_read << ")";
+ }
+ return result;
+}
+
+Status DeleteFile(const std::string& path) {
+ if (::DeleteFileA(path.c_str()) == FALSE) {
+ return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
+ << "Unable to delete/access file: " << path;
+ }
+ return OkStatus();
+}
+
+Status MoveFile(const std::string& source_path,
+ const std::string& destination_path) {
+ if (::MoveFileA(source_path.c_str(), destination_path.c_str()) == FALSE) {
+ return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
+ << "Unable to move file " << source_path << " to "
+ << destination_path;
+ }
+ return OkStatus();
+}
+
+} // namespace file_io
+} // namespace iree
+
+#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/file_mapping_posix.cc b/base/internal/file_mapping_posix.cc
new file mode 100644
index 0000000..fd4e1e6
--- /dev/null
+++ b/base/internal/file_mapping_posix.cc
@@ -0,0 +1,106 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/file_mapping.h"
+#include "base/target_platform.h"
+#include "base/tracing.h"
+
+#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \
+ defined(IREE_PLATFORM_LINUX)
+
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <cerrno>
+
+namespace iree {
+
+namespace {
+
+class FileDescriptor {
+ public:
+  static StatusOr<std::unique_ptr<FileDescriptor>> OpenRead(std::string path) {
+    // stat() rather than lstat(): open() below follows symlinks, so the size
+    // recorded here must be that of the link target, not of the link itself.
+    struct stat buf;
+    if (::stat(path.c_str(), &buf) == -1) {
+      return NotFoundErrorBuilder(IREE_LOC)
+             << "Unable to stat file " << path << ": " << ::strerror(errno);
+    }
+    const size_t file_size = static_cast<size_t>(buf.st_size);
+    int fd = ::open(path.c_str(), O_RDONLY);
+    if (fd == -1) {
+      return UnavailableErrorBuilder(IREE_LOC)
+             << "Unable to open file " << path << ": " << ::strerror(errno);
+    }
+    return absl::make_unique<FileDescriptor>(std::move(path), fd, file_size);
+  }
+
+  FileDescriptor(std::string path, int fd, size_t size)
+      : path_(std::move(path)), fd_(fd), size_(size) {}
+  ~FileDescriptor() { ::close(fd_); }
+
+  absl::string_view path() const { return path_; }
+  int fd() const { return fd_; }
+  size_t size() const { return size_; }
+
+ private:
+  FileDescriptor(const FileDescriptor&) = delete;
+  FileDescriptor& operator=(const FileDescriptor&) = delete;
+
+  std::string path_;
+  int fd_;
+  size_t size_;
+};
+
+class MMapMapping : public FileMapping {
+ public:
+ MMapMapping(void* data, size_t data_size)
+ : FileMapping(
+ absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)) {}
+
+ ~MMapMapping() override {
+ if (::munmap(const_cast<uint8_t*>(data_.data()), data_.size()) != 0) {
+ LOG(WARNING) << "Unable to unmap file: " << strerror(errno);
+ }
+ }
+};
+
+} // namespace
+
+// static
+StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) {
+ IREE_TRACE_SCOPE0("FileMapping::Open");
+
+ // Open the file for reading. Note that we only need to keep it open long
+ // enough to map it and we can close the descriptor after that.
+ ASSIGN_OR_RETURN(auto file, FileDescriptor::OpenRead(std::move(path)));
+
+ // Map the file from the file descriptor.
+ void* data =
+ ::mmap(nullptr, file->size(), PROT_READ, MAP_SHARED, file->fd(), 0);
+ if (data == MAP_FAILED) {
+ return UnavailableErrorBuilder(IREE_LOC)
+ << "Mapping failed on file (ensure uncompressed): " << file->path();
+ }
+
+ return make_ref<MMapMapping>(data, file->size());
+}
+
+} // namespace iree
+
+#endif // IREE_PLATFORM_*
diff --git a/base/internal/file_mapping_win32.cc b/base/internal/file_mapping_win32.cc
new file mode 100644
index 0000000..9cdead8
--- /dev/null
+++ b/base/internal/file_mapping_win32.cc
@@ -0,0 +1,98 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "base/file_mapping.h"
+#include "base/internal/file_handle_win32.h"
+#include "base/target_platform.h"
+#include "base/tracing.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+#include <windows.h>
+
+namespace iree {
+
+namespace {
+
+class Win32FileMapping : public FileMapping {
+ public:
+ Win32FileMapping(HANDLE mapping_handle, void* data, size_t data_size)
+ : FileMapping(
+ absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)),
+ mapping_handle_(mapping_handle) {}
+
+ ~Win32FileMapping() override {
+ if (!data_.empty()) {
+ if (::UnmapViewOfFile(data_.data()) == FALSE) {
+ LOG(WARNING) << "Unable to unmap file: " << GetLastError();
+ }
+ data_ = {};
+ }
+ if (mapping_handle_) {
+ ::CloseHandle(mapping_handle_);
+ mapping_handle_ = nullptr;
+ }
+ }
+
+ private:
+ HANDLE mapping_handle_;
+};
+
+} // namespace
+
+// static
+StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) {
+ IREE_TRACE_SCOPE0("FileMapping::Open");
+
+ // Open the file for reading. Note that we only need to keep it open long
+ // enough to map it and we can close the descriptor after that.
+ ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path),
+ FILE_FLAG_RANDOM_ACCESS));
+
+ HANDLE mapping_handle = ::CreateFileMappingA(
+ /*hFile=*/file->handle(), /*lpFileMappingAttributes=*/nullptr,
+ /*flProtect=*/PAGE_READONLY, /*dwMaximumSizeHigh=*/0,
+ /*dwMaximumSizeLow=*/0, /*lpName=*/nullptr);
+ if (!mapping_handle) {
+ return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
+ << "Failed to create mapping on file (ensure uncompressed): "
+ << file->path();
+ }
+
+ void* data =
+ ::MapViewOfFileEx(/*hFileMappingObject=*/mapping_handle,
+ /*dwDesiredAccess=*/FILE_MAP_READ,
+ /*dwFileOffsetHigh=*/0, /*dwFileOffsetLow=*/0,
+ /*dwNumberOfBytesToMap=*/0, /*lpBaseAddress=*/nullptr);
+ if (!data) {
+ DWORD map_view_error = GetLastError();
+ ::CloseHandle(mapping_handle);
+ return Win32ErrorToCanonicalStatusBuilder(map_view_error, IREE_LOC)
+ << "Failed to map view of file: " << file->path();
+ }
+
+ auto result = make_ref<Win32FileMapping>(mapping_handle, data, file->size());
+
+ // NOTE: file mappings hold references to the file, so we don't need to keep
+ // the file around any longer than this function.
+ file.reset();
+
+ return result;
+}
+
+} // namespace iree
+
+#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/init_internal.cc b/base/internal/init_internal.cc
new file mode 100644
index 0000000..0f0ddeb
--- /dev/null
+++ b/base/internal/init_internal.cc
@@ -0,0 +1,110 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/init_internal.h"
+
+#include <string.h>
+
+#include <set>
+
+#include "absl/flags/parse.h"
+
+namespace iree {
+
+static Initializer::NameMap* static_name_map = nullptr;
+
+struct Initializer::InitializerData {
+ Initializer* initializer_obj;
+ std::set<std::string> dependency_names;
+
+ InitializerData() : initializer_obj(nullptr) {}
+ explicit InitializerData(Initializer* i) : initializer_obj(i) {}
+};
+
+Initializer::DependencyRegisterer::DependencyRegisterer(
+ const char* name, Initializer* initializer, const Dependency& dependency) {
+ NameMap* name_map = InitializerNameMap();
+
+ // Insert 'dependency' into the 'dependency_names' set for 'initializer'.
+ InitializerData* initializer_data = &(*name_map)[name];
+ initializer_data->dependency_names.insert(dependency.name);
+
+ // Ensure that 'dependency' exists in the map.
+ InitializerData* dependency_data = &(*name_map)[dependency.name];
+ dependency_data->initializer_obj = dependency.initializer;
+}
+
+Initializer::Initializer(const char* name, InitializerFunc function)
+ : name_(name), function_(function), done_(false) {
+ // Register this Initializer instance (wrapped by an InitializerData) within
+ // the static name map.
+ NameMap* name_map = InitializerNameMap();
+ InitializerData* initializer_data = &(*name_map)[name];
+ initializer_data->initializer_obj = this;
+}
+
+void Initializer::RunInitializers() {
+ // Run each registered Initializer, in lexicographic order of their names.
+ // Initializer dependencies will be run first as needed.
+ NameMap* name_map = InitializerNameMap();
+ for (auto& p : *name_map) {
+ RunInitializer(&p.second);
+ }
+}
+
+void Initializer::Require() {
+ NameMap* name_map = InitializerNameMap();
+ InitializerData* initializer_data = &(name_map->find(name_)->second);
+ RunInitializer(initializer_data);
+}
+
+Initializer::NameMap* Initializer::InitializerNameMap() {
+ if (static_name_map == nullptr) {
+ static_name_map = new Initializer::NameMap;
+ }
+ return static_name_map;
+}
+
+void Initializer::RunInitializer(InitializerData* initializer_data) {
+ if (initializer_data->initializer_obj->done_) {
+ return;
+ }
+
+ // Run Initializer dependencies first.
+ NameMap* name_map = InitializerNameMap();
+ for (const auto& dependency_name : initializer_data->dependency_names) {
+ auto dep_init = name_map->find(dependency_name);
+ RunInitializer(&dep_init->second);
+ }
+
+ // Finally run the Initializer itself.
+ initializer_data->initializer_obj->function_();
+ initializer_data->initializer_obj->done_ = true;
+}
+
+// Parses absl flags out of the command line, rewrites |argc|/|argv| so that
+// only the positional arguments remain, then runs the module initializers.
+void InitializeEnvironment(int* argc, char*** argv) {
+  auto positional_args = absl::ParseCommandLine(*argc, *argv);
+  if (positional_args.size() < static_cast<size_t>(*argc)) {
+    // Edit the passed argument refs to only include positional args.
+    *argc = static_cast<int>(positional_args.size());
+    for (int i = 0; i < *argc; ++i) (*argv)[i] = positional_args[i];
+    // Keep the array null-terminated at index argc, as argv is required to be.
+    (*argv)[*argc] = nullptr;
+  }
+  IREE_RUN_MODULE_INITIALIZERS();
+}
+
+} // namespace iree
diff --git a/base/internal/init_internal.h b/base/internal/init_internal.h
new file mode 100644
index 0000000..bbf2591
--- /dev/null
+++ b/base/internal/init_internal.h
@@ -0,0 +1,110 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_INIT_INTERNAL_H_
+#define IREE_BASE_INTERNAL_INIT_INTERNAL_H_
+
+#include <map>
+#include <string>
+
+#include "base/target_platform.h"
+
+namespace iree {
+
+// A static instance of this class is declared for each piece of initialization
+// code using the initializer macros.
+class Initializer {
+ public:
+ typedef void (*InitializerFunc)();
+
+ Initializer(const char* name, InitializerFunc function);
+
+ // Runs all registered initializers that have not yet run.
+ // The initializers are invoked in lexicographically increasing order by name,
+ // except as necessary to satisfy dependencies.
+ //
+ // This is normally called by InitializeEnvironment(), so application code
+ // typically should not call it directly.
+ static void RunInitializers();
+
+ // Runs this initializer if it has not yet run, including any dependencies.
+ void Require();
+
+ struct Dependency {
+ Dependency(const char* n, Initializer* i) : name(n), initializer(i) {}
+ const char* const name;
+ Initializer* const initializer;
+ };
+
+ // A static instance of this class is declared for each piece of
+ // initializer ordering definition.
+ struct DependencyRegisterer {
+ DependencyRegisterer(const char* name, Initializer* initializer,
+ const Dependency& dependency);
+ };
+
+ struct InitializerData;
+ typedef std::map<std::string, InitializerData> NameMap;
+
+ private:
+ static NameMap* InitializerNameMap();
+ static void RunInitializer(InitializerData* initializer_data);
+
+ const std::string name_;
+ InitializerFunc function_;
+ bool done_;
+};
+
+// In iree/base/init.h:
+void InitializeEnvironment(int* argc, char*** argv);
+
+} // namespace iree
+
+#define IREE_DECLARE_MODULE_INITIALIZER(name) \
+ extern ::iree::Initializer iree_initializer_##name
+
+#define IREE_REGISTER_MODULE_INITIALIZER(name, body) \
+ static void iree_init_##name() { body; } \
+ ::iree::Initializer iree_initializer_##name(#name, iree_init_##name)
+
+#define IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) \
+ namespace { \
+ static ::iree::Initializer::DependencyRegisterer \
+ iree_initializer_dependency_##name1##_##name2( \
+ #name2, &iree_initializer_##name2, \
+ ::iree::Initializer::Dependency(#name1, &iree_initializer_##name1)); \
+ }
+
+#define IREE_REQUIRE_MODULE_INITIALIZED(name) \
+ do { \
+ IREE_DECLARE_MODULE_INITIALIZER(name); \
+ iree_initializer_##name.Require(); \
+ } while (0)
+
+#define IREE_RUN_MODULE_INITIALIZERS() \
+ do { \
+ ::iree::Initializer::RunInitializers(); \
+ } while (0)
+
+#if !defined(IREE_COMPILER_MSVC)
+#define IREE_ATTRIBUTE_USED __attribute__((used))
+#else
+#define IREE_ATTRIBUTE_USED
+#endif // IREE_COMPILER_MSVC
+
+#define IREE_REQUIRE_MODULE_LINKED(name) \
+ IREE_ATTRIBUTE_USED static ::iree::Initializer* iree_module_ref_##name = \
+ &iree_initializer_##name
+
+#endif // IREE_BASE_INTERNAL_INIT_INTERNAL_H_
diff --git a/base/internal/logging.cc b/base/internal/logging.cc
new file mode 100644
index 0000000..5e4c52e
--- /dev/null
+++ b/base/internal/logging.cc
@@ -0,0 +1,106 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/logging.h"
+
+#include <string>
+
+#include "absl/flags/flag.h"
+
+ABSL_FLAG(int, iree_minloglevel, 0,
+ "Minimum logging level. 0 = INFO and above.");
+ABSL_FLAG(int, iree_v, 0,
+ "Verbosity level maximum. 1 = VLOG(0-1), 2 = VLOG(0-2).");
+ABSL_FLAG(bool, iree_logtostderr, false, "Logs to stderr instead of stdout");
+
+namespace iree {
+namespace internal {
+
+namespace {
+
+// Parse log level (int64_t) from environment variable (char*).
+// Returns true if the value was present and parsed successfully.
+bool LogLevelStrToInt(const char* iree_env_var_val, int64_t* out_level) {
+ *out_level = 0;
+ if (iree_env_var_val == nullptr) {
+ return false;
+ }
+
+ std::string min_log_level(iree_env_var_val);
+ std::istringstream ss(min_log_level);
+ int64_t level;
+ if (!(ss >> level)) {
+ // Invalid vlog level setting, set level to default (0).
+ return false;
+ }
+
+ *out_level = level;
+ return true;
+}
+
+int64_t MinLogLevelFromEnv() {
+ const char* iree_env_var_val = getenv("IREE_MIN_LOG_LEVEL");
+ int64_t level = 0;
+ if (LogLevelStrToInt(iree_env_var_val, &level)) {
+ return level;
+ }
+ return absl::GetFlag(FLAGS_iree_minloglevel);
+}
+
+int64_t MinVLogLevelFromEnv() {
+ const char* iree_env_var_val = getenv("IREE_MIN_VLOG_LEVEL");
+ int64_t level = 0;
+ if (LogLevelStrToInt(iree_env_var_val, &level)) {
+ return level;
+ }
+ return absl::GetFlag(FLAGS_iree_v);
+}
+
+} // namespace
+
+LogMessage::LogMessage(const char* file_name, int line, int severity)
+ : file_name_(file_name), line_(line), severity_(severity) {}
+
+LogMessage::~LogMessage() {
+ // Read the min log level once during the first call to logging.
+ static int64_t min_log_level = MinLogLevelFromEnv();
+ if (ABSL_PREDICT_TRUE(severity_ >= min_log_level)) {
+ EmitLogMessage();
+ }
+}
+
+int64_t LogMessage::MinVLogLevel() {
+ static int64_t min_vlog_level = MinVLogLevelFromEnv();
+ return min_vlog_level;
+}
+
+void LogMessage::EmitLogMessage() {
+ // TODO(scotttodd): Include current system time
+ fprintf(absl::GetFlag(FLAGS_iree_logtostderr) ? stderr : stdout,
+ "%c %s:%d] %s\n", "IWEF"[severity_], file_name_, line_,
+ str().c_str());
+}
+
+LogMessageFatal::LogMessageFatal(const char* file, int line)
+ : LogMessage(file, line, FATAL) {}
+
+LogMessageFatal::~LogMessageFatal() {
+ EmitLogMessage();
+
+ // abort() ensures we don't return (as promised via ATTRIBUTE_NORETURN).
+ abort();
+}
+
+} // namespace internal
+} // namespace iree
diff --git a/iree/base/internal/logging.h b/base/internal/logging.h
similarity index 100%
rename from iree/base/internal/logging.h
rename to base/internal/logging.h
diff --git a/iree/base/internal/source_location.h b/base/internal/source_location.h
similarity index 100%
rename from iree/base/internal/source_location.h
rename to base/internal/source_location.h
diff --git a/base/internal/status.cc b/base/internal/status.cc
new file mode 100644
index 0000000..447ab52
--- /dev/null
+++ b/base/internal/status.cc
@@ -0,0 +1,178 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/status.h"
+
+#include <atomic>
+#include <memory>
+
+#include "absl/base/attributes.h"
+#include "absl/debugging/stacktrace.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+
+ABSL_FLAG(bool, iree_status_save_stack_trace, false,
+ "Save and display the full stack trace of the point of error")
+ .OnUpdate([]() {
+ iree::StatusSavesStackTrace(
+ absl::GetFlag(FLAGS_iree_status_save_stack_trace));
+ });
+
+namespace iree {
+
+namespace status_internal {
+
+ABSL_CONST_INIT std::atomic<bool> iree_save_stack_trace{false};
+
+} // namespace status_internal
+
+bool DoesStatusSaveStackTrace() {
+ return status_internal::iree_save_stack_trace.load(std::memory_order_relaxed);
+}
+void StatusSavesStackTrace(bool on_off) {
+ status_internal::iree_save_stack_trace.store(on_off,
+ std::memory_order_relaxed);
+}
+
+std::string StatusCodeToString(StatusCode code) {
+ switch (code) {
+ case StatusCode::kOk:
+ return "OK";
+ case StatusCode::kCancelled:
+ return "CANCELLED";
+ case StatusCode::kUnknown:
+ return "UNKNOWN";
+ case StatusCode::kInvalidArgument:
+ return "INVALID_ARGUMENT";
+ case StatusCode::kDeadlineExceeded:
+ return "DEADLINE_EXCEEDED";
+ case StatusCode::kNotFound:
+ return "NOT_FOUND";
+ case StatusCode::kAlreadyExists:
+ return "ALREADY_EXISTS";
+ case StatusCode::kPermissionDenied:
+ return "PERMISSION_DENIED";
+ case StatusCode::kUnauthenticated:
+ return "UNAUTHENTICATED";
+ case StatusCode::kResourceExhausted:
+ return "RESOURCE_EXHAUSTED";
+ case StatusCode::kFailedPrecondition:
+ return "FAILED_PRECONDITION";
+ case StatusCode::kAborted:
+ return "ABORTED";
+ case StatusCode::kOutOfRange:
+ return "OUT_OF_RANGE";
+ case StatusCode::kUnimplemented:
+ return "UNIMPLEMENTED";
+ case StatusCode::kInternal:
+ return "INTERNAL";
+ case StatusCode::kUnavailable:
+ return "UNAVAILABLE";
+ case StatusCode::kDataLoss:
+ return "DATA_LOSS";
+ default:
+ return "";
+ }
+}
+
+Status::Status() {}
+
+Status::Status(StatusCode code, absl::string_view message) {
+ state_ = absl::make_unique<State>();
+ state_->code = code;
+ state_->message = std::string(message);
+}
+
+Status::Status(const Status& x) {
+ if (x.ok()) return;
+
+ state_ = absl::make_unique<State>();
+ state_->code = x.state_->code;
+ state_->message = x.state_->message;
+}
+
+Status& Status::operator=(const Status& x) {
+ if (x.ok()) {
+ state_ = nullptr;
+ } else {
+ state_ = absl::make_unique<State>();
+ state_->code = x.state_->code;
+ state_->message = x.state_->message;
+ }
+ return *this;
+}
+
+Status::~Status() {}
+
+bool Status::ok() const { return state_ == nullptr; }
+
+StatusCode Status::code() const {
+ return ok() ? StatusCode::kOk : state_->code;
+}
+
+absl::string_view Status::message() const {
+ return ok() ? absl::string_view() : absl::string_view(state_->message);
+}
+
+std::string Status::ToString() const {
+ if (ok()) {
+ return "OK";
+ }
+
+ std::string text;
+ absl::StrAppend(&text, StatusCodeToString(state_->code), ": ",
+ state_->message);
+ // TODO(scotttodd): Payloads (stack traces)
+ return text;
+}
+
+void Status::IgnoreError() const {
+ // no-op
+}
+
+bool Status::EqualsSlow(const Status& a, const Status& b) {
+ if (a.code() != b.code()) return false;
+ if (a.message() != b.message()) return false;
+ // TODO(scotttodd): Payloads
+ return true;
+}
+
+bool operator==(const Status& lhs, const Status& rhs) {
+ return lhs.state_ == rhs.state_ || Status::EqualsSlow(lhs, rhs);
+}
+
+bool operator!=(const Status& lhs, const Status& rhs) { return !(lhs == rhs); }
+
+std::ostream& operator<<(std::ostream& os, const Status& x) {
+ os << x.ToString();
+ return os;
+}
+
+Status OkStatus() { return Status(); }
+
+Status Annotate(const Status& s, absl::string_view msg) {
+ if (s.ok() || msg.empty()) return s;
+
+ absl::string_view new_msg = msg;
+ std::string annotated;
+ if (!s.message().empty()) {
+ absl::StrAppend(&annotated, s.message(), "; ", msg);
+ new_msg = annotated;
+ }
+ Status result(s.code(), new_msg);
+ // TODO(scotttodd): Copy payload(s) into the new Status
+ return result;
+}
+
+} // namespace iree
diff --git a/base/internal/status.h b/base/internal/status.h
new file mode 100644
index 0000000..cc1d206
--- /dev/null
+++ b/base/internal/status.h
@@ -0,0 +1,130 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_H_
+#define IREE_BASE_INTERNAL_STATUS_H_
+
+#include <atomic>
+#include <string>
+
+#include "absl/base/attributes.h"
+#include "absl/flags/flag.h"
+#include "absl/strings/string_view.h"
+#include "base/internal/logging.h"
+
+ABSL_DECLARE_FLAG(bool, iree_status_save_stack_trace);
+
+namespace iree {
+
+// True if Status objects will capture stack traces on init for non-ok Statuses.
+bool DoesStatusSaveStackTrace();
+
+// Enables/disables status stack trace saving. This is global for the process.
+// While useful for debugging, stack traces can impact performance severely.
+void StatusSavesStackTrace(bool on_off);
+
+enum class StatusCode : int {
+ kOk = 0,
+ kCancelled = 1,
+ kUnknown = 2,
+ kInvalidArgument = 3,
+ kDeadlineExceeded = 4,
+ kNotFound = 5,
+ kAlreadyExists = 6,
+ kPermissionDenied = 7,
+ kResourceExhausted = 8,
+ kFailedPrecondition = 9,
+ kAborted = 10,
+ kOutOfRange = 11,
+ kUnimplemented = 12,
+ kInternal = 13,
+ kUnavailable = 14,
+ kDataLoss = 15,
+ kUnauthenticated = 16,
+ kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20
+};
+
+std::string StatusCodeToString(StatusCode code);
+
+class ABSL_MUST_USE_RESULT Status;
+
+// A Status value can be either OK or not-OK
+// * OK indicates that the operation succeeded.
+// * A not-OK value indicates that the operation failed and contains details
+// about the error.
+class Status final {
+ public:
+ // Creates an OK status with no message.
+ Status();
+
+ // Creates a status with the specified code and error message.
+ Status(StatusCode code, absl::string_view message);
+
+ Status(const Status&);
+ Status& operator=(const Status& x);
+
+ ~Status();
+
+ // Returns true if the Status is OK.
+ ABSL_MUST_USE_RESULT bool ok() const;
+
+ // Returns the error code.
+ StatusCode code() const;
+
+ // Returns the error message. Note: prefer ToString() for debug logging.
+ // This message rarely describes the error code. It is not unusual for the
+ // error message to be the empty string.
+ absl::string_view message() const;
+
+ // Return a combination of the error code name and message.
+ std::string ToString() const;
+
+ // Compatibility with upstream API. Equiv to ToString().
+ std::string error_message() const { return ToString(); }
+
+ friend bool operator==(const Status&, const Status&);
+ friend bool operator!=(const Status&, const Status&);
+
+ // Ignores any errors, potentially suppressing complaints from any tools.
+ void IgnoreError() const;
+
+ private:
+ static bool EqualsSlow(const Status& a, const Status& b);
+
+ struct State {
+ StatusCode code;
+ std::string message;
+ };
+ // OK status has a nullptr state_. Otherwise, 'state_' points to
+ // a 'State' structure containing the error code and message(s).
+ std::unique_ptr<State> state_;
+};
+
+// Returns an OK status, equivalent to a default constructed instance.
+Status OkStatus();
+
+// Prints a human-readable representation of `x` to `os`.
+std::ostream& operator<<(std::ostream& os, const Status& x);
+
+// Returns a Status that is identical to `s` except that the message()
+// has been augmented by adding `msg` to the end of the original message.
+Status Annotate(const Status& s, absl::string_view msg);
+
+#define CHECK_OK(val) CHECK_EQ(::iree::OkStatus(), (val))
+#define QCHECK_OK(val) QCHECK_EQ(::iree::OkStatus(), (val))
+#define DCHECK_OK(val) DCHECK_EQ(::iree::OkStatus(), (val))
+
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUS_H_
diff --git a/base/internal/status_builder.cc b/base/internal/status_builder.cc
new file mode 100644
index 0000000..d5bf924
--- /dev/null
+++ b/base/internal/status_builder.cc
@@ -0,0 +1,140 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/status_builder.h"
+
+#include <cstdio>
+
+#include "base/internal/status_errors.h"
+
+namespace iree {
+
+StatusBuilder::StatusBuilder(const Status& original_status,
+ SourceLocation location)
+ : status_(original_status), loc_(location) {}
+
+StatusBuilder::StatusBuilder(Status&& original_status, SourceLocation location)
+ : status_(original_status), loc_(location) {}
+
+StatusBuilder::StatusBuilder(const StatusBuilder& sb)
+ : status_(sb.status_), loc_(sb.loc_), message_(sb.message_) {}
+
+StatusBuilder::StatusBuilder(StatusCode code, SourceLocation location)
+ : status_(code, ""), loc_(location) {}
+
+StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
+ status_ = sb.status_;
+ loc_ = sb.loc_;
+ message_ = sb.message_;
+ return *this;
+}
+
+StatusBuilder::operator Status() const& {
+ return StatusBuilder(*this).CreateStatus();
+}
+StatusBuilder::operator Status() && { return std::move(*this).CreateStatus(); }
+
+bool StatusBuilder::ok() const { return status_.ok(); }
+
+StatusCode StatusBuilder::code() const { return status_.code(); }
+
+SourceLocation StatusBuilder::source_location() const { return loc_; }
+
+Status StatusBuilder::CreateStatus() && {
+ Status result = JoinMessageToStatus(status_, message_);
+
+ // Reset the status after consuming it.
+ status_ = UnknownError("");
+ message_ = "";
+ return result;
+}
+
+Status StatusBuilder::JoinMessageToStatus(Status s, absl::string_view msg) {
+ if (msg.empty()) return s;
+ return Annotate(s, msg);
+}
+
+std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder) {
+ return os << static_cast<Status>(builder);
+}
+
+std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder) {
+ return os << static_cast<Status>(std::move(builder));
+}
+
+StatusBuilder AbortedErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kAborted, location);
+}
+
+StatusBuilder AlreadyExistsErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kAlreadyExists, location);
+}
+
+StatusBuilder CancelledErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kCancelled, location);
+}
+
+StatusBuilder DataLossErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kDataLoss, location);
+}
+
+StatusBuilder DeadlineExceededErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kDeadlineExceeded, location);
+}
+
+StatusBuilder FailedPreconditionErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kFailedPrecondition, location);
+}
+
+StatusBuilder InternalErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kInternal, location);
+}
+
+StatusBuilder InvalidArgumentErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kInvalidArgument, location);
+}
+
+StatusBuilder NotFoundErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kNotFound, location);
+}
+
+StatusBuilder OutOfRangeErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kOutOfRange, location);
+}
+
+StatusBuilder PermissionDeniedErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kPermissionDenied, location);
+}
+
+StatusBuilder UnauthenticatedErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kUnauthenticated, location);
+}
+
+StatusBuilder ResourceExhaustedErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kResourceExhausted, location);
+}
+
+StatusBuilder UnavailableErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kUnavailable, location);
+}
+
+StatusBuilder UnimplementedErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kUnimplemented, location);
+}
+
+StatusBuilder UnknownErrorBuilder(SourceLocation location) {
+ return StatusBuilder(StatusCode::kUnknown, location);
+}
+
+} // namespace iree
diff --git a/base/internal/status_builder.h b/base/internal/status_builder.h
new file mode 100644
index 0000000..60f3221
--- /dev/null
+++ b/base/internal/status_builder.h
@@ -0,0 +1,137 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_BUILDER_H_
+#define IREE_BASE_INTERNAL_STATUS_BUILDER_H_
+
+#include "base/internal/status.h"
+#include "base/source_location.h"
+
+namespace iree {
+
+// Creates a status based on an original_status, but enriched with additional
+// information. The builder implicitly converts to Status and StatusOr<T>
+// allowing for it to be returned directly.
+class ABSL_MUST_USE_RESULT StatusBuilder {
+ public:
+ // Creates a `StatusBuilder` based on an original status.
+ explicit StatusBuilder(const Status& original_status,
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+ explicit StatusBuilder(Status&& original_status,
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+
+ // Creates a `StatusBuilder` from a status code.
+ // A typical user will not specify `location`, allowing it to default to the
+ // current location.
+ explicit StatusBuilder(StatusCode code,
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+
+ StatusBuilder(const StatusBuilder& sb);
+ StatusBuilder& operator=(const StatusBuilder& sb);
+ StatusBuilder(StatusBuilder&&) = default;
+ StatusBuilder& operator=(StatusBuilder&&) = default;
+
+ // Meant to append to the extra message, but the definitions below are no-ops.
+ template <typename T>
+ StatusBuilder& operator<<(const T& value) &;
+ template <typename T>
+ StatusBuilder&& operator<<(const T& value) &&;
+
+ // No-op placeholders; logging behavior may be added later.
+ StatusBuilder& LogError() & { return *this; }
+ StatusBuilder&& LogError() && { return std::move(LogError()); }
+ StatusBuilder& LogWarning() & { return *this; }
+ StatusBuilder&& LogWarning() && { return std::move(LogWarning()); }
+ StatusBuilder& LogInfo() & { return *this; }
+ StatusBuilder&& LogInfo() && { return std::move(LogInfo()); }
+
+ // Returns true if the Status created by this builder will be ok().
+ bool ok() const;
+
+ // Returns the error code for the Status created by this builder.
+ StatusCode code() const;
+
+ // Returns the source location used to create this builder.
+ SourceLocation source_location() const;
+
+ // Implicit conversion to Status.
+ operator Status() const&;
+ operator Status() &&;
+
+ private:
+ Status CreateStatus() &&;
+
+ static Status JoinMessageToStatus(Status s, absl::string_view msg);
+
+ // The status that the result will be based on.
+ Status status_;
+
+ // The location to record if this status is logged.
+ SourceLocation loc_;
+
+ // The message that will be added to the original status.
+ std::string message_;
+};
+
+template <typename T>
+StatusBuilder& StatusBuilder::operator<<(const T& value) & {
+ return *this;
+}
+template <typename T>
+StatusBuilder&& StatusBuilder::operator<<(const T& value) && {
+ return std::move(operator<<(value));
+}
+
+// Implicitly converts `builder` to `Status` and writes it to `os`.
+std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder);
+std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder);
+
+// Each of the functions below creates a StatusBuilder with a canonical error.
+// The error code of the StatusBuilder matches the name of the function.
+StatusBuilder AbortedErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder AlreadyExistsErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder CancelledErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder DataLossErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder DeadlineExceededErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder FailedPreconditionErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder InternalErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder InvalidArgumentErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder NotFoundErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder OutOfRangeErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder PermissionDeniedErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder UnauthenticatedErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder ResourceExhaustedErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder UnavailableErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder UnimplementedErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+StatusBuilder UnknownErrorBuilder(
+ SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
+
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUS_BUILDER_H_
diff --git a/base/internal/status_errno.cc b/base/internal/status_errno.cc
new file mode 100644
index 0000000..3788bc3
--- /dev/null
+++ b/base/internal/status_errno.cc
@@ -0,0 +1,175 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/status_errno.h"
+
+#include <cerrno>
+
+#include "absl/strings/str_cat.h"
+
+namespace iree {
+
+StatusCode ErrnoToCanonicalCode(int error_number) {
+ switch (error_number) {
+ case 0:
+ return StatusCode::kOk;
+ case EINVAL: // Invalid argument
+ case ENAMETOOLONG: // Filename too long
+ case E2BIG: // Argument list too long
+ case EDESTADDRREQ: // Destination address required
+ case EDOM: // Mathematics argument out of domain of function
+ case EFAULT: // Bad address
+ case EILSEQ: // Illegal byte sequence
+ case ENOPROTOOPT: // Protocol not available
+ case ENOSTR: // Not a STREAM
+ case ENOTSOCK: // Not a socket
+ case ENOTTY: // Inappropriate I/O control operation
+ case EPROTOTYPE: // Protocol wrong type for socket
+ case ESPIPE: // Invalid seek
+ return StatusCode::kInvalidArgument;
+ case ETIMEDOUT: // Connection timed out
+ case ETIME: // Timer expired
+ return StatusCode::kDeadlineExceeded;
+ case ENODEV: // No such device
+ case ENOENT: // No such file or directory
+#ifdef ENOMEDIUM
+ case ENOMEDIUM: // No medium found
+#endif
+ case ENXIO: // No such device or address
+ case ESRCH: // No such process
+ return StatusCode::kNotFound;
+ case EEXIST: // File exists
+ case EADDRNOTAVAIL: // Address not available
+ case EALREADY: // Connection already in progress
+#ifdef ENOTUNIQ
+ case ENOTUNIQ: // Name not unique on network
+#endif
+ return StatusCode::kAlreadyExists;
+ case EPERM: // Operation not permitted
+ case EACCES: // Permission denied
+#ifdef ENOKEY
+ case ENOKEY: // Required key not available
+#endif
+ case EROFS: // Read only file system
+ return StatusCode::kPermissionDenied;
+ case ENOTEMPTY: // Directory not empty
+ case EISDIR: // Is a directory
+ case ENOTDIR: // Not a directory
+ case EADDRINUSE: // Address already in use
+ case EBADF: // Invalid file descriptor
+#ifdef EBADFD
+ case EBADFD: // File descriptor in bad state
+#endif
+ case EBUSY: // Device or resource busy
+ case ECHILD: // No child processes
+ case EISCONN: // Socket is connected
+#ifdef EISNAM
+ case EISNAM: // Is a named type file
+#endif
+#ifdef ENOTBLK
+ case ENOTBLK: // Block device required
+#endif
+ case ENOTCONN: // The socket is not connected
+ case EPIPE: // Broken pipe
+#ifdef ESHUTDOWN
+ case ESHUTDOWN: // Cannot send after transport endpoint shutdown
+#endif
+ case ETXTBSY: // Text file busy
+#ifdef EUNATCH
+ case EUNATCH: // Protocol driver not attached
+#endif
+ return StatusCode::kFailedPrecondition;
+ case ENOSPC: // No space left on device
+#ifdef EDQUOT
+ case EDQUOT: // Disk quota exceeded
+#endif
+ case EMFILE: // Too many open files
+ case EMLINK: // Too many links
+ case ENFILE: // Too many open files in system
+ case ENOBUFS: // No buffer space available
+ case ENODATA: // No message is available on the STREAM read queue
+ case ENOMEM: // Not enough space
+ case ENOSR: // No STREAM resources
+#ifdef EUSERS
+ case EUSERS: // Too many users
+#endif
+ return StatusCode::kResourceExhausted;
+#ifdef ECHRNG
+ case ECHRNG: // Channel number out of range
+#endif
+ case EFBIG: // File too large
+ case EOVERFLOW: // Value too large to be stored in data type
+ case ERANGE: // Result too large
+ return StatusCode::kOutOfRange;
+#ifdef ENOPKG
+ case ENOPKG: // Package not installed
+#endif
+ case ENOSYS: // Function not implemented
+ case ENOTSUP: // Operation not supported
+ case EAFNOSUPPORT: // Address family not supported
+#ifdef EPFNOSUPPORT
+ case EPFNOSUPPORT: // Protocol family not supported
+#endif
+ case EPROTONOSUPPORT: // Protocol not supported
+#ifdef ESOCKTNOSUPPORT
+ case ESOCKTNOSUPPORT: // Socket type not supported
+#endif
+ case EXDEV: // Improper link
+ return StatusCode::kUnimplemented;
+ case EAGAIN: // Resource temporarily unavailable
+#ifdef ECOMM
+ case ECOMM: // Communication error on send
+#endif
+ case ECONNREFUSED: // Connection refused
+ case ECONNABORTED: // Connection aborted
+ case ECONNRESET: // Connection reset
+ case EINTR: // Interrupted function call
+#ifdef EHOSTDOWN
+ case EHOSTDOWN: // Host is down
+#endif
+ case EHOSTUNREACH: // Host is unreachable
+ case ENETDOWN: // Network is down
+ case ENETRESET: // Connection aborted by network
+ case ENETUNREACH: // Network unreachable
+ case ENOLCK: // No locks available
+ case ENOLINK: // Link has been severed
+#ifdef ENONET
+ case ENONET: // Machine is not on the network
+#endif
+ return StatusCode::kUnavailable;
+ case EDEADLK: // Resource deadlock avoided
+#ifdef ESTALE
+ case ESTALE: // Stale file handle
+#endif
+ return StatusCode::kAborted;
+ case ECANCELED: // Operation cancelled
+ return StatusCode::kCancelled;
+ default:
+ return StatusCode::kUnknown;
+ }
+}
+
+Status ErrnoToCanonicalStatus(int error_number, absl::string_view message) {
+ // TODO(scotttodd): convert error number to a string
+ return Status(ErrnoToCanonicalCode(error_number),
+ absl::StrCat(message, ": ", error_number));
+}
+
+StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number,
+ absl::string_view message,
+ SourceLocation location) {
+ return StatusBuilder(ErrnoToCanonicalStatus(error_number, message), location);
+}
+
+} // namespace iree
diff --git a/base/internal/status_errno.h b/base/internal/status_errno.h
new file mode 100644
index 0000000..651e80d
--- /dev/null
+++ b/base/internal/status_errno.h
@@ -0,0 +1,41 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_ERRNO_H_
+#define IREE_BASE_INTERNAL_STATUS_ERRNO_H_
+
+#include "absl/strings/string_view.h"
+#include "base/internal/status.h"
+#include "base/internal/statusor.h"
+#include "base/source_location.h"
+
+namespace iree {
+
+// Returns the code for |error_number|, which should be an |errno| value.
+// See https://en.cppreference.com/w/cpp/error/errno_macros and similar refs.
+StatusCode ErrnoToCanonicalCode(int error_number);
+
+// Returns a Status with code `ErrnoToCanonicalCode(error_number)` and a
+// |message| followed by ": " and the numeric |error_number| (no strerror yet).
+Status ErrnoToCanonicalStatus(int error_number, absl::string_view message);
+
+// Returns a StatusBuilder using a status of
+// `ErrnoToCanonicalStatus(error_number, message)` and |location|.
+StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number,
+ absl::string_view message,
+ SourceLocation location);
+
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUS_ERRNO_H_
diff --git a/base/internal/status_errors.cc b/base/internal/status_errors.cc
new file mode 100644
index 0000000..ae695e5
--- /dev/null
+++ b/base/internal/status_errors.cc
@@ -0,0 +1,147 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/status_errors.h"
+
+namespace iree {
+
+Status AbortedError(absl::string_view message) {
+ return Status(StatusCode::kAborted, message);
+}
+
+Status AlreadyExistsError(absl::string_view message) {
+ return Status(StatusCode::kAlreadyExists, message);
+}
+
+Status CancelledError(absl::string_view message) {
+ return Status(StatusCode::kCancelled, message);
+}
+
+Status DataLossError(absl::string_view message) {
+ return Status(StatusCode::kDataLoss, message);
+}
+
+Status DeadlineExceededError(absl::string_view message) {
+ return Status(StatusCode::kDeadlineExceeded, message);
+}
+
+Status FailedPreconditionError(absl::string_view message) {
+ return Status(StatusCode::kFailedPrecondition, message);
+}
+
+Status InternalError(absl::string_view message) {
+ return Status(StatusCode::kInternal, message);
+}
+
+Status InvalidArgumentError(absl::string_view message) {
+ return Status(StatusCode::kInvalidArgument, message);
+}
+
+Status NotFoundError(absl::string_view message) {
+ return Status(StatusCode::kNotFound, message);
+}
+
+Status OutOfRangeError(absl::string_view message) {
+ return Status(StatusCode::kOutOfRange, message);
+}
+
+Status PermissionDeniedError(absl::string_view message) {
+ return Status(StatusCode::kPermissionDenied, message);
+}
+
+Status ResourceExhaustedError(absl::string_view message) {
+ return Status(StatusCode::kResourceExhausted, message);
+}
+
+Status UnauthenticatedError(absl::string_view message) {
+ return Status(StatusCode::kUnauthenticated, message);
+}
+
+Status UnavailableError(absl::string_view message) {
+ return Status(StatusCode::kUnavailable, message);
+}
+
+Status UnimplementedError(absl::string_view message) {
+ return Status(StatusCode::kUnimplemented, message);
+}
+
+Status UnknownError(absl::string_view message) {
+ return Status(StatusCode::kUnknown, message);
+}
+
+bool IsAborted(const Status& status) {
+ return status.code() == StatusCode::kAborted;
+}
+
+bool IsAlreadyExists(const Status& status) {
+ return status.code() == StatusCode::kAlreadyExists;
+}
+
+bool IsCancelled(const Status& status) {
+ return status.code() == StatusCode::kCancelled;
+}
+
+bool IsDataLoss(const Status& status) {
+ return status.code() == StatusCode::kDataLoss;
+}
+
+bool IsDeadlineExceeded(const Status& status) {
+ return status.code() == StatusCode::kDeadlineExceeded;
+}
+
+bool IsFailedPrecondition(const Status& status) {
+ return status.code() == StatusCode::kFailedPrecondition;
+}
+
+bool IsInternal(const Status& status) {
+ return status.code() == StatusCode::kInternal;
+}
+
+bool IsInvalidArgument(const Status& status) {
+ return status.code() == StatusCode::kInvalidArgument;
+}
+
+bool IsNotFound(const Status& status) {
+ return status.code() == StatusCode::kNotFound;
+}
+
+bool IsOutOfRange(const Status& status) {
+ return status.code() == StatusCode::kOutOfRange;
+}
+
+bool IsPermissionDenied(const Status& status) {
+ return status.code() == StatusCode::kPermissionDenied;
+}
+
+bool IsResourceExhausted(const Status& status) {
+ return status.code() == StatusCode::kResourceExhausted;
+}
+
+bool IsUnauthenticated(const Status& status) {
+ return status.code() == StatusCode::kUnauthenticated;
+}
+
+bool IsUnavailable(const Status& status) {
+ return status.code() == StatusCode::kUnavailable;
+}
+
+bool IsUnimplemented(const Status& status) {
+ return status.code() == StatusCode::kUnimplemented;
+}
+
+bool IsUnknown(const Status& status) {
+ return status.code() == StatusCode::kUnknown;
+}
+
+} // namespace iree
diff --git a/base/internal/status_errors.h b/base/internal/status_errors.h
new file mode 100644
index 0000000..fee3f0f
--- /dev/null
+++ b/base/internal/status_errors.h
@@ -0,0 +1,60 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_ERRORS_H_
+#define IREE_BASE_INTERNAL_STATUS_ERRORS_H_
+
+#include "absl/base/attributes.h"
+#include "absl/strings/string_view.h"
+#include "base/internal/status.h"
+
+namespace iree {
+
+Status AbortedError(absl::string_view message);
+Status AlreadyExistsError(absl::string_view message);
+Status CancelledError(absl::string_view message);
+Status DataLossError(absl::string_view message);
+Status DeadlineExceededError(absl::string_view message);
+Status FailedPreconditionError(absl::string_view message);
+Status InternalError(absl::string_view message);
+Status InvalidArgumentError(absl::string_view message);
+Status NotFoundError(absl::string_view message);
+Status OutOfRangeError(absl::string_view message);
+Status PermissionDeniedError(absl::string_view message);
+Status ResourceExhaustedError(absl::string_view message);
+Status UnauthenticatedError(absl::string_view message);
+Status UnavailableError(absl::string_view message);
+Status UnimplementedError(absl::string_view message);
+Status UnknownError(absl::string_view message);
+
+ABSL_MUST_USE_RESULT bool IsAborted(const Status& status);
+ABSL_MUST_USE_RESULT bool IsAlreadyExists(const Status& status);
+ABSL_MUST_USE_RESULT bool IsCancelled(const Status& status);
+ABSL_MUST_USE_RESULT bool IsDataLoss(const Status& status);
+ABSL_MUST_USE_RESULT bool IsDeadlineExceeded(const Status& status);
+ABSL_MUST_USE_RESULT bool IsFailedPrecondition(const Status& status);
+ABSL_MUST_USE_RESULT bool IsInternal(const Status& status);
+ABSL_MUST_USE_RESULT bool IsInvalidArgument(const Status& status);
+ABSL_MUST_USE_RESULT bool IsNotFound(const Status& status);
+ABSL_MUST_USE_RESULT bool IsOutOfRange(const Status& status);
+ABSL_MUST_USE_RESULT bool IsPermissionDenied(const Status& status);
+ABSL_MUST_USE_RESULT bool IsResourceExhausted(const Status& status);
+ABSL_MUST_USE_RESULT bool IsUnauthenticated(const Status& status);
+ABSL_MUST_USE_RESULT bool IsUnavailable(const Status& status);
+ABSL_MUST_USE_RESULT bool IsUnimplemented(const Status& status);
+ABSL_MUST_USE_RESULT bool IsUnknown(const Status& status);
+
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUS_ERRORS_H_
diff --git a/base/internal/status_macros.h b/base/internal/status_macros.h
new file mode 100644
index 0000000..bcae07f
--- /dev/null
+++ b/base/internal/status_macros.h
@@ -0,0 +1,108 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_MACROS_H_
+#define IREE_BASE_INTERNAL_STATUS_MACROS_H_
+
+#include "base/internal/status.h"
+#include "base/internal/status_builder.h"
+#include "base/internal/statusor.h"
+#include "base/source_location.h"
+
+// Evaluates an expression that produces an `iree::Status`. If the status is not
+// ok, returns it from the current function.
+#define RETURN_IF_ERROR(expr) \
+ STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ if (iree::status_macro_internal::StatusAdaptorForMacros \
+ status_macro_internal_adaptor = {(expr), IREE_LOC}) { \
+ } else /* NOLINT */ \
+ return status_macro_internal_adaptor.Consume()
+
+// Executes an expression `rexpr` that returns an `iree::StatusOr<T>`. On OK,
+// moves its value into the variable defined by `lhs`, otherwise returns
+// from the current function.
+#define ASSIGN_OR_RETURN(...) \
+ STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \
+ (__VA_ARGS__)
+
+// =================================================================
+// == Implementation details, do not rely on anything below here. ==
+// =================================================================
+
+// MSVC incorrectly expands variadic macros, splice together a macro call to
+// work around the bug.
+#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
+#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \
+ STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args
+
+#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_))
+#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
+ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \
+ error_expression)
+#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
+ error_expression) \
+ auto statusor = (rexpr); \
+ if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
+ iree::StatusBuilder _(std::move(statusor).status(), IREE_LOC); \
+ (void)_; /* error_expression is allowed to not use this variable */ \
+ return (error_expression); \
+ } \
+ lhs = std::move(statusor).ValueOrDie()
+
+// Internal helper for concatenating macro values.
+#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
+#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
+
+// clang-format off
+#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT
+// clang-format on
+
+namespace iree {
+namespace status_macro_internal {
+
+// Provides a conversion to bool so that it can be used inside an if statement
+// that declares a variable.
+class StatusAdaptorForMacros {
+ public:
+ StatusAdaptorForMacros(const Status& status, SourceLocation loc)
+ : builder_(status, loc) {}
+
+ StatusAdaptorForMacros(Status&& status, SourceLocation loc)
+ : builder_(std::move(status), loc) {}
+
+ StatusAdaptorForMacros(const StatusBuilder& builder, SourceLocation loc)
+ : builder_(builder) {}
+
+ StatusAdaptorForMacros(StatusBuilder&& builder, SourceLocation loc)
+ : builder_(std::move(builder)) {}
+
+ StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete;
+ StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete;
+
+ explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); }
+
+ StatusBuilder&& Consume() { return std::move(builder_); }
+
+ private:
+ StatusBuilder builder_;
+};
+
+} // namespace status_macro_internal
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUS_MACROS_H_
diff --git a/base/internal/status_matchers.h b/base/internal/status_matchers.h
new file mode 100644
index 0000000..f951e17
--- /dev/null
+++ b/base/internal/status_matchers.h
@@ -0,0 +1,299 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
+#define IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
+
+#include <memory>
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
+#include "base/status.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#undef EXPECT_OK
+#undef ASSERT_OK
+#undef ASSERT_OK_AND_ASSIGN
+
+namespace iree {
+
+namespace internal {
+
+// Implements a gMock matcher that checks that an iree::StatusOr<T> has an OK
+// status and that the contained T value matches another matcher.
+template <typename T>
+class IsOkAndHoldsMatcher
+ : public ::testing::MatcherInterface<const StatusOr<T> &> {
+ public:
+ template <typename MatcherT>
+ IsOkAndHoldsMatcher(MatcherT &&value_matcher)
+ : value_matcher_(::testing::SafeMatcherCast<const T &>(value_matcher)) {}
+
+ // From testing::MatcherInterface.
+ void DescribeTo(std::ostream *os) const override {
+ *os << "is OK and contains a value that ";
+ value_matcher_.DescribeTo(os);
+ }
+
+ // From testing::MatcherInterface.
+ void DescribeNegationTo(std::ostream *os) const override {
+ *os << "is not OK or contains a value that ";
+ value_matcher_.DescribeNegationTo(os);
+ }
+
+ // From testing::MatcherInterface.
+ bool MatchAndExplain(
+ const StatusOr<T> &status_or,
+ ::testing::MatchResultListener *listener) const override {
+ if (!status_or.ok()) {
+ *listener << "which is not OK";
+ return false;
+ }
+
+ ::testing::StringMatchResultListener value_listener;
+ bool is_a_match =
+ value_matcher_.MatchAndExplain(status_or.ValueOrDie(), &value_listener);
+ std::string value_explanation = value_listener.str();
+ if (!value_explanation.empty()) {
+ *listener << absl::StrCat("which contains a value ", value_explanation);
+ }
+
+ return is_a_match;
+ }
+
+ private:
+ const ::testing::Matcher<const T &> value_matcher_;
+};
+
+// A polymorphic IsOkAndHolds() matcher.
+//
+// IsOkAndHolds() returns a matcher that can be used to process an IsOkAndHolds
+// expectation. However, the value type T is not provided when IsOkAndHolds() is
+// invoked. The value type is only inferable when the gUnit framework invokes
+// the matcher with a value. Consequently, the IsOkAndHolds() function must
+// return an object that is implicitly convertible to a matcher for StatusOr<T>.
+// gUnit refers to such an object as a polymorphic matcher, since it can be used
+// to match with more than one type of value.
+template <typename ValueMatcherT>
+class IsOkAndHoldsGenerator {
+ public:
+ explicit IsOkAndHoldsGenerator(ValueMatcherT value_matcher)
+ : value_matcher_(std::move(value_matcher)) {}
+
+ template <typename T>
+ operator ::testing::Matcher<const StatusOr<T> &>() const {
+ return ::testing::MakeMatcher(new IsOkAndHoldsMatcher<T>(value_matcher_));
+ }
+
+ private:
+ const ValueMatcherT value_matcher_;
+};
+
+// Implements a gMock matcher for checking error-code expectations on
+// iree::Status objects.
+template <typename Enum>
+class StatusMatcher : public ::testing::MatcherInterface<const Status &> {
+ public:
+ StatusMatcher(Enum code, absl::optional<absl::string_view> message)
+ : code_(code), message_(message) {}
+
+ // From testing::MatcherInterface.
+ //
+ // Describes the expected error code.
+ void DescribeTo(std::ostream *os) const override {
+ *os << "error code " << StatusCodeToString(code_);
+ if (message_.has_value()) {
+ *os << "::'" << message_.value() << "'";
+ }
+ }
+
+ // From testing::MatcherInterface.
+ //
+ // Tests whether |status| has an error code that meets this matcher's
+ // expectation. If an error message string is specified in this matcher, it
+ // also tests that |status| has an error message that matches that
+ // expectation.
+ bool MatchAndExplain(
+ const Status &status,
+ ::testing::MatchResultListener *listener) const override {
+ if (status.code() != code_) {
+ *listener << "whose error code is " << StatusCodeToString(status.code());
+ return false;
+ }
+ if (message_.has_value() && status.message() != message_.value()) {
+ *listener << "whose error message is '" << status.message() << "'";
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ // Expected error code.
+ const Enum code_;
+
+ // Expected error message (absl::nullopt if the message is not checked).
+ const absl::optional<std::string> message_;
+};
+
+// Implements a gMock matcher that checks whether a status container (e.g.
+// iree::Status or iree::StatusOr<T>) has an OK status.
+template <class T>
+class IsOkMatcherImpl : public ::testing::MatcherInterface<T> {
+ public:
+ IsOkMatcherImpl() = default;
+
+ // From testing::MatcherInterface.
+ //
+ // Describes the OK expectation.
+ void DescribeTo(std::ostream *os) const override { *os << "is OK"; }
+
+ // From testing::MatcherInterface.
+ //
+ // Describes the negative OK expectation.
+ void DescribeNegationTo(std::ostream *os) const override {
+ *os << "is not OK";
+ }
+
+ // From testing::MatcherInterface.
+ //
+ // Tests whether |status_container|'s OK value meets this matcher's
+ // expectation.
+ bool MatchAndExplain(
+ const T &status_container,
+ ::testing::MatchResultListener *listener) const override {
+ if (!status_container.ok()) {
+ *listener << "which is not OK";
+ return false;
+ }
+ return true;
+ }
+};
+
+// IsOkMatcherGenerator is an intermediate object returned by iree::IsOk().
+// It implements implicit type-cast operators to supported matcher types:
+// Matcher<const Status &> and Matcher<const StatusOr<T> &>. These typecast
+// operators create gMock matchers that test OK expectations on a status
+// container.
+class IsOkMatcherGenerator {
+ public:
+ // Type-cast operator for Matcher<const iree::Status &>.
+ operator ::testing::Matcher<const Status &>() const {
+ return ::testing::MakeMatcher(
+ new internal::IsOkMatcherImpl<const Status &>());
+ }
+
+ // Type-cast operator for Matcher<const iree::StatusOr<T> &>.
+ template <class T>
+ operator ::testing::Matcher<const StatusOr<T> &>() const {
+ return ::testing::MakeMatcher(
+ new internal::IsOkMatcherImpl<const StatusOr<T> &>());
+ }
+};
+
+} // namespace internal
+
+// Returns a gMock matcher that expects an iree::StatusOr<T> object to have an
+// OK status and for the contained T object to match |value_matcher|.
+//
+// Example:
+//
+// StatusOr<string> raven_speech_result = raven.Speak();
+// EXPECT_THAT(raven_speech_result, IsOkAndHolds(HasSubstr("nevermore")));
+//
+// If foo is an object of type T and foo_result is an object of type
+// StatusOr<T>, you can write:
+//
+// EXPECT_THAT(foo_result, IsOkAndHolds(foo));
+//
+// instead of:
+//
+// EXPECT_THAT(foo_result, IsOkAndHolds(Eq(foo)));
+template <typename ValueMatcherT>
+internal::IsOkAndHoldsGenerator<ValueMatcherT> IsOkAndHolds(
+ ValueMatcherT value_matcher) {
+ return internal::IsOkAndHoldsGenerator<ValueMatcherT>(value_matcher);
+}
+
+// Returns a gMock matcher that expects an iree::Status object to have the
+// given |code|.
+template <typename Enum>
+::testing::Matcher<const Status &> StatusIs(Enum code) {
+ return ::testing::MakeMatcher(
+ new internal::StatusMatcher<Enum>(code, absl::nullopt));
+}
+
+// Returns a gMock matcher that expects an iree::Status object to have the
+// given |code| and |message|.
+template <typename Enum>
+::testing::Matcher<const Status &> StatusIs(Enum code,
+ absl::string_view message) {
+ return ::testing::MakeMatcher(
+ new internal::StatusMatcher<Enum>(code, message));
+}
+
+// Returns an internal::IsOkMatcherGenerator, which may be typecast to a
+// Matcher<iree::Status> or Matcher<iree::StatusOr<T>>. These gMock
+// matchers test that a given status container has an OK status.
+inline internal::IsOkMatcherGenerator IsOk() {
+ return internal::IsOkMatcherGenerator();
+}
+
+// Macros for testing the results of functions that return iree::Status or
+// iree::StatusOr<T> (for any type T).
+#define EXPECT_OK(rexpr) EXPECT_THAT(rexpr, ::iree::IsOk())
+#define ASSERT_OK(rexpr) ASSERT_THAT(rexpr, ::iree::IsOk())
+
+// Executes an expression that returns an iree::StatusOr<T>, and assigns the
+// contained variable to lhs if the error code is OK.
+// If the Status is non-OK, generates a test failure and returns from the
+// current function, which must have a void return type.
+//
+// Example: Declaring and assigning a new value
+// ASSERT_OK_AND_ASSIGN(ValueType value, MaybeGetValue(arg));
+//
+// The value assignment example might expand into:
+// StatusOr<ValueType> status_or_value = MaybeGetValue(arg);
+// ASSERT_OK(status_or_value.status());
+// ValueType value = status_or_value.ValueOrDie();
+#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
+ IREE_ASSERT_OK_AND_ASSIGN_IMPL( \
+ IREE_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
+ rexpr);
+
+#define IREE_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ ASSERT_OK(statusor.status()) << statusor.status(); \
+ lhs = std::move(statusor.ValueOrDie())
+#define IREE_STATUS_MACROS_CONCAT_NAME(x, y) \
+ IREE_STATUS_MACROS_CONCAT_IMPL(x, y)
+#define IREE_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
+
+// Implements the PrintTo() method for iree::StatusOr<T>. This method is
+// used by gUnit to print iree::StatusOr<T> objects for debugging. The
+// implementation relies on gUnit for printing values of T when a
+// iree::StatusOr<T> object is OK and contains a value.
+template <typename T>
+void PrintTo(const StatusOr<T> &statusor, std::ostream *os) {
+ if (!statusor.ok()) {
+ *os << statusor.status();
+ } else {
+ *os << absl::StrCat("OK: ",
+ ::testing::PrintToString(statusor.ValueOrDie()));
+ }
+}
+
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
diff --git a/base/internal/status_win32_errors.cc b/base/internal/status_win32_errors.cc
new file mode 100644
index 0000000..f09009c
--- /dev/null
+++ b/base/internal/status_win32_errors.cc
@@ -0,0 +1,63 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/status_win32_errors.h"
+
+#include "absl/strings/str_cat.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+#include <windows.h>
+
+namespace iree {
+
+StatusCode Win32ErrorToCanonicalCode(uint32_t error) {
+ switch (error) {  // Grouped case labels fall through to one shared code.
+ case ERROR_SUCCESS:
+ return StatusCode::kOk;
+ case ERROR_FILE_NOT_FOUND:
+ case ERROR_PATH_NOT_FOUND:
+ return StatusCode::kNotFound;
+ case ERROR_TOO_MANY_OPEN_FILES:
+ case ERROR_OUTOFMEMORY:
+ case ERROR_HANDLE_DISK_FULL:
+ case ERROR_HANDLE_EOF:  // NOTE(review): EOF reported as exhaustion; confirm.
+ return StatusCode::kResourceExhausted;
+ case ERROR_ACCESS_DENIED:
+ return StatusCode::kPermissionDenied;
+ case ERROR_INVALID_HANDLE:
+ return StatusCode::kInvalidArgument;
+ case ERROR_NOT_READY:
+ case ERROR_READ_FAULT:
+ return StatusCode::kUnavailable;
+ case ERROR_WRITE_FAULT:
+ return StatusCode::kDataLoss;
+ case ERROR_NOT_SUPPORTED:
+ return StatusCode::kUnimplemented;
+ default:  // Every Win32 code without an explicit case above.
+ return StatusCode::kUnknown;
+ }
+}
+
+StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error,
+ SourceLocation location) {
+ // TODO(benvanik): use FormatMessage; or defer until required?
+ return StatusBuilder(  // Code comes from Win32ErrorToCanonicalCode(error).
+ Status(Win32ErrorToCanonicalCode(error), absl::StrCat("<TBD>", error)),
+ location);  // Message is the "<TBD>" placeholder plus the numeric code.
+}
+
+} // namespace iree
+
+#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/status_win32_errors.h b/base/internal/status_win32_errors.h
new file mode 100644
index 0000000..33ce74f
--- /dev/null
+++ b/base/internal/status_win32_errors.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
+#define IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
+
+#include "absl/strings/string_view.h"
+#include "base/internal/statusor.h"
+#include "base/source_location.h"
+#include "base/target_platform.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+namespace iree {
+
+// Maps a Win32 error dword |error| to a canonical code (kUnknown if unmapped).
+StatusCode Win32ErrorToCanonicalCode(uint32_t error);
+
+// Returns a StatusBuilder for |error| at |location|; message is a placeholder.
+StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error,
+ SourceLocation location);
+
+} // namespace iree
+
+#endif // IREE_PLATFORM_WINDOWS
+
+#endif // IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
diff --git a/base/internal/statusor.cc b/base/internal/statusor.cc
new file mode 100644
index 0000000..0550241
--- /dev/null
+++ b/base/internal/statusor.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/internal/statusor.h"
+
+#include "base/internal/status_errors.h"
+
+namespace iree {
+
+namespace internal_statusor {
+
+void Helper::HandleInvalidStatusCtorArg(Status* status) {
+ const char* kMessage =
+ "An OK status is not a valid constructor argument to StatusOr<T>";
+ LOG(ERROR) << kMessage;
+ *status = InternalError(kMessage);  // Swap the invalid OK for an error...
+ abort();  // ...then terminate; the declaration is marked noreturn.
+}
+
+void Helper::Crash(const Status& status) {
+ LOG(FATAL) << "Attempting to fetch value instead of handling error "
+ << status;
+ abort();  // Keeps the noreturn contract even if LOG(FATAL) were to return.
+}
+
+} // namespace internal_statusor
+
+} // namespace iree
diff --git a/base/internal/statusor.h b/base/internal/statusor.h
new file mode 100644
index 0000000..895a8dd
--- /dev/null
+++ b/base/internal/statusor.h
@@ -0,0 +1,699 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_INTERNAL_STATUSOR_H_
+#define IREE_BASE_INTERNAL_STATUSOR_H_
+
+#include "absl/base/attributes.h"
+#include "base/internal/status.h"
+#include "base/internal/status_builder.h"
+
+namespace iree {
+
+template <typename T>
+class ABSL_MUST_USE_RESULT StatusOr;
+
+namespace internal_statusor {
+
+template <typename T, typename U>
+using IsStatusOrConversionAmbiguous =
+ absl::disjunction<std::is_constructible<T, StatusOr<U>&>,
+ std::is_constructible<T, const StatusOr<U>&>,
+ std::is_constructible<T, StatusOr<U>&&>,
+ std::is_constructible<T, const StatusOr<U>&&>,
+ std::is_convertible<StatusOr<U>&, T>,
+ std::is_convertible<const StatusOr<U>&, T>,
+ std::is_convertible<StatusOr<U>&&, T>,
+ std::is_convertible<const StatusOr<U>&&, T>>;
+
+template <typename T, typename U>
+using IsStatusOrConversionAssigmentAmbiguous =
+ absl::disjunction<IsStatusOrConversionAmbiguous<T, U>,
+ std::is_assignable<T&, StatusOr<U>&>,
+ std::is_assignable<T&, const StatusOr<U>&>,
+ std::is_assignable<T&, StatusOr<U>&&>,
+ std::is_assignable<T&, const StatusOr<U>&&>>;
+
+template <typename T, typename U>
+struct IsAmbiguousStatusOrForInitialization
+ : // Strip const-value refs from type and check again, else false_type.
+ public absl::conditional_t<
+ std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
+ U>::value,
+ std::false_type,
+ IsAmbiguousStatusOrForInitialization<
+ T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+
+template <typename T, typename U>
+struct IsAmbiguousStatusOrForInitialization<T, StatusOr<U>>
+ : public IsStatusOrConversionAmbiguous<T, U> {};
+
+template <typename T, typename U>
+using IsStatusOrDirectInitializationAmbiguous = absl::disjunction<
+ std::is_same<StatusOr<T>, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<Status, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<StatusBuilder, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<absl::in_place_t,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ IsAmbiguousStatusOrForInitialization<T, U>>;
+
+template <typename T, typename U>
+using IsStatusOrDirectInitializationValid = absl::disjunction<
+ // The is_same allows nested status ors to ignore this check iff same type.
+ std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ absl::negation<IsStatusOrDirectInitializationAmbiguous<T, U>>>;
+
+class Helper {
+ public:
+ ABSL_ATTRIBUTE_NORETURN static void HandleInvalidStatusCtorArg(Status*);
+ ABSL_ATTRIBUTE_NORETURN static void Crash(const Status& status);
+};
+
+// Constructs a T in the storage at `p` via placement new, forwarding
+// `args...` to T's constructor. `p` must be non-null.
+// This wrapper exists mostly to host the gcc-only hint below.
+template <typename T, typename... Args>
+void PlacementNew(void* p, Args&&... args) {
+#if defined(__GNUC__) && !defined(__clang__)
+ // Tell gcc (not clang) that 'p' is never null, fixing code size issues.
+ if (p == nullptr) __builtin_unreachable();
+#endif
+ new (p) T(std::forward<Args>(args)...);
+}
+
+// Helper base class to hold the data and all operations.
+// We move all this to a base class to allow mixing with the appropriate
+// TraitsBase specialization.
+template <typename T>
+class StatusOrData {
+ template <typename U>
+ friend class StatusOrData;
+
+ public:
+ StatusOrData() = delete;
+
+ StatusOrData(const StatusOrData& other) {
+ if (other.ok()) {
+ MakeValue(other.data_);
+ MakeStatus();
+ } else {
+ MakeStatus(other.status_);
+ }
+ }
+
+ StatusOrData(StatusOrData&& other) noexcept {
+ if (other.ok()) {
+ MakeValue(std::move(other.data_));
+ MakeStatus();
+ } else {
+ MakeStatus(std::move(other.status_));
+ }
+ }
+
+ template <typename U>
+ explicit StatusOrData(const StatusOrData<U>& other) {
+ if (other.ok()) {
+ MakeValue(other.data_);
+ MakeStatus();
+ } else {
+ MakeStatus(other.status_);
+ }
+ }
+
+ template <typename U>
+ explicit StatusOrData(StatusOrData<U>&& other) {
+ if (other.ok()) {
+ MakeValue(std::move(other.data_));
+ MakeStatus();
+ } else {
+ MakeStatus(std::move(other.status_));
+ }
+ }
+
+ template <typename... Args>
+ explicit StatusOrData(absl::in_place_t, Args&&... args)
+ : data_(std::forward<Args>(args)...) {
+ MakeStatus();
+ }
+
+ explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); }
+ explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); }
+
+ explicit StatusOrData(const Status& status)  // Copies |status| (non-OK only).
+     : status_(status) { EnsureNotOk(); }
+ explicit StatusOrData(Status&& status)  // Moves |status| instead of copying.
+     : status_(std::move(status)) { EnsureNotOk(); }
+
+ explicit StatusOrData(const StatusBuilder& builder) : status_(builder) {
+ EnsureNotOk();
+ }
+ explicit StatusOrData(StatusBuilder&& builder) : status_(std::move(builder)) {
+ EnsureNotOk();
+ }
+
+ StatusOrData& operator=(const StatusOrData& other) {
+ if (this == &other) return *this;
+ if (other.ok())
+ Assign(other.data_);
+ else
+ Assign(other.status_);
+ return *this;
+ }
+
+ StatusOrData& operator=(StatusOrData&& other) {
+ if (this == &other) return *this;
+ if (other.ok())
+ Assign(std::move(other.data_));
+ else
+ Assign(std::move(other.status_));
+ return *this;
+ }
+
+ ~StatusOrData() {
+ if (ok()) {
+ status_.~Status();
+ data_.~T();
+ } else {
+ status_.~Status();
+ }
+ }
+
+ void Assign(const T& value) {
+ if (ok()) {
+ data_.~T();
+ MakeValue(value);
+ } else {
+ MakeValue(value);
+ status_ = OkStatus();
+ }
+ }
+
+ void Assign(T&& value) {
+ if (ok()) {
+ data_.~T();
+ MakeValue(std::move(value));
+ } else {
+ MakeValue(std::move(value));
+ status_ = OkStatus();
+ }
+ }
+
+ void Assign(const Status& status) {
+ Clear();
+ status_ = status;
+ EnsureNotOk();
+ }
+
+ void Assign(Status&& status) {
+ Clear();
+ status_ = std::move(status);
+ EnsureNotOk();
+ }
+
+ bool ok() const { return status_.ok(); }
+
+ protected:
+ // status_ will always be active after the constructor.
+ // Union to be able to initialize exactly how we need without waste.
+ // Eg. in the copy constructor we use the default constructor of Status in
+ // the ok() path to avoid an extra Ref call.
+ union {
+ Status status_;
+ };
+
+ // data_ is active iff status_.ok()==true
+ struct Dummy {};
+ union {
+ // When T is const, we need some non-const object we can cast to void* for
+ // the placement new. dummy_ is that object.
+ Dummy dummy_;
+ T data_;
+ };
+
+ void Clear() {
+ if (ok()) data_.~T();
+ }
+
+ void EnsureOk() const {
+ if (!ok()) Helper::Crash(status_);
+ }
+
+ void EnsureNotOk() {
+ if (ok()) Helper::HandleInvalidStatusCtorArg(&status_);
+ }
+
+ // Construct the value (data_) through placement new with the passed arg.
+ template <typename Arg>
+ void MakeValue(Arg&& arg) {
+ internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg));
+ }
+
+ // Construct the status (status_) through placement new with the passed arg.
+ template <typename... Args>
+ void MakeStatus(Args&&... args) {
+ internal_statusor::PlacementNew<Status>(&status_,
+ std::forward<Args>(args)...);
+ }
+};
+
+// Empty base that lets StatusOr inherit deleted copy/move operations.
+// StatusOr derives from TraitsBase<Copy, Move>; a specialization deletes what
+// is unsupported and StatusOr's defaulted members become deleted in turn.
+// This primary template (used whenever Copy is true) keeps all defaulted.
+template <bool Copy, bool Move>
+struct TraitsBase {
+ TraitsBase() = default;
+ TraitsBase(const TraitsBase&) = default;
+ TraitsBase(TraitsBase&&) = default;
+ TraitsBase& operator=(const TraitsBase&) = default;
+ TraitsBase& operator=(TraitsBase&&) = default;
+};
+
+template <>
+struct TraitsBase<false, true> {
+ TraitsBase() = default;
+ TraitsBase(const TraitsBase&) = delete;
+ TraitsBase(TraitsBase&&) = default;
+ TraitsBase& operator=(const TraitsBase&) = delete;
+ TraitsBase& operator=(TraitsBase&&) = default;
+};
+
+template <>
+struct TraitsBase<false, false> {
+ TraitsBase() = default;
+ TraitsBase(const TraitsBase&) = delete;
+ TraitsBase(TraitsBase&&) = delete;
+ TraitsBase& operator=(const TraitsBase&) = delete;
+ TraitsBase& operator=(TraitsBase&&) = delete;
+};
+
+} // namespace internal_statusor
+
+// StatusOr<T> is the union of a Status object and a T object.
+//
+// A StatusOr object either holds a usable value, or an error Status explaining
+// why such a value is not present.
+template <typename T>
+class StatusOr : private internal_statusor::StatusOrData<T>,
+ private internal_statusor::TraitsBase<
+ std::is_copy_constructible<T>::value,
+ std::is_move_constructible<T>::value> {
+ template <typename U>
+ friend class StatusOr;
+
+ typedef internal_statusor::StatusOrData<T> Base;
+
+ public:
+ typedef T element_type;
+
+ // Constructs a new StatusOr with StatusCode::kUnknown status.
+ explicit StatusOr();
+
+ // StatusOr<T> is copy constructible/assignable if T is copy constructible.
+ StatusOr(const StatusOr&) = default;
+ StatusOr& operator=(const StatusOr&) = default;
+
+ // StatusOr<T> is move constructible/assignable if T is move constructible.
+ StatusOr(StatusOr&&) = default;
+ StatusOr& operator=(StatusOr&&) = default;
+
+ // Converting constructors from StatusOr<U>, when T is constructible from U.
+ // To avoid ambiguity, they are disabled if T is also constructible from
+ // StatusOr<U>. Explicit iff the corresponding construction of T from U is
+ // explicit.
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_convertible<const U&, T>,
+ absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr(const StatusOr<U>& other) // NOLINT
+ : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ absl::negation<std::is_convertible<const U&, T>>,
+ absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
+ T, U>>>::value,
+ int> = 0>
+ explicit StatusOr(const StatusOr<U>& other)
+ : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
+
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>,
+ absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr(StatusOr<U>&& other) // NOLINT
+ : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
+ absl::negation<std::is_convertible<U&&, T>>,
+ absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
+ T, U>>>::value,
+ int> = 0>
+ explicit StatusOr(StatusOr<U>&& other)
+ : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
+
+ // Conversion copy/move assignment operator, T must be constructible and
+ // assignable from U. Only enable if T cannot be directly assigned from
+ // StatusOr<U>.
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_assignable<T, const U&>,
+ absl::negation<
+ internal_statusor::IsStatusOrConversionAssigmentAmbiguous<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr& operator=(const StatusOr<U>& other) {
+ this->Assign(other);
+ return *this;
+ }
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
+ std::is_assignable<T, U&&>,
+ absl::negation<
+ internal_statusor::IsStatusOrConversionAssigmentAmbiguous<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr& operator=(StatusOr<U>&& other) {
+ this->Assign(std::move(other));
+ return *this;
+ }
+
+ // Constructs a new StatusOr with the given value. After calling this
+ // constructor, this->ok() will be true and the contained value may be
+ // retrieved with ValueOrDie(), operator*(), or operator->().
+ StatusOr(const T& value);
+
+ // Constructs a new StatusOr with the given non-ok status. After calling this
+ // constructor, this->ok() will be false and calls to ValueOrDie() will
+ // CHECK-fail.
+ StatusOr(const Status& status);
+ StatusOr& operator=(const Status& status);
+ StatusOr(const StatusBuilder& builder);
+ StatusOr& operator=(const StatusBuilder& builder);
+
+ // Similar to the `const T&` overload.
+ //
+ // REQUIRES: T is move constructible.
+ StatusOr(T&& value);
+
+ // RValue versions of the operations declared above.
+ StatusOr(Status&& status);
+ StatusOr& operator=(Status&& status);
+ StatusOr(StatusBuilder&& builder);
+ StatusOr& operator=(StatusBuilder&& builder);
+
+ // Constructs the inner value T in-place using the provided args, using the
+ // T(args...) constructor.
+ template <typename... Args>
+ explicit StatusOr(absl::in_place_t, Args&&... args);
+ template <typename U, typename... Args>
+ explicit StatusOr(absl::in_place_t, std::initializer_list<U> ilist,
+ Args&&... args);
+
+ // Constructs the inner value T in-place using the provided args, using the
+ // T(U) (direct-initialization) constructor. Only valid if T can be
+ // constructed from a U. Can accept move or copy constructors. Explicit if
+ // U is not convertible to T. To avoid ambiguity, this is disabled if U is
+ // a StatusOr<J>, where J is convertible to T.
+ template <
+ typename U = T,
+ absl::enable_if_t<
+ absl::conjunction<
+ internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
+ std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>>::value,
+ int> = 0>
+ StatusOr(U&& u) // NOLINT
+ : StatusOr(absl::in_place, std::forward<U>(u)) {}
+
+ template <
+ typename U = T,
+ absl::enable_if_t<
+ absl::conjunction<
+ internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
+ std::is_constructible<T, U&&>,
+ absl::negation<std::is_convertible<U&&, T>>>::value,
+ int> = 0>
+ explicit StatusOr(U&& u) // NOLINT
+ : StatusOr(absl::in_place, std::forward<U>(u)) {}
+
+ // Returns this->ok()
+ explicit operator bool() const { return ok(); }
+
+ // Returns this->status().ok()
+ ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); }
+
+ // Returns a reference to our status. If this contains a T, then
+ // returns OkStatus().
+ const Status& status() const&;
+ Status status() &&;
+
+ // Returns a reference to our current value, or CHECK-fails if !this->ok(). If
+ // you have already checked the status using this->ok() or operator bool(),
+ // then you probably want to use operator*() or operator->() to access the
+ // current value instead of ValueOrDie().
+ const T& ValueOrDie() const&;
+ T& ValueOrDie() &;
+ const T&& ValueOrDie() const&&;
+ T&& ValueOrDie() &&;
+
+ // Returns a reference to the current value.
+ //
+ // REQUIRES: this->ok() == true, otherwise the behavior is undefined.
+ const T& operator*() const&;
+ T& operator*() &;
+ const T&& operator*() const&&;
+ T&& operator*() &&;
+
+ // Returns a pointer to the current value.
+ //
+ // REQUIRES: this->ok() == true, otherwise the behavior is undefined.
+ const T* operator->() const;
+ T* operator->();
+
+ // Returns a copy of the current value if this->ok() == true. Otherwise
+ // returns a default value.
+ template <typename U>
+ T value_or(U&& default_value) const&;
+ template <typename U>
+ T value_or(U&& default_value) &&;
+
+ // Ignores any errors. This method does nothing except potentially suppress
+ // complaints from any tools that are checking that errors are not dropped on
+ // the floor.
+ void IgnoreError() const;
+
+ private:
+ using internal_statusor::StatusOrData<T>::Assign;
+ template <typename U>
+ void Assign(const StatusOr<U>& other);
+ template <typename U>
+ void Assign(StatusOr<U>&& other);
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Implementation details for StatusOr<T>
+
+template <typename T>
+StatusOr<T>::StatusOr() : Base(Status(StatusCode::kUnknown, "")) {}
+
+template <typename T>
+StatusOr<T>::StatusOr(const T& value) : Base(value) {}
+
+template <typename T>
+StatusOr<T>::StatusOr(const Status& status) : Base(status) {}
+
+template <typename T>
+StatusOr<T>::StatusOr(const StatusBuilder& builder) : Base(builder) {}
+
+template <typename T>
+StatusOr<T>& StatusOr<T>::operator=(const Status& status) {
+ this->Assign(status);
+ return *this;
+}
+
+template <typename T>
+StatusOr<T>& StatusOr<T>::operator=(const StatusBuilder& builder) {
+ return *this = static_cast<Status>(builder);
+}
+
+template <typename T>
+StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {}
+
+template <typename T>
+StatusOr<T>::StatusOr(Status&& status) : Base(std::move(status)) {}
+
+template <typename T>
+StatusOr<T>::StatusOr(StatusBuilder&& builder) : Base(std::move(builder)) {}
+
+template <typename T>
+StatusOr<T>& StatusOr<T>::operator=(Status&& status) {
+ this->Assign(std::move(status));
+ return *this;
+}
+
+template <typename T>
+StatusOr<T>& StatusOr<T>::operator=(StatusBuilder&& builder) {
+ return *this = static_cast<Status>(std::move(builder));
+}
+
+template <typename T>
+template <typename U>
+inline void StatusOr<T>::Assign(const StatusOr<U>& other) {
+  if (!other.ok()) {  // Error case: adopt a copy of the other's status.
+    this->Assign(other.status());
+    return;
+  }
+  this->Assign(other.ValueOrDie());  // OK case: assign from the held value.
+}
+
+template <typename T>
+template <typename U>
+inline void StatusOr<T>::Assign(StatusOr<U>&& other) {
+  if (!other.ok()) {  // Error case: take the other's status by move.
+    this->Assign(std::move(other).status());
+    return;
+  }
+  this->Assign(std::move(other).ValueOrDie());  // OK case: move the value.
+}
+template <typename T>
+template <typename... Args>
+StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args)
+ : Base(absl::in_place, std::forward<Args>(args)...) {}
+
+template <typename T>
+template <typename U, typename... Args>
+StatusOr<T>::StatusOr(absl::in_place_t, std::initializer_list<U> ilist,
+ Args&&... args)
+ : Base(absl::in_place, ilist, std::forward<Args>(args)...) {}
+
+template <typename T>
+const Status& StatusOr<T>::status() const& {
+ return this->status_;
+}
+template <typename T>
+Status StatusOr<T>::status() && {
+ return ok() ? OkStatus() : std::move(this->status_);
+}
+
+template <typename T>
+const T& StatusOr<T>::ValueOrDie() const& {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+T& StatusOr<T>::ValueOrDie() & {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+const T&& StatusOr<T>::ValueOrDie() const&& {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+T&& StatusOr<T>::ValueOrDie() && {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+const T& StatusOr<T>::operator*() const& {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+T& StatusOr<T>::operator*() & {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+const T&& StatusOr<T>::operator*() const&& {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+T&& StatusOr<T>::operator*() && {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+const T* StatusOr<T>::operator->() const {
+ this->EnsureOk();
+ return &this->data_;
+}
+
+template <typename T>
+T* StatusOr<T>::operator->() {
+ this->EnsureOk();
+ return &this->data_;
+}
+
+template <typename T>
+template <typename U>
+T StatusOr<T>::value_or(U&& default_value) const& {
+  if (!ok()) {  // No value held: fall back to the caller's default.
+    return std::forward<U>(default_value);
+  }
+  return this->data_;  // Copies the held value.
+}
+
+template <typename T>
+template <typename U>
+T StatusOr<T>::value_or(U&& default_value) && {
+  if (!ok()) {  // No value held: fall back to the caller's default.
+    return std::forward<U>(default_value);
+  }
+  return std::move(this->data_);  // Moves the held value out.
+}
+
+template <typename T>
+void StatusOr<T>::IgnoreError() const {
+ // no-op
+}
+
+} // namespace iree
+
+#endif // IREE_BASE_INTERNAL_STATUSOR_H_
diff --git a/base/intrusive_list.h b/base/intrusive_list.h
new file mode 100644
index 0000000..506c37b
--- /dev/null
+++ b/base/intrusive_list.h
@@ -0,0 +1,758 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Doubly linked list using element interior storage.
+// This has the performance of std::list (that means O(1) on insert and remove)
+// but performs no allocations and has better caching behavior.
+//
+// Elements are maintained in lists by way of IntrusiveListLinks, with each link
+// allowing the element to exist in one list simultaneously. In the most simple
+// case subclassing IntrusiveLinkBase will let the type be added to a list with
+// little boilerplate. If an element must be in more than one list
+// simultaneously IntrusiveListLinks can be added as members.
+//
+// Usage (simple):
+// class MySimpleElement : public IntrusiveLinkBase<void> {};
+// IntrusiveList<MySimpleElement> list;
+// list.push_back(new MySimpleElement());
+// for (auto element : list) { ... }
+//
+// Usage (multiple lists):
+// class MultiElement {
+// public:
+// IntrusiveListLink list_link_a;
+// IntrusiveListLink list_link_b;
+// };
+// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_a)> list_a;
+// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_b)> list_b;
+//
+// By default elements in the list are not retained and must be kept alive
+// externally. For automatic memory management there are specializations for
+// std::unique_ptr.
+//
+// Usage (unique_ptr):
+// IntrusiveList<std::unique_ptr<MyElement>> list;
+// list.push_back(absl::make_unique<MyElement>());
+// std::unique_ptr<MyElement> elm = list.take(list.front());
+//
+// This type is thread-unsafe.
+
+#ifndef IREE_BASE_INTRUSIVE_LIST_H_
+#define IREE_BASE_INTRUSIVE_LIST_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <limits>
+
+#include "base/logging.h"
+
+namespace iree {
+
+// Define to enable extensive checks after each mutation of the intrusive list.
+// #define IREE_PARANOID_INTRUSIVE_LIST
+
+// Storage for the doubly-linked list: one prev/next pointer pair.
+// This is embedded within all elements in an intrusive list.
+struct IntrusiveListLink {
+ IntrusiveListLink* prev = nullptr;  // Starts unlinked (no predecessor).
+ IntrusiveListLink* next = nullptr;  // Starts unlinked (no successor).
+
+ IntrusiveListLink() = default;
+
+ // Prevent copies; declaring these also suppresses the implicit moves.
+ IntrusiveListLink(const IntrusiveListLink&) = delete;
+ IntrusiveListLink& operator=(const IntrusiveListLink&) = delete;
+};
+
+template <class T>
+struct IntrusiveLinkBase : public T {  // Adds a list link on top of base T.
+ public:
+ IntrusiveListLink link;  // The link member the element is chained through.
+};
+
+template <>
+struct IntrusiveLinkBase<void> {  // Variant with no additional base class.
+ public:
+ IntrusiveListLink link;  // The link member the element is chained through.
+};
+
+// Base type for intrusive lists; holds the head/tail links and element count.
+// This is either used directly when the list is on naked pointers or
+// specialized to std::unique_ptr. Subclasses observe changes via On* hooks.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+ size_t kOffset>
+class IntrusiveListBase {
+ public:
+ using self_type = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
+
+ IntrusiveListBase() = default;
+ virtual ~IntrusiveListBase() { clear(); }  // NOTE(review): any On* hooks fired from clear() here would bypass subclass overrides.
+
+ // Prevent copies.
+ IntrusiveListBase(const IntrusiveListBase&) = delete;
+ IntrusiveListBase& operator=(const IntrusiveListBase&) = delete;
+
+ // Returns true if the list is empty.
+ // Performance: O(1)
+ constexpr bool empty() const { return head_ == nullptr; }
+
+ // Returns the total number of items in the list.
+ // Performance: O(1)
+ constexpr size_t size() const { return count_; }
+
+ // Returns true if the given item is contained within the list.
+ // Performance: O(n)
+ bool contains(T* value) const;
+
+ // Appends the contents of the given list to this one.
+ // The |other_list| is cleared.
+ // Performance: O(1)
+ void merge_from(self_type* other_list);
+
+ // Removes all items from the list.
+ // Performance: O(n)
+ void clear();
+
+ IteratorT begin() const { return IteratorT(head_); }
+ IteratorT end() const { return IteratorT(nullptr); }
+ ReverseIteratorT rbegin() const { return ReverseIteratorT(tail_); }
+ ReverseIteratorT rend() const { return ReverseIteratorT(nullptr); }
+
+ // Returns the next item in the list relative to the given item.
+ // |value| must exist in the list.
+ // Performance: O(1)
+ T* next(T* value) const;
+
+ // Returns the previous item in the list relative to the given item.
+ // |value| must exist in the list.
+ // Performance: O(1)
+ T* previous(T* value) const;
+
+ // Returns the item at the front of the list, if any.
+ // Performance: O(1)
+ T* front() const;
+
+ // Inserts an item at the front of the list.
+ // Performance: O(1)
+ void push_front(T* value);
+
+ // Removes the item at the front of the list.
+ // Performance: O(1)
+ void pop_front();
+
+ // Returns the item at the back of the list, if any.
+ // Performance: O(1)
+ T* back() const;
+
+ // Inserts an item at the back of the list.
+ // Performance: O(1)
+ void push_back(T* value);
+
+ // Removes the item at the back of the list.
+ // Performance: O(1)
+ void pop_back();
+
+ // Inserts an item into the list before the given iterator.
+ // Performance: O(1)
+ void insert(const IteratorT& it, T* value) { return insert(*it, value); }
+ void insert(T* position, T* value);
+
+ // Erases the given item from the list.
+ // Returns the item following the erased item, if any.
+ // Performance: O(1)
+ T* erase(T* value);
+
+ // Erases the item from the list at the given iterator.
+ // Performance: O(1)
+ IteratorT erase(const IteratorT& it);
+ ReverseIteratorT erase(const ReverseIteratorT& it);
+
+ // Replaces the item with a new item at the same position.
+ // |new_value| must not be contained in any list.
+ // Performance: O(1)
+ void replace(T* old_value, T* new_value);
+
+ // Sorts the list with the given comparison function.
+ // The sort function is the same as used by std::sort.
+ //
+ // Uses merge sort O(N log N) using the algorithm described here:
+ // http://www.chiark.greenend.org.uk/~sgtatham/algorithms/listsort.html
+ void sort(bool (*compare_fn)(T* a, T* b));
+
+ protected:
+ // Called when an item is added to the list.
+ virtual void OnAdd(T* value) {}
+ // Called when an item is removed from the list.
+ virtual void OnRemove(T* value) {}
+ // Called when an item is removed and deallocated.
+ virtual void OnDeallocate(T* value) {}
+
+ // Performs expensive correctness checks on the list structure. It's too slow
+ // to use in normal builds (even dbg), so it should only be used when there's
+ // a suspected issue with an intrusive list. Define
+ // IREE_PARANOID_INTRUSIVE_LIST to enable.
+ void CheckCorrectness() const;
+
+ IntrusiveListLink* head_ = nullptr;
+ IntrusiveListLink* tail_ = nullptr;
+ size_t count_ = 0;
+};
+
+// Basic iterator for an IntrusiveList.
+//
+// Dereferencing yields the T* that embeds the current link, or nullptr when
+// the iterator holds no link (the end position). kForward selects which link
+// pointer operator++ follows: next when true, prev when false; operator--
+// follows the opposite one.
+//
+// The iterator traits advertise T* as the value type (returned by value from
+// operator*) so that generic code inspecting std::iterator_traits sees the
+// type that dereferencing actually produces.
+template <typename T, size_t kOffset, bool kForward>
+class IntrusiveListIterator
+    : public std::iterator<std::input_iterator_tag, T*, std::ptrdiff_t, T**,
+                           T*> {
+ public:
+  using self_type = IntrusiveListIterator<T, kOffset, kForward>;
+
+  // Wraps |current|; a null link represents the end position.
+  explicit IntrusiveListIterator(IntrusiveListLink* current)
+      : current_(current) {}
+  // Pre/post increment and decrement. Stepping an iterator that holds no link
+  // leaves it unchanged.
+  IntrusiveListIterator& operator++();
+  self_type operator++(int);
+  self_type& operator--();
+  self_type operator--(int);
+  // Iterators compare equal when they reference the same link.
+  bool operator==(const self_type& rhs) const;
+  bool operator!=(const self_type& rhs) const;
+  // Returns the item containing the current link, or nullptr at the end.
+  T* operator*() const;
+
+ protected:
+  IntrusiveListLink* current_;
+};
+
+// Specialized IntrusiveListBase used for unreferenced naked pointers.
+// This very thinly wraps the base type and does no special memory management.
+// Items are iterated as plain T* values via IntrusiveListIterator.
+template <typename T, size_t kOffset>
+class IntrusiveListUnrefBase
+ : public IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>,
+ IntrusiveListIterator<T, kOffset, false>,
+ kOffset> {
+ public:
+ using IteratorT = IntrusiveListIterator<T, kOffset, true>;
+ using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>;
+ using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
+
+ // Keeps the base clear() visible alongside the deleter overload below.
+ using base_list::clear;
+
+ // Removes all items from the list and calls the given deleter function for
+ // each of them. The built-in OnDeallocate will not be used.
+ // Each item's link is reset before |deleter| is called on it.
+ // Performance: O(n)
+ void clear(const std::function<void(T*)>& deleter);
+
+ private:
+ using base_list::count_;
+ using base_list::head_;
+ using base_list::tail_;
+};
+
+// Sentinel kOffset value used as the default template argument of
+// IntrusiveList. Specializations matching it locate the link with
+// offsetof(T, link) instead of an explicitly supplied offset.
+constexpr size_t kUseDefaultLinkOffset = std::numeric_limits<size_t>::max();
+
+// IntrusiveList for raw pointers with a specified offset.
+// Use this if there are multiple links within a type; kOffset is the byte
+// offset of the IntrusiveListLink member this list should thread through.
+//
+// Usage:
+// struct MyType {
+// IntrusiveListLink link_a;
+// IntrusiveListLink link_b;
+// };
+// IntrusiveList<MyType, offsetof(MyType, link_a)> list_a;
+// IntrusiveList<MyType, offsetof(MyType, link_b)> list_b;
+template <typename T, size_t kOffset = kUseDefaultLinkOffset>
+class IntrusiveList : public IntrusiveListUnrefBase<T, kOffset> {};
+
+// IntrusiveList for raw pointers using the default link member.
+// T must have an IntrusiveListLink data member named |link|; its offset is
+// what the list uses to map between items and their links.
+// Items added to the list will not be owned by the list and must be freed by
+// the caller.
+//
+// Usage:
+// struct MyType {
+// IntrusiveListLink link;
+// };
+// IntrusiveList<MyType> list;
+// auto* p = new MyType();
+// list.push_back(p); // p is not retained and won't be freed!
+// delete p;
+template <typename T>
+class IntrusiveList<T, kUseDefaultLinkOffset>
+ : public IntrusiveListUnrefBase<T, offsetof(T, link)> {};
+
+// -- implementation --
+
+namespace impl {
+
+// Recovers the T that embeds |link|, given that the link sits kOffset bytes
+// into the object. A null link maps to a null T*.
+template <typename T, size_t kOffset>
+static inline T* LinkToT(IntrusiveListLink* link) {
+  if (link == nullptr) {
+    return nullptr;
+  }
+  const uintptr_t link_address = reinterpret_cast<uintptr_t>(link);
+  return reinterpret_cast<T*>(link_address - kOffset);
+}
+
+// Locates the IntrusiveListLink stored kOffset bytes into |value|.
+// A null item maps to a null link.
+template <typename T, size_t kOffset>
+static inline IntrusiveListLink* TToLink(T* value) {
+  if (value == nullptr) {
+    return nullptr;
+  }
+  const uintptr_t item_address = reinterpret_cast<uintptr_t>(value);
+  return reinterpret_cast<IntrusiveListLink*>(item_address + kOffset);
+}
+
+} // namespace impl
+
+// Pre-increment: steps to the following item in iteration order (next for
+// forward iterators, prev for reverse ones). An iterator without a current
+// link is left untouched.
+template <typename T, size_t kOffset, bool kForward>
+IntrusiveListIterator<T, kOffset, kForward>&
+IntrusiveListIterator<T, kOffset, kForward>::operator++() {
+  if (!current_) {
+    return *this;
+  }
+  if (kForward) {
+    current_ = current_->next;
+  } else {
+    current_ = current_->prev;
+  }
+  return *this;
+}
+
+// Post-increment: advances this iterator and returns a copy of its position
+// from before the step.
+template <typename T, size_t kOffset, bool kForward>
+IntrusiveListIterator<T, kOffset, kForward>
+IntrusiveListIterator<T, kOffset, kForward>::operator++(int) {
+  self_type before(current_);
+  ++(*this);
+  return before;
+}
+
+// Pre-decrement: steps to the preceding item in iteration order (prev for
+// forward iterators, next for reverse ones). An iterator without a current
+// link is left untouched.
+template <typename T, size_t kOffset, bool kForward>
+IntrusiveListIterator<T, kOffset, kForward>&
+IntrusiveListIterator<T, kOffset, kForward>::operator--() {
+  if (!current_) {
+    return *this;
+  }
+  if (kForward) {
+    current_ = current_->prev;
+  } else {
+    current_ = current_->next;
+  }
+  return *this;
+}
+
+// Post-decrement: steps this iterator backwards and returns a copy of its
+// position from before the step.
+template <typename T, size_t kOffset, bool kForward>
+IntrusiveListIterator<T, kOffset, kForward>
+IntrusiveListIterator<T, kOffset, kForward>::operator--(int) {
+  self_type before(current_);
+  --(*this);
+  return before;
+}
+
+// Two iterators are equal when they reference the same link (or both hold
+// none).
+template <typename T, size_t kOffset, bool kForward>
+bool IntrusiveListIterator<T, kOffset, kForward>::operator==(
+    const self_type& rhs) const {
+  const bool same_link = (rhs.current_ == current_);
+  return same_link;
+}
+
+// Negation of operator==.
+template <typename T, size_t kOffset, bool kForward>
+bool IntrusiveListIterator<T, kOffset, kForward>::operator!=(
+    const self_type& rhs) const {
+  return !(*this == rhs);
+}
+
+// Returns the item that embeds the current link; nullptr when the iterator
+// holds no link.
+template <typename T, size_t kOffset, bool kForward>
+T* IntrusiveListIterator<T, kOffset, kForward>::operator*() const {
+  T* item = impl::LinkToT<T, kOffset>(current_);
+  return item;
+}
+
+// Empties the list, handing every item to |deleter| in front-to-back order.
+// OnDeallocate is not invoked for the removed items.
+template <typename T, size_t kOffset>
+void IntrusiveListUnrefBase<T, kOffset>::clear(
+    const std::function<void(T*)>& deleter) {
+  for (auto* cursor = head_; cursor != nullptr;) {
+    // Read the successor before the callback runs: the item's link is reset
+    // here and the deleter may free the item itself.
+    auto* const following = cursor->next;
+    cursor->next = nullptr;
+    cursor->prev = nullptr;
+    deleter(impl::LinkToT<T, kOffset>(cursor));
+    cursor = following;
+  }
+  head_ = nullptr;
+  tail_ = nullptr;
+  count_ = 0;
+}
+
+// Walks the whole list verifying link consistency: only the head may lack a
+// prev, only the tail may lack a next, every prev pointer must reference the
+// item visited just before, and the number of items must equal count_.
+// Compiles to a no-op unless IREE_PARANOID_INTRUSIVE_LIST is defined.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT,
+                       kOffset>::CheckCorrectness() const {
+#if defined(IREE_PARANOID_INTRUSIVE_LIST)
+  size_t actual_count = 0;
+  IntrusiveListLink* previous = nullptr;
+  for (auto* link = head_; link != nullptr;
+       previous = link, link = link->next) {
+    ++actual_count;
+    if (!link->prev) {
+      DCHECK_EQ(link, head_);
+    }
+    if (!link->next) {
+      DCHECK_EQ(link, tail_);
+    }
+    DCHECK_EQ(link->prev, previous);
+  }
+  DCHECK_EQ(actual_count, count_);
+#endif  // IREE_PARANOID_INTRUSIVE_LIST
+}
+
+// Returns true if |value| is linked into this list. A null |value| is never
+// contained. Scans the list from the head, so the cost is linear in its size.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+bool IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::contains(
+    T* value) const {
+  if (value == nullptr) {
+    return false;
+  }
+  // TODO(benvanik): faster way of checking? requires list ptr in link?
+  const IntrusiveListLink* needle = impl::TToLink<T, kOffset>(value);
+  for (const IntrusiveListLink* cursor = head_; cursor != nullptr;
+       cursor = cursor->next) {
+    if (cursor == needle) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Moves every item of |other_list| onto the back of this list, preserving
+// their order, and leaves |other_list| empty.
+//
+// Merging an empty list (or merging a list into itself) is a no-op; in
+// particular it must not disturb tail_ of a non-empty destination.
+// The OnAdd/OnDeallocate hooks are not invoked for the moved items.
+// Performance: O(1)
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::merge_from(
+    self_type* other_list) {
+  if (other_list == this || !other_list->head_) {
+    // Nothing to splice. Without this guard an empty source would overwrite
+    // tail_ with nullptr and a self-merge would link the list into a cycle.
+    return;
+  }
+
+  // Stitch the source chain onto our tail (or adopt it wholesale if empty).
+  if (tail_) {
+    tail_->next = other_list->head_;
+  }
+  other_list->head_->prev = tail_;
+  if (!head_) {
+    head_ = other_list->head_;
+  }
+  tail_ = other_list->tail_;
+  count_ += other_list->count_;
+
+  // The source no longer owns any links.
+  other_list->head_ = nullptr;
+  other_list->tail_ = nullptr;
+  other_list->count_ = 0;
+
+  CheckCorrectness();
+}
+
+// Removes every item from the list, resetting each link and reporting the
+// item to OnDeallocate in front-to-back order.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::clear() {
+  for (auto* cursor = head_; cursor != nullptr;) {
+    // Grab the successor first; the hook may release the item.
+    auto* const following = cursor->next;
+    cursor->next = nullptr;
+    cursor->prev = nullptr;
+    OnDeallocate(impl::LinkToT<T, kOffset>(cursor));
+    cursor = following;
+  }
+  head_ = nullptr;
+  tail_ = nullptr;
+  count_ = 0;
+}
+
+// Returns the item linked after |value|, or nullptr when |value| is null or
+// has no successor.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::next(
+    T* value) const {
+  if (value == nullptr) return nullptr;
+  auto* const successor = impl::TToLink<T, kOffset>(value)->next;
+  return impl::LinkToT<T, kOffset>(successor);
+}
+
+// Returns the item linked before |value|, or nullptr when |value| is null or
+// has no predecessor.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::previous(
+    T* value) const {
+  if (value == nullptr) return nullptr;
+  auto* const predecessor = impl::TToLink<T, kOffset>(value)->prev;
+  return impl::LinkToT<T, kOffset>(predecessor);
+}
+
+// Returns the first item, or nullptr if the list has no head.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::front()
+    const {
+  T* first = impl::LinkToT<T, kOffset>(head_);
+  return first;
+}
+
+// Links |value| in as the new first item and notifies OnAdd.
+// |value| must be non-null and its link must not already be in use (checked
+// with DCHECKs).
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_front(
+    T* value) {
+  DCHECK(value);
+  auto* link = impl::TToLink<T, kOffset>(value);
+  DCHECK(!link->next);
+  DCHECK(!link->prev);
+
+  auto* const old_head = head_;
+  link->prev = nullptr;
+  link->next = old_head;
+  if (old_head != nullptr) {
+    old_head->prev = link;
+  }
+  head_ = link;
+  if (tail_ == nullptr) {
+    // First item: it is both ends of the list.
+    tail_ = link;
+  }
+
+  ++count_;
+  OnAdd(value);
+  CheckCorrectness();
+}
+
+// Unlinks the first item and reports it to OnDeallocate.
+// The list is expected to be non-empty (DCHECK); when the check is compiled
+// out an empty list is left unchanged.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_front() {
+  DCHECK(head_);
+  if (auto* removed = head_) {
+    // Advance the head, then fully detach the old front item.
+    head_ = removed->next;
+    removed->next = nullptr;
+    removed->prev = nullptr;
+    if (head_ != nullptr) {
+      head_->prev = nullptr;
+    }
+    if (tail_ == removed) {
+      tail_ = nullptr;
+    }
+    --count_;
+    OnDeallocate(impl::LinkToT<T, kOffset>(removed));
+  }
+  CheckCorrectness();
+}
+
+// Returns the last item, or nullptr if the list has no tail.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::back()
+    const {
+  T* last = impl::LinkToT<T, kOffset>(tail_);
+  return last;
+}
+
+// Links |value| in as the new last item and notifies OnAdd.
+// |value| must be non-null and its link must not already be in use (checked
+// with DCHECKs).
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_back(
+    T* value) {
+  DCHECK(value);
+  auto* link = impl::TToLink<T, kOffset>(value);
+  DCHECK(!link->next);
+  DCHECK(!link->prev);
+
+  auto* const old_tail = tail_;
+  link->next = nullptr;
+  link->prev = old_tail;
+  if (old_tail != nullptr) {
+    old_tail->next = link;
+  }
+  tail_ = link;
+  if (head_ == nullptr) {
+    // First item: it is both ends of the list.
+    head_ = link;
+  }
+
+  ++count_;
+  OnAdd(value);
+  CheckCorrectness();
+}
+
+// Unlinks the last item and reports it to OnDeallocate.
+// The list is expected to be non-empty (DCHECK); when the check is compiled
+// out an empty list is left unchanged.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_back() {
+  DCHECK(tail_);
+  if (auto* removed = tail_) {
+    // Retreat the tail, then fully detach the old back item.
+    tail_ = removed->prev;
+    removed->next = nullptr;
+    removed->prev = nullptr;
+    if (tail_ != nullptr) {
+      tail_->next = nullptr;
+    }
+    if (head_ == removed) {
+      head_ = nullptr;
+    }
+    --count_;
+    OnDeallocate(impl::LinkToT<T, kOffset>(removed));
+  }
+  CheckCorrectness();
+}
+
+// Links |value| into the list immediately before |position|.
+// A |position| equal to the current head behaves like push_front and a null
+// |position| behaves like push_back. |value| must be non-null and unlinked
+// (checked with DCHECKs).
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::insert(
+    T* position, T* value) {
+  DCHECK(value);
+  auto* link = impl::TToLink<T, kOffset>(value);
+  auto* const position_link = impl::TToLink<T, kOffset>(position);
+  DCHECK(!link->next);
+  DCHECK(!link->prev);
+
+  const bool at_front = (position_link == head_);
+  const bool at_back = (position_link == nullptr);
+  if (at_front) {
+    push_front(value);
+  } else if (at_back) {
+    push_back(value);
+  } else {
+    // Interior insertion: splice between |position| and its predecessor.
+    auto* const before = position_link->prev;
+    link->prev = before;
+    link->next = position_link;
+    before->next = link;
+    position_link->prev = link;
+    ++count_;
+    OnAdd(value);
+  }
+  CheckCorrectness();
+}
+
+// Unlinks |value| from the list, reports it to OnDeallocate and returns the
+// item that followed it (nullptr if it was the last one). A null |value| is
+// ignored and yields nullptr.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase(T* value) {
+  if (value == nullptr) return nullptr;
+
+  auto* link = impl::TToLink<T, kOffset>(value);
+  auto* const before = link->prev;
+  auto* const after = link->next;
+
+  // Bridge the predecessor (or the head pointer) over the erased item.
+  if (before) {
+    DCHECK_NE(link, head_);
+    before->next = after;
+  } else {
+    DCHECK_EQ(link, head_);
+    head_ = after;
+  }
+  // Bridge the successor (or the tail pointer) likewise.
+  if (after) {
+    DCHECK_NE(link, tail_);
+    after->prev = before;
+  } else {
+    DCHECK_EQ(link, tail_);
+    tail_ = before;
+  }
+
+  link->next = nullptr;
+  link->prev = nullptr;
+  --count_;
+  OnDeallocate(value);
+  CheckCorrectness();
+  return impl::LinkToT<T, kOffset>(after);
+}
+
+// Erases the item |it| refers to and returns an iterator positioned at the
+// item that followed it (the end position if there is none).
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+IteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase(
+    const IteratorT& it) {
+  T* const following = erase(*it);
+  return IteratorT(impl::TToLink<T, kOffset>(following));
+}
+
+// Erases the item |it| refers to and returns a reverse iterator positioned at
+// the item that follows it in reverse iteration order, i.e. the erased item's
+// predecessor in the list (the reverse end position if there is none).
+//
+// erase(T*) reports the forward successor, which a reverse traversal has
+// already visited (and which is null when erasing the tail), so the
+// predecessor is captured before the item is unlinked.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+ReverseIteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT,
+                                   kOffset>::erase(const ReverseIteratorT& it) {
+  T* const following = previous(*it);
+  erase(*it);
+  return ReverseIteratorT(impl::TToLink<T, kOffset>(following));
+}
+
+// Puts |new_value| into the list at the position currently occupied by
+// |old_value| and unlinks |old_value|. The item count is unchanged.
+// OnAdd is invoked for |new_value| and then OnDeallocate for |old_value|.
+//
+// Both values must be non-null and distinct, and |new_value| must not be
+// linked into any list; like push_front/push_back/insert this is checked in
+// debug builds by requiring its link to be clear.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+          size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::replace(
+    T* old_value, T* new_value) {
+  DCHECK(old_value);
+  DCHECK(new_value);
+  DCHECK_NE(old_value, new_value);
+  auto* old_link = impl::TToLink<T, kOffset>(old_value);
+  auto* new_link = impl::TToLink<T, kOffset>(new_value);
+  DCHECK(!new_link->next);
+  DCHECK(!new_link->prev);
+
+  // Take over the old item's neighbours.
+  new_link->next = old_link->next;
+  new_link->prev = old_link->prev;
+  if (new_link->prev) {
+    new_link->prev->next = new_link;
+  } else {
+    // The old item was the head.
+    head_ = new_link;
+  }
+  if (new_link->next) {
+    new_link->next->prev = new_link;
+  } else {
+    // The old item was the tail.
+    tail_ = new_link;
+  }
+  old_link->next = old_link->prev = nullptr;
+
+  OnAdd(new_value);
+  OnDeallocate(old_value);
+  CheckCorrectness();
+}
+
+// Sorts the list in place with a bottom-up merge sort: each pass merges
+// adjacent runs of |in_size| items, doubling |in_size| until a pass performs
+// at most one merge. Only the links are rewired; count_ is untouched and no
+// OnAdd/OnDeallocate hooks are invoked.
+//
+// When both runs have items, the one from the first run is taken only if
+// compare_fn(p, q) returns true; otherwise the item from the second run is
+// taken.
+template <typename T, typename IteratorT, typename ReverseIteratorT,
+ size_t kOffset>
+void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::sort(
+ bool (*compare_fn)(T* a, T* b)) {
+ if (empty()) {
+ // Empty list no-op.
+ return;
+ }
+ // Repeatedly run until the list is sorted.
+ int in_size = 1;
+ while (true) {
+ // p walks the first run, q the second; e is the element chosen next and
+ // tail is the end of the merged output built so far in this pass.
+ IntrusiveListLink* p = head_;
+ IntrusiveListLink* q = nullptr;
+ IntrusiveListLink* e = nullptr;
+ IntrusiveListLink* tail = nullptr;
+ head_ = nullptr;
+ tail_ = nullptr;
+ // Repeatedly merge sublists.
+ int merge_count = 0;
+ do {
+ ++merge_count;
+ q = p;
+ // Determine the size of the first part and find the second.
+ int p_size = 0;
+ for (int i = 0; i < in_size; ++i) {
+ ++p_size;
+ q = q->next;
+ if (!q) {
+ break;
+ }
+ }
+ // Merge the two lists (if we have two).
+ int q_size = in_size;
+ while (p_size > 0 || (q_size > 0 && q)) {
+ if (p_size == 0) {
+ // p is empty; e must come from q.
+ e = q;
+ q = q->next;
+ --q_size;
+ } else if (q_size == 0 || !q) {
+ // q is empty; e must come from p.
+ e = p;
+ p = p->next;
+ --p_size;
+ } else if (compare_fn(impl::LinkToT<T, kOffset>(p),
+ impl::LinkToT<T, kOffset>(q))) {
+ // compare_fn(p, q) holds; e must come from p.
+ e = p;
+ p = p->next;
+ --p_size;
+ } else {
+ // compare_fn(p, q) does not hold; e must come from q.
+ e = q;
+ q = q->next;
+ --q_size;
+ }
+ // Append e to the merged list.
+ if (tail) {
+ tail->next = e;
+ } else {
+ head_ = e;
+ }
+ e->prev = tail;
+ tail = e;
+ }
+ // q now points just past the second run, where the next pair starts.
+ p = q;
+ } while (p);
+ tail->next = nullptr;
+ if (merge_count <= 1) {
+ // List is now sorted; stash and return.
+ tail_ = tail;
+ CheckCorrectness();
+ return;
+ }
+ // Run merge again with larger lists.
+ in_size *= 2;
+ }
+}
+
+} // namespace iree
+
+// Specializations:
+#include "base/intrusive_list_ref_ptr.inc"
+#include "base/intrusive_list_unique_ptr.inc"
+
+#endif // IREE_BASE_INTRUSIVE_LIST_H_
diff --git a/base/intrusive_list_ref_ptr.inc b/base/intrusive_list_ref_ptr.inc
new file mode 100644
index 0000000..962a1fe
--- /dev/null
+++ b/base/intrusive_list_ref_ptr.inc
@@ -0,0 +1,174 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// IWYU pragma: private, include "base/intrusive_list.h"
+
+#ifndef IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
+#define IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
+
+#include <cstddef>
+#include <iterator>
+
+#include "base/intrusive_list.h"
+#include "base/ref_ptr.h"
+
+namespace iree {
+
+// Iterator for an IntrusiveList specialized to ref_ptr.
+//
+// Dereferencing returns a new ref_ptr<T> (by value) for the item embedding the
+// current link. kForward selects which link pointer operator++ follows: next
+// when true, prev when false; operator-- follows the opposite one.
+//
+// The iterator traits advertise ref_ptr<T> as the value type so that generic
+// code inspecting std::iterator_traits sees the type that dereferencing
+// actually produces.
+template <typename T, size_t kOffset, bool kForward>
+class IntrusiveListRefPtrIterator
+    : public std::iterator<std::input_iterator_tag, ref_ptr<T>, std::ptrdiff_t,
+                           ref_ptr<T>*, ref_ptr<T>> {
+ public:
+  using self_type = IntrusiveListRefPtrIterator<T, kOffset, kForward>;
+
+  // Wraps |current|; a null link represents the end position.
+  explicit IntrusiveListRefPtrIterator(IntrusiveListLink* current)
+      : current_(current) {}
+
+  // Steps to the following item in iteration order. An iterator that holds no
+  // link is left unchanged.
+  self_type& operator++() {
+    if (current_) {
+      current_ = kForward ? current_->next : current_->prev;
+    }
+    return *this;
+  }
+  self_type operator++(int) {
+    self_type tmp(current_);
+    operator++();
+    return tmp;
+  }
+
+  // Steps to the preceding item in iteration order. An iterator that holds no
+  // link is left unchanged.
+  self_type& operator--() {
+    if (current_) {
+      current_ = kForward ? current_->prev : current_->next;
+    }
+    return *this;
+  }
+  self_type operator--(int) {
+    self_type tmp(current_);
+    operator--();
+    return tmp;
+  }
+
+  // Iterators compare equal when they reference the same link.
+  bool operator==(const self_type& rhs) const {
+    return rhs.current_ == current_;
+  }
+  bool operator!=(const self_type& rhs) const { return !operator==(rhs); }
+
+  // Returns an additional reference to the item containing the current link.
+  ref_ptr<T> operator*() const {
+    return add_ref(impl::LinkToT<T, kOffset>(current_));
+  }
+
+ protected:
+  IntrusiveListLink* current_;
+};
+
+// IntrusiveListBase flavour for reference-counted items.
+// The list adds a reference to every item it links (OnAdd) and releases it
+// when the item is unlinked (OnRemove/OnDeallocate). Accessors hand out
+// ref_ptr<T> values and iteration uses IntrusiveListRefPtrIterator.
+template <typename T, size_t kOffset>
+class IntrusiveListRefPtrBase
+    : private IntrusiveListBase<
+          T, IntrusiveListRefPtrIterator<T, kOffset, true>,
+          IntrusiveListRefPtrIterator<T, kOffset, false>, kOffset> {
+ public:
+  using IteratorT = IntrusiveListRefPtrIterator<T, kOffset, true>;
+  using ReverseIteratorT = IntrusiveListRefPtrIterator<T, kOffset, false>;
+  using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
+
+  IntrusiveListRefPtrBase() = default;
+
+  // Operations re-exported unchanged from the (privately inherited) base.
+  using base_list::empty;
+  using base_list::size;
+  using base_list::clear;
+  using base_list::begin;
+  using base_list::end;
+  using base_list::rbegin;
+  using base_list::rend;
+  using base_list::pop_front;
+  using base_list::pop_back;
+  using base_list::sort;
+
+  // Membership test; accepts either a raw pointer or a ref_ptr.
+  using base_list::contains;
+  bool contains(const ref_ptr<T>& value) const {
+    T* raw = value.get();
+    return base_list::contains(raw);
+  }
+
+  // Neighbour lookups; each returns a new reference to the adjacent item.
+  ref_ptr<T> next(const ref_ptr<T>& value) const { return next(value.get()); }
+  ref_ptr<T> next(T* value) const { return add_ref(base_list::next(value)); }
+  ref_ptr<T> previous(const ref_ptr<T>& value) const {
+    return previous(value.get());
+  }
+  ref_ptr<T> previous(T* value) const {
+    return add_ref(base_list::previous(value));
+  }
+
+  // Returns a new reference to the first item.
+  // Performance: O(1)
+  ref_ptr<T> front() const { return add_ref(base_list::front()); }
+
+  // Returns a new reference to the last item.
+  // Performance: O(1)
+  ref_ptr<T> back() const { return add_ref(base_list::back()); }
+
+  // Insertion; the list takes its own reference through OnAdd.
+  void push_front(const ref_ptr<T>& value) {
+    T* raw = value.get();
+    base_list::push_front(raw);
+  }
+  void push_back(const ref_ptr<T>& value) {
+    T* raw = value.get();
+    base_list::push_back(raw);
+  }
+  void insert(const IteratorT& it, const ref_ptr<T>& value) {
+    T* raw = value.get();
+    base_list::insert(it, raw);
+  }
+
+  // Removal; returns a new reference to the item after the erased one.
+  using base_list::erase;
+  ref_ptr<T> erase(const ref_ptr<T>& value) {
+    T* following = base_list::erase(value.get());
+    return add_ref(following);
+  }
+
+  // Swaps |new_value| into the position held by |old_value|.
+  void replace(const ref_ptr<T>& old_value, const ref_ptr<T>& new_value) {
+    replace(old_value.get(), new_value);
+  }
+  void replace(T* old_value, const ref_ptr<T>& new_value) {
+    base_list::replace(old_value, new_value.get());
+  }
+
+ private:
+  // Reference bookkeeping hooks invoked by the base list.
+  void OnAdd(T* value) override { value->AddReference(); }
+  void OnRemove(T* value) override { value->ReleaseReference(); }
+  void OnDeallocate(T* value) override { value->ReleaseReference(); }
+
+  using base_list::count_;
+  using base_list::head_;
+  using base_list::tail_;
+};
+
+// IntrusiveList of ref_ptr<U> items threaded through the IntrusiveListLink
+// located kOffset bytes into U.
+template <typename U, size_t kOffset>
+class IntrusiveList<ref_ptr<U>, kOffset>
+ : public IntrusiveListRefPtrBase<U, kOffset> {};
+
+// IntrusiveList of ref_ptr<U> items using the default link: U must have an
+// IntrusiveListLink data member named |link|.
+template <typename U>
+class IntrusiveList<ref_ptr<U>, kUseDefaultLinkOffset>
+ : public IntrusiveListRefPtrBase<U, offsetof(U, link)> {};
+
+} // namespace iree
+
+#endif // IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
diff --git a/base/intrusive_list_ref_ptr_test.cc b/base/intrusive_list_ref_ptr_test.cc
new file mode 100644
index 0000000..81d4296
--- /dev/null
+++ b/base/intrusive_list_ref_ptr_test.cc
@@ -0,0 +1,100 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/intrusive_list.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace {
+
+// Number of RefCountedType instances currently alive; each test resets it to
+// zero before use.
+static int alloc_count = 0;
+// Reference-counted test item that tracks its own lifetime in alloc_count and
+// exposes the RefObject counter so tests can assert on it.
+struct RefCountedType : public RefObject<RefCountedType> {
+ // Link used by IntrusiveList<ref_ptr<RefCountedType>> (default offset).
+ IntrusiveListLink link;
+ RefCountedType() { ++alloc_count; }
+ ~RefCountedType() { --alloc_count; }
+ // NOTE(review): presumably called by RefObject when the last reference is
+ // released - confirm against base/ref_ptr.h.
+ static void Deallocate(RefCountedType* value) { delete value; }
+ using RefObject<RefCountedType>::counter_;
+};
+
+// Pushing a ref_ptr keeps the object alive through the list, and clear()
+// releases it again.
+TEST(IntrusiveListRefPtrTest, PushAndClear) {
+ alloc_count = 0;
+ IntrusiveList<ref_ptr<RefCountedType>> list;
+ EXPECT_EQ(0, alloc_count);
+ list.push_back(make_ref<RefCountedType>());
+ EXPECT_EQ(1, alloc_count);
+ EXPECT_NE(nullptr, list.front());
+ // Presumably one reference held by the list plus one held by the temporary
+ // ref_ptr returned from front().
+ EXPECT_EQ(2, list.front()->counter_);
+ list.clear();
+ EXPECT_EQ(0, alloc_count);
+}
+
+// pop_back()/pop_front() each release the list's reference, destroying the
+// popped object when no other reference exists.
+TEST(IntrusiveListRefPtrTest, PushPop) {
+ alloc_count = 0;
+ IntrusiveList<ref_ptr<RefCountedType>> list;
+ list.push_back(make_ref<RefCountedType>());
+ EXPECT_EQ(1, alloc_count);
+ list.push_back(make_ref<RefCountedType>());
+ EXPECT_EQ(2, alloc_count);
+ EXPECT_NE(list.front(), list.back());
+ list.pop_back();
+ EXPECT_EQ(1, alloc_count);
+ list.pop_front();
+ EXPECT_EQ(0, alloc_count);
+}
+
+// Erasing an item that the caller still references must not destroy it; the
+// object goes away only once the caller's ref_ptr is reset.
+TEST(IntrusiveListRefPtrTest, PushErase) {
+ alloc_count = 0;
+ IntrusiveList<ref_ptr<RefCountedType>> list;
+ list.push_back(make_ref<RefCountedType>());
+ EXPECT_EQ(1, alloc_count);
+ EXPECT_NE(nullptr, list.front());
+ EXPECT_EQ(2, list.front()->counter_);
+ auto item = list.front();
+ EXPECT_NE(nullptr, item.get());
+ // Presumably list + |item| + the temporary returned by front().
+ EXPECT_EQ(3, list.front()->counter_);
+ EXPECT_EQ(1, alloc_count);
+ list.erase(item);
+ EXPECT_EQ(1, alloc_count);
+ item.reset();
+ EXPECT_EQ(0, alloc_count);
+}
+
+// replace() drops the list's reference to the old item and takes one on the
+// new item, so exactly one object stays alive afterwards.
+TEST(IntrusiveListRefPtrTest, PushReplace) {
+ alloc_count = 0;
+ IntrusiveList<ref_ptr<RefCountedType>> list;
+ list.push_back(make_ref<RefCountedType>());
+ EXPECT_EQ(1, alloc_count);
+ list.replace(list.front(), make_ref<RefCountedType>());
+ EXPECT_EQ(1, alloc_count);
+ list.clear();
+ EXPECT_EQ(0, alloc_count);
+}
+
+// Range-based iteration yields ref_ptr values, and iterating does not leak
+// references: clear() still destroys every object.
+TEST(IntrusiveListRefPtrTest, Iteration) {
+ alloc_count = 0;
+ IntrusiveList<ref_ptr<RefCountedType>> list;
+ list.push_back(make_ref<RefCountedType>());
+ list.push_back(make_ref<RefCountedType>());
+ list.push_back(make_ref<RefCountedType>());
+ EXPECT_EQ(3, alloc_count);
+ for (auto item : list) {
+ // Binding to a ref_ptr reference checks the iterator's dereferenced type.
+ const ref_ptr<RefCountedType>& item_ref = item;
+ EXPECT_NE(nullptr, item_ref.get());
+ }
+ list.clear();
+ EXPECT_EQ(0, alloc_count);
+}
+
+} // namespace
+} // namespace iree
diff --git a/base/intrusive_list_test.cc b/base/intrusive_list_test.cc
new file mode 100644
index 0000000..e189105
--- /dev/null
+++ b/base/intrusive_list_test.cc
@@ -0,0 +1,523 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/intrusive_list.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace {
+
+using ::testing::ElementsAre;
+
+// Test item carrying two independent list links interleaved with token
+// fields. The tokens are set to kToken on construction and re-checked by
+// is_valid(), presumably to catch list operations that write outside the
+// links.
+struct Item {
+ size_t some_data_0;
+ IntrusiveListLink list_a;
+ size_t some_data_1;
+ IntrusiveListLink list_b;
+ size_t some_data_2;
+ int value;
+
+ static const size_t kToken = 0xDEADBEEF;
+ explicit Item(int value)
+ : some_data_0(kToken),
+ some_data_1(kToken),
+ some_data_2(kToken),
+ value(value) {}
+ // Returns true while all three token fields still hold kToken.
+ bool is_valid() {
+ return some_data_0 == kToken && some_data_1 == kToken &&
+ some_data_2 == kToken;
+ }
+};
+
+// Collects the item pointers of |list| in front-to-back order.
+template <typename T, size_t V>
+std::vector<T*> ExtractItems(const IntrusiveList<T, V>& list) {
+  std::vector<T*> collected;
+  for (auto it = list.begin(); it != list.end(); ++it) {
+    collected.push_back(*it);
+  }
+  return collected;
+}
+
+// Collects the |value| field of every item in |list|, front to back.
+template <typename T, size_t V>
+std::vector<int> ExtractValues(const IntrusiveList<T, V>& list) {
+  std::vector<int> collected;
+  for (auto it = list.begin(); it != list.end(); ++it) {
+    collected.push_back((*it)->value);
+  }
+  return collected;
+}
+
+// Collects the |value| field of every item in |list|, front to back.
+// NOTE(review): despite the name, |list| is taken by const reference, so this
+// is currently equivalent to ExtractValues - confirm whether a non-const
+// reference was intended.
+template <typename T, size_t V>
+std::vector<int> ExtractValuesMutable(const IntrusiveList<T, V>& list) {
+  std::vector<int> collected;
+  for (auto it = list.begin(); it != list.end(); ++it) {
+    collected.push_back((*it)->value);
+  }
+  return collected;
+}
+
+// Exercises push_front/push_back/pop_front/pop_back and checks size, front,
+// back, iteration order and emptiness after every step.
+TEST(IntrusiveListTest, PushPopItems) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items;
+ EXPECT_TRUE(items.empty());
+ EXPECT_EQ(items.size(), 0u);
+ EXPECT_EQ(items.front(), nullptr);
+ EXPECT_EQ(items.back(), nullptr);
+ EXPECT_TRUE(items.begin() == items.end());
+ items.push_front(&item1);
+ EXPECT_FALSE(items.empty());
+ EXPECT_EQ(items.size(), 1u);
+ EXPECT_EQ(items.front(), &item1);
+ EXPECT_EQ(items.back(), &item1);
+ EXPECT_FALSE(items.begin() == items.end());
+ items.push_front(&item2);
+ EXPECT_EQ(items.size(), 2u);
+ EXPECT_EQ(items.front(), &item2);
+ EXPECT_EQ(items.back(), &item1);
+ items.push_front(&item3);
+ EXPECT_EQ(items.size(), 3u);
+ EXPECT_EQ(items.front(), &item3);
+ EXPECT_EQ(items.back(), &item1);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1));
+
+ items.push_back(&item4);
+ EXPECT_EQ(items.size(), 4u);
+ EXPECT_EQ(items.front(), &item3);
+ EXPECT_EQ(items.back(), &item4);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1, 4));
+
+ items.pop_front();
+ EXPECT_EQ(items.size(), 3u);
+ EXPECT_EQ(items.front(), &item2);
+ EXPECT_EQ(items.back(), &item4);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1, 4));
+
+ items.pop_back();
+ EXPECT_EQ(items.size(), 2u);
+ EXPECT_EQ(items.front(), &item2);
+ EXPECT_EQ(items.back(), &item1);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1));
+
+ items.pop_back();
+ items.pop_front();
+ EXPECT_TRUE(items.empty());
+ EXPECT_EQ(items.size(), 0u);
+ EXPECT_EQ(items.front(), nullptr);
+ EXPECT_EQ(items.back(), nullptr);
+ EXPECT_TRUE(items.begin() == items.end());
+
+ // The token fields around the links must be untouched by the list.
+ EXPECT_TRUE(item1.is_valid());
+ EXPECT_TRUE(item2.is_valid());
+ EXPECT_TRUE(item3.is_valid());
+ EXPECT_TRUE(item4.is_valid());
+}
+
+// contains() reports linked items, rejects an item that was never added and
+// rejects nullptr.
+TEST(IntrusiveListTest, Contains) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items;
+ items.push_back(&item1);
+ items.push_back(&item2);
+ items.push_back(&item3);
+ // item4 omitted.
+
+ EXPECT_TRUE(items.contains(&item1));
+ EXPECT_TRUE(items.contains(&item2));
+ EXPECT_TRUE(items.contains(&item3));
+ EXPECT_FALSE(items.contains(&item4));
+
+ EXPECT_FALSE(items.contains(nullptr));
+}
+
+// merge_from() appends the source list's items after the destination's items
+// and leaves the source empty.
+TEST(IntrusiveListTest, MergeFrom) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items0;
+ items0.push_back(&item1);
+ items0.push_back(&item2);
+ items0.push_back(&item3);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items1;
+ items1.push_back(&item4);
+
+ items0.merge_from(&items1);
+ EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4));
+ EXPECT_TRUE(items1.empty());
+}
+
+// Merging two empty lists must not crash; no further state is asserted.
+TEST(IntrusiveListTest, MergeFromEmpty) {
+ IntrusiveList<Item, offsetof(Item, list_a)> items0;
+ IntrusiveList<Item, offsetof(Item, list_a)> items1;
+ items0.merge_from(&items1);
+}
+
+// Merging from an empty list must leave the destination's contents intact,
+// and merging into an empty list must move everything across.
+// NOTE(review): only forward iteration order and emptiness are asserted here;
+// back() and size() of the destination are not checked.
+TEST(IntrusiveListTest, MergeFromAll) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+ IntrusiveList<Item, offsetof(Item, list_a)> items0;
+ items0.push_back(&item1);
+ items0.push_back(&item2);
+ items0.push_back(&item3);
+ items0.push_back(&item4);
+ IntrusiveList<Item, offsetof(Item, list_a)> items1;
+
+ // Merge all items from items1 into items0. Shouldn't change anything.
+ items0.merge_from(&items1);
+ EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4));
+ EXPECT_TRUE(items1.empty());
+
+ // Merge all items from items0 into items1. Should move everything.
+ items1.merge_from(&items0);
+ EXPECT_TRUE(items0.empty());
+ EXPECT_THAT(ExtractValues(items1), ElementsAre(1, 2, 3, 4));
+}
+
+// Covers erase() by item pointer (middle, head, tail, last remaining item)
+// and by iterator, including the iterator returned from erase().
+TEST(IntrusiveListTest, Erase) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items;
+ items.push_back(&item1);
+ items.push_back(&item2);
+ items.push_back(&item3);
+ items.push_back(&item4);
+
+ EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
+ items.erase(&item3);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 4));
+ items.erase(&item1);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4));
+ items.erase(&item4);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2));
+ items.erase(&item2);
+ EXPECT_TRUE(items.empty());
+
+ items.push_back(&item1);
+ items.push_back(&item2);
+ items.push_back(&item3);
+ items.push_back(&item4);
+
+ EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
+ auto it = items.begin();
+ items.erase(it);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4));
+ // Erasing at end() is expected to leave the list unchanged.
+ it = items.end();
+ items.erase(it);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4));
+ it = items.begin();
+ ++it;
+ items.erase(it);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4));
+
+ // The iterator returned by erase() refers to the item after the erased one.
+ it = items.begin();
+ it = items.erase(it);
+ EXPECT_EQ(4, (*it)->value);
+ EXPECT_THAT(ExtractValues(items), ElementsAre(4));
+ it = items.erase(it);
+ EXPECT_TRUE(items.empty());
+ EXPECT_EQ(items.end(), it);
+}
+
+// The same items can sit in two lists at once through different links;
+// modifying one list must not affect the other.
+TEST(IntrusiveListTest, MultipleLists) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items_a;
+ IntrusiveList<Item, offsetof(Item, list_b)> items_b;
+ items_a.push_back(&item1);
+ items_a.push_back(&item2);
+ items_a.push_back(&item3);
+ items_a.push_back(&item4);
+ items_b.push_front(&item1);
+ items_b.push_front(&item2);
+ items_b.push_front(&item3);
+ items_b.push_front(&item4);
+ EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4));
+ EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 3, 2, 1));
+ items_b.erase(&item3);
+ EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4));
+ EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1));
+ items_a.pop_back();
+ EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3));
+ EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1));
+}
+
+// Checks iteration order through ExtractValuesMutable after a mix of
+// push_back and push_front calls.
+TEST(IntrusiveListTest, MutableIterator) {
+ Item item1(1);
+ Item item2(2);
+ Item item3(3);
+ Item item4(4);
+
+ IntrusiveList<Item, offsetof(Item, list_a)> items;
+ items.push_back(&item4);
+ items.push_front(&item1);
+ items.push_front(&item2);
+ items.push_front(&item3);
+
+ EXPECT_THAT(ExtractValuesMutable(items), ElementsAre(3, 2, 1, 4));
+}
+
+// Non-polymorphic base class carrying its own list link; used to check lists
+// keyed on a link declared in a base type.
+struct BaseType {
+ explicit BaseType(int value) : value(value) {}
+ int value;
+ IntrusiveListLink base_link;
+};
+// Derived type adding a second link, so one object can be in a BaseType list
+// and a SubType list at the same time.
+struct SubType : public BaseType {
+ explicit SubType(int value) : BaseType(value) {}
+ IntrusiveListLink sub_link;
+};
+// Derived objects can be listed both through the base-class link (as
+// BaseType) and through the derived-class link (as SubType).
+TEST(IntrusiveListTest, SimpleType) {
+ SubType item1(1);
+ SubType item2(2);
+ SubType item3(3);
+ SubType item4(4);
+
+ IntrusiveList<BaseType, offsetof(BaseType, base_link)> items_a;
+ items_a.push_front(&item1);
+ items_a.push_front(&item2);
+ items_a.push_front(&item3);
+ items_a.push_front(&item4);
+ EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1));
+
+ IntrusiveList<SubType, offsetof(SubType, sub_link)> items_b;
+ items_b.push_back(&item1);
+ items_b.push_back(&item2);
+ items_b.push_back(&item3);
+ items_b.push_back(&item4);
+ EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4));
+}
+
+struct AbstractType {
+  explicit AbstractType(int initial) : value{initial} {}
+  virtual ~AbstractType() = default;
+  virtual int DoSomething() = 0;  // Pure virtual: the type is abstract.
+  int value;
+  IntrusiveListLink base_link;  // Hook used by lists over AbstractType.
+};
+struct ImplType : public AbstractType {
+  explicit ImplType(int initial) : AbstractType{initial} {}
+  int DoSomething() override { return this->value; }
+  IntrusiveListLink sub_link;  // Hook used by lists over ImplType.
+};
+
+TEST(IntrusiveListTest, ComplexType) {
+  ImplType one(1);
+  ImplType two(2);
+  ImplType three(3);
+  ImplType four(4);
+  // List over the polymorphic base, linked through AbstractType::base_link.
+  IntrusiveList<AbstractType, offsetof(AbstractType, base_link)> by_base;
+  by_base.push_front(&one);
+  by_base.push_front(&two);
+  by_base.push_front(&three);
+  by_base.push_front(&four);
+  EXPECT_THAT(ExtractValues(by_base), ElementsAre(4, 3, 2, 1));
+  // List over the concrete type, linked through ImplType::sub_link.
+  IntrusiveList<ImplType, offsetof(ImplType, sub_link)> by_impl;
+  by_impl.push_back(&one);
+  by_impl.push_back(&two);
+  by_impl.push_back(&three);
+  by_impl.push_back(&four);
+  EXPECT_THAT(ExtractValues(by_impl), ElementsAre(1, 2, 3, 4));
+}
+
+bool Comparison(Item* a, Item* b) { return a->value < b->value; }  // True when a's value is strictly less than b's.
+
+TEST(IntrusiveListTest, Inserting) {
+  Item one(1);
+  Item two(2);
+  Item three(3);
+  Item four(4);
+
+  // Seed the list as {1, 3, 4} using positional inserts at both ends.
+  IntrusiveList<Item, offsetof(Item, list_a)> list;
+  list.insert(list.end(), &three);
+  list.insert(list.begin(), &one);
+  list.insert(list.end(), &four);
+  // upper_bound yields the first element that Comparison orders after `two`.
+  auto where = std::upper_bound(list.begin(), list.end(), &two, Comparison);
+  list.insert(where, &two);
+  EXPECT_THAT(ExtractValues(list), ElementsAre(1, 2, 3, 4));
+}
+
+// TODO(benvanik): test reverse iteration.
+
+TEST(IntrusiveListTest, NextPrevious) {
+  Item head(1);
+  Item tail(2);
+
+  IntrusiveList<Item, offsetof(Item, list_a)> list;
+  EXPECT_EQ(nullptr, list.next(nullptr));  // Null in, null out.
+  EXPECT_EQ(nullptr, list.previous(nullptr));
+  // A lone element has no neighbor in either direction.
+  list.push_back(&head);
+  EXPECT_EQ(nullptr, list.next(&head));
+  EXPECT_EQ(nullptr, list.previous(&head));
+  // With two elements each one sees the other on exactly one side.
+  list.push_back(&tail);
+  EXPECT_EQ(&tail, list.next(&head));
+  EXPECT_EQ(nullptr, list.previous(&head));
+  EXPECT_EQ(nullptr, list.next(&tail));
+  EXPECT_EQ(&head, list.previous(&tail));
+}
+
+TEST(IntrusiveListTest, Clear) {
+  Item one(1);
+  Item two(2);
+  Item three(3);
+  Item four(4);
+
+  IntrusiveList<Item, offsetof(Item, list_a)> list;
+
+  // Clearing a list that never held anything.
+  list.clear();
+  EXPECT_TRUE(list.empty());
+
+  // Clearing a single-element list.
+  list.push_back(&one);
+  list.clear();
+  EXPECT_TRUE(list.empty());
+
+  // Clearing several elements; `one` is linked again after the clear above.
+  list.push_back(&one);
+  list.push_back(&two);
+  list.push_back(&three);
+  list.push_back(&four);
+  list.clear();
+  EXPECT_TRUE(list.empty());
+}
+
+TEST(IntrusiveListTest, ClearDeleter) {
+  Item one(1);
+  Item two(2);
+
+  IntrusiveList<Item, offsetof(Item, list_a)> list;
+  int deleter_calls = 0;
+
+  // On an empty list the deleter must never run.
+  list.clear([&](Item*) { ++deleter_calls; });
+  EXPECT_EQ(0, deleter_calls);
+
+  // With two linked items the deleter is expected to run once per item.
+  list.push_back(&one);
+  list.push_back(&two);
+  list.clear([&](Item*) { ++deleter_calls; });
+  EXPECT_EQ(2, deleter_calls);
+  EXPECT_TRUE(list.empty());
+}
+
+TEST(IntrusiveListTest, Replace) {
+  Item one(1);
+  Item two(2);
+  Item three(3);
+
+  IntrusiveList<Item, offsetof(Item, list_a)> list;
+  list.push_back(&one);
+  list.push_back(&two);
+  // Swap out the head; the evicted item must no longer be a member.
+  list.replace(&one, &three);
+  EXPECT_FALSE(list.contains(&one));
+  EXPECT_THAT(ExtractValues(list), ElementsAre(3, 2));
+  list.replace(&two, &one);  // A previously evicted item is linked again.
+  EXPECT_FALSE(list.contains(&two));
+  EXPECT_THAT(ExtractValues(list), ElementsAre(3, 1));
+}
+
+TEST(IntrusiveListTest, Sort) {
+  Item one(1);
+  Item two(2);
+  Item three(3);
+  Item four(4);
+  const auto ascending = [](Item* a, Item* b) { return a->value < b->value; };
+  IntrusiveList<Item, offsetof(Item, list_a)> list;
+
+  // Sorting an empty list must not crash.
+  list.sort(ascending);
+
+  // One element.
+  list.clear();
+  list.push_back(&one);
+  list.sort(ascending);
+  EXPECT_THAT(ExtractValues(list), ElementsAre(1));
+
+  // Input that is already in ascending order.
+  list.clear();
+  list.push_back(&one);
+  list.push_back(&two);
+  list.push_back(&three);
+  list.push_back(&four);
+  list.sort(ascending);
+  EXPECT_THAT(ExtractValues(list), ElementsAre(1, 2, 3, 4));
+
+  // Input in descending order.
+  list.clear();
+  list.push_back(&four);
+  list.push_back(&three);
+  list.push_back(&two);
+  list.push_back(&one);
+  list.sort(ascending);
+  EXPECT_THAT(ExtractValues(list), ElementsAre(1, 2, 3, 4));
+
+  // Input in a shuffled order.
+  list.clear();
+  list.push_back(&two);
+  list.push_back(&four);
+  list.push_back(&one);
+  list.push_back(&three);
+  list.sort(ascending);
+  EXPECT_THAT(ExtractValues(list), ElementsAre(1, 2, 3, 4));
+
+  // Stability: equal values must keep their insertion order (uses <=).
+  Item one_dup(1);
+  Item two_dup(2);
+  list.clear();
+  list.push_back(&two);
+  list.push_back(&four);
+  list.push_back(&one);
+  list.push_back(&three);
+  list.push_back(&one_dup);
+  list.push_back(&two_dup);
+  list.sort([](Item* a, Item* b) { return a->value <= b->value; });
+  EXPECT_THAT(ExtractValues(list), ElementsAre(1, 1, 2, 2, 3, 4));
+  auto sorted = ExtractItems(list);
+  EXPECT_EQ(&one, sorted[0]);
+  EXPECT_EQ(&one_dup, sorted[1]);
+  EXPECT_EQ(&two, sorted[2]);
+  EXPECT_EQ(&two_dup, sorted[3]);
+  list.clear();
+}
+
+} // namespace
+} // namespace iree
diff --git a/base/intrusive_list_unique_ptr.inc b/base/intrusive_list_unique_ptr.inc
new file mode 100644
index 0000000..4397265
--- /dev/null
+++ b/base/intrusive_list_unique_ptr.inc
@@ -0,0 +1,140 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// IWYU pragma: private, include "base/intrusive_list.h"
+
+#ifndef IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
+#define IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
+
+#include <cstddef>
+#include <memory>
+
+#include "base/intrusive_list.h"
+#include "base/logging.h"
+
+namespace iree {
+
+// Specialized IntrusiveListBase for std::unique_ptr types.
+// This makes the list methods accept std::unique_ptrs and contains a special
+// take() method that takes ownership of a list item.
+template <typename T, size_t kOffset>
+class IntrusiveListUniquePtrBase
+ : private IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>,
+ IntrusiveListIterator<T, kOffset, false>,
+ kOffset> {
+ public:
+ using IteratorT = IntrusiveListIterator<T, kOffset, true>;
+ using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>;
+ using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
+
+ IntrusiveListUniquePtrBase() = default;
+ // The base is inherited privately; only the members named below are public.
+ using base_list::empty;
+ using base_list::size;
+
+ using base_list::contains;
+
+ using base_list::clear;
+
+ using base_list::begin;
+ using base_list::end;
+ using base_list::rbegin;
+ using base_list::rend;
+
+ using base_list::next;
+
+ using base_list::previous;
+
+ using base_list::front;
+ // Releases `value` and hands the raw pointer to the base push_front().
+ void push_front(std::unique_ptr<T> value) {
+ base_list::push_front(value.release());
+ }
+
+ using base_list::pop_front;
+
+ using base_list::back;
+ // Releases `value` and hands the raw pointer to the base push_back().
+ void push_back(std::unique_ptr<T> value) {
+ base_list::push_back(value.release());
+ }
+
+ using base_list::pop_back;
+ // Releases `value` and hands the raw pointer to the base insert() at `it`.
+ void insert(const IteratorT& it, std::unique_ptr<T> value) {
+ base_list::insert(it, value.release());
+ }
+
+ using base_list::erase;
+
+ // Removes an item from the list at the given iterator and transfers ownership
+ // to the caller.
+ // Performance: O(1)
+ std::unique_ptr<T> take(IteratorT& it) { // NOLINT(runtime/references)
+ return take(*it);
+ }
+
+ // Removes an item from the list and transfers ownership to the caller.
+ // Performance: O(1)
+ std::unique_ptr<T> take(T* value) {
+ if (!value) {
+ return {nullptr};
+ }
+ auto* link = impl::TToLink<T, kOffset>(value);
+ if (link->prev) {  // Has a predecessor: splice it past this link.
+ DCHECK_NE(link, head_);
+ link->prev->next = link->next;
+ } else {  // No predecessor: this link is the head, so advance head_.
+ DCHECK_EQ(link, head_);
+ head_ = link->next;
+ }
+ if (link->next) {  // Has a successor: point it back past this link.
+ DCHECK_NE(link, tail_);
+ link->next->prev = link->prev;
+ } else {  // No successor: this link is the tail, so retreat tail_.
+ DCHECK_EQ(link, tail_);
+ tail_ = link->prev;
+ }
+ link->next = link->prev = nullptr;  // Leave the removed link fully detached.
+ --count_;
+ base_list::OnRemove(value);
+ base_list::CheckCorrectness();
+ return std::unique_ptr<T>(value);  // Caller now owns the item.
+ }
+ // Releases `new_value` and forwards both raw pointers to the base replace().
+ void replace(T* old_value, std::unique_ptr<T> new_value) {
+ base_list::replace(old_value, new_value.release());
+ }
+
+ using base_list::sort;
+
+ private:
+ void OnDeallocate(T* value) override { delete value; }  // Base-class hook.
+
+ using base_list::count_;
+ using base_list::head_;
+ using base_list::tail_;
+};
+
+template <typename U, size_t kOffset>
+class IntrusiveList<std::unique_ptr<U>, kOffset>
+ : public IntrusiveListUniquePtrBase<U, kOffset> {};  // Caller-given offset.
+// Same, but locates the link through a member of U that must be named `link`.
+template <typename U>
+class IntrusiveList<std::unique_ptr<U>, kUseDefaultLinkOffset>
+ : public IntrusiveListUniquePtrBase<U, offsetof(U, link)> {};
+
+} // namespace iree
+
+#endif // IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
diff --git a/base/intrusive_list_unique_ptr_test.cc b/base/intrusive_list_unique_ptr_test.cc
new file mode 100644
index 0000000..c5e9421
--- /dev/null
+++ b/base/intrusive_list_unique_ptr_test.cc
@@ -0,0 +1,84 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "absl/memory/memory.h"
+#include "base/intrusive_list.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace {
+
+struct AllocatedType : public IntrusiveLinkBase<void> {
+  static int alloc_count;  // Number of instances currently alive.
+  AllocatedType() { alloc_count += 1; }
+  ~AllocatedType() { alloc_count -= 1; }
+};
+int AllocatedType::alloc_count{0};
+
+TEST(IntrusiveListUniquePtrTest, UniquePtr) {
+  AllocatedType::alloc_count = 0;
+
+  // clear() must destroy the items the list owns.
+  IntrusiveList<std::unique_ptr<AllocatedType>> owned;
+  EXPECT_EQ(0, AllocatedType::alloc_count);
+  owned.push_back(absl::make_unique<AllocatedType>());
+  EXPECT_EQ(1, AllocatedType::alloc_count);
+  EXPECT_NE(nullptr, owned.front());
+  owned.clear();
+  EXPECT_EQ(0, AllocatedType::alloc_count);
+
+  // pop_back() must destroy the popped item.
+  owned.push_back(absl::make_unique<AllocatedType>());
+  EXPECT_EQ(1, AllocatedType::alloc_count);
+  EXPECT_NE(nullptr, owned.front());
+  for (auto entry : owned) {
+    EXPECT_EQ(entry, owned.front());
+  }
+  owned.pop_back();
+  EXPECT_EQ(0, AllocatedType::alloc_count);
+
+  // take() unlinks the item and hands ownership back without destroying it.
+  owned.push_back(absl::make_unique<AllocatedType>());
+  EXPECT_EQ(1, AllocatedType::alloc_count);
+  EXPECT_NE(nullptr, owned.front());
+  auto taken = owned.take(owned.front());
+  EXPECT_TRUE(owned.empty());
+  EXPECT_NE(nullptr, taken.get());
+  EXPECT_EQ(1, AllocatedType::alloc_count);
+  taken.reset();
+  EXPECT_EQ(0, AllocatedType::alloc_count);
+
+  // replace() must destroy the item that gets replaced.
+  owned.push_back(absl::make_unique<AllocatedType>());
+  EXPECT_EQ(1, AllocatedType::alloc_count);
+  owned.replace(owned.front(), absl::make_unique<AllocatedType>());
+  EXPECT_EQ(1, AllocatedType::alloc_count);
+  owned.clear();
+  EXPECT_EQ(0, AllocatedType::alloc_count);
+
+  // Iteration yields values convertible to raw AllocatedType pointers.
+  owned.push_back(absl::make_unique<AllocatedType>());
+  owned.push_back(absl::make_unique<AllocatedType>());
+  owned.push_back(absl::make_unique<AllocatedType>());
+  EXPECT_EQ(3, AllocatedType::alloc_count);
+  for (auto entry : owned) {
+    AllocatedType* raw = entry;
+    EXPECT_NE(nullptr, raw);
+  }
+  owned.clear();
+  EXPECT_EQ(0, AllocatedType::alloc_count);
+}
+
+} // namespace
+} // namespace iree
diff --git a/base/logging.h b/base/logging.h
new file mode 100644
index 0000000..b256cd6
--- /dev/null
+++ b/base/logging.h
@@ -0,0 +1,63 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_LOGGING_H_
+#define IREE_BASE_LOGGING_H_
+
+// Logging macros live in their own file so that we can use external versions
+// as required.
+//
+// LOG(severity) << ...;
+// Logs a message at the given severity.
+// Severity:
+// INFO Logs information text.
+// WARNING Logs a warning.
+// ERROR Logs an error.
+// FATAL Logs an error and exit(1).
+//
+// VLOG(level) << ...;
+// Logs a verbose message at the given verbosity level.
+//
+// DVLOG(level) << ...;
+// Behaves like `VLOG` in debug mode (i.e. `#ifndef NDEBUG`).
+// Otherwise, it compiles away and does nothing.
+//
+// CHECK(condition) << ...;
+// Runtime asserts that the given condition is true even in release builds.
+// It's recommended that DCHECK is used instead as too many CHECKs
+// can impact performance.
+//
+// CHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...;
+// Runtime assert the specified operation with the given values.
+//
+// DCHECK(condition) << ...;
+// Runtime asserts that the given condition is true only in non-opt builds.
+//
+// DCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...;
+// Runtime assert the specified operation with the given values in non-opt
+// builds.
+//
+// QCHECK(condition) << ...;
+// QCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...;
+// These behave like `CHECK` but do not print a full stack trace.
+// They are useful when problems are definitely unrelated to program flow,
+// e.g. when validating user input.
+
+#ifdef IREE_CONFIG_GOOGLE_INTERNAL
+#include "base/google/logging_google.h"
+#else
+#include "base/internal/logging.h"
+#endif // IREE_CONFIG_GOOGLE_INTERNAL
+
+#endif // IREE_BASE_LOGGING_H_
diff --git a/iree/base/math.h b/base/math.h
similarity index 100%
rename from iree/base/math.h
rename to base/math.h
diff --git a/iree/base/memory.h b/base/memory.h
similarity index 100%
rename from iree/base/memory.h
rename to base/memory.h
diff --git a/base/ref_ptr.h b/base/ref_ptr.h
new file mode 100644
index 0000000..996331d
--- /dev/null
+++ b/base/ref_ptr.h
@@ -0,0 +1,364 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_REF_PTR_H_
+#define IREE_BASE_REF_PTR_H_
+
+#include <atomic>
+#include <cstdint>
+#include <type_traits>
+#include <utility>
+
+#include "absl/base/attributes.h"
+#include "base/logging.h"
+
+namespace iree {
+
+// Use this to get really verbose refptr logging:
+// #define IREE_VERBOSE_REF_PTR
+
+template <class T>
+class ref_ptr;
+
+// Heap-allocates a T from the forwarded `args` and wraps it in a ref_ptr.
+// Like make_unique, but for ref_ptr; the new pointer is adopted as-is.
+//
+// Usage:
+//  ref_ptr<MyType> p = make_ref<MyType>(1, 2, 3);
+template <typename T, typename... Args>
+ref_ptr<T> make_ref(Args&&... args) {
+  return ref_ptr<T>{new T(std::forward<Args>(args)...)};
+}
+
+// Wraps a raw pointer in a ref_ptr without touching its reference count.
+//
+// Usage:
+//  ref_ptr<MyType> p = assign_ref(new MyType());
+template <typename T>
+inline ref_ptr<T> assign_ref(T* value) {
+  return ref_ptr<T>{value};
+}
+
+// Takes an additional reference on a raw pointer (null is passed through).
+//
+// Usage:
+//  MyType* raw_ptr = AcquirePointerFromSomewhere();
+//  ref_ptr<MyType> p = add_ref(raw_ptr);
+template <typename T>
+inline ref_ptr<T> add_ref(T* value) {
+  if (value != nullptr) ref_ptr_add_ref(value);
+  return ref_ptr<T>{value};
+}
+
+// Takes an additional reference on the object held by another ref_ptr.
+//
+// Usage:
+//  ref_ptr<MyType> a = make_ref<MyType>();
+//  ref_ptr<MyType> p = add_ref(a);
+template <typename T>
+inline ref_ptr<T> add_ref(const ref_ptr<T>& value) {
+  // Same null check and increment as the raw-pointer overload.
+  return add_ref(value.get());
+}
+
+// Reference counted pointer container.
+// This is modeled on boost::intrusive_ptr in that it requires no
+// extra storage over the pointer type and should compile to almost
+// no additional code. It also allows us to round-trip object pointers
+// through regular pointers, which is critical when having to round-trip
+// them through JNI/etc where we can't use things like unique_ptr/shared_ptr.
+//
+// ref_ptr<Foo> p1(new Foo()); // ref count 1
+// ref_ptr<Foo> p2(p1); // ref count 2
+// p1.reset(); // ref count 1
+// p2.reset(); // ref count 0, deleted
+//
+// When round-tripping the pointer through external APIs, use release():
+// ref_ptr<Foo> p1(new Foo()); // ref count 1
+// Foo* raw_p = p1.release(); // ref count 1
+// // pass to API
+// ref_ptr<Foo> p2(raw_p); // ref count 1 (don't add ref)
+// p2.reset(); // ref count 0, deleted
+//
+// See the boost intrusive_ptr docs for details of behavior:
+// http://www.boost.org/doc/libs/1_55_0/libs/smart_ptr/intrusive_ptr.html
+//
+// ref_ptr manages the target objects in a thread-safe way, though you'll want
+// to take care with objects that may have pinned threads for deallocation. If
+// you release the last reference to an object on a thread other than what it
+// was expecting you're gonna have a bad time.
+//
+// Compatible only with types that subclass RefObject or implement the following
+// methods:
+// ref_ptr_add_ref
+// ref_ptr_release_ref
+template <class T>
+class ref_ptr {
+ private:
+  using this_type = ref_ptr;
+  using unspecified_bool_type = T* this_type::*;
+
+ public:
+  // Initializes with nullptr.
+  ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr() noexcept = default;
+
+  // Initializes with nullptr so that there is no way to create an
+  // uninitialized ref_ptr.
+  ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr(std::nullptr_t) noexcept {}  // NOLINT
+
+  // Adopts the given pointer.
+  // The reference count is NOT incremented: the ref_ptr takes over a reference
+  // the caller already holds. Use add_ref() to take an additional reference.
+  ABSL_ATTRIBUTE_ALWAYS_INLINE explicit ref_ptr(T* p) noexcept : px_(p) {}
+
+  // Decrements the reference count of the owned pointer.
+  ABSL_ATTRIBUTE_ALWAYS_INLINE ~ref_ptr() noexcept {
+    if (px_) ref_ptr_release_ref(px_);
+  }
+
+  // No implicit ref_ptr copying allowed; use add_ref instead.
+  ref_ptr(const ref_ptr&) noexcept = delete;
+  ref_ptr& operator=(const ref_ptr&) noexcept = delete;
+
+  // Move support to transfer ownership from one ref_ptr to another.
+  ref_ptr(ref_ptr&& rhs) noexcept : px_(rhs.release()) {}
+  ref_ptr& operator=(ref_ptr&& rhs) noexcept {
+    if (px_ == rhs.px_) return *this;  // Self-move or same target: no-op.
+    T* old = px_;
+    px_ = rhs.release();  // Adopt first: releasing `old` may destroy `rhs`.
+    if (old) ref_ptr_release_ref(old);
+    return *this;
+  }
+
+  // Move support from another compatible type.
+  template <typename U>
+  ref_ptr(ref_ptr<U>&& rhs) noexcept : px_(rhs.release()) {}  // NOLINT
+  template <typename U>
+  ref_ptr& operator=(ref_ptr<U>&& rhs) noexcept {
+    if (px_ == rhs.get()) return *this;  // Same target: nothing to transfer.
+    T* old = px_;
+    px_ = rhs.release();  // Adopt first: releasing `old` may destroy `rhs`.
+    if (old) ref_ptr_release_ref(old);
+    return *this;
+  }
+
+  // Resets the object to nullptr and decrements the reference count, possibly
+  // deleting it.
+  void reset() noexcept {
+    // Detach before releasing so that a destructor which re-enters this
+    // ref_ptr never observes a pointer to the object being destroyed.
+    T* old = release();
+    if (old) ref_ptr_release_ref(old);
+  }
+
+  // Releases a pointer.
+  // Returns the current pointer held by this object without having
+  // its reference count decremented and resets the ref_ptr to empty.
+  // Returns nullptr if the ref_ptr holds no value.
+  // To re-wrap in a ref_ptr use either ref_ptr<T>(value) or assign().
+  ABSL_ATTRIBUTE_ALWAYS_INLINE T* release() noexcept {
+    T* p = px_;
+    px_ = nullptr;
+    return p;
+  }
+
+  // Assigns a pointer.
+  // The pointer will be accepted by the ref_ptr and its reference count will
+  // not be incremented. Any previously held reference is released first.
+  ABSL_ATTRIBUTE_ALWAYS_INLINE void assign(T* value) noexcept {
+    reset();
+    px_ = value;
+  }
+
+  // Gets the pointer referenced by this instance.
+  // operator* and operator-> perform no null check; test the ref_ptr first.
+  constexpr T* get() const noexcept { return px_; }
+  constexpr T& operator*() const noexcept { return *px_; }
+  constexpr T* operator->() const noexcept { return px_; }
+
+  // Support boolean expression evaluation ala unique_ptr/shared_ptr:
+  // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool
+  constexpr operator unspecified_bool_type() const noexcept {
+    return px_ ? &this_type::px_ : nullptr;
+  }
+  // Supports unary expression evaluation.
+  constexpr bool operator!() const noexcept { return !px_; }
+
+  // Swap support; only exchanges the raw pointers, so it cannot throw.
+  void swap(ref_ptr& rhs) noexcept { std::swap(px_, rhs.px_); }
+
+ private:
+  T* px_ = nullptr;
+};
+
+// Base class for reference counted objects.
+// Reference counted objects should be used with the ref_ptr pointer type.
+// As reference counting can be tricky always prefer to use unique_ptr and
+// avoid this type. Only use this when unique_ptr is not possible, such as
+// when round-tripping objects through marshaling boundaries (v8/Java) or
+// any objects that may have their lifetime tied to a garbage collected
+// object.
+//
+// Subclasses should protect their dtor so that reference counting must
+// be used.
+//
+// This is designed to avoid the need for extra vtable space or for adding
+// methods to the vtable of subclasses. This differs from the boost Pointable
+// version of this object.
+// Inspiration for this comes from Peter Weinert's Dr. Dobb's article:
+// http://www.drdobbs.com/cpp/a-base-class-for-intrusively-reference-c/229218807
+//
+// RefObjects are thread safe and may be used with ref_ptrs from multiple
+// threads.
+//
+// Subclasses may implement a custom Delete operator to handle their
+// deallocation. It should be thread safe as it may be called from any thread.
+//
+// Usage:
+// class MyRefObject : public RefObject<MyRefObject> {
+// public:
+// MyRefObject() = default;
+// // Optional; can be used to return to pool/etc - must be public:
+// static void Delete(MyRefObject* ptr) {
+// ::operator delete(ptr);
+// }
+// };
+template <class T>
+class RefObject {
+  static_assert(!std::is_array<T>::value, "T must not be an array");
+
+  // value is true if a static Delete(T*) function is present.
+  struct has_custom_deleter {
+    template <typename C>
+    static auto Test(C* p) -> decltype(C::Delete(nullptr), std::true_type());
+    template <typename>
+    static std::false_type Test(...);
+    static constexpr bool value =
+        std::is_same<std::true_type, decltype(Test<T>(nullptr))>::value;
+  };
+
+  template <typename V, bool has_custom_deleter>
+  struct delete_thunk {
+    static void Delete(V* p) {
+      auto ref_obj = static_cast<RefObject<V>*>(p);
+      const intptr_t previous_count = ref_obj->counter_.fetch_sub(1);
+#ifdef IREE_VERBOSE_REF_PTR
+      LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now "
+                << previous_count - 1
+                << (previous_count == 1 ? " DEAD (CUSTOM)" : "");
+#endif  // IREE_VERBOSE_REF_PTR
+      if (previous_count == 1) {
+        // Last reference dropped: hand the object to the type's own Delete().
+        V::Delete(p);
+      }
+    }
+  };
+
+  template <typename V>
+  struct delete_thunk<V, false> {
+    static void Delete(V* p) {
+      auto ref_obj = static_cast<RefObject<V>*>(p);
+      const intptr_t previous_count = ref_obj->counter_.fetch_sub(1);
+#ifdef IREE_VERBOSE_REF_PTR
+      LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now "
+                << previous_count - 1 << (previous_count == 1 ? " DEAD" : "");
+#endif  // IREE_VERBOSE_REF_PTR
+      if (previous_count == 1) {
+        // We delete type T pointer here to avoid the need for a virtual dtor.
+        delete p;
+      }
+    }
+  };
+
+ public:
+  // Adds a reference; used by ref_ptr.
+  friend void ref_ptr_add_ref(T* p) {
+    auto ref_obj = static_cast<RefObject*>(p);
+    ++ref_obj->counter_;
+
+#ifdef IREE_VERBOSE_REF_PTR
+    LOG(INFO) << "ro++ " << typeid(T).name() << " " << p << " now "
+              << ref_obj->counter_;
+#endif  // IREE_VERBOSE_REF_PTR
+  }
+
+  // Releases a reference, potentially deleting the object; used by ref_ptr.
+  friend void ref_ptr_release_ref(T* p) {
+    delete_thunk<T, has_custom_deleter::value>::Delete(p);
+  }
+
+  // Adds a reference.
+  // ref_ptr should be used instead of this in most cases. This is required
+  // for when interoperating with marshaling APIs.
+  void AddReference() { ref_ptr_add_ref(static_cast<T*>(this)); }
+
+  // Releases a reference, potentially deleting the object.
+  // ref_ptr should be used instead of this in most cases. This is required
+  // for when interoperating with marshaling APIs.
+  void ReleaseReference() { ref_ptr_release_ref(static_cast<T*>(this)); }
+
+ protected:
+  RefObject() { ref_ptr_add_ref(static_cast<T*>(this)); }  // Starts at 1.
+  RefObject(const RefObject&) = default;  // Deleted: std::atomic can't copy.
+  RefObject& operator=(const RefObject&) { return *this; }  // Keeps count.
+
+  std::atomic<intptr_t> counter_{0};
+};
+
+// Comparison overloads; each compares only the raw pointers returned by get(),
+// never the pointed-to objects.
+template <class T, class U>
+inline bool operator==(const ref_ptr<T>& lhs, const ref_ptr<U>& rhs) {
+  return lhs.get() == rhs.get();
+}
+
+template <class T, class U>
+inline bool operator!=(const ref_ptr<T>& lhs, const ref_ptr<U>& rhs) {
+  return !(lhs.get() == rhs.get());
+}
+
+template <class T, class U>
+inline bool operator==(const ref_ptr<T>& lhs, U* rhs) {
+  return lhs.get() == rhs;
+}
+
+template <class T, class U>
+inline bool operator!=(const ref_ptr<T>& lhs, U* rhs) {
+  return !(lhs.get() == rhs);
+}
+
+template <class T, class U>
+inline bool operator==(T* lhs, const ref_ptr<U>& rhs) {
+  return lhs == rhs.get();
+}
+
+template <class T, class U>
+inline bool operator!=(T* lhs, const ref_ptr<U>& rhs) {
+  return !(lhs == rhs.get());
+}
+
+template <class T>
+inline bool operator<(const ref_ptr<T>& lhs, const ref_ptr<T>& rhs) {
+  return lhs.get() < rhs.get();
+}
+
+// Swaps the pointers of two ref_ptrs; reference counts are left untouched.
+template <class T>
+void swap(ref_ptr<T>& lhs, ref_ptr<T>& rhs) noexcept {
+  lhs.swap(rhs);  // Member swap only exchanges raw pointers: cannot throw.
+}
+
+} // namespace iree
+
+#endif // IREE_BASE_REF_PTR_H_
diff --git a/base/ref_ptr_test.cc b/base/ref_ptr_test.cc
new file mode 100644
index 0000000..5642eef
--- /dev/null
+++ b/base/ref_ptr_test.cc
@@ -0,0 +1,330 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/ref_ptr.h"
+
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace {
+
+// Minimal ref-counted type for the tests in this file.
+struct MyType : public RefObject<MyType> {
+  using RefObject<MyType>::counter_;  // Re-exported so tests can read it.
+
+  int x = 5;
+};
+
+TEST(RefPtrTest, Construction) {
+  // Default and nullptr construction both yield an empty ref_ptr.
+  ref_ptr<MyType> empty_default;
+  EXPECT_EQ(nullptr, empty_default.get());
+  ref_ptr<MyType> empty_null(nullptr);
+  EXPECT_EQ(nullptr, empty_null.get());
+
+  // A new object starts at count 1; wrapping it adopts that reference.
+  MyType* raw = new MyType();
+  EXPECT_EQ(1, raw->counter_);
+  ref_ptr<MyType> owner(raw);
+  EXPECT_EQ(1, owner->counter_);
+
+  // Wrapping the same raw pointer again still does not add a reference.
+  ref_ptr<MyType> alias(raw);
+  EXPECT_EQ(1, alias->counter_);
+
+  // add_ref is what actually bumps the count.
+  ref_ptr<MyType> extra = add_ref(alias);
+  EXPECT_EQ(2, extra->counter_);
+
+  alias.release();  // `alias` never held its own reference; don't drop one.
+}
+
+TEST(RefPtrTest, Assign) {
+  // assign_ref(nullptr) gives an empty ref_ptr.
+  ref_ptr<MyType> empty = assign_ref<MyType>(nullptr);
+  EXPECT_EQ(nullptr, empty.get());
+
+  ref_ptr<MyType> original = make_ref<MyType>();
+  EXPECT_EQ(1, original->counter_);
+  ref_ptr<MyType> adopted = assign_ref(original.get());
+  EXPECT_EQ(1, original->counter_);
+  original.release();  // `adopted` now carries the only reference.
+  EXPECT_EQ(1, adopted->counter_);
+  adopted.reset();
+}
+
+TEST(RefPtrTest, Retain) {
+  // add_ref(nullptr) gives an empty ref_ptr.
+  ref_ptr<MyType> empty = add_ref<MyType>(nullptr);
+  EXPECT_EQ(nullptr, empty.get());
+
+  ref_ptr<MyType> original = make_ref<MyType>();
+  EXPECT_EQ(1, original->counter_);
+  ref_ptr<MyType> retained = add_ref(original.get());
+  EXPECT_EQ(2, original->counter_);
+  original.reset();  // Drops one of the two references.
+  EXPECT_EQ(1, retained->counter_);
+  retained.reset();
+}
+
+TEST(RefPtrTest, Reset) {
+  ref_ptr<MyType> first(new MyType());
+  ref_ptr<MyType> second(new MyType());
+
+  // reset() gives up exactly one reference.
+  ref_ptr<MyType> first_extra = add_ref(first);
+  EXPECT_EQ(2, first_extra->counter_);
+  first.reset();
+  EXPECT_EQ(1, first_extra->counter_);
+
+  // Assigning nullptr to an already-empty ref_ptr changes nothing.
+  first = nullptr;
+  EXPECT_EQ(1, first_extra->counter_);
+  first = add_ref(first_extra);
+  EXPECT_EQ(2, first_extra->counter_);
+
+  // reset()/assign(nullptr) on an empty ref_ptr are harmless.
+  ref_ptr<MyType> empty;
+  empty.reset();
+  empty.assign(nullptr);
+}
+
+TEST(RefPtrTest, ReleaseAssign) {
+  ref_ptr<MyType> holder(new MyType());
+
+  // release() returns the pointer and empties the holder; count is untouched.
+  MyType* expected = holder.get();
+  MyType* released = holder.release();
+  EXPECT_EQ(expected, released);
+  EXPECT_EQ(nullptr, holder.get());
+  EXPECT_EQ(1, released->counter_);
+
+  // assign() adopts the pointer again without adding a reference.
+  holder.assign(released);
+  EXPECT_EQ(1, holder->counter_);
+
+  // release() on an empty ref_ptr returns nullptr.
+  ref_ptr<MyType> empty;
+  EXPECT_EQ(nullptr, empty.release());
+}
+
+TEST(RefPtrTest, Accessors) {
+  ref_ptr<MyType> obj(new MyType());
+  EXPECT_EQ(5, obj->x);  // Default member initializer.
+  obj->x = 100;
+  EXPECT_EQ(100, obj->x);
+
+  MyType& mutable_ref = *obj;
+  mutable_ref.x = 200;
+  EXPECT_EQ(200, mutable_ref.x);
+
+  const MyType& const_ref = *obj;
+  EXPECT_EQ(200, const_ref.x);
+}
+
+TEST(RefPtrTest, BooleanExpressions) {
+  ref_ptr<MyType> full(new MyType());
+  ref_ptr<MyType> empty;
+  // A ref_ptr holding an object is truthy.
+  EXPECT_NE(nullptr, full.get());
+  EXPECT_TRUE(full);
+  EXPECT_FALSE(!full);
+  EXPECT_EQ(true, static_cast<bool>(full));
+  // An empty ref_ptr is falsy.
+  EXPECT_EQ(nullptr, empty.get());
+  EXPECT_FALSE(empty);
+  EXPECT_TRUE(!empty);
+  EXPECT_EQ(false, static_cast<bool>(empty));
+}
+
+TEST(RefPtrTest, Comparisons) {
+  ref_ptr<MyType> lhs(new MyType());
+  ref_ptr<MyType> rhs(new MyType());
+  ref_ptr<MyType> empty;
+  // Same object, in every ref_ptr/raw-pointer combination.
+  EXPECT_TRUE(lhs == lhs);
+  EXPECT_TRUE(lhs == lhs.get());
+  EXPECT_TRUE(lhs.get() == lhs);
+  EXPECT_FALSE(lhs != lhs);
+  EXPECT_FALSE(lhs != lhs.get());
+  EXPECT_FALSE(lhs.get() != lhs);
+  // Two distinct objects.
+  EXPECT_FALSE(lhs == rhs);
+  EXPECT_FALSE(lhs == rhs.get());
+  EXPECT_FALSE(lhs.get() == rhs);
+  EXPECT_TRUE(lhs != rhs);
+  EXPECT_TRUE(lhs != rhs.get());
+  EXPECT_TRUE(lhs.get() != rhs);
+  // An empty ref_ptr equals itself and its own null pointer.
+  EXPECT_TRUE(empty == empty);
+  EXPECT_TRUE(empty == empty.get());
+  EXPECT_TRUE(empty.get() == empty);
+  EXPECT_FALSE(empty != empty);
+  EXPECT_FALSE(empty != empty.get());
+  EXPECT_FALSE(empty.get() != empty);
+  // operator< orders by raw pointer value.
+  EXPECT_FALSE(lhs < lhs);
+  EXPECT_TRUE(empty < lhs);
+}
+
+TEST(RefPtrTest, Swap) {
+  ref_ptr<MyType> first(new MyType());
+  ref_ptr<MyType> second(new MyType());
+  MyType* first_raw = first.get();
+  MyType* second_raw = second.get();
+  // Self-swap leaves the pointer in place.
+  swap(first, first);
+  EXPECT_EQ(first_raw, first);
+  // Swapping exchanges the two pointers...
+  swap(first, second);
+  EXPECT_EQ(first_raw, second.get());
+  EXPECT_EQ(second_raw, first.get());
+  // ...and swapping again restores them.
+  swap(first, second);
+  EXPECT_EQ(first_raw, first.get());
+  EXPECT_EQ(second_raw, second.get());
+  // Swapping with an empty ref_ptr moves the pointer across.
+  ref_ptr<MyType> empty;
+  swap(first, empty);
+  EXPECT_EQ(first_raw, empty.get());
+  EXPECT_EQ(nullptr, first.get());
+}
+
+TEST(RefPtrTest, Move) {
+  auto source = make_ref<MyType>();
+  auto occupied = make_ref<MyType>();
+  ref_ptr<MyType> target;
+  EXPECT_EQ(nullptr, target.get());
+  // Move into an empty ref_ptr.
+  target = std::move(source);
+  EXPECT_NE(nullptr, target.get());
+  // Move into a ref_ptr that already holds a different object.
+  occupied = std::move(target);
+  EXPECT_NE(nullptr, occupied.get());
+}
+
+TEST(RefPtrTest, MoveCompatible) {
+  struct MyBaseType : public RefObject<MyBaseType> {
+    using RefObject<MyBaseType>::counter_;  // Re-exported for the checks.
+    int x = 5;
+  };
+  struct MyTypeA : public MyBaseType {
+    int a = 6;
+  };
+  struct MyTypeB : public MyBaseType {
+    int b = 7;
+  };
+  // A ref_ptr to a derived type converts into a ref_ptr to its base.
+  ref_ptr<MyTypeA> derived = make_ref<MyTypeA>();
+  EXPECT_EQ(1, derived->counter_);
+  ref_ptr<MyBaseType> base = add_ref(derived);
+  EXPECT_EQ(derived.get(), base.get());
+  EXPECT_EQ(2, derived->counter_);
+  // Move-assigning another derived type drops the reference held so far.
+  base = make_ref<MyTypeB>();
+  EXPECT_EQ(1, derived->counter_);
+  EXPECT_EQ(1, base->counter_);
+}
+
+TEST(RefPtrTest, StackAllocation) {
+  static int live_objects = 0;
+  // A RefObject subclass used as a plain automatic variable.
+  struct StackAllocationType : public RefObject<StackAllocationType> {
+    StackAllocationType() { live_objects += 1; }
+    ~StackAllocationType() { live_objects -= 1; }
+  };
+  {
+    StackAllocationType on_stack;
+    EXPECT_EQ(1, live_objects);
+  }
+  EXPECT_EQ(0, live_objects);
+}
+
+TEST(RefPtrTest, DefaultDeleter) {
+  static int live_objects = 0;
+  // No static Delete() here, so the plain `delete` path is taken.
+  struct DefaultDeleterType : public RefObject<DefaultDeleterType> {
+    DefaultDeleterType() { live_objects += 1; }
+    ~DefaultDeleterType() { live_objects -= 1; }
+  };
+
+  // Resetting an empty ref_ptr is fine.
+  ref_ptr<DefaultDeleterType> empty;
+  empty.reset();
+
+  // Dropping the only reference runs the destructor.
+  EXPECT_EQ(0, live_objects);
+  ref_ptr<DefaultDeleterType> obj = make_ref<DefaultDeleterType>();
+  EXPECT_EQ(1, live_objects);
+  obj.reset();
+  EXPECT_EQ(0, live_objects);
+}
+
+TEST(RefPtrTest, InlineDeallocator) {
+  static int live_objects = 0;
+  // Provides a static Delete(), which replaces the plain `delete`.
+  struct CustomDeleterType : public RefObject<CustomDeleterType> {
+    CustomDeleterType() { live_objects += 1; }
+    static void Delete(CustomDeleterType* ptr) {
+      live_objects -= 1;
+      ::operator delete(ptr);
+    }
+  };
+
+  // Resetting an empty ref_ptr is fine.
+  ref_ptr<CustomDeleterType> empty;
+  empty.reset();
+
+  // Dropping the only reference must go through Delete().
+  EXPECT_EQ(0, live_objects);
+  auto obj = make_ref<CustomDeleterType>();
+  EXPECT_EQ(1, live_objects);
+  obj.reset();
+  EXPECT_EQ(0, live_objects);
+}
+
+class VirtualDtorTypeA : public RefObject<VirtualDtorTypeA> {
+ public:
+  static int alloc_count_a;  // Live VirtualDtorTypeA (sub)objects.
+  VirtualDtorTypeA() { alloc_count_a += 1; }
+  virtual ~VirtualDtorTypeA() { alloc_count_a -= 1; }
+};
+int VirtualDtorTypeA::alloc_count_a{0};
+
+class VirtualDtorTypeB : public VirtualDtorTypeA {
+ public:
+  static int alloc_count_b;  // Live VirtualDtorTypeB objects.
+  VirtualDtorTypeB() { alloc_count_b += 1; }
+  ~VirtualDtorTypeB() override { alloc_count_b -= 1; }
+};
+int VirtualDtorTypeB::alloc_count_b{0};
+
+TEST(RefPtrTest, VirtualDestructor) {
+  // Resetting an empty ref_ptr is fine.
+  ref_ptr<VirtualDtorTypeB> empty;
+  empty.reset();
+
+  // Both the derived and the base destructor must run on the last release.
+  EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b);
+  EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a);
+  ref_ptr<VirtualDtorTypeB> obj = make_ref<VirtualDtorTypeB>();
+  EXPECT_EQ(1, VirtualDtorTypeB::alloc_count_b);
+  EXPECT_EQ(1, VirtualDtorTypeA::alloc_count_a);
+  obj.reset();
+  EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b);
+  EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a);
+}
+
+} // namespace
+} // namespace iree
diff --git a/base/shape.cc b/base/shape.cc
new file mode 100644
index 0000000..611d295
--- /dev/null
+++ b/base/shape.cc
@@ -0,0 +1,100 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/shape.h"
+
+#include <cstddef>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "base/source_location.h"
+#include "base/status.h"
+
+namespace iree {
+
+Shape::Shape(const int* values, int size) : rank_(size) {
+  QCHECK_LE(size, kMaxRank)
+      << "Max rank of " << kMaxRank << ", shape has " << size;
+  if (size > 0) std::memcpy(value_, values, size * sizeof(int));  // null-safe
+}
+
+std::string Shape::DebugString() const {
+ return absl::StrCat("[", absl::StrJoin(subspan(), ","), "]");
+}
+
+absl::Span<const int> Shape::subspan(size_type pos, size_type len) const {
+ if (len == npos) {
+ len = rank_ - pos;
+ }
+ return absl::MakeConstSpan(&value_[pos], len);
+}
+
+void Shape::push_back(int dim) {
+ DCHECK_LE(rank_ + 1, kMaxRank);
+ value_[rank_++] = dim;
+}
+
+void Shape::insert(iterator pos, int dim) {
+  int axis = static_cast<int>(pos - value_);
+  DCHECK_GE(axis, 0);
+  DCHECK_LE(axis, rank_);
+  DCHECK_LE(rank_ + 1, kMaxRank);
+  // Shift [axis, rank_) one slot toward the end, back to front so no element
+  // is overwritten before it has been moved, then fill the gap.
+  for (int dst = rank_; dst > axis; --dst) value_[dst] = value_[dst - 1];
+  value_[axis] = dim;
+  ++rank_;
+}
+
+void Shape::erase(iterator pos) {
+  int axis = static_cast<int>(pos - value_);  // Index of the erased dimension.
+  DCHECK_GE(axis, 0);
+  DCHECK_LE(axis, rank_);  // pos == end() is allowed: the shift is a no-op.
+  for (int i = axis; i < rank_ - 1; ++i) {
+    value_[i] = value_[i + 1];  // Shift the tail down over the erased slot.
+  }
+  --rank_;  // Runs unconditionally, so erase(end()) drops the last dimension.
+}
+
+int Shape::element_count() const {
+  // An unknown (-1) dimension means the shape is incomplete: report 0.
+  size_t total = 1;
+  for (const int* it = value_; it != value_ + rank_; ++it) {
+    if (*it == -1) {
+      return 0;
+    }
+    total *= *it;
+  }
+  return total;
+}
+
+StatusOr<int> Shape::ResolveAxis(int axis) const {
+  // A rank-0 (scalar) shape has no dimensions to index, so both spellings of
+  // "the only axis" (0 and -1) are accepted and resolve to 0.
+  const bool is_scalar_axis = rank_ == 0 && (axis == 0 || axis == -1);
+  if (is_scalar_axis) return 0;
+
+  // Negative axes are offsets from the end: -1 names the last dimension.
+  const int resolved = axis < 0 ? axis + rank_ : axis;
+  const bool in_bounds = resolved >= 0 && resolved < rank_;
+  if (!in_bounds) {
+    // Report both the wrapped value and what the caller originally passed.
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Axis " << resolved << " (orig " << axis
+           << ") out of bounds of rank " << rank_;
+  }
+  return resolved;
+}
+
+} // namespace iree
diff --git a/base/shape.h b/base/shape.h
new file mode 100644
index 0000000..b0e3ddd
--- /dev/null
+++ b/base/shape.h
@@ -0,0 +1,156 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_SHAPE_H_
+#define IREE_BASE_SHAPE_H_
+
+#include <array>
+#include <cstring>
+#include <initializer_list>
+#include <iterator>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "absl/meta/type_traits.h"
+#include "absl/types/span.h"
+#include "base/logging.h"
+#include "base/status.h"
+
+namespace iree {
+
+// For simplicity we limit our shapes to a max of rank-N (shape.size() == N) as
+// this prevents dynamic allocations and rarely are there greater ranks.
+constexpr int kMaxRank = 5;
+
+// Represent indices and lengths of tensors.
+using Index = std::array<int, kMaxRank>;
+using Length = std::array<int, kMaxRank>;
+
+// Represents the number of elements in multiple dimensions.
+// Can be rank-0 (scalar) to rank-kMaxRank. Tries to match the API of
+// std::vector and can be converted to a Span via subspan().
+//
+// https://www.tensorflow.org/guide/tensors#shape
+class Shape {
+ public:
+ using size_type = int;
+ static constexpr size_type npos = ~(size_type(0)); // NOLINT
+ using iterator = int*;
+ using const_iterator = const int*;
+
+ Shape() = default;
+ Shape(const int* values, int size);
+ Shape(std::initializer_list<int> values)
+ : Shape(values.begin(), values.size()) {}
+ explicit Shape(absl::Span<const int> values)
+ : Shape(values.data(), values.size()) {}
+
+ template <typename Iterator>
+ using EnableIfForwardIterator = absl::enable_if_t<std::is_convertible<
+ typename std::iterator_traits<Iterator>::iterator_category,
+ std::forward_iterator_tag>::value>;
+ template <typename Iterator, EnableIfForwardIterator<Iterator>* = nullptr>
+ Shape(Iterator first, Iterator last) {
+ rank_ = std::distance(first, last);
+ QCHECK_LE(rank_, kMaxRank);
+ for (int i = 0; first != last; ++i, static_cast<void>(++first)) {
+ value_[i] = *first;
+ }
+ }
+
+ // Returns a string representation of the given shape.
+ std::string DebugString() const;
+
+ // Size (aka 'rank') of the shape, counting the number of dimensions.
+ constexpr size_type size() const noexcept { return rank_; }
+
+ // Whether the shape is rank-0 (scalar).
+ constexpr bool empty() const noexcept { return rank_ == 0; }
+
+ // Returns the total elements in the tensor shape.
+ // Returns 0 if the tensor shape is not complete and 1 if the shape is a
+ // scalar value.
+ int element_count() const;
+
+ // Resolves an axis in [-R,R) to the real axis value and verifies the range.
+ StatusOr<int> ResolveAxis(int axis) const;
+
+ // Compares two shapes for equality.
+ inline static bool Equal(const Shape& a, const Shape& b) {
+ return a.rank_ == b.rank_ &&
+ std::memcmp(a.value_, b.value_, a.rank_ * sizeof(value_[0])) == 0;
+ }
+
+ int& operator[](size_type i) noexcept {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, rank_);
+ return value_[i];
+ }
+
+ const int& operator[](size_type i) const noexcept {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, rank_);
+ return value_[i];
+ }
+
+ int front() const noexcept {
+ DCHECK_GE(rank_, 1);
+ return value_[0];
+ }
+
+ int back() const noexcept {
+ DCHECK_GE(rank_, 1);
+ return value_[rank_ - 1];
+ }
+
+ constexpr iterator begin() const noexcept {
+ return const_cast<iterator>(&value_[0]);
+ }
+ constexpr iterator end() const noexcept {
+ return const_cast<iterator>(&value_[rank_]);
+ }
+ constexpr const_iterator cbegin() const noexcept { return &value_[0]; }
+ constexpr const_iterator cend() const noexcept { return &value_[rank_]; }
+
+ absl::Span<const int> subspan(size_type pos = 0, size_type len = npos) const;
+ absl::Span<const int> data() const { return subspan(); }
+
+ void push_back(int dim);
+
+ void insert(iterator pos, int dim);
+
+ void erase(iterator pos);
+
+ void clear() { rank_ = 0; }
+
+ private:
+ size_type rank_ = 0;
+ int value_[kMaxRank];
+};
+
+inline bool operator==(const Shape& a, const Shape& b) {
+ return Shape::Equal(a, b);
+}
+
+inline bool operator!=(const Shape& a, const Shape& b) { return !(a == b); }
+
+inline std::ostream& operator<<(std::ostream& stream, const Shape& shape) {
+ stream << shape.DebugString();
+ return stream;
+}
+
+} // namespace iree
+
+#endif // IREE_BASE_SHAPE_H_
diff --git a/base/shape_test.cc b/base/shape_test.cc
new file mode 100644
index 0000000..ac7b071
--- /dev/null
+++ b/base/shape_test.cc
@@ -0,0 +1,222 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/shape.h"
+
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace {
+
+using ::testing::ElementsAre;
+
+// Tests shapes that represent 0-D scalar values.
+TEST(ShapeTest, Scalar) {
+ Shape shape;
+ EXPECT_EQ(0, shape.size());
+ EXPECT_TRUE(shape.empty());
+ EXPECT_EQ(1, shape.element_count());
+ EXPECT_EQ(shape, shape);
+ EXPECT_EQ(0, shape.subspan().size());
+ for (const int dim : shape) {
+ FAIL() << "Should have no dimensions, have: " << dim;
+ }
+ EXPECT_EQ(shape.begin(), shape.end());
+ EXPECT_EQ(shape.cbegin(), shape.cend());
+ shape.clear();
+ EXPECT_EQ(0, shape.size());
+}
+
+// Tests the various ways of constructing a 1+D shape.
+TEST(ShapeTest, NonScalarConstruction) {
+ EXPECT_EQ(0, Shape().size());
+ EXPECT_EQ(0, Shape({}).size());
+ EXPECT_EQ(1, Shape({10}).size());
+ EXPECT_EQ(4, Shape({10, 20, 30, 40}).size());
+
+ std::vector<int> empty_data = {};
+ EXPECT_EQ(0, Shape(empty_data.data(), empty_data.size()).size());
+ EXPECT_EQ(0, Shape(empty_data.begin(), empty_data.end()).size());
+ EXPECT_EQ(0, Shape(absl::MakeConstSpan(empty_data)).size());
+
+ EXPECT_THAT(Shape({}).subspan(), ElementsAre());
+ EXPECT_THAT(Shape({10}).subspan(), ElementsAre(10));
+ EXPECT_THAT(Shape({10, 20, 30, 40}).subspan(), ElementsAre(10, 20, 30, 40));
+
+ std::vector<int> valid_data = {10, 20, 30, 40};
+ EXPECT_THAT(Shape(valid_data.begin(), valid_data.end()).subspan(),
+ ElementsAre(10, 20, 30, 40));
+ EXPECT_THAT(Shape(absl::MakeConstSpan(valid_data)).subspan(),
+ ElementsAre(10, 20, 30, 40));
+}
+
+// Tests shapes that represent 1+D multidimensional values.
+TEST(ShapeTest, NonScalarAccess) {
+ Shape shape = {1, 2, 3, 4};
+ EXPECT_EQ(4, shape.size());
+ EXPECT_FALSE(shape.empty());
+ EXPECT_EQ(1 * 2 * 3 * 4, shape.element_count());
+ EXPECT_EQ(shape, shape);
+ EXPECT_NE(shape, Shape({4, 3, 2, 1}));
+ EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4));
+ std::vector<int> readout;
+ for (const int dim : shape) {
+ readout.push_back(dim);
+ }
+ EXPECT_THAT(readout, ElementsAre(1, 2, 3, 4));
+ EXPECT_EQ(1, shape[0]);
+ EXPECT_EQ(2, shape[1]);
+ EXPECT_EQ(3, shape[2]);
+ EXPECT_EQ(4, shape[3]);
+ EXPECT_EQ(1, shape.front());
+ EXPECT_EQ(4, shape.back());
+}
+
+TEST(ShapeTest, PushBack) {
+ Shape shape;
+ EXPECT_EQ(0, shape.size());
+
+ shape.push_back(10);
+ EXPECT_EQ(1, shape.size());
+ EXPECT_EQ(10, shape.front());
+ EXPECT_EQ(10, shape.back());
+ EXPECT_EQ(10, shape[0]);
+ EXPECT_THAT(shape.subspan(), ElementsAre(10));
+
+ shape.push_back(20);
+ EXPECT_EQ(2, shape.size());
+ EXPECT_EQ(10, shape.front());
+ EXPECT_EQ(20, shape.back());
+ EXPECT_EQ(10, shape[0]);
+ EXPECT_EQ(20, shape[1]);
+ EXPECT_THAT(shape.subspan(), ElementsAre(10, 20));
+}
+
+TEST(ShapeTest, Insert) {
+ Shape shape;
+ EXPECT_EQ(0, shape.size());
+
+ shape.insert(shape.begin(), 20);
+ EXPECT_THAT(shape.subspan(), ElementsAre(20));
+ shape.insert(shape.begin(), 10);
+ EXPECT_THAT(shape.subspan(), ElementsAre(10, 20));
+ shape.insert(shape.end(), 40);
+ EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 40));
+ shape.insert(shape.begin() + 2, 30);
+ EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 30, 40));
+
+ Shape ex_shape{72, 4};
+ ex_shape.insert(ex_shape.begin(), 144);
+ EXPECT_THAT(ex_shape.subspan(), ElementsAre(144, 72, 4));
+}
+
+TEST(ShapeTest, Erase) {
+ Shape shape = {1, 2, 3, 4};
+ EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4));
+ shape.erase(shape.begin());
+ EXPECT_THAT(shape.subspan(), ElementsAre(2, 3, 4));
+ shape.erase(shape.end());
+ EXPECT_THAT(shape.subspan(), ElementsAre(2, 3));
+ shape.erase(shape.begin() + 1);
+ EXPECT_THAT(shape.subspan(), ElementsAre(2));
+ shape.erase(shape.end());
+ EXPECT_THAT(shape.subspan(), ElementsAre());
+}
+
+TEST(ShapeTest, Clear) {
+ Shape shape;
+ EXPECT_EQ(0, shape.size());
+ shape.clear();
+ EXPECT_EQ(0, shape.size());
+
+ shape = Shape({1});
+ shape.clear();
+ EXPECT_EQ(0, shape.size());
+
+ shape = Shape({1, 2, 3, 4});
+ shape.clear();
+ EXPECT_EQ(0, shape.size());
+}
+
+TEST(ShapeTest, DebugString) {
+ EXPECT_EQ("[]", Shape({}).DebugString());
+ EXPECT_EQ("[1]", Shape({1}).DebugString());
+ EXPECT_EQ("[1,2]", Shape({1, 2}).DebugString());
+}
+
+TEST(ShapeTest, ElementCount) {
+ EXPECT_EQ(1, Shape({}).element_count());
+ EXPECT_EQ(0, Shape({0}).element_count());
+ EXPECT_EQ(1, Shape({1}).element_count());
+ EXPECT_EQ(2, Shape({2, 1}).element_count());
+ EXPECT_EQ(10, Shape({2, 5}).element_count());
+ EXPECT_EQ(9216, Shape({72, 1, 128}).element_count());
+ EXPECT_EQ(9216, Shape({1, 72, 128}).element_count());
+
+ // Partial shaping should yield no elements.
+ EXPECT_EQ(0, Shape({1, -1, 2, 3}).element_count());
+}
+
+TEST(ShapeTest, ResolveAxis) {
+ int axis;
+ ASSERT_OK_AND_ASSIGN(axis, Shape({0}).ResolveAxis(0));
+ EXPECT_EQ(0, axis);
+ ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(1));
+ EXPECT_EQ(1, axis);
+ ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(2));
+ EXPECT_EQ(2, axis);
+
+ EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(3).status()));
+}
+
+TEST(ShapeTest, ResolveAxisNegative) {
+ int axis;
+ ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-3));
+ EXPECT_EQ(0, axis);
+ ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-2));
+ EXPECT_EQ(1, axis);
+ ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-1));
+ EXPECT_EQ(2, axis);
+
+ EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(-4).status()));
+}
+
+TEST(ShapeTest, ResolveAxisScalar) {
+ int axis;
+ ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(0));
+ EXPECT_EQ(0, axis);
+ ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(-1));
+ EXPECT_EQ(0, axis);
+
+ EXPECT_TRUE(IsInvalidArgument(Shape({}).ResolveAxis(1).status()));
+}
+
+TEST(ShapeTest, Equality) {
+ EXPECT_EQ(Shape({}), Shape({}));
+ EXPECT_EQ(Shape({0}), Shape({0}));
+ EXPECT_EQ(Shape({1}), Shape({1}));
+ EXPECT_EQ(Shape({1, 2}), Shape({1, 2}));
+
+ EXPECT_NE(Shape({}), Shape({1}));
+ EXPECT_NE(Shape({-1}), Shape({1}));
+ EXPECT_NE(Shape({1}), Shape({}));
+ EXPECT_NE(Shape({1}), Shape({2}));
+ EXPECT_NE(Shape({1, 2}), Shape({3, 4}));
+}
+
+} // namespace
+} // namespace iree
diff --git a/base/source_location.h b/base/source_location.h
new file mode 100644
index 0000000..c93db4a
--- /dev/null
+++ b/base/source_location.h
@@ -0,0 +1,24 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_SOURCE_LOCATION_H_
+#define IREE_BASE_SOURCE_LOCATION_H_
+
+#ifdef IREE_CONFIG_GOOGLE_INTERNAL
+#include "base/google/source_location_google.h"
+#else
+#include "base/internal/source_location.h"
+#endif // IREE_CONFIG_GOOGLE_INTERNAL
+
+#endif // IREE_BASE_SOURCE_LOCATION_H_
diff --git a/base/status.h b/base/status.h
new file mode 100644
index 0000000..feadf75
--- /dev/null
+++ b/base/status.h
@@ -0,0 +1,32 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_STATUS_H_
+#define IREE_BASE_STATUS_H_
+
+#ifdef IREE_CONFIG_GOOGLE_INTERNAL
+#include "base/google/status_google.h"
+#else
+#include "base/internal/status.h"
+#include "base/internal/status_builder.h"
+#include "base/internal/status_errno.h"
+#include "base/internal/status_errors.h"
+#include "base/internal/status_macros.h"
+#include "base/internal/status_win32_errors.h"
+#include "base/internal/statusor.h"
+#endif // IREE_CONFIG_GOOGLE_INTERNAL
+
+#include "base/source_location.h" // IWYU pragma: export
+
+#endif // IREE_BASE_STATUS_H_
diff --git a/base/status_matchers.h b/base/status_matchers.h
new file mode 100644
index 0000000..533f4a4
--- /dev/null
+++ b/base/status_matchers.h
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_STATUS_MATCHERS_H_
+#define IREE_BASE_STATUS_MATCHERS_H_
+
+#ifdef IREE_CONFIG_GOOGLE_INTERNAL
+
+#include "base/google/status_matchers_google.h" // IWYU pragma: export
+
+#else
+
+#include "base/internal/status_matchers.h" // IWYU pragma: export
+
+#endif // IREE_CONFIG_GOOGLE_INTERNAL
+
+#endif // IREE_BASE_STATUS_MATCHERS_H_
diff --git a/iree/base/target_platform.h b/base/target_platform.h
similarity index 100%
rename from iree/base/target_platform.h
rename to base/target_platform.h
diff --git a/iree/base/time.h b/base/time.h
similarity index 100%
rename from iree/base/time.h
rename to base/time.h
diff --git a/base/tracing.cc b/base/tracing.cc
new file mode 100644
index 0000000..0708532
--- /dev/null
+++ b/base/tracing.cc
@@ -0,0 +1,188 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Force the header to detect WTF_ENABLE so that this library builds
+// (for when building recursively).
+#if !defined(WTF_ENABLE)
+#define WTF_ENABLE
+#endif
+
+#include "base/tracing.h"
+
+#include <thread> // NOLINT: Fiber doesn't work during startup on Android.
+
+#include "absl/base/attributes.h"
+#include "absl/base/const_init.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/flags/flag.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "base/file_io.h"
+#include "base/file_path.h"
+#include "base/init.h"
+#include "base/logging.h"
+#include "base/status.h"
+
+ABSL_FLAG(int32_t, iree_trace_file_period, 5,
+ "Seconds between automatic flushing of WTF trace files. 0 to "
+ "disable auto-flush.");
+ABSL_FLAG(std::string, iree_trace_file, "/dev/null",
+ "wtf-trace file to save if --define=GLOBAL_WTF_ENABLE=1 was used "
+ "when building.");
+
+namespace iree {
+namespace {
+
+// Guards global WTF state (like the flush thread and IO).
+ABSL_CONST_INIT absl::Mutex global_tracing_mutex(absl::kConstInit);
+
+// True when tracing has been enabled and initialized.
+bool global_tracing_initialized ABSL_GUARDED_BY(global_tracing_mutex) = false;
+
+// If there is an existing file at the given path back it up by moving it aside.
+// Only kMaxBackups will be kept to avoid unbounded growth.
+void RollTraceFiles(const std::string& path) {
+  std::string path_stem = file_path::JoinPaths(file_path::DirectoryName(path),
+                                               file_path::Stem(path));
+  const int kMaxBackups = 5;
+  // Walk oldest to newest so each file moves into the slot just vacated.
+  for (int i = kMaxBackups; i >= 0; i--) {
+    // Slot 0 is the live trace at |path|; slots 1..N are numbered backups.
+    const std::string source_name =
+        i > 0 ? absl::StrCat(path_stem, ".", i, ".wtf-trace") : path;
+    if (!file_io::FileExists(source_name).ok()) {
+      continue;
+    }
+
+    // The oldest backup is dropped; everything else shifts up one slot.
+    const bool is_oldest = i == kMaxBackups;
+    Status status;
+    if (is_oldest) {
+      status = file_io::DeleteFile(source_name);
+    } else {
+      std::string backup_name =
+          absl::StrCat(path_stem, ".", (i + 1), ".wtf-trace");
+      status = file_io::MoveFile(source_name, backup_name);
+    }
+    if (!status.ok()) {
+      LOG(WARNING) << "Could not " << (is_oldest ? "remove" : "move")
+                   << " backup trace file " << source_name << ": " << status;
+    }
+  }
+}
+
+// Flushes all recorded trace data since the last flush.
+void FlushTraceFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(global_tracing_mutex) {
+ if (!global_tracing_initialized) return;
+
+ const auto& trace_path = absl::GetFlag(FLAGS_iree_trace_file);
+
+ static ::wtf::Runtime::SaveCheckpoint checkpoint;
+ static bool is_first_flush = true;
+
+ if (is_first_flush && trace_path != "/dev/null") {
+    // Back up any existing trace files at the specified path.
+ RollTraceFiles(trace_path);
+ }
+
+ auto save_options =
+ ::wtf::Runtime::SaveOptions::ForStreamingFile(&checkpoint);
+ if (is_first_flush) {
+ // On the first time, truncate the file. All subsequent flushes append.
+ save_options.open_mode = std::ios_base::trunc;
+ }
+
+ is_first_flush = false;
+
+ auto* runtime = ::wtf::Runtime::GetInstance();
+ if (!runtime->SaveToFile(trace_path, save_options)) {
+ LOG(ERROR) << "Error saving WTF file: " << trace_path;
+ return;
+ }
+
+ VLOG(1) << "Flushed WTF trace to: " << trace_path;
+}
+
+} // namespace
+
+void InitializeTracing() {
+ if (!::wtf::kMasterEnable) {
+ if (!absl::GetFlag(FLAGS_iree_trace_file).empty()) {
+ LOG(WARNING) << "WTF trace save requested but WTF is not compiled in. "
+ << "Enable by building with --define=GLOBAL_WTF_ENABLE=1.";
+ }
+ return;
+ }
+
+ absl::MutexLock lock(&global_tracing_mutex);
+ if (global_tracing_initialized) return;
+ global_tracing_initialized = true;
+
+ LOG(INFO) << "Tracing enabled and streaming to: "
+ << absl::GetFlag(FLAGS_iree_trace_file);
+
+ // Enable tracing on this thread, which we know is main.
+ IREE_TRACE_THREAD_ENABLE("main");
+
+  // Register atexit callback to stop tracing.
+ atexit(StopTracing);
+
+ // Launch a thread to periodically flush the trace.
+ if (absl::GetFlag(FLAGS_iree_trace_file_period) > 0) {
+ auto flush_thread = std::thread(+[]() {
+ absl::Duration period =
+ absl::Seconds(absl::GetFlag(FLAGS_iree_trace_file_period));
+ while (true) {
+ absl::SleepFor(period);
+ absl::MutexLock lock(&global_tracing_mutex);
+ if (!global_tracing_initialized) {
+ return;
+ }
+ FlushTraceFile();
+ }
+ });
+ flush_thread.detach();
+ }
+}
+
+// Stops tracing if currently initialized.
+void StopTracing() {
+ if (!::wtf::kMasterEnable) return;
+ absl::MutexLock lock(&global_tracing_mutex);
+ if (!global_tracing_initialized) return;
+
+ // Flush any pending trace data.
+ FlushTraceFile();
+
+ // Mark WTF as uninitialized to kill the flush thread.
+ global_tracing_initialized = false;
+
+ LOG(INFO) << "Tracing stopped and flushed to file: "
+ << absl::GetFlag(FLAGS_iree_trace_file);
+}
+
+void FlushTrace() {
+ if (!::wtf::kMasterEnable) return;
+ absl::MutexLock lock(&global_tracing_mutex);
+ if (!global_tracing_initialized) return;
+ FlushTraceFile();
+}
+
+} // namespace iree
+
+IREE_DECLARE_MODULE_INITIALIZER(iree_tracing);
+
+IREE_REGISTER_MODULE_INITIALIZER(iree_tracing, ::iree::InitializeTracing());
diff --git a/iree/base/tracing.h b/base/tracing.h
similarity index 100%
rename from iree/base/tracing.h
rename to base/tracing.h
diff --git a/base/tracing_disabled.cc b/base/tracing_disabled.cc
new file mode 100644
index 0000000..2360d54
--- /dev/null
+++ b/base/tracing_disabled.cc
@@ -0,0 +1,29 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file is linked in only when WTF is not enabled. It allows us to keep the
+// same flags and functions without needing to do a bunch of ifdef hackery or
+// undefok mangling.
+
+#include <cstdint>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "base/tracing.h"
+
+// TODO(benvanik): remove this when disabled so that we don't dep on flags.
+ABSL_FLAG(int32_t, iree_trace_file_period, 0,
+ "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF.");
+ABSL_FLAG(std::string, iree_trace_file, "",
+ "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF.");
diff --git a/base/wait_handle.cc b/base/wait_handle.cc
new file mode 100644
index 0000000..c8a99ce
--- /dev/null
+++ b/base/wait_handle.cc
@@ -0,0 +1,532 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/wait_handle.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <type_traits>
+#include <utility>
+
+#include "absl/container/fixed_array.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "base/source_location.h"
+#include "base/status.h"
+
+// TODO(benvanik): organize these macros - they are terrible.
+
+// Feature selection for the fd-based wait implementation in this file.
+//
+// ppoll() is a Linux-specific extension, so it is disabled on Apple platforms
+// (iOS and macOS) in addition to Android and Emscripten. poll() is assumed to
+// be available everywhere and is used as the fallback.
+#if !defined(__ANDROID__) && !defined(OS_IOS) && !defined(OS_MACOSX) && \
+    !defined(__EMSCRIPTEN__)
+#define IREE_HAS_PPOLL 1
+#endif  // !__ANDROID__ && !OS_IOS && !OS_MACOSX && !__EMSCRIPTEN__
+#define IREE_HAS_POLL 1
+
+// eventfd is preferred over pipes when the platform provides it; pipes are
+// assumed to be available everywhere.
+#if !defined(OS_IOS) && !defined(OS_MACOSX) && !defined(__EMSCRIPTEN__)
+#define IREE_HAS_EVENTFD 1
+#endif  // !OS_IOS && !OS_MACOSX && !__EMSCRIPTEN__
+#define IREE_HAS_PIPE 1
+// #define IREE_HAS_SYNC_FILE 1
+
+#if defined(IREE_HAS_EVENTFD)
+#include <sys/eventfd.h>
+#endif  // IREE_HAS_EVENTFD
+
+namespace iree {
+
+namespace {
+
+// File-local aliases for the WaitableObject sentinel fd values so the code
+// below can refer to them without qualification.
+constexpr int kInvalidFd = WaitableObject::kInvalidFd;
+constexpr int kSignaledFd = WaitableObject::kSignaledFd;
+
+// Invokes |syscall| with |params|, re-issuing the call for as long as it fails
+// with EINTR.
+//
+// Returns the call's result when it is non-negative. Any other failure is
+// converted from errno with ErrnoToCanonicalStatus.
+template <typename SyscallT, typename... ParamsT>
+StatusOr<typename std::result_of<SyscallT(ParamsT...)>::type> Syscall(
+    SyscallT syscall, ParamsT&&... params) {
+  for (;;) {
+    const auto result = syscall(std::forward<ParamsT>(params)...);
+    if (result >= 0) return result;
+    if (errno != EINTR) {
+      // A real failure: surface it to the caller.
+      return ErrnoToCanonicalStatus(errno, "");
+    }
+    // Interrupted by a signal; go around again.
+  }
+}
+
+#if defined(IREE_HAS_PPOLL)
+
+// ppoll()-based implementation of SystemPoll.
+// ppoll is preferred as it has a much better timing mechanism; poll can have a
+// large slop on the deadline.
+// Documentation: http://man7.org/linux/man-pages/man2/poll.2.html
+//
+// |deadline| is translated into the ppoll timeout argument as follows:
+//   absl::InfinitePast()   -> zeroed timespec (do not block)
+//   absl::InfiniteFuture() -> nullptr (block until an fd is ready)
+//   anything else          -> time remaining until |deadline|; if that is
+//                             already negative DEADLINE_EXCEEDED is returned
+//                             without calling ppoll.
+// Returns whatever ppoll returned (via Syscall, which retries on EINTR).
+StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) {
+  timespec timeout_spec;
+  timespec* tmo_p = &timeout_spec;
+  if (deadline == absl::InfiniteFuture()) {
+    // No timeout at all.
+    tmo_p = nullptr;
+  } else if (deadline == absl::InfinitePast()) {
+    // Pure poll: return immediately.
+    timeout_spec = {0};
+  } else {
+    const absl::Duration remaining_time = deadline - absl::Now();
+    if (remaining_time < absl::ZeroDuration()) {
+      // Note: we likely have already bailed before getting here with a
+      // negative duration.
+      return DeadlineExceededErrorBuilder(IREE_LOC);
+    }
+    timeout_spec = absl::ToTimespec(remaining_time);
+  }
+  return Syscall(::ppoll, poll_fds.data(), poll_fds.size(), tmo_p, nullptr);
+}
+
+#elif defined(IREE_HAS_POLL)
+
+// poll(), present pretty much everywhere.
+// Documentation: https://linux.die.net/man/2/poll
+//
+// |deadline| is translated into the poll timeout (milliseconds) as follows:
+//   absl::InfinitePast()   -> 0 (do not block)
+//   absl::InfiniteFuture() -> -1 (block until an fd is ready)
+//   anything else          -> whole milliseconds remaining until |deadline|,
+//                             clamped to the largest value poll accepts; if
+//                             the remaining time is already negative
+//                             DEADLINE_EXCEEDED is returned without polling.
+// Returns whatever poll returned (via Syscall, which retries on EINTR).
+StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) {
+  // poll takes its timeout as an int. POSIX guarantees INT_MAX is at least
+  // 2147483647, so clamping to that value can never overflow the cast below.
+  constexpr int64_t kMaxPollTimeoutMs = 2147483647;
+
+  int timeout;
+  if (deadline == absl::InfinitePast()) {
+    // Don't block.
+    timeout = 0;
+  } else if (deadline == absl::InfiniteFuture()) {
+    // Block forever.
+    timeout = -1;
+  } else {
+    absl::Duration remaining_time = deadline - absl::Now();
+    if (remaining_time < absl::ZeroDuration()) {
+      return DeadlineExceededErrorBuilder(IREE_LOC);
+    }
+    // Far-future deadlines would otherwise wrap when narrowed to int and
+    // could turn into a negative (infinite) or arbitrarily short timeout.
+    const int64_t remaining_ms = absl::ToInt64Milliseconds(remaining_time);
+    timeout = static_cast<int>(
+        remaining_ms < kMaxPollTimeoutMs ? remaining_ms : kMaxPollTimeoutMs);
+  }
+  return Syscall(::poll, poll_fds.data(), poll_fds.size(), timeout);
+}
+
+#else
+#error "No SystemPoll implementation"
+#endif // IREE_HAS_PPOLL / IREE_HAS_POLL / etc
+
+// Builds the pollfd list to wait on: one entry per element of |wait_handles|,
+// in the same order, obtained by asking each handle's WaitableObject for the
+// fd to poll.
+//
+// Null handles and handles without an object get kInvalidFd. Each
+// AcquireFdForWait call is given |deadline| and may block up to it. Unless
+// |deadline| is absl::InfinitePast(), DEADLINE_EXCEEDED is returned as soon as
+// the deadline is found to have passed after an acquisition.
+StatusOr<absl::FixedArray<pollfd>> AcquireWaitHandles(
+    WaitHandle::WaitHandleSpan wait_handles, absl::Time deadline) {
+  absl::FixedArray<pollfd> poll_fds{wait_handles.size()};
+  const bool enforce_deadline = deadline != absl::InfinitePast();
+  for (int i = 0; i < wait_handles.size(); ++i) {
+    pollfd& entry = poll_fds[i];
+    entry.events = POLLIN | POLLPRI | POLLERR | POLLHUP | POLLNVAL;
+    entry.revents = 0;
+
+    WaitHandle* const handle = wait_handles[i];
+    const bool waitable = handle && handle->object();
+    if (!waitable) {
+      // NOTE: poll will ignore any negative fds and our kInvalidFd == -1 so
+      // the entry can stay in the list and will simply be skipped.
+      entry.fd = kInvalidFd;
+      continue;
+    }
+
+    // Ask the object for its fd. This may block (if |deadline| allows it)
+    // when the fd is not yet available, acting as a pre-wait ahead of the
+    // actual poll. That can be bad with WaitAny, though it could be handled
+    // better here.
+    ASSIGN_OR_RETURN(auto fd_info,
+                     handle->object()->AcquireFdForWait(deadline));
+    entry.fd = fd_info.second;
+
+    // Stop early once the deadline has passed.
+    if (enforce_deadline && deadline < absl::Now()) {
+      return DeadlineExceededErrorBuilder(IREE_LOC)
+             << "Deadline exceeded acquiring for fds";
+    }
+  }
+  return poll_fds;
+}
+
+// Drains |fd| so that it no longer polls as readable.
+//
+// Reads in a loop until the read would block. Depending on how the users setup
+// the fd the act of reading may reset the entire handle (such as with the
+// default eventfd mode) or multiple reads may be required (such as with
+// semaphores). |fd| is expected to be non-blocking.
+//
+// |fd_type| is currently unused; the read primitive is chosen at compile time.
+//
+// Returns OK once the fd is drained (the read reports EAGAIN/EWOULDBLOCK or,
+// for pipes, end-of-file), UNIMPLEMENTED on platforms without an fd-based
+// primitive, or the canonical status for any other errno.
+Status ClearFd(WaitableObject::FdType fd_type, int fd) {
+#if defined(IREE_HAS_EVENTFD) || defined(IREE_HAS_PIPE)
+  while (true) {
+#if defined(IREE_HAS_EVENTFD)
+    eventfd_t val = 0;
+    const int rv = ::eventfd_read(fd, &val);
+#else
+    char buf;
+    const ssize_t rv = ::read(fd, &buf, 1);
+    if (rv == 0) {
+      // End-of-file: the write end is closed and nothing is buffered. There
+      // is nothing left to clear and reading again would spin forever.
+      return OkStatus();
+    }
+#endif  // IREE_HAS_EVENTFD
+    if (rv != -1) {
+      // Success! Keep going.
+      continue;
+    }
+    if (errno == EAGAIN || errno == EWOULDBLOCK) {
+      // The read would have blocked meaning that we've hit the end and
+      // successfully cleared the fd. POSIX allows EAGAIN and EWOULDBLOCK to be
+      // distinct values, so both are accepted.
+      return OkStatus();
+    }
+    if (errno == EINTR) {
+      // Retry.
+      continue;
+    }
+    return ErrnoToCanonicalStatus(errno, "ClearFd failed");
+  }
+#else
+  return UnimplementedErrorBuilder(IREE_LOC) << "fd_type cannot be cleared";
+#endif  // IREE_HAS_EVENTFD || IREE_HAS_PIPE
+}
+
+// Performs a single poll on multiple fds and returns information about the
+// signaled fds, if any.
+//
+// |wait_handles| and |poll_fds| are parallel lists (entry i of one belongs to
+// entry i of the other) and must be non-empty. Entries whose fd is kInvalidFd
+// are ignored; entries whose fd is kSignaledFd are treated as already ready.
+//
+// On success:
+//   |out_any_signaled_index| is the index of the last entry that was resolved
+//     during this call, or -1 if none was.
+//   |out_unsignaled_count| is the number of entries still waiting.
+//   Resolved entries have their fd replaced with kInvalidFd so that later
+//     polls skip them.
+//
+// Returns DEADLINE_EXCEEDED if the system poll timed out, the error from
+// TryResolveWakeOnFd if resolution failed, or an error derived from the
+// revents flags (POLLERR/POLLHUP/POLLNVAL). When an error is returned the
+// order of |poll_fds| may have been left modified.
+Status MultiPoll(WaitHandle::WaitHandleSpan wait_handles,
+                 absl::Span<pollfd> poll_fds, absl::Time deadline,
+                 int* out_any_signaled_index, int* out_unsignaled_count) {
+  *out_any_signaled_index = -1;
+  *out_unsignaled_count = 0;
+
+  // poll has a nasty behavior where it allows -1 for fds... except for at [0].
+  // To keep the rest of the code sane we correct for that here as epoll doesn't
+  // have that behavior and we may want to special case this later.
+  bool any_valid_fds = true;
+  int swapped_zero_index = -1;
+  if (poll_fds[0].fd < 0) {
+    // Find a valid handle. Descriptor 0 is a perfectly legal fd, so only
+    // negative values (kInvalidFd/kSignaledFd) are skipped.
+    for (int i = 1; i < poll_fds.size(); ++i) {
+      if (poll_fds[i].fd >= 0) {
+        swapped_zero_index = i;
+        std::swap(poll_fds[0], poll_fds[i]);
+        break;
+      }
+    }
+    if (swapped_zero_index == -1) {
+      // No valid handles found, meaning that all handles are invalid.
+      // We'll skip the wait below so we can share the processing code for any
+      // fds that may be kSignaledFd.
+      any_valid_fds = false;
+    }
+  }
+
+  // Pass handles to ppoll.
+  // http://man7.org/linux/man-pages/man2/poll.2.html
+  if (any_valid_fds) {
+    ASSIGN_OR_RETURN(int rv, SystemPoll(poll_fds, deadline));
+    if (rv == 0) {
+      // Call timed out and no descriptors were ready.
+      // If this was just a poll then that's fine.
+      return DeadlineExceededErrorBuilder(IREE_LOC);
+    }
+  }
+
+  // If we had swapped fds[0] above we need to correct for that now so that
+  // indices line up with |wait_handles| again.
+  if (swapped_zero_index != -1) {
+    std::swap(poll_fds[0], poll_fds[swapped_zero_index]);
+  }
+
+  // Run through the list and find the entries that were ready and mark them
+  // as completed.
+  for (int i = 0; i < poll_fds.size(); ++i) {
+    if (poll_fds[i].fd == kSignaledFd || poll_fds[i].revents == POLLIN) {
+      // First attempt any resolve actions. If these fail we can't consider the
+      // fd as having been signaled.
+      ASSIGN_OR_RETURN(
+          bool resolved,
+          wait_handles[i]->object()->TryResolveWakeOnFd(poll_fds[i].fd));
+      if (!resolved) {
+        ++(*out_unsignaled_count);
+        continue;
+      }
+
+      // Successful wait. Kill the fd so it is ignored on the next poll.
+      poll_fds[i].fd = kInvalidFd;
+      *out_any_signaled_index = i;
+    } else if (poll_fds[i].revents) {
+      // Any revents value other than exactly POLLIN is reported as an error,
+      // checked in priority order.
+      if (poll_fds[i].revents & POLLERR) {
+        return InternalErrorBuilder(IREE_LOC);
+      } else if (poll_fds[i].revents & POLLHUP) {
+        return CancelledErrorBuilder(IREE_LOC);
+      } else if (poll_fds[i].revents & POLLNVAL) {
+        return InvalidArgumentErrorBuilder(IREE_LOC);
+      } else {
+        return UnknownErrorBuilder(IREE_LOC);
+      }
+    } else if (poll_fds[i].fd != kInvalidFd) {
+      // Still a live fd with no events: it has not been signaled yet.
+      ++(*out_unsignaled_count);
+    }
+  }
+
+  return OkStatus();
+}
+
+} // namespace
+
+// static
+// Process-wide counter used to hand out WaitHandle unique ids. It is
+// pre-incremented on use, so with this initial value the first id handed out
+// is 2.
+std::atomic<uint64_t> WaitHandle::next_unique_id_{1};
+
+// static
+// Returns a handle backed by a shared, function-local object that always
+// reports itself as signaled.
+WaitHandle WaitHandle::AlwaysSignaling() {
+  // Waitable that hands out kSignaledFd and resolves every wake.
+  class AlwaysSignalingObject : public WaitableObject {
+   public:
+    StatusOr<std::pair<FdType, int>> AcquireFdForWait(
+        absl::Time deadline) override {
+      return std::make_pair(FdType::kPermanent, kSignaledFd);
+    }
+    StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; }
+    std::string DebugString() const override { return "signal"; }
+  };
+  // Allocated once and never deleted here; every returned handle takes its
+  // own reference.
+  static auto* signaling_object = new AlwaysSignalingObject();
+  return WaitHandle(add_ref(signaling_object));
+}
+
+// static
+// Returns a handle backed by a shared, function-local object whose wait
+// operations always fail with an internal error.
+WaitHandle WaitHandle::AlwaysFailing() {
+  // Waitable that fails both fd acquisition and wake resolution.
+  class AlwaysFailingObject : public WaitableObject {
+   public:
+    StatusOr<std::pair<FdType, int>> AcquireFdForWait(
+        absl::Time deadline) override {
+      return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject";
+    }
+    StatusOr<bool> TryResolveWakeOnFd(int fd) override {
+      return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject";
+    }
+    std::string DebugString() const override { return "fail"; }
+  };
+  // Allocated once and never deleted here; every returned handle takes its
+  // own reference.
+  static auto* failing_object = new AlwaysFailingObject();
+  return WaitHandle(add_ref(failing_object));
+}
+
+// static
+// Waits until every handle in |wait_handles| has been signaled or |deadline|
+// passes. An empty list succeeds immediately.
+Status WaitHandle::WaitAll(WaitHandleSpan wait_handles, absl::Time deadline) {
+  if (wait_handles.empty()) return OkStatus();
+
+  // Collect the fds to poll; this may already fail or hit the deadline.
+  ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline));
+
+  // Poll repeatedly: each round retires the handles that were signaled and
+  // reports how many are still outstanding.
+  while (true) {
+    int signaled_index = 0;
+    int pending_count = 0;
+    RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline,
+                              &signaled_index, &pending_count));
+    if (pending_count == 0) {
+      // All waits resolved.
+      return OkStatus();
+    }
+    if (!(absl::Now() < deadline)) {
+      // One or more were unsignaled and we are out of time.
+      return DeadlineExceededErrorBuilder(IREE_LOC);
+    }
+  }
+}
+
+// static
+// Non-blocking variant of WaitAll: true if all handles are signaled, false if
+// the wait would have blocked, any other failure is passed through.
+StatusOr<bool> WaitHandle::TryWaitAll(WaitHandleSpan wait_handles) {
+  auto wait_status = WaitAll(wait_handles, absl::InfinitePast());
+  if (wait_status.ok()) return true;
+  // A deadline failure with an infinite-past deadline just means "not yet".
+  if (IsDeadlineExceeded(wait_status)) return false;
+  return wait_status;
+}
+
+// static
+// Waits until at least one handle in |wait_handles| is signaled or |deadline|
+// passes and returns the index of a signaled handle. An empty list is an
+// INVALID_ARGUMENT error.
+StatusOr<int> WaitHandle::WaitAny(WaitHandleSpan wait_handles,
+                                  absl::Time deadline) {
+  if (wait_handles.empty()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "At least one wait handle is required for WaitAny";
+  }
+
+  // Collect the fds to poll.
+  ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline));
+
+  // A single poll round is enough: WaitAny is WaitAll without the loop.
+  int signaled_index = -1;
+  int pending_count = 0;
+  RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline,
+                            &signaled_index, &pending_count));
+
+  // When no wait handles were valid nothing is reported as signaled; pretend
+  // index 0 was.
+  return signaled_index == -1 ? 0 : signaled_index;
+}
+
+// static
+// Non-blocking variant of WaitAny: yields -1 instead of DEADLINE_EXCEEDED when
+// no handle is signaled yet; every other outcome is passed through.
+StatusOr<int> WaitHandle::TryWaitAny(WaitHandleSpan wait_handles) {
+  auto result = WaitAny(wait_handles, absl::InfinitePast());
+  if (IsDeadlineExceeded(result.status())) {
+    return -1;
+  }
+  return result;
+}
+
+// Storage for static class variables; these won't be needed when we can use
+// c++17 everywhere.
+// The constants are declared (with their values) in the WaitableObject class;
+// these out-of-class definitions only provide storage for them.
+constexpr int WaitableObject::kInvalidFd;
+constexpr int WaitableObject::kSignaledFd;
+
+// Wraps |object| and assigns this handle a fresh unique id taken from the
+// shared atomic counter.
+WaitHandle::WaitHandle(ref_ptr<WaitableObject> object)
+ : unique_id_(++next_unique_id_), object_(std::move(object)) {}
+
+// Releases this handle's reference to the waitable object.
+WaitHandle::~WaitHandle() { Dispose(); }
+
+// Drops the reference to the waitable object; the handle is left without an
+// object afterwards.
+void WaitHandle::Dispose() { object_.reset(); }
+
+// Move constructor: takes over |other|'s unique id and object reference.
+// |other| is left with no object and a unique id of 0.
+WaitHandle::WaitHandle(WaitHandle&& other)
+ : unique_id_(other.unique_id_), object_(std::move(other.object_)) {
+ other.unique_id_ = 0;
+}
+
+// Move assignment: releases the currently held object and takes over
+// |other|'s object reference. Self-assignment is a no-op.
+//
+// NOTE(review): unlike the move constructor, this does not transfer the
+// unique id: this handle keeps its existing unique_id_ while |other| is given
+// a brand new one. Confirm whether that asymmetry is intended.
+WaitHandle& WaitHandle::operator=(WaitHandle&& other) {
+ if (this != std::addressof(other)) {
+ // Close current handle.
+ Dispose();
+
+ // Take ownership of handle and resources.
+ object_ = std::move(other.object_);
+
+ // The moved-from handle gets a fresh id from the shared counter.
+ other.unique_id_ = ++next_unique_id_;
+ }
+ return *this;
+}
+
+// Returns the wrapped object's debug string, or "wh_<unique id>" when the
+// handle has no object.
+std::string WaitHandle::DebugString() const {
+  if (object_) {
+    return object_->DebugString();
+  }
+  return absl::StrCat("wh_", unique_id_);
+}
+
+// Non-blocking wait on this handle: true if signaled, false if the wait would
+// have blocked, any other failure is passed through.
+StatusOr<bool> WaitHandle::TryWait() {
+  auto wait_status = WaitAll({this}, absl::InfinitePast());
+  if (wait_status.ok()) return true;
+  // A deadline failure with an infinite-past deadline just means "not yet".
+  if (IsDeadlineExceeded(wait_status)) return false;
+  return wait_status;
+}
+
+// Creates the event and its underlying fd(s) via Initialize().
+// |debug_name| is stored as a raw pointer and is not copied, so the string it
+// points to has to outlive the event; it may be nullptr.
+ManualResetEvent::ManualResetEvent(const char* debug_name)
+ : debug_name_(debug_name) {
+ Initialize();
+}
+
+// Signals the event and closes its fd(s); see Dispose().
+ManualResetEvent::~ManualResetEvent() { Dispose(); }
+
+// Creates the OS primitive backing the event and stores its descriptor(s) in
+// fd_ (and write_fd_ for pipes). The primitive is chosen at compile time:
+// eventfd when available, otherwise a pipe.
+//
+// Any syscall failure is fatal: the results are unwrapped with ValueOrDie().
+void ManualResetEvent::Initialize() {
+#if defined(IREE_HAS_EVENTFD)
+ // Create with an eventfd by default when we support it.
+ // eventfd has lower overhead than pipes (the syscalls are cheap).
+ // This usually will only fail if the system is completely out of handles.
+ //
+ // Docs: http://man7.org/linux/man-pages/man2/eventfd.2.html
+ fd_type_ = FdType::kEventFd;
+ // Initial counter value 0 (unsignaled); non-blocking so draining reads in
+ // ClearFd terminate instead of blocking.
+ fd_ = Syscall(::eventfd, 0, EFD_CLOEXEC | EFD_NONBLOCK).ValueOrDie();
+#elif defined(IREE_HAS_PIPE)
+ // Android/Linux/iOS-compatible POSIX pipe handle.
+ // Two handles are generated: one for transmitting and one for receiving.
+ //
+ // Docs: http://man7.org/linux/man-pages/man2/pipe.2.html
+ fd_type_ = FdType::kPipe;
+ int pipefd[2];
+ Syscall(::pipe, pipefd).ValueOrDie();
+ // Only the read end is made non-blocking; it is the one that gets polled and
+ // drained.
+ Syscall(::fcntl, pipefd[0], F_SETFL, O_NONBLOCK).ValueOrDie();
+ fd_ = pipefd[0];
+ write_fd_ = pipefd[1];
+#else
+// NOTE: sync_file does not use Notifier as they come from the kernel.
+#error "No fd-based sync primitive on this platform"
+#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE / etc
+}
+
+// Signals the event one last time and closes its descriptor(s), leaving both
+// fd_ and write_fd_ as kInvalidFd. Safe to call on an event that holds no fds
+// (for example one that has been moved from).
+//
+// Failures are fatal: Set() is wrapped in CHECK_OK and close results are
+// unwrapped with ValueOrDie().
+// NOTE(review): close is routed through Syscall, which retries on EINTR;
+// confirm that retrying close is safe on the target platforms.
+void ManualResetEvent::Dispose() {
+ if (fd_ != kInvalidFd) {
+ // Always signal, as we need to ensure waiters are woken.
+ CHECK_OK(Set());
+ Syscall(::close, fd_).ValueOrDie();
+ fd_ = kInvalidFd;
+ }
+ // Only set for the pipe-based implementation.
+ if (write_fd_ != kInvalidFd) {
+ Syscall(::close, write_fd_).ValueOrDie();
+ write_fd_ = kInvalidFd;
+ }
+}
+
+// Move constructor: takes over |other|'s descriptors, fd type and debug name.
+// |other| is left holding no descriptors (kInvalidFd, FdType::kPermanent) and
+// no debug name, so disposing it afterwards closes nothing.
+ManualResetEvent::ManualResetEvent(ManualResetEvent&& other)
+ : fd_type_(other.fd_type_),
+ fd_(other.fd_),
+ write_fd_(other.write_fd_),
+ debug_name_(other.debug_name_) {
+ other.fd_type_ = FdType::kPermanent;
+ other.fd_ = kInvalidFd;
+ other.write_fd_ = kInvalidFd;
+ other.debug_name_ = nullptr;
+}
+
+// Move assignment: disposes this event's current descriptors (signaling it
+// first), then takes over |other|'s descriptors, fd type and debug name.
+// Self-assignment is a no-op.
+//
+// NOTE(review): unlike the move constructor, |other| is re-initialized with a
+// brand new primitive afterwards instead of being left empty. Confirm that
+// this asymmetry is intended.
+ManualResetEvent& ManualResetEvent::operator=(ManualResetEvent&& other) {
+ if (this != std::addressof(other)) {
+ Dispose();
+ fd_type_ = other.fd_type_;
+ fd_ = other.fd_;
+ write_fd_ = other.write_fd_;
+ debug_name_ = other.debug_name_;
+ other.fd_type_ = FdType::kPermanent;
+ other.fd_ = kInvalidFd;
+ other.write_fd_ = kInvalidFd;
+ other.debug_name_ = nullptr;
+ // Gives |other| fresh descriptors so it remains a usable event.
+ other.Initialize();
+ }
+ return *this;
+}
+
+// Returns the debug name given at construction when there is one; otherwise a
+// generated name built from the primitive kind and the descriptor value(s).
+std::string ManualResetEvent::DebugString() const {
+ if (debug_name_) {
+ return debug_name_;
+ }
+#if defined(IREE_HAS_EVENTFD)
+ return absl::StrCat("eventfd_", fd_);
+#elif defined(IREE_HAS_PIPE)
+ return absl::StrCat("pipe_", fd_, "_", write_fd_);
+#else
+ return absl::StrCat("unknown_", fd_, "_", write_fd_);
+#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE
+}
+
+// Signals the event by writing to the underlying primitive so that its read
+// side becomes readable: adds 1 to the eventfd counter, or writes one byte to
+// the pipe's write end.
+//
+// Returns the status of the write (retried on EINTR by Syscall), or
+// UNIMPLEMENTED when no fd-based primitive is compiled in.
+Status ManualResetEvent::Set() {
+#if defined(IREE_HAS_EVENTFD)
+ return Syscall(::eventfd_write, fd_, 1ull).status();
+#elif defined(IREE_HAS_PIPE)
+ char buf = '\n';
+ return Syscall(::write, write_fd_, &buf, 1).status();
+#else
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "No fd-based sync primitive on this platform";
+#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE
+}
+
+// Returns the event to the unsignaled state by draining its read fd with
+// ClearFd.
+Status ManualResetEvent::Reset() { return ClearFd(fd_type_, fd_); }
+
+// Returns a new WaitHandle that holds its own reference to this event.
+WaitHandle ManualResetEvent::OnSet() { return WaitHandle(add_ref(this)); }
+
+} // namespace iree
diff --git a/base/wait_handle.h b/base/wait_handle.h
new file mode 100644
index 0000000..8cf8946
--- /dev/null
+++ b/base/wait_handle.h
@@ -0,0 +1,321 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_WAIT_HANDLE_H_
+#define IREE_BASE_WAIT_HANDLE_H_
+
+#include <atomic>
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "base/time.h"
+
+namespace iree {
+
+// Interfaces for waitable objects that can produce WaitHandles.
+// WaitableObjects are much like ::thread::Selectable, only they support both
+// the classic locking style as well as file descriptors for use with select().
+//
+// Objects are reference counted (RefObject); each WaitHandle keeps a reference
+// to the object it was created from.
+//
+// Usage:
+// class MyWaitableObject : public WaitableObject {
+// public:
+// std::string DebugString() const override { return "something useful"; }
+// WaitHandle OnAsyncTask() {
+// return WaitHandle(add_ref(this));
+// }
+// private:
+// StatusOr<std::pair<FdType, int>> AcquireFdForWait(
+// absl::Time deadline) override {
+// // If blocking traditionally do so now and then return this:
+// return std::make_pair(FdType::kPermanent, kSignaledFd);
+// // Otherwise, see ManualResetEvent for an example using fds.
+// }
+// StatusOr<bool> TryResolveWakeOnFd(int fd) override {
+// // Return true iff the object is really acquired, such as the semaphore
+// // being decremented.
+// return true;
+// }
+// };
+class WaitableObject : public RefObject<WaitableObject> {
+ public:
+ // Indicates that a file descriptor is invalid. It will not block when waited
+ // upon.
+ constexpr static int kInvalidFd = -1;
+ // Indicates that a file descriptor should be treated as signaled.
+ // Waiting on this fd should return as if it has already been signaled.
+ constexpr static int kSignaledFd = -2;
+
+ // Defines the type of the native handle used for synchronization.
+ enum class FdType : uint16_t {
+ // Event has no handle and should be treated as permanently signaled.
+ kPermanent,
+
+ // Android/Linux/iOS-compatible POSIX pipe handle.
+ // Two handles are generated: one for transmitting and one for receiving.
+ //
+ // More information:
+ // http://man7.org/linux/man-pages/man2/pipe.2.html
+ kPipe,
+
+ // Android/Linux eventfd handle.
+ // These are akin to pipe() but require only a single handle and have
+ // significantly lower overhead (equivalent if not slightly better than
+ // pthreads condvars).
+ //
+ // eventfds support acting as both semaphores and auto reset events.
+ //
+ // More information:
+ // http://man7.org/linux/man-pages/man2/eventfd.2.html
+ kEventFd,
+
+ // Android/Linux sync_file handle (aka 'sync fence').
+ // The handle is allocated indirectly by the device driver via the
+ // <linux/sync_file.h> API. It may be waited upon with poll(), select(), or
+ // epoll() and must be closed with close() when no longer required. If
+ // waiting on multiple sync_files the caller should first merge them
+ // together.
+ //
+ // A sync_file must only be used as fences (one-shot manual reset events).
+ //
+ // More information:
+ // https://www.kernel.org/doc/Documentation/sync_file.txt
+ // https://lwn.net/Articles/702339/
+ // https://source.android.com/devices/graphics/implement-vsync#explicit_synchronization
+ kSyncFile,
+ };
+
+ virtual ~WaitableObject() = default;
+
+ // Returns a string representing the object, either specified as a debug_name
+ // or a unique ID.
+ virtual std::string DebugString() const = 0;
+
+ // Attempts to acquire a file descriptor for the waitable objects by the given
+ // |deadline|. In many cases this will return immediately with a valid fd.
+ //
+ // In cases where the file descriptor may not be available the call may block
+ // until either it is available or the |deadline| has elapsed. Use
+ // absl::InfinitePast() to prevent blocking.
+ //
+ // Returns a valid file descriptor or kInvalidFd as an indication that the
+ // object should not be waited on (already signaled, etc). Can return
+ // kSignaledFd to indicate that it's already known that the handle has been
+ // signaled and the caller should resolve as if it caused a wake normally.
+ //
+ // The result is a (fd type, fd) pair; the fd is the second element.
+ virtual StatusOr<std::pair<FdType, int>> AcquireFdForWait(
+ absl::Time deadline) = 0;
+
+ // Tries to resolve the object with the given |fd|.
+ // In many cases this will no-op, however some types may require additional
+ // checks to ensure that the wait operation succeeded (such as semaphores
+ // that may need to query a count). If resolution fails the waitable object
+ // must not be considered signaled. This call will never block.
+ //
+ // Returns true when the wake is resolved and false when the object should
+ // still be treated as unsignaled.
+ virtual StatusOr<bool> TryResolveWakeOnFd(int fd) = 0;
+};
+
+// Handle to waitable objects.
+// WaitHandles are created by a particular synchronization primitive, such as
+// Fence, as a way for one or more observers to poll or wait for notification.
+//
+// External synchronization primitives can be wrapped in WaitHandles to enable
+// other libraries or languages to be waited on alongside WaitHandles created
+// by the IREE primitives like Fence. See the notes on WaitableObject::FdType
+// for a list of handle types that are supported.
+//
+// Wait handles are thread-safe in that multiple threads may be waiting on them
+// concurrently.
+class WaitHandle {
+ public:
+ // Returns a WaitHandle that when waited on will never block.
+ static WaitHandle AlwaysSignaling();
+
+ // Returns a WaitHandle that when waited on will always fail.
+ static WaitHandle AlwaysFailing();
+
+ // Non-owning list of handle pointers; entries may be nullptr, in which case
+ // they are skipped when waiting.
+ using WaitHandleSpan = absl::Span<WaitHandle* const>;
+
+ // Blocks the caller until all passed |wait_handles| are signaled or the
+ // |deadline| elapses.
+ //
+ // Returns success if the wait is successful and all events have been
+ // signaled. An empty |wait_handles| list succeeds immediately.
+ //
+ // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all handles
+ // having been signaled. Note that a subset of the |wait_handles| may have
+ // been signaled and each can be queried to see which one.
+ static Status WaitAll(WaitHandleSpan wait_handles, absl::Time deadline);
+ static Status WaitAll(WaitHandleSpan wait_handles, absl::Duration timeout) {
+ return WaitAll(wait_handles, RelativeTimeoutToDeadline(timeout));
+ }
+ static Status WaitAll(WaitHandleSpan wait_handles) {
+ return WaitAll(wait_handles, absl::InfiniteFuture());
+ }
+
+ // Tries waiting on the handles and returns immediately if it would have
+ // blocked. The caller will not be blocked even if a handle has not yet been
+ // signaled.
+ //
+ // Returns true if all handles have been signaled.
+ static StatusOr<bool> TryWaitAll(WaitHandleSpan wait_handles);
+
+ // Blocks the caller until at least one of the |wait_handles| is signaled or
+ // the |deadline| elapses.
+ //
+ // Returns the index into |wait_handles| of a handle that was signaled. Note
+ // that more than one handle may have been signaled and all of the other
+ // |wait_handles| should be queried or waited on again until waits for them
+ // succeed.
+ //
+ // Returns INVALID_ARGUMENT if |wait_handles| is empty.
+ //
+ // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any handles
+ // having been signaled.
+ static StatusOr<int> WaitAny(WaitHandleSpan wait_handles,
+ absl::Time deadline);
+ static StatusOr<int> WaitAny(WaitHandleSpan wait_handles,
+ absl::Duration timeout) {
+ return WaitAny(wait_handles, RelativeTimeoutToDeadline(timeout));
+ }
+ static StatusOr<int> WaitAny(WaitHandleSpan wait_handles) {
+ return WaitAny(wait_handles, absl::InfiniteFuture());
+ }
+
+ // Tries waiting for at least one handle to complete and returns immediately
+ // if none have been. The caller will not be blocked even if a handle has not
+ // yet been signaled.
+ //
+ // Returns the index into |wait_handles| of a handle that was signaled. Note
+ // that more than one handle may have been signaled and all of the other
+ // |wait_handles| should be queried or waited on again until waits for them
+ // succeed.
+ //
+ // Returns -1 if no handles were signaled.
+ static StatusOr<int> TryWaitAny(WaitHandleSpan wait_handles);
+
+ // Default constructor creates a permanently signaled handle.
+ // Waiting on this handle will never block.
+ // Such a handle has no object and a unique_id() of 0.
+ WaitHandle() = default;
+
+ // Wraps an existing waitable object.
+ // The handle shares ownership of |object| through the passed reference and
+ // releases it when the handle is destroyed or moved from.
+ explicit WaitHandle(ref_ptr<WaitableObject> object);
+
+ ~WaitHandle();
+
+ // Copying not supported. Create a new WaitHandle from the source.
+ WaitHandle(const WaitHandle&) = delete;
+ WaitHandle& operator=(const WaitHandle&) = delete;
+
+ // Moving supported; sync primitive ownership is transferred.
+ WaitHandle(WaitHandle&& other);
+ WaitHandle& operator=(WaitHandle&& other);
+
+ // Unique ID for the WaitHandle instance.
+ // Two wait handles, even if waiting on the same underlying primitive, will
+ // have differing unique_ids. This can be used for deduping the handles or
+ // storing handles in a map.
+ // NOTE(review): default-constructed handles all report 0 here, so the
+ // uniqueness guarantee presumably only covers handles wrapping an object.
+ uint64_t unique_id() const { return unique_id_; }
+
+ // Returns a unique string representing the handle.
+ std::string DebugString() const;
+
+ // Blocks the caller until the handle is signaled or the |deadline| elapses.
+ //
+ // If waiting on multiple wait handles use WaitAll or WaitAny instead of
+ // multiple calls to Wait as they can significantly reduce overhead.
+ //
+ // Returns success if the wait is successful and the |wait_handle| was
+ // signaled. Returns DEADLINE_EXCEEDED if the timeout elapses without the
+ // handle having been signaled.
+ Status Wait(absl::Time deadline) { return WaitAll({this}, deadline); }
+ Status Wait(absl::Duration timeout) {
+ return WaitAll({this}, RelativeTimeoutToDeadline(timeout));
+ }
+ Status Wait() { return WaitAll({this}, absl::InfiniteFuture()); }
+
+ // Tries waiting on the handle and returns immediately if it would have
+ // waited. The caller will not be blocked even if the handle has not yet been
+ // signaled.
+ //
+ // Returns true if the handle has been signaled.
+ StatusOr<bool> TryWait();
+
+ // These accessors should generally be considered opaque but may be useful to
+ // code trying to interop with other runtimes.
+ const ref_ptr<WaitableObject>& object() const { return object_; }
+
+ private:
+ // Disposes the handle by releasing its reference to the waitable object.
+ void Dispose();
+
+ // Shared counter that unique ids are drawn from.
+ static std::atomic<uint64_t> next_unique_id_;
+
+ // 0 for default-constructed and moved-from (via move construction) handles.
+ uint64_t unique_id_ = 0;
+ // The waitable object this handle refers to; null for default-constructed
+ // handles.
+ ref_ptr<WaitableObject> object_;
+};
+
+// A manually-resettable event primitive.
+// Effectively a binary semaphore with a maximum_count of 1 when running in
+// auto-reset mode but also provides a sticky manual reset mode.
+// NOTE(review): nothing in this class selects an auto-reset mode -
+// TryResolveWakeOnFd never clears the event - so only the sticky manual reset
+// behavior appears to be implemented; confirm.
+class ManualResetEvent : public WaitableObject {
+ public:
+ // |debug_name| is optional and is stored as a pointer without being copied.
+ explicit ManualResetEvent(const char* debug_name = nullptr);
+
+ ~ManualResetEvent() override;
+
+ // Copying not supported.
+ ManualResetEvent(const ManualResetEvent&) = delete;
+ ManualResetEvent& operator=(const ManualResetEvent&) = delete;
+
+ // Moving supported; sync primitive ownership is transferred.
+ ManualResetEvent(ManualResetEvent&& other);
+ ManualResetEvent& operator=(ManualResetEvent&& other);
+
+ std::string DebugString() const override;
+
+ // Sets the specified event object to the signaled state.
+ // The event stays signaled until Reset is called. Multiple waiters will be
+ // woken.
+ Status Set();
+
+ // Resets the specified event object to the nonsignaled state.
+ // Resetting an event that is already reset has no effect.
+ Status Reset();
+
+ // Returns a WaitHandle that will be signaled when the event is set.
+ WaitHandle OnSet();
+
+ protected:
+ // Creates the underlying fd(s).
+ void Initialize();
+ // Signals the event and closes the underlying fd(s).
+ void Dispose();
+
+ // Never blocks: the event's read fd is always available.
+ StatusOr<std::pair<FdType, int>> AcquireFdForWait(
+ absl::Time deadline) override {
+ return std::make_pair(fd_type_, fd_);
+ }
+ // A wake on the fd always counts as signaled; the event is left set.
+ StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; }
+
+ FdType fd_type_ = FdType::kPermanent;
+ // Descriptor that waiters poll on (eventfd, or the pipe's read end).
+ int fd_ = kInvalidFd;
+ int write_fd_ = kInvalidFd; // Used only for fd_type_ == kPipe.
+ // Optional caller-provided name; not owned.
+ const char* debug_name_ = nullptr;
+};
+
+} // namespace iree
+
+#endif // IREE_BASE_WAIT_HANDLE_H_
diff --git a/base/wait_handle_test.cc b/base/wait_handle_test.cc
new file mode 100644
index 0000000..ec1f28e
--- /dev/null
+++ b/base/wait_handle_test.cc
@@ -0,0 +1,555 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/wait_handle.h"
+
+#include <unistd.h>
+
+#include <string>
+#include <thread> // NOLINT
+#include <type_traits>
+
+#include "absl/time/time.h"
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+// StatusOr<bool> is true whenever the status is ok, so unwrap the value first.
+#define ASSERT_STATUSOR_TRUE(x) ASSERT_TRUE((x).ValueOrDie())
+#define ASSERT_STATUSOR_FALSE(x) ASSERT_FALSE((x).ValueOrDie())
+
+namespace iree {
+namespace {
+
+using ::testing::_;
+using ::testing::Return;
+
+// Tests the AlwaysSignaling helper.
+TEST(WaitHandleTest, AlwaysSignaling) {
+ ASSERT_OK(WaitHandle::AlwaysSignaling().Wait());
+ EXPECT_FALSE(WaitHandle::AlwaysSignaling().DebugString().empty());
+}
+
+// Tests the AlwaysFailing helper.
+TEST(WaitHandleTest, AlwaysFailing) {
+ ASSERT_FALSE(WaitHandle::AlwaysFailing().Wait().ok());
+ EXPECT_FALSE(WaitHandle::AlwaysFailing().DebugString().empty());
+}
+
+// Tests the basic lifecycle of a permanently signaled wait handle.
+TEST(WaitHandleTest, LifecyclePermanentSignaled) {
+ // Just to be sure it's ok to safely no-op a WaitHandle value.
+ WaitHandle wh_never_used;
+ (void)wh_never_used;
+
+ // Try waiting; should return immediately.
+ WaitHandle wh0;
+ ASSERT_OK(wh0.Wait());
+
+ // Waits on multiple permanent handles should be ok.
+ WaitHandle wh1;
+ ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
+}
+
+// Tests moving permanent WaitHandles around.
+TEST(WaitHandleTest, MovePermanent) {
+ WaitHandle wh0;
+ WaitHandle wh1{std::move(wh0)};
+ WaitHandle wh2 = std::move(wh1);
+ wh1 = std::move(wh2);
+}
+
+// Tests moving around real handles (that may require closing).
+TEST(WaitHandleTest, MoveRealHandle) {
+ ManualResetEvent fence0;
+ WaitHandle wh0 = fence0.OnSet();
+ WaitHandle wh1{std::move(wh0)};
+ WaitHandle wh2 = std::move(wh1);
+ wh1 = std::move(wh2);
+
+ // Now overwrite the handle value to force a close.
+ ManualResetEvent fence1;
+ WaitHandle wh3 = fence1.OnSet();
+ wh1 = std::move(wh3);
+ wh1 = WaitHandle(); // Ensure handle dies first.
+}
+
+// Tests the various forms of waiting on a single WaitHandle.
+// Since these just call WaitAll we leave the involved testing to those.
+TEST(WaitHandleTest, SingleWait) {
+ WaitHandle wh;
+ ASSERT_OK(wh.Wait());
+ ASSERT_OK(wh.Wait(absl::Now() + absl::Seconds(1)));
+ ASSERT_OK(wh.Wait(absl::Seconds(1)));
+ ASSERT_STATUSOR_TRUE(wh.TryWait());
+}
+
+// Tests using WaitAll with no valid handles. This should no-op.
+TEST(WaitHandleTest, WaitAllNop) {
+ ASSERT_OK(WaitHandle::WaitAll({}));
+ ASSERT_OK(WaitHandle::WaitAll({nullptr}));
+ ASSERT_OK(WaitHandle::WaitAll({nullptr, nullptr}));
+}
+
+// Tests polling with WaitAll with multiple wait handles.
+TEST(WaitHandleTest, WaitAllPoll) {
+ ManualResetEvent fence0;
+ WaitHandle wh0 = fence0.OnSet();
+ ManualResetEvent fence1;
+ WaitHandle wh1 = fence1.OnSet();
+
+ // Poll; should time out immediately as neither fence is signaled.
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())));
+
+ // Notify fence1.
+ ASSERT_OK(fence1.Set());
+
+ // Poll; should return immediately with timeout as fence0 is not signaled.
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())));
+
+ // Notify fence0.
+ ASSERT_OK(fence0.Set());
+
+ // Poll again; should return immediately with success.
+ ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()));
+}
+
+// Tests waiting when the first file handle is invalid. This is to verify a
+// workaround for bad poll() behavior with fds[0] == -1.
+TEST(WaitHandleTest, WaitAllWithInvalid0) {
+ ManualResetEvent fence;
+ WaitHandle wh = fence.OnSet();
+
+ // Poll; should return immediately with timeout as fence is not signaled.
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast())));
+
+ // Notify fence.
+ ASSERT_OK(fence.Set());
+
+ // Poll again; should return immediately with success.
+ ASSERT_OK(WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast()));
+}
+
+// Tests exceeding the timeout deadline with WaitAll.
+TEST(WaitHandleTest, WaitAllTimeout) {
+ ManualResetEvent fence;
+ WaitHandle wh = fence.OnSet();
+
+ // Wait with timeout on the unsignaled fence:
+ // Via polling (should never block):
+ ASSERT_TRUE(
+ IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::InfinitePast())));
+ ASSERT_STATUSOR_FALSE(WaitHandle::TryWaitAll({&wh}));
+ // Via time in the near future (should block):
+ ASSERT_TRUE(
+ IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250))));
+ // Via time in the past, should exceed deadline.
+ ASSERT_TRUE(
+ IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250))));
+
+ // Notify and ensure no more timeouts.
+ ASSERT_OK(fence.Set());
+ ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::InfinitePast()));
+ ASSERT_STATUSOR_TRUE(WaitHandle::TryWaitAll({&wh}));
+ ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250)));
+
+ // Via time in the past, should exceed deadline even if signaled.
+ ASSERT_TRUE(
+ IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250))));
+}
+
+// Tests using WaitAll to wait on other threads.
+TEST(WaitHandleTest, WaitAllThreaded) {
+ // Spin up two threads.
+ ManualResetEvent fence0;
+ std::thread t0{[&]() {
+ ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
+ ASSERT_OK(fence0.Set());
+ }};
+ ManualResetEvent fence1;
+ std::thread t1{[&]() {
+ ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
+ ASSERT_OK(fence1.Set());
+ }};
+
+ // Wait on both threads to complete.
+ WaitHandle wh0 = fence0.OnSet();
+ WaitHandle wh1 = fence1.OnSet();
+ ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
+
+ t0.join();
+ t1.join();
+}
+
+// Tests using WaitAll with multiple wait handles from the same fence.
+TEST(WaitHandleTest, WaitAllSameSource) {
+ ManualResetEvent fence;
+ WaitHandle wh0 = fence.OnSet();
+ WaitHandle wh1 = fence.OnSet();
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())));
+ ASSERT_OK(fence.Set());
+ ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
+}
+
+// Tests using WaitAll with literally the same wait handles.
+TEST(WaitHandleTest, WaitAllSameHandle) {
+ ManualResetEvent fence;
+ WaitHandle wh = fence.OnSet();
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAll({&wh, &wh}, absl::InfinitePast())));
+ ASSERT_OK(fence.Set());
+ ASSERT_OK(WaitHandle::WaitAll({&wh, &wh}));
+}
+
+// Tests WaitAll when a wait handle fails.
+TEST(WaitHandleTest, WaitAllFailure) {
+ WaitHandle good_wh;
+ // Create a purposefully bad handle to induce an error.
+ WaitHandle bad_wh = WaitHandle::AlwaysFailing();
+ // Should fail with some posixy error.
+ ASSERT_FALSE(WaitHandle::WaitAll({&good_wh, &bad_wh}).ok());
+}
+
+// Tests WaitAny with no valid handles: empty fails, all-null returns index 0.
+TEST(WaitHandleTest, WaitAnyNop) {
+ ASSERT_TRUE(IsInvalidArgument(WaitHandle::WaitAny({}).status()));
+ ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({nullptr}));
+ ASSERT_EQ(0, index);
+ ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({nullptr, nullptr}));
+ ASSERT_EQ(0, index);
+}
+
+// Tests polling with WaitAny with multiple wait handles.
+TEST(WaitHandleTest, WaitAnyPoll) {
+ ManualResetEvent fence0;
+ WaitHandle wh0 = fence0.OnSet();
+ ManualResetEvent fence1;
+ WaitHandle wh1 = fence1.OnSet();
+
+ // Poll; should return immediately with timeout.
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status()));
+
+ // Notify fence1.
+ ASSERT_OK(fence1.Set());
+
+ // Poll; should return immediately with fence1 signaled.
+ ASSERT_OK_AND_ASSIGN(int index,
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()));
+ EXPECT_EQ(1, index);
+
+ // Notify fence0.
+ ASSERT_OK(fence0.Set());
+
+ // Poll again; should return immediately; which one is signaled is undefined.
+ ASSERT_OK_AND_ASSIGN(index,
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()));
+ ASSERT_TRUE(index == 0 || index == 1);
+}
+
+// Tests exceeding the timeout deadline with WaitAny.
+TEST(WaitHandleTest, WaitAnyTimeout) {
+ ManualResetEvent fence0;
+ WaitHandle wh0 = fence0.OnSet();
+ ManualResetEvent fence1;
+ WaitHandle wh1 = fence1.OnSet();
+
+ // Wait with timeout on the unsignaled fences:
+ // Via polling (should never block):
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status()));
+ ASSERT_OK_AND_ASSIGN(int index, WaitHandle::TryWaitAny({&wh0, &wh1}));
+ ASSERT_EQ(-1, index);
+ // Via time in the near future (should block):
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250)).status()));
+
+ // Notify one of the fences. Should return immediately.
+ ASSERT_OK(fence1.Set());
+ ASSERT_OK_AND_ASSIGN(index,
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()));
+ ASSERT_EQ(1, index);
+ ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0, &wh1}));
+ ASSERT_EQ(1, index);
+ ASSERT_OK_AND_ASSIGN(
+ index, WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250)));
+ ASSERT_EQ(1, index);
+
+ // The unnotified fence should still timeout.
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh0}, absl::InfinitePast()).status()));
+ ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0}));
+ ASSERT_EQ(-1, index);
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250)).status()));
+
+ // Notify last fence and ensure complete.
+ ASSERT_OK(fence0.Set());
+ ASSERT_OK_AND_ASSIGN(index,
+ WaitHandle::WaitAny({&wh0}, absl::InfinitePast()));
+ ASSERT_EQ(0, index);
+ ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0}));
+ ASSERT_EQ(0, index);
+ ASSERT_OK_AND_ASSIGN(index,
+ WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250)));
+ ASSERT_EQ(0, index);
+}
+
+// Tests using WaitAny to wait on other threads.
+TEST(WaitHandleTest, WaitAnyThreaded) {
+ // Spin up two threads.
+ // t1 will wait on t0 such that they will act in sequence.
+ ManualResetEvent fence0;
+ std::thread t0{[&]() {
+ ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
+ ASSERT_OK(fence0.Set());
+ }};
+ ManualResetEvent fence1;
+ std::thread t1{[&]() {
+ ASSERT_OK(fence0.OnSet().Wait());
+ ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
+ ASSERT_OK(fence1.Set());
+ }};
+
+ // Wait on both threads. We expect 0 to complete first.
+ WaitHandle wh0 = fence0.OnSet();
+ WaitHandle wh1 = fence1.OnSet();
+ ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1}));
+ ASSERT_EQ(0, index);
+
+ // Now wait for thread 1.
+ ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({&wh1}));
+ ASSERT_EQ(0, index);
+
+ t0.join();
+ t1.join();
+}
+
+// Tests using WaitAny with multiple wait handles from the same fence.
+TEST(WaitHandleTest, WaitAnySameSource) {
+ ManualResetEvent fence;
+ WaitHandle wh0 = fence.OnSet();
+ WaitHandle wh1 = fence.OnSet();
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status()));
+ ASSERT_OK(fence.Set());
+ ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1}));
+ ASSERT_TRUE(index == 0 || index == 1);
+}
+
+// Tests using WaitAny with literally the same wait handles.
+TEST(WaitHandleTest, WaitAnySameHandle) {
+ ManualResetEvent fence;
+ WaitHandle wh = fence.OnSet();
+ ASSERT_TRUE(IsDeadlineExceeded(
+ WaitHandle::WaitAny({&wh, &wh}, absl::InfinitePast()).status()));
+ ASSERT_OK(fence.Set());
+ ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh, &wh}));
+ ASSERT_TRUE(index == 0 || index == 1);
+}
+
+// Tests WaitAny when a wait handle fails.
+TEST(WaitHandleTest, WaitAnyFailure) {
+ WaitHandle good_wh;
+ // Create a purposefully bad handle to induce an error.
+ WaitHandle bad_wh = WaitHandle::AlwaysFailing();
+ // Should fail with some posixy error.
+ ASSERT_FALSE(WaitHandle::WaitAny({&good_wh, &bad_wh}).ok());
+}
+
+// ManualResetEvent with innards exposed. Meh.
+class ExposedManualResetEvent : public ManualResetEvent {
+ public:
+ using ManualResetEvent::AcquireFdForWait;
+ using ManualResetEvent::TryResolveWakeOnFd;
+};
+
+// Mock type for the WaitableObject methods.
+class MockWaitableObject : public ::testing::StrictMock<WaitableObject> {
+ public:
+ MockWaitableObject() : ::testing::StrictMock<WaitableObject>() {}
+
+ MOCK_CONST_METHOD0(DebugString, std::string());
+ MOCK_METHOD1(AcquireFdForWait,
+ StatusOr<std::pair<FdType, int>>(absl::Time deadline));
+ MOCK_METHOD1(TryResolveWakeOnFd, StatusOr<bool>(int fd));
+
+ WaitHandle OnSomething() { return WaitHandle(add_ref(this)); }
+};
+
+// Tests normal AcquireFdForWait + TryResolveWakeOnFd use.
+TEST(WaitableObjectTest, AcquireAndResolve) {
+ MockWaitableObject mwo;
+ WaitHandle wh = mwo.OnSomething();
+
+ // Use a MRE for testing, as we can just use its fd.
+ ExposedManualResetEvent mre;
+
+ // Try waiting; we should see the AcquireFdForWait and then return because
+ // the fd has not been resolved.
+ EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) {
+ // Return the valid FD from the MRE.
+ return mre.AcquireFdForWait(deadline);
+ });
+ ASSERT_STATUSOR_FALSE(wh.TryWait());
+
+ // Signal the MRE.
+ ASSERT_OK(mre.Set());
+
+ // Try waiting again; we should get the AcquireFdForWait and then also get
+ // the TryResolveWakeOnFd.
+ EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) {
+ // Return the valid (and now signaled) FD from the MRE.
+ return mre.AcquireFdForWait(deadline);
+ });
+ EXPECT_CALL(mwo, TryResolveWakeOnFd(_)).WillOnce(Return(true));
+ ASSERT_STATUSOR_TRUE(wh.TryWait());
+}
+
+// Tests timing out in AcquireFdForWait.
+TEST(WaitableObjectTest, AcquireFdForWaitTimeout) {
+ ManualResetEvent mre;
+ WaitHandle always_wait = mre.OnSet();
+ WaitHandle always_signal = WaitHandle::AlwaysSignaling();
+ MockWaitableObject mwo;
+ WaitHandle wh = mwo.OnSomething();
+
+ // Make the AcquireFdForWait take longer than the timeout. We should hit
+ // deadline exceeded even though always_wait hasn't been signaled.
+ EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([](absl::Time deadline) {
+ ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(10)));
+ return std::make_pair(WaitableObject::FdType::kPermanent,
+ WaitableObject::kInvalidFd);
+ });
+ ASSERT_TRUE(IsDeadlineExceeded(WaitHandle::WaitAll(
+ {&wh, &always_signal}, absl::Now() - absl::Milliseconds(250))));
+}
+
+// Tests TryResolveWakeOnFd when a handle is a permanent kSignaledFd.
+TEST(WaitableObjectTest, SignaledFd) {
+ MockWaitableObject mwo;
+ WaitHandle wh = mwo.OnSomething();
+
+ // Return the kSignaledFd handle and expect that we still get our notify call.
+ // We can do this multiple times.
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_CALL(mwo, AcquireFdForWait(_))
+ .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
+ WaitableObject::kSignaledFd)));
+ EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
+ .WillOnce(Return(true));
+ ASSERT_STATUSOR_TRUE(wh.TryWait());
+ }
+}
+
+// Tests that waiting will not resolve if TryResolveWakeOnFd returns false.
+TEST(WaitableObjectTest, UnresolvedWake) {
+ MockWaitableObject mwo;
+ WaitHandle wh = mwo.OnSomething();
+
+ // Fail to resolve the first time.
+ // Since we are only trying to wait it should bail.
+ EXPECT_CALL(mwo, AcquireFdForWait(_))
+ .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
+ WaitableObject::kSignaledFd)));
+ EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
+ .WillOnce(Return(false));
+ ASSERT_STATUSOR_FALSE(wh.TryWait());
+
+ // Resolve on the next try.
+ EXPECT_CALL(mwo, AcquireFdForWait(_))
+ .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
+ WaitableObject::kSignaledFd)));
+ EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
+ .WillOnce(Return(true));
+ ASSERT_STATUSOR_TRUE(wh.TryWait());
+}
+
+// Tests the normal lifecycle of a ManualResetEvent.
+TEST(ManualResetEventTest, Lifecycle) {
+ ManualResetEvent ev;
+ EXPECT_FALSE(ev.DebugString().empty());
+ WaitHandle wh0 = ev.OnSet();
+ EXPECT_EQ(ev.DebugString(), wh0.DebugString());
+ WaitHandle wh1 = ev.OnSet();
+ EXPECT_EQ(ev.DebugString(), wh1.DebugString());
+ // Should not be set.
+ ASSERT_STATUSOR_FALSE(wh0.TryWait());
+ ASSERT_STATUSOR_FALSE(wh1.TryWait());
+ // Set should be sticky.
+ ASSERT_OK(ev.Set());
+ ASSERT_STATUSOR_TRUE(wh0.TryWait());
+ ASSERT_STATUSOR_TRUE(wh1.TryWait());
+ // Reset should clear.
+ ASSERT_OK(ev.Reset());
+ ASSERT_STATUSOR_FALSE(wh0.TryWait());
+ ASSERT_STATUSOR_FALSE(wh1.TryWait());
+ // Setting again should enable the previous WaitHandles to be signaled.
+ ASSERT_OK(ev.Set());
+ ASSERT_STATUSOR_TRUE(wh0.TryWait());
+ ASSERT_STATUSOR_TRUE(wh1.TryWait());
+}
+
+// Tests moving ManualResetEvents around.
+TEST(ManualResetEventTest, Move) {
+ ManualResetEvent ev0;
+ WaitHandle wh = ev0.OnSet();
+ ManualResetEvent ev1{std::move(ev0)};
+ ManualResetEvent ev2 = std::move(ev1);
+ ev1 = std::move(ev2);
+ ASSERT_OK(ev1.Set());
+ ASSERT_STATUSOR_TRUE(wh.TryWait());
+}
+
+// Tests redundantly setting and resetting ManualResetEvents.
+TEST(ManualResetEventTest, RedundantUse) {
+ ManualResetEvent ev;
+ ASSERT_OK(ev.Reset());
+ ASSERT_OK(ev.Reset());
+ ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie());
+ ASSERT_OK(ev.Set());
+ ASSERT_OK(ev.Set());
+ ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie());
+ ASSERT_OK(ev.Reset());
+ ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie());
+}
+
+// Tests waiting on an initially-set ManualResetEvent.
+TEST(ManualResetEventTest, SetThenWait) {
+ ManualResetEvent ev;
+ ASSERT_OK(ev.Set());
+ ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie());
+}
+
+// Tests that dangling an event will not wake waiters.
+// This is intentional (for now); we could with a bit of wrangling make it so
+// that WaitableObjects tracked their waiters and ensured they were all cleaned
+// up, but that seems hard. Don't drop your objects.
+TEST(ManualResetEventTest, NeverSet) {
+ ManualResetEvent ev;
+ WaitHandle wh = ev.OnSet();
+ ASSERT_STATUSOR_FALSE(wh.TryWait());
+ // Overwrite the event with a fresh one without ever setting it.
+ ev = ManualResetEvent();
+ // Waiter should not have woken.
+ ASSERT_STATUSOR_FALSE(wh.TryWait());
+}
+
+} // namespace
+} // namespace iree
diff --git a/bindings/python/BUILD b/bindings/python/BUILD
new file mode 100644
index 0000000..793b93d
--- /dev/null
+++ b/bindings/python/BUILD
@@ -0,0 +1,25 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//:build_defs.google.bzl", "iree_py_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_py_library(
+ name = "pathsetup",
+ imports = ["."],
+)
diff --git a/iree/bindings/python/README.md b/bindings/python/README.md
similarity index 100%
rename from iree/bindings/python/README.md
rename to bindings/python/README.md
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
new file mode 100644
index 0000000..151f5ee
--- /dev/null
+++ b/bindings/python/pyiree/BUILD
@@ -0,0 +1,103 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//:build_defs.google.bzl", "NUMPY_DEPS", "iree_py_extension")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+COMPILER_DEPS = [  # Appended to the deps of the :binding extension.
+ "//compiler/Translation/Sequencer",
+ "//compiler/Translation/Interpreter",
+ "//compiler/Translation/SPIRV",
+]
+
+DRIVER_DEPS = [  # Appended to the deps of the :binding extension.
+ "//hal/interpreter:interpreter_driver_module",
+ "//hal/vulkan:vulkan_driver_module",
+]
+
+iree_py_extension(
+ name = "binding",
+ srcs = [
+ "binding.cc",
+ "binding.h",
+ "compiler.cc",
+ "compiler.h",
+ "hal.cc",
+ "hal.h",
+ "initialize.cc",
+ "initialize.h",
+ "rt.cc",
+ "rt.h",
+ "status_utils.cc",
+ "status_utils.h",
+ "vm.cc",
+ "vm.h",
+ ],
+ copts = [
+ "-fexceptions",
+ ],
+ features = ["-use_header_modules"],
+ deps = [
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "//base:api",
+ "//base:init",
+ "//base:status",
+ "//hal:api",
+ "//rt:api",
+ "//schemas",
+ "//vm:api",
+ "@llvm//:support",
+ "@local_config_mlir//:IR",
+ "@local_config_mlir//:Parser",
+ "@iree_pybind11//:pybind11",
+ ] + COMPILER_DEPS + DRIVER_DEPS,
+)
+
+py_test(
+ name = "compiler_test",
+ srcs = ["compiler_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup",
+ "@absl_py//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "hal_test",
+ srcs = ["hal_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup",
+ "@absl_py//absl/testing:absltest",
+ ],
+)
+
+py_test(
+ name = "runtime_test",
+ srcs = ["runtime_test.py"],
+ python_version = "PY3",
+ deps = NUMPY_DEPS + [
+ ":binding",
+ "@absl_py//absl/testing:absltest",
+ ],
+)
diff --git a/bindings/python/pyiree/binding.cc b/bindings/python/pyiree/binding.cc
new file mode 100644
index 0000000..1cbdef7
--- /dev/null
+++ b/bindings/python/pyiree/binding.cc
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/binding.h"
+
+#include "bindings/python/pyiree/compiler.h"
+#include "bindings/python/pyiree/hal.h"
+#include "bindings/python/pyiree/initialize.h"
+#include "bindings/python/pyiree/rt.h"
+#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/vm.h"
+
+namespace iree {
+namespace python {
+
+PYBIND11_MODULE(binding, m) {
+ m.doc() = "IREE Binding Backend Helpers";
+ py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob");
+ m.def("initialize_extension", &InitializeExtension);
+
+ auto compiler_m = m.def_submodule("compiler", "IREE compiler support");
+ SetupCompilerBindings(compiler_m);
+
+ auto hal_m = m.def_submodule("hal", "IREE HAL support");
+ SetupHalBindings(hal_m);
+
+ auto rt_m = m.def_submodule("rt", "IREE RT api");
+ SetupRtBindings(rt_m);
+
+ auto vm_m = m.def_submodule("vm", "IREE VM api");
+ SetupVmBindings(vm_m);
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/binding.h b/bindings/python/pyiree/binding.h
new file mode 100644
index 0000000..e470ee6
--- /dev/null
+++ b/bindings/python/pyiree/binding.h
@@ -0,0 +1,147 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "base/api.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace pybind11 {
+namespace detail {
+#if !defined(ABSL_HAVE_STD_OPTIONAL)
+// Make absl::optional act like the future C++17 optional for pybind11.
+// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional
+// and the default type caster is sufficient.
+template <typename T>
+struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {};
+#endif
+} // namespace detail
+} // namespace pybind11
+
+namespace iree {
+namespace python {
+
+namespace py = pybind11;
+
+// Wrapper around a blob of memory.
+// Used to transport blobs back and forth between C++ and Python.
+class OpaqueBlob {
+ public:
+ OpaqueBlob() : data_(nullptr), size_(0) {}
+ OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {}
+ virtual ~OpaqueBlob() = default;
+
+ void* data() { return data_; }
+ const void* data() const { return data_; }
+ size_t size() const { return size_; }
+
+ // Create a free function from the OpaqueBlob shared pointer.
+ using BufferFreeFn = void (*)(void* self, iree_byte_span_t);
+ static std::pair<BufferFreeFn, void*> CreateFreeFn(
+ std::shared_ptr<OpaqueBlob> blob) {
+ // Note that there are more efficient ways to write this which
+ // don't bounce through an extra heap alloc, but this is not
+ // intended to be a high impact code path.
+ struct Holder {
+ std::shared_ptr<OpaqueBlob> blob;
+ };
+ Holder* holder = new Holder{std::move(blob)};
+ auto free_fn = +([](void* self, iree_byte_span_t) {
+ Holder* self_holder = static_cast<Holder*>(self);
+ delete self_holder;
+ });
+ return {free_fn, holder};
+ }
+
+ protected:
+ void* data_;
+ size_t size_;
+};
+
+// Opaque blob that owns a vector.
+class OpaqueByteVectorBlob : public OpaqueBlob {
+ public:
+ OpaqueByteVectorBlob(std::vector<uint8_t> v)
+ : OpaqueBlob(), v_(std::move(v)) {
+ data_ = v_.data();
+ size_ = v_.size();
+ }
+
+ private:
+ std::vector<uint8_t> v_;
+};
+
+template <typename T>
+struct ApiPtrAdapter {};
+
+template <typename Self, typename T>
+class ApiRefCounted {
+ public:
+ ApiRefCounted() : instance_(nullptr) {}
+ ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
+ other.instance_ = nullptr;
+ }
+ void operator=(const ApiRefCounted&) = delete;
+
+ ~ApiRefCounted() { Release(); }
+
+ // Creates an instance of the ref counted wrapper based on an instance
+ // that has already been retained. Ownership is transferred to the
+ // wrapper.
+ static Self CreateRetained(T* retained_inst) {
+ auto self = Self();
+ self.instance_ = retained_inst;
+ return self;
+ }
+
+ // Creates a new instance, retaining the underlying object.
+ static Self RetainAndCreate(T* non_retained_inst) {
+ auto self = Self();
+ self.instance_ = non_retained_inst;
+ if (non_retained_inst) {
+ ApiPtrAdapter<T>::Retain(non_retained_inst);
+ }
+ return self;
+ }
+
+ T* raw_ptr() {
+ if (!instance_) {
+ throw std::invalid_argument("API object is null");
+ }
+ return instance_;
+ }
+ void Retain() {
+ if (instance_) {
+ ApiPtrAdapter<T>::Retain(instance_);
+ }
+ }
+ void Release() {
+ if (instance_) {
+ ApiPtrAdapter<T>::Release(instance_);
+ }
+ }
+
+ private:
+ T* instance_;
+};
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc
new file mode 100644
index 0000000..562f1c9
--- /dev/null
+++ b/bindings/python/pyiree/compiler.cc
@@ -0,0 +1,93 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/compiler.h"
+
+#include <stdexcept>
+
+#include "bindings/python/pyiree/binding.h"
+#include "bindings/python/pyiree/initialize.h"
+#include "bindings/python/pyiree/status_utils.h"
+#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "schemas/module_def_generated.h"
+
+namespace py = pybind11;
+
+using namespace mlir;
+using namespace mlir::iree_compiler;
+
+using llvm::MemoryBuffer;
+using llvm::MemoryBufferRef;
+using llvm::StringRef;
+
+namespace iree {
+namespace python {
+
+namespace {
+
+OwningModuleRef parseMLIRModuleFromString(StringRef contents,
+ MLIRContext* context) {
+ std::unique_ptr<MemoryBuffer> contents_buffer;
+ if (contents.back() == 0) {
+ // If it has a nul terminator, just use as-is.
+ contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back());
+ } else {
+ // Otherwise, make a copy.
+ contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED");
+ }
+
+ llvm::SourceMgr source_mgr;
+ source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
+ OwningModuleRef mlir_module = parseSourceFile(source_mgr, context);
+ return mlir_module;
+}
+
+} // namespace
+
+std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) {
+ InitializeExtension({});
+
+ MLIRContext context;
+
+ // Arrange to get a view that includes a terminating null to avoid additional
+ // copy.
+ const char* moduleAsmChars = moduleAsm.c_str();
+ StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1);
+
+ // TODO(laurenzo): This error handling is super hokey. Hook into the MLIR
+ // error reporter and plumb through properly.
+ OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context);
+ if (!mlirModule) {
+ throw std::runtime_error("Failed to parse MLIR asm");
+ }
+
+ auto moduleBlob =
+ mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get());
+ if (moduleBlob.empty()) {
+ throw std::runtime_error("Failed to translate MLIR module");
+ }
+ return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob));
+}
+
+void SetupCompilerBindings(pybind11::module m) {
+ m.def("compile_module_from_asm", CompileModuleFromAsm);
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h
new file mode 100644
index 0000000..0bd6624
--- /dev/null
+++ b/bindings/python/pyiree/compiler.h
@@ -0,0 +1,30 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
+
+#include <string>
+
+#include "bindings/python/pyiree/binding.h"
+
+namespace iree {
+namespace python {
+
+void SetupCompilerBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
diff --git a/iree/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py
similarity index 100%
rename from iree/bindings/python/pyiree/compiler_test.py
rename to bindings/python/pyiree/compiler_test.py
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/hal.cc
new file mode 100644
index 0000000..e7a59a2
--- /dev/null
+++ b/bindings/python/pyiree/hal.cc
@@ -0,0 +1,135 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/hal.h"
+
+#include "hal/api.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+// Read-only mapping of a buffer view's buffer for Python's buffer protocol.
+class HalMappedMemory {
+ public:
+  HalMappedMemory(iree_hal_mapped_memory_t mapped_memory,
+                  iree_hal_buffer_view_t* bv)
+      : mapped_memory_(mapped_memory), bv_(bv) {
+    iree_hal_buffer_view_retain(bv_);  // Released again in the destructor.
+  }
+  ~HalMappedMemory() {
+    if (bv_) {  // Null once this object has been moved from.
+      iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_);
+      CHECK_EQ(iree_hal_buffer_unmap(buffer, &mapped_memory_), IREE_STATUS_OK);
+      iree_hal_buffer_view_release(bv_);
+    }
+  }
+  HalMappedMemory(HalMappedMemory&& other)
+      : mapped_memory_(other.mapped_memory_), bv_(other.bv_) {
+    other.bv_ = nullptr;  // The source must no longer unmap or release.
+  }
+
+  // Maps the full byte length of the buffer behind `bv` for reading.
+  static HalMappedMemory Create(HalBufferView& bv) {
+    iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr());
+    iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
+    iree_hal_mapped_memory_t mapped_memory;
+    CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ,
+                                       0 /* element_offset */, byte_length,
+                                       &mapped_memory),
+                   "Could not map memory");
+    return HalMappedMemory(mapped_memory, bv.raw_ptr());
+  }
+
+  // Describes the mapping as a dense row-major array of the view's shape.
+  py::buffer_info ToBufferInfo() {
+    iree_shape_t shape;
+    CheckApiStatus(iree_hal_buffer_view_shape(bv_, &shape),
+                   "Error getting buffer view shape");
+    int8_t element_size = iree_hal_buffer_view_element_size(bv_);
+    absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> dims;
+    dims.resize(shape.rank);
+    for (int i = 0; i < shape.rank; ++i) {
+      dims[i] = shape.dims[i];
+    }
+    // Byte strides: innermost is one element, each outer one is inner * extent.
+    absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> strides;
+    strides.resize(shape.rank);
+    if (!strides.empty()) {
+      strides.back() = 1 * element_size;
+    }
+    for (int i = shape.rank - 1; i > 0; --i) {
+      strides[i - 1] = strides[i] * shape.dims[i];
+    }
+    // TODO(laurenzo): Propagate dtype through the buffer view.
+    return py::buffer_info(
+        mapped_memory_.contents.data, element_size,
+        py::format_descriptor<float>::format(),  // TODO(laurenzo): DTYPE!
+        shape.rank, dims, strides);
+  }
+
+ private:
+  iree_hal_mapped_memory_t mapped_memory_;
+  iree_hal_buffer_view_t* bv_;
+};
+
+} // namespace
+
+void SetupHalBindings(pybind11::module m) {
+  // Enums: each Python name maps to the HAL C API constant of the same name.
+  py::enum_<iree_hal_memory_type_t>(m, "MemoryType")
+      .value("NONE", IREE_HAL_MEMORY_TYPE_NONE)
+      .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT)
+      .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)
+      .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT)
+      .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED)
+      .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL)
+      .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)
+      .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)
+      .export_values();
+  py::enum_<iree_hal_buffer_usage_t>(m, "BufferUsage")
+      .value("NONE", IREE_HAL_BUFFER_USAGE_NONE)
+      .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT)
+      .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER)
+      .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING)
+      .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH)
+      .value("ALL", IREE_HAL_BUFFER_USAGE_ALL)
+      .export_values();
+  py::enum_<iree_hal_memory_access_t>(m, "MemoryAccess")
+      .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE)
+      .value("READ", IREE_HAL_MEMORY_ACCESS_READ)
+      .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE)
+      .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD)
+      .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE)
+      .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL)
+      .export_values();
+
+  py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
+  py::class_<HalBufferView>(m, "BufferView")
+      .def("map", HalMappedMemory::Create);
+  py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol())
+      .def_buffer(&HalMappedMemory::ToBufferInfo);
+  py::class_<HalBuffer>(m, "Buffer")
+      .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer,
+                  py::arg("memory_type"), py::arg("usage"),
+                  py::arg("allocation_size"))
+      .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
+           py::arg("byte_length"))
+      .def("create_view", &HalBuffer::CreateView, py::arg("shape"),
+           py::arg("element_size"));
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/hal.h
new file mode 100644
index 0000000..d26bcf0
--- /dev/null
+++ b/bindings/python/pyiree/hal.h
@@ -0,0 +1,97 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
+
+#include "bindings/python/pyiree/binding.h"
+#include "bindings/python/pyiree/status_utils.h"
+#include "hal/api.h"
+
+namespace iree {
+namespace python {
+
+template <>
+struct ApiPtrAdapter<iree_hal_buffer_t> {  // Retain/release for HAL buffers.
+  static void Retain(iree_hal_buffer_t* buf) { iree_hal_buffer_retain(buf); }
+  static void Release(iree_hal_buffer_t* buf) { iree_hal_buffer_release(buf); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_buffer_view_t> {  // Retain/release for views.
+  static void Retain(iree_hal_buffer_view_t* view) {
+    iree_hal_buffer_view_retain(view);
+  }
+  static void Release(iree_hal_buffer_view_t* view) {
+    iree_hal_buffer_view_release(view);
+  }
+};
+
+// Python-constructible holder for an iree_shape_t.
+struct HalShape {
+  // Builds a shape from `indices`; rejects ranks above IREE_SHAPE_MAX_RANK.
+  static HalShape FromIntVector(std::vector<int32_t> indices) {
+    if (indices.size() > IREE_SHAPE_MAX_RANK) {
+      throw RaiseValueError("Shape exceeded maximum rank");
+    }
+    HalShape result;
+    result.s.rank = indices.size();
+    size_t dim = 0;
+    for (int32_t extent : indices) result.s.dims[dim++] = extent;
+    return result;
+  }
+
+  iree_shape_t s;
+};
+
+// Python-facing handle for iree_hal_buffer_view_t; adds no members of its own.
+class HalBufferView
+    : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> {
+};
+
+class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
+ public:
+  static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage,
+                                      iree_host_size_t allocation_size) {
+    const auto type = static_cast<iree_hal_memory_type_t>(memory_type);
+    const auto bits = static_cast<iree_hal_buffer_usage_t>(usage);
+    iree_hal_buffer_t* raw_buffer = nullptr;
+    CheckApiStatus(iree_hal_heap_buffer_allocate(
+                       type, bits, allocation_size, IREE_ALLOCATOR_DEFAULT,
+                       IREE_ALLOCATOR_DEFAULT, &raw_buffer),
+                   "Error allocating heap buffer");
+    return HalBuffer::CreateRetained(raw_buffer);
+  }
+  // Zero-fills `byte_length` bytes of this buffer starting at `byte_offset`.
+  void FillZero(iree_device_size_t byte_offset,
+                iree_device_size_t byte_length) {
+    CheckApiStatus(iree_hal_buffer_zero(raw_ptr(), byte_offset, byte_length),
+                   "Error zero filling buffer");
+  }
+  // Creates a view of this buffer with the given shape and element size.
+  HalBufferView CreateView(HalShape& shape, size_t element_size) {
+    iree_hal_buffer_view_t* view;
+    CheckApiStatus(iree_hal_buffer_view_create(raw_ptr(), shape.s, element_size,
+                                               IREE_ALLOCATOR_DEFAULT, &view),
+                   "Error creating buffer view");
+    return HalBufferView::CreateRetained(view);
+  }
+};
+
+void SetupHalBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
diff --git a/iree/bindings/python/pyiree/hal_test.py b/bindings/python/pyiree/hal_test.py
similarity index 100%
rename from iree/bindings/python/pyiree/hal_test.py
rename to bindings/python/pyiree/hal_test.py
diff --git a/bindings/python/pyiree/initialize.cc b/bindings/python/pyiree/initialize.cc
new file mode 100644
index 0000000..acf5cf0
--- /dev/null
+++ b/bindings/python/pyiree/initialize.cc
@@ -0,0 +1,53 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/initialize.h"
+
+#include <string.h>
+
+#include <mutex> // NOLINT
+
+#include "base/init.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+void InternalInitialize(const std::vector<std::string>& arguments) {
+ int argc = arguments.size() + 1;  // Extra slot for the synthetic program name.
+ char** argv = static_cast<char**>(
+ malloc(sizeof(char*) * (argc + 1)));  // Extra slot for the null terminator.
+ char** orig_argv = argv;  // Remembered because the callee is handed &argv.
+ argv[0] = strdup("<python_extension>");
+ for (int i = 1; i < argc; ++i) {
+ argv[i] = strdup(arguments[i - 1].c_str());
+ }
+ argv[argc] = nullptr;
+ InitializeEnvironment(&argc, &argv);  // NOTE(review): if this shrinks argc, dropped strings leak - confirm.
+ for (int i = 0; i < argc; ++i) {
+ free(argv[i]);  // Frees whatever entries remain visible after initialization.
+ }
+ free(orig_argv);  // Frees the array allocated above.
+}
+
+} // namespace
+
+void InitializeExtension(const std::vector<std::string>& arguments) {
+  static std::once_flag once;  // Later calls return without re-initializing.
+  std::call_once(once, [&arguments] { InternalInitialize(arguments); });
+}
+
+} // namespace python
+} // namespace iree
diff --git a/iree/bindings/python/pyiree/initialize.h b/bindings/python/pyiree/initialize.h
similarity index 100%
rename from iree/bindings/python/pyiree/initialize.h
rename to bindings/python/pyiree/initialize.h
diff --git a/bindings/python/pyiree/rt.cc b/bindings/python/pyiree/rt.cc
new file mode 100644
index 0000000..b683d02
--- /dev/null
+++ b/bindings/python/pyiree/rt.cc
@@ -0,0 +1,150 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/rt.h"
+
+#include "base/api.h"
+#include "bindings/python/pyiree/status_utils.h"
+#include "hal/api.h"
+
+namespace iree {
+namespace python {
+
+// Copies a Python buffer into a new device-visible buffer and wraps it in a
+// buffer view. Rejects bad rank/size, non-row-major layout and negative dims.
+HalBufferView RtContext::WrapPyBufferForInput(py::buffer py_buffer) {
+  auto py_buffer_info = py_buffer.request(false /* writable */);
+  if (py_buffer_info.ndim > IREE_SHAPE_MAX_RANK || py_buffer_info.ndim < 0) {
+    throw RaiseValueError("Unsupported buffer rank");
+  }
+  if (py_buffer_info.size < 0) {
+    throw RaiseValueError("Illegal buffer size");
+  }
+
+  // Verify strides are dense row-major. pybind11 reports strides in bytes.
+  // TODO(laurenzo): Test this with rank>1.
+  iree_shape_t shape;
+  shape.rank = py_buffer_info.ndim;
+  for (int i = 1; i < shape.rank; ++i) {
+    if (py_buffer_info.strides[i - 1] !=
+        py_buffer_info.strides[i] * py_buffer_info.shape[i]) {
+      throw RaiseValueError("Expected row-major layout");
+    }
+  }
+  if (!py_buffer_info.strides.empty()) {
+    if (py_buffer_info.strides.back() != py_buffer_info.itemsize) {
+      throw RaiseValueError("Expected row-major layout");
+    }
+  }
+
+  // Populate shape.
+  for (int i = 0; i < shape.rank; ++i) {
+    ssize_t dim = py_buffer_info.shape[i];
+    if (dim < 0) {
+      throw RaiseValueError("Unsupported negative dim");
+    }
+    shape.dims[i] = dim;
+  }
+
+  // For the moment, allocate a device visible buffer of equivalent size and
+  // copy into it. Done after validation so rejected inputs allocate nothing.
+  // TODO(laurenzo): Once sequencer is in place, switch to HeapBuffer, wrap
+  // and retain the original buffer.
+  iree_host_size_t byte_size = py_buffer_info.size * py_buffer_info.itemsize;
+  HalBuffer buffer =
+      AllocateDeviceVisible(byte_size, IREE_HAL_BUFFER_USAGE_CONSTANT |
+                                           IREE_HAL_BUFFER_USAGE_TRANSFER |
+                                           IREE_HAL_BUFFER_USAGE_DISPATCH);
+  CheckApiStatus(iree_hal_buffer_write_data(buffer.raw_ptr(), 0,
+                                            py_buffer_info.ptr, byte_size),
+                 "Error writing to input buffer");
+
+  // TODO(laurenzo): This does no validation on dtype and only cares if the
+  // elementsize matches. Figure out where to enforce actual dtype.
+  iree_hal_buffer_view_t* bv;
+  CheckApiStatus(iree_hal_buffer_view_create(buffer.raw_ptr(), shape,
+                                             py_buffer_info.itemsize,
+                                             IREE_ALLOCATOR_DEFAULT, &bv),
+                 "Error allocating buffer view");
+
+  return HalBufferView::CreateRetained(bv);
+}
+
+void SetupRtBindings(pybind11::module m) {
+ // BufferPlacement: selects the allocation entry point used by Context.allocate.
+ py::enum_<BufferPlacement>(m, "BufferPlacement")
+ .value("HEAP", BufferPlacement::kHeap)
+ .value("DEVICE_VISIBLE", BufferPlacement::kDeviceVisible)
+ .value("DEVICE_LOCAL", BufferPlacement::kDeviceLocal)
+ .export_values();
+
+ // RtModule: exposed to Python as `Module`.
+ py::class_<RtModule>(m, "Module")
+ .def_property_readonly("name", &RtModule::name)
+ .def("lookup_function_by_ordinal", &RtModule::lookup_function_by_ordinal)
+ .def("lookup_function_by_name", &RtModule::lookup_function_by_name);
+ // RtFunction: exposed as `Function`, with its signature struct bound below.
+ py::class_<RtFunction>(m, "Function")
+ .def_property_readonly("name", &RtFunction::name)
+ .def_property_readonly("signature", &RtFunction::signature);
+ py::class_<iree_rt_function_signature_t>(m, "FunctionSignature")
+ .def_readonly("argument_count",
+ &iree_rt_function_signature_t::argument_count)
+ .def_readonly("result_count",
+ &iree_rt_function_signature_t::result_count);
+
+ // RtPolicy: exposed as `Policy`; constructed with no arguments.
+ py::class_<RtPolicy>(m, "Policy").def(py::init(&RtPolicy::Create));
+
+ // RtInstance: exposed as `Instance`; `driver_name` defaults to an empty optional.
+ py::class_<RtInstance>(m, "Instance")
+ .def(py::init(&RtInstance::Create),
+ py::arg_v("driver_name", absl::optional<std::string>()));
+
+ // RtContext: exposed as `Context`; allocation usage defaults to USAGE_ALL.
+ py::class_<RtContext>(m, "Context")
+ .def(py::init(&RtContext::Create), py::arg("instance"), py::arg("policy"))
+ .def_property_readonly("context_id", &RtContext::context_id)
+ .def("register_modules", &RtContext::RegisterModules, py::arg("modules"))
+ .def("register_module", &RtContext::RegisterModule, py::arg("module"))
+ .def("lookup_module_by_name", &RtContext::LookupModuleByName,
+ py::arg("name"))
+ .def("resolve_function", &RtContext::ResolveFunction,
+ py::arg("full_name"))
+ .def("allocate", &RtContext::Allocate, py::arg("allocation_size"),
+ py::arg("placement") = BufferPlacement::kHeap,
+ py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL)
+ .def("allocate_device_local", &RtContext::AllocateDeviceLocal,
+ py::arg("allocation_size"),
+ py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL)
+ .def("allocate_device_visible", &RtContext::AllocateDeviceVisible,
+ py::arg("allocation_size"),
+ py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL)
+ .def("wrap_for_input", &RtContext::WrapPyBufferForInput, py::arg("v"))
+ .def("invoke", &RtContext::Invoke, py::arg("f"), py::arg("policy"),
+ py::arg("arguments"),
+ py::arg("results") = absl::optional<std::vector<HalBufferView*>>());
+
+ // RtInvocation: exposed as `Invocation`; reading `results` calls ConsumeResults.
+ py::class_<RtInvocation>(m, "Invocation")
+ .def("query_status", &RtInvocation::QueryStatus)
+ .def("await", &RtInvocation::Await,  // NOTE(review): "await" is a Python 3.7+ keyword.
+ py::arg("deadline") = IREE_TIME_INFINITE_FUTURE)
+ .def("await_optional", &RtInvocation::AwaitOptional,
+ py::arg("deadline") = IREE_TIME_INFINITE_FUTURE)
+ .def_property_readonly("results", &RtInvocation::ConsumeResults);
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/rt.h b/bindings/python/pyiree/rt.h
new file mode 100644
index 0000000..85e85da
--- /dev/null
+++ b/bindings/python/pyiree/rt.h
@@ -0,0 +1,390 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_H_
+
+#include "absl/container/inlined_vector.h"
+#include "base/api.h"
+#include "bindings/python/pyiree/binding.h"
+#include "bindings/python/pyiree/hal.h"
+#include "bindings/python/pyiree/initialize.h"
+#include "bindings/python/pyiree/status_utils.h"
+#include "hal/api.h"
+#include "rt/api.h"
+
+namespace iree {
+namespace python {
+
+// Selects which allocation entry point a context-level buffer allocation
+// uses: heap, device-visible or device-local (these are separate functions
+// in the C API).
+enum class BufferPlacement {
+ kHeap,
+ kDeviceVisible,
+ kDeviceLocal,
+};
+
+// ApiPtrAdapter specializations: forward retain/release to the rt C API.
+template <>
+struct ApiPtrAdapter<iree_rt_module_t> {
+  static void Retain(iree_rt_module_t* mod) { iree_rt_module_retain(mod); }
+  static void Release(iree_rt_module_t* mod) { iree_rt_module_release(mod); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_rt_instance_t> {  // Retain/release for instances.
+  static void Retain(iree_rt_instance_t* instance) {
+    iree_rt_instance_retain(instance);
+  }
+  static void Release(iree_rt_instance_t* instance) {
+    iree_rt_instance_release(instance);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_rt_policy_t> {  // Retain/release for policies.
+  static void Retain(iree_rt_policy_t* pol) { iree_rt_policy_retain(pol); }
+  static void Release(iree_rt_policy_t* pol) { iree_rt_policy_release(pol); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_rt_context_t> {  // Retain/release for contexts.
+  static void Retain(iree_rt_context_t* ctx) { iree_rt_context_retain(ctx); }
+  static void Release(iree_rt_context_t* ctx) { iree_rt_context_release(ctx); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_rt_invocation_t> {  // Retain/release for invocations.
+  static void Retain(iree_rt_invocation_t* invocation) {
+    iree_rt_invocation_retain(invocation);
+  }
+  static void Release(iree_rt_invocation_t* invocation) {
+    iree_rt_invocation_release(invocation);
+  }
+};
+
+// Wrapper classes. These mirror the Python declarations.
+class RtFunction {
+ public:
+  // Takes a reference on the function's module for this object's lifetime.
+  RtFunction(iree_rt_function_t function) : function_(function) {
+    iree_rt_module_retain(function_.module);
+  }
+  ~RtFunction() {
+    if (function_.module != nullptr) iree_rt_module_release(function_.module);
+  }
+  RtFunction(RtFunction&& other) : function_(other.function_) {
+    other.function_.module = nullptr;  // Moved-from object skips the release.
+  }
+  void operator=(const RtFunction&) = delete;
+
+  std::string name() {
+    const auto name_view = iree_rt_function_name(&function_);
+    return std::string(name_view.data, name_view.size);
+  }
+
+  iree_rt_function_signature_t signature() {
+    iree_rt_function_signature_t result;
+    CheckApiStatus(iree_rt_function_signature(&function_, &result),
+                   "Error getting function signature");
+    return result;
+  }
+
+  iree_rt_function_t& raw_function() { return function_; }
+
+ private:
+  iree_rt_function_t function_;
+};
+
+class RtModule : public ApiRefCounted<RtModule, iree_rt_module_t> {
+ public:
+  std::string name() {
+    const auto name_view = iree_rt_module_name(raw_ptr());
+    return std::string(name_view.data, name_view.size);
+  }
+
+  // Looks up an exported function by ordinal; empty optional if not found.
+  absl::optional<RtFunction> lookup_function_by_ordinal(int32_t ordinal) {
+    // TODO(laurenzo): Support an optional linkage argument.
+    iree_rt_function_t function;
+    const auto status = iree_rt_module_lookup_function_by_ordinal(
+        raw_ptr(), IREE_RT_FUNCTION_LINKAGE_EXPORT, ordinal, &function);
+    if (status == IREE_STATUS_NOT_FOUND) {
+      return absl::optional<RtFunction>();
+    }
+    CheckApiStatus(status, "Error looking up function");
+    return RtFunction(function);
+  }
+
+  // Looks up an exported function by name; empty optional if not found.
+  absl::optional<RtFunction> lookup_function_by_name(const std::string& name) {
+    // TODO(laurenzo): Support an optional linkage argument.
+    iree_rt_function_t function;
+    const iree_string_view_t name_view{name.data(), name.size()};
+    const auto status = iree_rt_module_lookup_function_by_name(
+        raw_ptr(), IREE_RT_FUNCTION_LINKAGE_EXPORT, name_view, &function);
+    if (status == IREE_STATUS_NOT_FOUND) {
+      return absl::optional<RtFunction>();
+    }
+    CheckApiStatus(status, "Error looking up function");
+    return RtFunction(function);
+  }
+};
+
+class RtInstance : public ApiRefCounted<RtInstance, iree_rt_instance_t> {
+ public:
+  // Creates an instance and registers `driver_name` on it, falling back to
+  // "interpreter" when no name is given.
+  // TODO(laurenzo): Support optional allocator argument.
+  static RtInstance Create(absl::optional<std::string> driver_name) {
+    InitializeExtension({});
+    iree_rt_instance_t* handle;
+    CheckApiStatus(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &handle),
+                   "Error creating instance");
+    RtInstance instance = RtInstance::CreateRetained(handle);
+
+    // Default to the interpreter driver when the caller did not pick one.
+    const std::string driver = driver_name.value_or("interpreter");
+    const iree_string_view_t driver_view{driver.c_str(), driver.size()};
+    CheckApiStatus(iree_rt_instance_register_driver_ex(handle, driver_view),
+                   "Error registering drivers");
+
+    return instance;
+  }
+};
+
+class RtPolicy : public ApiRefCounted<RtPolicy, iree_rt_policy_t> {
+ public:
+  // TODO(laurenzo): Support optional allocator argument.
+  static RtPolicy Create() {
+    iree_rt_policy_t* raw_policy;  // Filled in by the create call below.
+    CheckApiStatus(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &raw_policy),
+                   "Error creating policy");
+    return RtPolicy::CreateRetained(raw_policy);
+  }
+};
+
+class RtInvocation : public ApiRefCounted<RtInvocation, iree_rt_invocation_t> {
+ public:
+  // Returns true when the invocation is ready and false while it is still
+  // unavailable; any other status is reported through CheckApiStatus.
+  bool QueryStatus() {
+    auto status = iree_rt_invocation_query_status(raw_ptr());
+    if (status == IREE_STATUS_OK) {
+      return true;
+    } else if (status == IREE_STATUS_UNAVAILABLE) {
+      return false;
+    } else {
+      CheckApiStatus(status, "Error in function invocation");
+      return false;
+    }
+  }
+
+  // TODO(laurenzo): Convert to the pybind chrono support.
+  // Waits for the invocation; returns false if the deadline was exceeded.
+  bool AwaitOptional(iree_time_t epoch_nanos_deadline) {
+    auto status = iree_rt_invocation_await(raw_ptr(), epoch_nanos_deadline);
+    if (status == IREE_STATUS_OK) {
+      return true;
+    } else if (status == IREE_STATUS_DEADLINE_EXCEEDED) {
+      return false;
+    } else {
+      CheckApiStatus(status, "Error in invocation");
+      return false;
+    }
+  }
+
+  // Like AwaitOptional, but throws when the deadline expires before the
+  // invocation is ready instead of returning false.
+  void Await(iree_time_t epoch_nanos_deadline) {
+    if (!AwaitOptional(epoch_nanos_deadline)) {
+      throw RaiseValueError("Deadline expired");
+    }
+  }
+
+  std::vector<HalBufferView> ConsumeResults() {
+    static constexpr size_t kInlineSize = 8;  // First attempt's capacity.
+    iree_host_size_t result_count;
+    absl::InlinedVector<iree_hal_buffer_view_t*, kInlineSize> result_bvs;
+    result_bvs.resize(kInlineSize);
+    auto status = iree_rt_invocation_consume_results(
+        raw_ptr(), kInlineSize, IREE_ALLOCATOR_DEFAULT, &result_bvs[0],
+        &result_count);
+    if (status == IREE_STATUS_OUT_OF_RANGE) {
+      // Resize to the count reported by the first call and retry.
+      result_bvs.resize(result_count);
+      status = iree_rt_invocation_consume_results(
+          raw_ptr(), result_count, IREE_ALLOCATOR_DEFAULT, &result_bvs[0],
+          &result_count);
+    }
+    CheckApiStatus(status, "Error consuming invocation results");
+    result_bvs.resize(result_count);  // Trim to the number actually returned.
+    std::vector<HalBufferView> results;
+    for (auto* raw_bv : result_bvs) {
+      results.push_back(HalBufferView::CreateRetained(raw_bv));
+    }
+    return results;
+  }
+};
+
+class RtContext : public ApiRefCounted<RtContext, iree_rt_context_t> {
+ public:
+  static RtContext Create(RtInstance* instance, RtPolicy* policy) {
+    iree_rt_context_t* context;
+    // TODO(laurenzo): Support optional allocator argument.
+    CheckApiStatus(
+        iree_rt_context_create(instance->raw_ptr(), policy->raw_ptr(),
+                               IREE_ALLOCATOR_DEFAULT, &context),
+        "Error creating context");
+    return RtContext::CreateRetained(context);
+  }
+
+  int context_id() { return iree_rt_context_id(raw_ptr()); }
+
+  void RegisterModules(std::vector<RtModule*> modules) {
+    // Collect the raw module pointers in order for the C API call.
+    std::vector<iree_rt_module_t*> module_raw_ptrs;
+    module_raw_ptrs.resize(modules.size());
+    for (size_t i = 0, e = modules.size(); i < e; ++i) {
+      module_raw_ptrs[i] = modules[i]->raw_ptr();
+    }
+    CheckApiStatus(
+        iree_rt_context_register_modules(raw_ptr(), module_raw_ptrs.data(),
+                                         module_raw_ptrs.size()),
+        "Error registering modules");
+  }
+
+  void RegisterModule(RtModule* module) {
+    iree_rt_module_t* module_raw_ptr = module->raw_ptr();
+    CheckApiStatus(
+        iree_rt_context_register_modules(raw_ptr(), &module_raw_ptr, 1),
+        "Error registering module");
+  }
+
+  absl::optional<RtModule> LookupModuleByName(const std::string& name) {
+    iree_rt_module_t* module = iree_rt_context_lookup_module_by_name(
+        raw_ptr(), {name.data(), name.size()});
+    if (!module) {
+      return absl::optional<RtModule>();
+    }
+    return RtModule::RetainAndCreate(module);
+  }
+
+  absl::optional<RtFunction> ResolveFunction(const std::string& full_name) {
+    iree_rt_function_t f;
+    auto status = iree_rt_context_resolve_function(
+        raw_ptr(), {full_name.data(), full_name.size()}, &f);
+    if (status == IREE_STATUS_NOT_FOUND) {
+      return absl::optional<RtFunction>();
+    }
+    CheckApiStatus(status, "Error resolving function");
+    return RtFunction(f);
+  }
+
+  // Convenience method to allocate host, device-visible or device-local
+  // buffers.
+  HalBuffer Allocate(iree_host_size_t allocation_size,
+                     BufferPlacement placement, int32_t usage) {
+    iree_hal_buffer_t* raw_buffer = nullptr;
+    switch (placement) {
+      case BufferPlacement::kHeap:
+        // Even though allocating a heap buffer does not require the context,
+        // provide it here to make the API easier to navigate.
+        CheckApiStatus(
+            iree_hal_heap_buffer_allocate(
+                IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
+                static_cast<iree_hal_buffer_usage_t>(usage), allocation_size,
+                IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &raw_buffer),
+            "Error allocating heap buffer");
+        break;
+      case BufferPlacement::kDeviceLocal:
+        CheckApiStatus(
+            iree_rt_context_allocate_device_local_buffer(
+                raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage),
+                allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer),
+            "Error allocating device local buffer");
+        break;
+      case BufferPlacement::kDeviceVisible:
+        CheckApiStatus(
+            iree_rt_context_allocate_device_visible_buffer(
+                raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage),
+                allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer),
+            "Error allocating device visible buffer");
+        break;
+      default:
+        throw RaiseValueError("Unknown BufferPlacement");
+    }
+
+    return HalBuffer::CreateRetained(raw_buffer);
+  }
+
+  HalBuffer AllocateHeap(iree_host_size_t allocation_size, int32_t usage) {
+    return Allocate(allocation_size, BufferPlacement::kHeap, usage);
+  }
+
+  HalBuffer AllocateDeviceLocal(iree_host_size_t allocation_size,
+                                int32_t usage) {
+    return Allocate(allocation_size, BufferPlacement::kDeviceLocal, usage);
+  }
+
+  HalBuffer AllocateDeviceVisible(iree_host_size_t allocation_size,
+                                  int32_t usage) {
+    return Allocate(allocation_size, BufferPlacement::kDeviceVisible, usage);
+  }
+
+  // One stop convenience method for wrapping a python buffer protocol buffer
+  // for input to a function. At the runtime's discretion, this may make a copy
+  // or do something smarter, meaning the data in the backing python buffer
+  // will either be accessed immediately or at some future point.
+  HalBufferView WrapPyBufferForInput(py::buffer py_buffer);
+
+  // Creates an invocation of `f`; `results`, when given, is passed through
+  RtInvocation Invoke(RtFunction& f, RtPolicy& policy,
+                      std::vector<HalBufferView*> arguments,
+                      absl::optional<std::vector<HalBufferView*>> results) {
+    absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_arguments;
+    raw_arguments.resize(arguments.size());
+    for (size_t i = 0, e = arguments.size(); i < e; ++i) {
+      auto inst = arguments[i];
+      CheckApiNotNull(inst, "Argument buffer view cannot be None");
+      raw_arguments[i] = inst->raw_ptr();
+    }
+    absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_results;
+    if (results) {
+      raw_results.resize(results->size());
+      for (size_t i = 0, e = results->size(); i < e; ++i) {
+        auto inst = (*results)[i];
+        CheckApiNotNull(inst, "Result buffer view cannot be None");
+        raw_results[i] = inst->raw_ptr();
+      }
+    }
+
+    iree_rt_invocation_t* invocation;
+    CheckApiStatus(iree_rt_invocation_create(
+                       raw_ptr(), &f.raw_function(), policy.raw_ptr(),
+                       nullptr /* dependencies */, raw_arguments.data(),
+                       raw_arguments.size(), raw_results.data(),
+                       raw_results.size(), IREE_ALLOCATOR_DEFAULT, &invocation),
+                   "Error invoking function");
+
+    return RtInvocation::CreateRetained(invocation);
+  }
+};
+
+void SetupRtBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_H_
diff --git a/iree/bindings/python/pyiree/runtime_test.py b/bindings/python/pyiree/runtime_test.py
similarity index 100%
rename from iree/bindings/python/pyiree/runtime_test.py
rename to bindings/python/pyiree/runtime_test.py
diff --git a/bindings/python/pyiree/status_utils.cc b/bindings/python/pyiree/status_utils.cc
new file mode 100644
index 0000000..63f2131
--- /dev/null
+++ b/bindings/python/pyiree/status_utils.cc
@@ -0,0 +1,72 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/status_utils.h"
+
+#include "absl/strings/str_cat.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+// Picks the Python exception class for |status|; unmapped codes -> RuntimeError.
+PyObject* StatusToPyExcClass(const Status& status) {
+  const auto code = status.code();
+  if (code == StatusCode::kInvalidArgument) {
+    return PyExc_ValueError;
+  } else if (code == StatusCode::kOutOfRange) {
+    return PyExc_IndexError;
+  } else if (code == StatusCode::kUnimplemented) {
+    return PyExc_NotImplementedError;
+  }
+  return PyExc_RuntimeError;
+}
+
+// Picks the Python exception class for a C API status; unmapped -> RuntimeError.
+PyObject* ApiStatusToPyExcClass(iree_status_t status) {
+  if (status == IREE_STATUS_INVALID_ARGUMENT) {
+    return PyExc_ValueError;
+  } else if (status == IREE_STATUS_OUT_OF_RANGE) {
+    return PyExc_IndexError;
+  } else if (status == IREE_STATUS_UNIMPLEMENTED) {
+    return PyExc_NotImplementedError;
+  } else {
+    return PyExc_RuntimeError;
+  }
+}
+
+} // namespace
+
+pybind11::error_already_set StatusToPyExc(const Status& status) {
+ assert(!status.ok());  // Precondition: only failing statuses are converted.
+ PyErr_SetString(StatusToPyExcClass(status), status.error_message().c_str());
+ return pybind11::error_already_set();  // Returned so the caller can `throw` it.
+}
+
+pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
+ const char* message) {
+ assert(status != IREE_STATUS_OK);  // Precondition: |status| is a failure.
+ auto full_message = absl::StrCat(message, ": ", static_cast<int>(status));  // "<message>: <numeric status>"
+ PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str());
+ return pybind11::error_already_set();  // Returned so the caller can `throw` it.
+}
+
+pybind11::error_already_set RaiseValueError(const char* message) {
+ PyErr_SetString(PyExc_ValueError, message);  // Sets a pending Python ValueError.
+ return pybind11::error_already_set();  // Returned so the caller can `throw` it.
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/status_utils.h b/bindings/python/pyiree/status_utils.h
new file mode 100644
index 0000000..ef89ba8
--- /dev/null
+++ b/bindings/python/pyiree/status_utils.h
@@ -0,0 +1,67 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
+
+#include "base/api.h"
+#include "base/status.h"
+#include "pybind11/pytypes.h"
+
+namespace iree {
+namespace python {
+
+// Converts a failing status to a throwable exception, setting Python
+// error information.
+// Correct usage is something like:
+// if (!status.ok()) {
+// throw StatusToPyExc(status);
+// }
+pybind11::error_already_set StatusToPyExc(const Status& status);
+
+// Raises a value error with the given message.
+// Correct usage:
+// throw RaiseValueError("Foobar'd");
+pybind11::error_already_set RaiseValueError(const char* message);
+
+// Consumes a StatusOr<T>, returning the contained T by value (moved out of
+// |sor|) if the status is ok(). Otherwise, throws a Python exception.
+template <typename T>
+T PyConsumeStatusOr(iree::StatusOr<T>&& sor) {
+  if (sor.ok()) {
+    return std::move(*sor);
+  }
+  throw StatusToPyExc(sor.status());
+}
+
+pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
+ const char* message);
+
+// Throws the Python exception for |status| (see ApiStatusToPyExc) unless OK.
+inline void CheckApiStatus(iree_status_t status, const char* message) {
+  if (status != IREE_STATUS_OK) {
+    throw ApiStatusToPyExc(status, message);
+  }
+}
+
+// Throws a Python ValueError carrying |message| when |p| is null.
+inline void CheckApiNotNull(const void* p, const char* message) {
+  if (p) return;
+  throw RaiseValueError(message);
+}
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
new file mode 100644
index 0000000..3d03d97
--- /dev/null
+++ b/bindings/python/pyiree/vm.cc
@@ -0,0 +1,37 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/vm.h"
+
+#include "bindings/python/pyiree/status_utils.h"
+
+namespace iree {
+namespace python {
+
+RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) {
+  auto release = OpaqueBlob::CreateFreeFn(blob);  // Pair forwarded to the API.
+  iree_rt_module_t* module;
+  const auto status = iree_vm_bytecode_module_create_from_buffer(
+      {static_cast<const uint8_t*>(blob->data()), blob->size()}, release.first,
+      release.second, IREE_ALLOCATOR_DEFAULT, &module);
+  CheckApiStatus(status, "Error creating vm module from blob");
+  return RtModule::CreateRetained(module);
+}
+
+void SetupVmBindings(pybind11::module m) {
+ m.def("create_module_from_blob", CreateModuleFromBlob);  // Exposed to Python under this name.
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/vm.h
new file mode 100644
index 0000000..e7338cc
--- /dev/null
+++ b/bindings/python/pyiree/vm.h
@@ -0,0 +1,30 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_
+
+#include "bindings/python/pyiree/binding.h"
+#include "bindings/python/pyiree/rt.h"
+#include "vm/api.h"
+
+namespace iree {
+namespace python {
+
+void SetupVmBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_
diff --git a/build_defs.oss.bzl b/build_defs.oss.bzl
new file mode 100644
index 0000000..92b002a
--- /dev/null
+++ b/build_defs.oss.bzl
@@ -0,0 +1,130 @@
+"""Common Bazel definitions for IREE."""
+
+load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+load("@iree_native_python//:build_defs.bzl", "py_extension")
+load("@iree_core///build_tools/third_party/glslang:build_defs.bzl", "glsl_vulkan")
+load("@rules_python//python:defs.bzl", "py_library")
+
+NUMPY_DEPS = []
+
+def platform_trampoline_deps(basename):
+ """Produce a list of deps for the given `basename` platform target.
+
+ Example:
+ "file_mapping" -> ["///base/internal:file_mapping_internal"]
+
+ This is used for compatibility with various methods of including the
+ library in foreign source control systems.
+
+ Args:
+ basename: Library name prefix for a library in base/internal.
+ Returns:
+ A list of dependencies for depending on the library in a platform
+ sensitive way.
+ """
+ return [
+ "///base/internal:%s_internal" % basename,
+ ]
+
+# A platform-sensitive list of copts for the Vulkan loader.
+PLATFORM_VULKAN_LOADER_COPTS = select({
+ "///hal/vulkan:native_vk": [],
+ "///hal/vulkan:swiftshader_vk": [],
+ "//conditions:default": [],
+})
+
+# A platform-sensitive list of dependencies for non-test targets using Vulkan.
+PLATFORM_VULKAN_DEPS = select({
+ "///hal/vulkan:native_vk": [],
+ "///hal/vulkan:swiftshader_vk": [],
+ "//conditions:default": [],
+})
+
+# A platform-sensitive list of dependencies for tests using Vulkan.
+PLATFORM_VULKAN_TEST_DEPS = [
+ "@com_google_googletest//:gtest_main",
+]
+
+def iree_py_library(**kwargs):
+ """Compatibility wrapper that forwards all kwargs unchanged to py_library."""
+
+ # This is used when args are needed that are incompatible with upstream.
+ # Presently, this includes:
+ # imports
+ py_library(**kwargs)
+
+def iree_py_extension(deps = [], **kwargs):
+ """Delegates to the real py_extension, prepending the Python headers dep."""
+ py_extension(
+ deps = ["@iree_native_python//:python_headers"] + deps,
+ **kwargs
+ )
+
+def iree_build_test(name, targets):
+ """Dummy rule to ensure that targets build.
+
+ No-op under Bazel, preserved for compatibility; both args are ignored.
+ """
+ pass
+
+def iree_setup_lit_package(data):
+ """Should be called once per test package that contains globbed lit tests.
+
+ Args:
+ data: Additional, project specific data deps to add alongside FileCheck.
+ """
+
+ # Bundle all test utilities used by tests into the :lit_test_utilities target.
+ native.filegroup(
+ name = "lit_test_utilities",
+ testonly = True,
+ data = data + [
+ "@llvm//:FileCheck",
+ ],
+ )
+
+def iree_glob_lit_tests(
+ data = [":lit_test_utilities"],
+ driver = "///tools:run_lit.sh",
+ test_file_exts = ["mlir"]):
+ """Creates one small sh_test named "<file>.test" per globbed lit test file.
+
+ For most packages, the defaults suffice. Packages that include this must
+ also include a call to iree_setup_lit_package().
+
+ Args:
+ data: Data files to include/build (the test file itself is always added).
+ driver: Test driver; it receives the test file's location as its argument.
+ test_file_exts: File extensions to glob (this directory and subdirectories).
+ """
+ for test_file_ext in test_file_exts:
+ test_files = native.glob([
+ "*.%s" % (test_file_ext,),
+ "**/*.%s" % (test_file_ext,),
+ ])
+ for test_file in test_files:
+ test_file_location = "$(location %s)" % (test_file,)
+ native.sh_test(
+ name = "%s.test" % (test_file,),
+ size = "small",
+ srcs = [driver],
+ data = data + [test_file],
+ args = [test_file_location],
+ )
+
+# The OSS build currently has issues with generating flatbuffer reflections.
+# It is hard-coded to disabled here (and in iree_flatbuffer_cc_library) until triaged/fixed.
+FLATBUFFER_SUPPORTS_REFLECTIONS = False
+
+def iree_flatbuffer_cc_library(**kwargs):
+ """Wraps flatbuffer_cc_library with gen_reflections forced to False."""
+
+ # TODO(laurenzo): The bazel rule for reflections seems broken in OSS
+ # builds. Fix it and enable by default.
+ flatbuffer_cc_library(
+ gen_reflections = False,
+ **kwargs
+ )
+
+def iree_glsl_vulkan(**kwargs):
+ glsl_vulkan(**kwargs)  # Straight pass-through to the loaded glsl_vulkan rule.
diff --git a/build_tools/embed_data/build_defs.bzl b/build_tools/embed_data/build_defs.bzl
index e48b22a..b4dc8ac 100644
--- a/build_tools/embed_data/build_defs.bzl
+++ b/build_tools/embed_data/build_defs.bzl
@@ -53,7 +53,7 @@
identifier: The identifier to use in generated names (defaults to name).
**kwargs: Args to pass to the cc_library.
"""
- generator = "//build_tools/embed_data:generate_cc_embed_data"
+ generator = "///build_tools/embed_data:generate_cc_embed_data"
generator_location = "$(location %s)" % generator
if identifier == None:
identifier = name
diff --git a/build_tools/embed_data/testembed1_test.cc b/build_tools/embed_data/testembed1_test.cc
index f34d19b..a2b9f01 100644
--- a/build_tools/embed_data/testembed1_test.cc
+++ b/build_tools/embed_data/testembed1_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/build_tools/embed_data/testembed1.h"
+#include "build_tools/embed_data/testembed1.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/build_tools/python/configure.bzl b/build_tools/python/configure.bzl
index 0e5e7b0..d215b4c 100644
--- a/build_tools/python/configure.bzl
+++ b/build_tools/python/configure.bzl
@@ -56,11 +56,11 @@
],
attrs = {
"_generate_script": attr.label(
- default = Label("//build_tools/python:generate_build.py"),
+ default = Label("///build_tools/python:generate_build.py"),
allow_single_file = True,
),
"_build_defs": attr.label(
- default = Label("//build_tools/python:build_defs.bzl"),
+ default = Label("///build_tools/python:build_defs.bzl"),
allow_single_file = True,
),
},
diff --git a/iree/compiler/BUILD b/compiler/BUILD
similarity index 100%
rename from iree/compiler/BUILD
rename to compiler/BUILD
diff --git a/iree/compiler/CMakeLists.txt b/compiler/CMakeLists.txt
similarity index 100%
rename from iree/compiler/CMakeLists.txt
rename to compiler/CMakeLists.txt
diff --git a/iree/compiler/IR/BUILD b/compiler/IR/BUILD
similarity index 100%
rename from iree/compiler/IR/BUILD
rename to compiler/IR/BUILD
diff --git a/iree/compiler/IR/CMakeLists.txt b/compiler/IR/CMakeLists.txt
similarity index 100%
rename from iree/compiler/IR/CMakeLists.txt
rename to compiler/IR/CMakeLists.txt
diff --git a/compiler/IR/ConfigOps.cpp b/compiler/IR/ConfigOps.cpp
new file mode 100644
index 0000000..0837e69
--- /dev/null
+++ b/compiler/IR/ConfigOps.cpp
@@ -0,0 +1,110 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/ConfigOps.h"
+
+#include "compiler/IR/StructureOps.h"
+#include "compiler/IR/Types.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+//===----------------------------------------------------------------------===//
+// Generic printers and parsers.
+//===----------------------------------------------------------------------===//
+
+// Parses an op that has no inputs and no outputs; only an optional attribute
+// dictionary is parsed.
+static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
+  const bool parseFailed =
+      failed(parser.parseOptionalAttributeDict(state.attributes));
+  return failure(parseFailed);
+}
+
+// Prints a no-input/no-output op as its name plus its optional attribute dict.
+static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
+ printer << op->getName();
+ printer.printOptionalAttrDict(op->getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// iree.target_config
+//===----------------------------------------------------------------------===//
+
+void ExecutableTargetConfigOp::build(Builder *builder, OperationState &state,
+ std::string backend) {
+ state.addAttribute("backend", builder->getStringAttr(backend));  // Stored as a string attribute.
+ ensureTerminator(*state.addRegion(), *builder, state.location);  // New body region gets its terminator.
+}
+
+static ParseResult parseExecutableTargetConfigOp(OpAsmParser &parser,
+                                                 OperationState &state) {
+  // Syntax: `(` backend `)` region [`attributes` dict]. The `)` emitted by the
+  // printer must be consumed here or the region parse below fails.
+  StringAttr backendAttr;
+  if (failed(parser.parseLParen()) ||
+      failed(parser.parseAttribute(backendAttr, "backend", state.attributes)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  Region *body = state.addRegion();
+  if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) {
+    return failure();
+  }
+  if (succeeded(parser.parseOptionalKeyword("attributes")) &&
+      failed(parser.parseOptionalAttributeDict(state.attributes))) {
+    return failure();
+  }
+
+  // The printer omits block terminators, so make sure the body has one.
+  ExecutableTargetConfigOp::ensureTerminator(*body, parser.getBuilder(),
+                                             state.location);
+  return success();
+}
+
+static void printExecutableTargetConfigOp(OpAsmPrinter &printer,
+                                          ExecutableTargetConfigOp op) {
+  // NOTE(review): backend() is streamed as-is while the parser reads an
+  // attribute; confirm the printed form round-trips through the parser.
+  printer << op.getOperationName() << "(" << op.backend() << ")";
+
+  printer.printRegion(op.body(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/false);
+
+  // A trailing `attributes` dict holds everything except "backend".
+  SmallVector<StringRef, 1> elidedAttrs = {"backend"};
+  if (op.getAttrs().size() > elidedAttrs.size()) {
+    printer << "\n attributes ";
+    printer.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+  }
+}
+
+#define GET_OP_CLASSES
+#include "compiler/IR/ConfigOps.cpp.inc"
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/ConfigOps.h b/compiler/IR/ConfigOps.h
new file mode 100644
index 0000000..de806ae
--- /dev/null
+++ b/compiler/IR/ConfigOps.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_CONFIGOPS_H_
+#define IREE_COMPILER_IR_CONFIGOPS_H_
+
+#include <cstdint>
+
+#include "compiler/IR/Types.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+#define GET_OP_CLASSES
+#include "compiler/IR/ConfigOps.h.inc"
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_CONFIGOPS_H_
diff --git a/compiler/IR/ConfigOps.td b/compiler/IR/ConfigOps.td
new file mode 100644
index 0000000..12dbce6
--- /dev/null
+++ b/compiler/IR/ConfigOps.td
@@ -0,0 +1,44 @@
+// Ops used to declare configuration used by the IREE compiler.
+// These allow inline config that follows along the IR they are associated with.
+// Multiple config ops are allowed within a single scope to indicate that the
+// parent IR node should be processed for multiple targets.
+
+#ifdef IREE_CONFIG_OPS
+#else
+#define IREE_CONFIG_OPS
+
+include "compiler/IR/OpBase.td"
+
+class IREE_ConfigOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREE_Dialect, mnemonic, traits> {
+ let parser = [{ return parse$cppClass(parser, result); }];  // Default: free function parse<CppClass>.
+ let printer = [{ print$cppClass(p, *this); }];  // Default: free function print<CppClass>.
+}
+
+//===----------------------------------------------------------------------===//
+// iree.executable configuration
+//===----------------------------------------------------------------------===//
+
+def IREE_ExecutableTargetConfigOp : IREE_ConfigOp<"target_config", [
+ IREE_ExecutableOnly,
+ SingleBlockImplicitTerminator<"ExecutableTargetConfigEndOp">
+]> {
+ let arguments = (ins
+ StrAttr:$backend
+ );
+
+ let regions = (region SizedRegion<1>:$body);
+
+ let skipDefaultBuilders = 1;  // Only the custom builder listed below is declared.
+ let builders = [
+ OpBuilder<"Builder *builder, OperationState &state, std::string backend">,
+ ];
+}
+
+def IREE_ExecutableTargetConfigEndOp :
+ IREE_ConfigOp<"_target_config_end", [Terminator, IREE_ExecutableTargetConfigOnly]> {
+ let parser = [{ return parseNoIOOp(parser, result); }];  // Overrides the IREE_ConfigOp default.
+ let printer = [{ printNoIOOp(getOperation(), p); }];  // Overrides the IREE_ConfigOp default.
+}
+
+#endif // IREE_CONFIG_OPS
diff --git a/compiler/IR/Dialect.cpp b/compiler/IR/Dialect.cpp
new file mode 100644
index 0000000..5968ade
--- /dev/null
+++ b/compiler/IR/Dialect.cpp
@@ -0,0 +1,90 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+
+#include "compiler/IR/ConfigOps.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/IR/Types.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static DialectRegistration<IREEDialect> iree_dialect;
+
+IREEDialect::IREEDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+#define IREE_ADD_TYPE(NAME, KIND, TYPE) addTypes<TYPE>();
+ IREE_TYPE_TABLE(IREE_ADD_TYPE);  // Registers every TYPE listed in IREE_TYPE_TABLE.
+
+#define GET_OP_LIST
+ addOperations<  // Each call registers the op list from the .inc file included inside it.
+#include "compiler/IR/Ops.cpp.inc"
+ >();
+#define GET_OP_LIST
+ addOperations<
+#include "compiler/IR/ConfigOps.cpp.inc"
+ >();
+#define GET_OP_LIST
+ addOperations<
+#include "compiler/IR/StructureOps.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+#define IREE_TYPE_PARSER(NAME, KIND, TYPE) \
+ static Type parse##TYPE(IREEDialect const &dialect, StringRef spec, \
+ Location loc) { \
+ spec.consume_front(NAME); \
+ return TYPE::get(dialect.getContext()); \
+ }
+IREE_TYPE_TABLE(IREE_TYPE_PARSER);
+
+#define IREE_PARSE_TYPE(NAME, KIND, TYPE) \
+ if (spec.startswith(NAME)) { \
+ return parse##TYPE(*this, spec, loc); \
+ }
+Type IREEDialect::parseType(StringRef spec, Location loc) const {
+ IREE_TYPE_TABLE(IREE_PARSE_TYPE);  // One startswith(NAME) check per table entry; the first match returns.
+ emitError(loc, "unknown IREE type: ") << spec;
+ return Type();  // Null Type, returned after the error above is emitted.
+}
+
+//===----------------------------------------------------------------------===//
+// Type Printing
+//===----------------------------------------------------------------------===//
+
+#define IREE_TYPE_PRINTER(NAME, KIND, TYPE) \
+ static void print##TYPE(TYPE type, llvm::raw_ostream &os) { os << NAME; }
+IREE_TYPE_TABLE(IREE_TYPE_PRINTER);
+
+#define IREE_PRINT_TYPE(NAME, KIND, TYPE) \
+ case KIND: \
+ print##TYPE(type.cast<TYPE>(), os); \
+ return;
+void IREEDialect::printType(Type type, llvm::raw_ostream &os) const {
+ switch (type.getKind()) {
+ IREE_TYPE_TABLE(IREE_PRINT_TYPE);  // One `case KIND:` per table entry; each prints and returns.
+ default:
+ llvm_unreachable("unhandled IREE type");
+ }
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/IR/Dialect.h b/compiler/IR/Dialect.h
similarity index 100%
rename from iree/compiler/IR/Dialect.h
rename to compiler/IR/Dialect.h
diff --git a/compiler/IR/Interpreter/BUILD b/compiler/IR/Interpreter/BUILD
new file mode 100644
index 0000000..8dc1cf3
--- /dev/null
+++ b/compiler/IR/Interpreter/BUILD
@@ -0,0 +1,75 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+load("@local_config_mlir//:tblgen.bzl", "gentbl")
+
+filegroup(
+ name = "td_files",
+ srcs = glob(["*.td"]),
+)
+
+cc_library(
+ name = "Interpreter",
+ srcs = [
+ "HLDialect.cpp",
+ "HLOps.cpp",
+ "HLOps.cpp.inc",
+ "LLDialect.cpp",
+ "LLOps.cpp",
+ "LLOps.cpp.inc",
+ "OpWriters.cpp",
+ ],
+ hdrs = [
+ "HLDialect.h",
+ "HLOps.h",
+ "HLOps.h.inc",
+ "LLDialect.h",
+ "LLOps.h",
+ "LLOps.h.inc",
+ "OpWriters.h",
+ ],
+ deps = [
+ ":HLOpsGen",
+ ":LLOpsGen",
+ "///compiler/IR",
+ "///compiler/Serialization",
+ "///compiler/Utils",
+ "///schemas/bytecode:interpreter_bytecode_v0",
+ "@llvm//:support",
+ "@local_config_mlir//:IR",
+ "@local_config_mlir//:StandardOps",
+ ],
+ alwayslink = 1,
+)
+
+gentbl(
+ name = "HLOpsGen",
+ tbl_outs = [
+ ("-gen-op-decls", "HLOps.h.inc"),
+ ("-gen-op-defs", "HLOps.cpp.inc"),
+ ],
+ tblgen = "@local_config_mlir//:mlir-tblgen",
+ td_file = "HLOps.td",
+ td_srcs = [
+ ":td_files",
+ "@local_config_mlir//:include/mlir/IR/OpBase.td",
+ "///compiler/IR:OpBase.td",
+ ],
+)
+
+gentbl(
+ name = "LLOpsGen",
+ tbl_outs = [
+ ("-gen-op-decls", "LLOps.h.inc"),
+ ("-gen-op-defs", "LLOps.cpp.inc"),
+ ],
+ tblgen = "@local_config_mlir//:mlir-tblgen",
+ td_file = "LLOps.td",
+ td_srcs = [
+ ":td_files",
+ "@local_config_mlir//:include/mlir/IR/OpBase.td",
+ "///compiler/IR:OpBase.td",
+ ],
+)
diff --git a/iree/compiler/IR/Interpreter/CMakeLists.txt b/compiler/IR/Interpreter/CMakeLists.txt
similarity index 100%
rename from iree/compiler/IR/Interpreter/CMakeLists.txt
rename to compiler/IR/Interpreter/CMakeLists.txt
diff --git a/compiler/IR/Interpreter/HLDialect.cpp b/compiler/IR/Interpreter/HLDialect.cpp
new file mode 100644
index 0000000..d34f44b
--- /dev/null
+++ b/compiler/IR/Interpreter/HLDialect.cpp
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Interpreter/HLDialect.h"
+
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+IREEHLInterpreterDialect::IREEHLInterpreterDialect(MLIRContext* context)
+ : Dialect(getDialectNamespace(), context) {
+#define GET_OP_LIST
+ addOperations<  // Registers the op list pulled in from HLOps.cpp.inc below.
+#include "compiler/IR/Interpreter/HLOps.cpp.inc"
+ >();
+}
+
+static DialectRegistration<IREEHLInterpreterDialect> iree_hl_interp_dialect;
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/HLDialect.h b/compiler/IR/Interpreter/HLDialect.h
similarity index 100%
rename from iree/compiler/IR/Interpreter/HLDialect.h
rename to compiler/IR/Interpreter/HLDialect.h
diff --git a/compiler/IR/Interpreter/HLOps.cpp b/compiler/IR/Interpreter/HLOps.cpp
new file mode 100644
index 0000000..6a4b4ff
--- /dev/null
+++ b/compiler/IR/Interpreter/HLOps.cpp
@@ -0,0 +1,246 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Interpreter/HLOps.h"
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREEInterp {
+namespace HL {
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.call
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
+  // Syntax: callee-symbol `(` operands `)` attr-dict? `:` function-type.
+  // Operand and result types are taken from the trailing function type.
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  SymbolRefAttr calleeAttr;
+  FunctionType calleeType;
+  auto calleeLoc = parser.getNameLoc();
+  return failure(
+      parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
+      parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes) ||
+      parser.parseColonType(calleeType) ||
+      parser.addTypesToList(calleeType.getResults(), state.types) ||
+      parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
+                             state.operands));
+}
+
+static void printCallOp(OpAsmPrinter &p, CallOp op) {
+  // Mirrors parseCallOp: callee, operands, remaining attrs, function type.
+  p << "iree_hl_interp.call " << op.getAttr("callee") << '(';
+  p.printOperands(op.getOperands());
+  p << ')';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : " << op.getCalleeType();
+}
+
+FunctionType CallOp::getCalleeType() {
+ SmallVector<Type, 4> resultTypes(getResultTypes());
+ SmallVector<Type, 8> argTypes(getOperandTypes());
+ return FunctionType::get(argTypes, resultTypes, getContext());  // Built from this op's own operand/result types.
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.call_indirect
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCallIndirectOp(OpAsmParser &parser,
+ OperationState &result) {
+ FunctionType calleeType;
+ OpAsmParser::OperandType callee;
+ llvm::SMLoc operandsLoc;
+ SmallVector<OpAsmParser::OperandType, 4> operands;
+ return failure(  // Any failing step short-circuits the chain below.
+ parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
+ parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+ parser.parseOptionalAttributeDict(result.attributes) ||
+ parser.parseColonType(calleeType) ||
+ parser.resolveOperand(callee, calleeType, result.operands) ||  // Callee is appended first (operand 0).
+ parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
+ result.operands) ||
+ parser.addTypesToList(calleeType.getResults(), result.types));
+}
+
+static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
+ p << "iree_hl_interp.call_indirect ";
+ p.printOperand(op.getCallee());
+ p << '(';
+ auto operandRange = op.getOperands();
+ p.printOperands(++operandRange.begin(), operandRange.end());  // Skips operand 0, which the parser fills with the callee.
+ p << ')';
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});  // NOTE(review): the parser never adds a "callee" attr; confirm this elision is needed.
+ p << " : " << op.getCallee()->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.return
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
+ SmallVector<OpAsmParser::OperandType, 2> opInfo;
+ SmallVector<Type, 2> types;
+ llvm::SMLoc loc = parser.getCurrentLocation();
+ return failure(parser.parseOperandList(opInfo) ||
+ (!opInfo.empty() && parser.parseColonTypeList(types)) ||  // Types are only parsed when operands exist.
+ parser.resolveOperands(opInfo, types, loc, state.operands));
+}
+
+static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
+  p << "iree_hl_interp.return";
+  if (op.getNumOperands() == 0) return;
+  // Operands are followed by a comma-separated list of their types.
+  p << ' ';
+  p.printOperands(op.operand_begin(), op.operand_end());
+  p << " : ";
+  interleaveComma(op.getOperandTypes(), p);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.br
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
+ Block *dest;
+ SmallVector<Value *, 4> destOperands;
+ if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();  // Destination block plus its operands.
+ result.addSuccessor(dest, destOperands);  // The single successor of the branch.
+ return success();
+}
+
+static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
+ p << "iree_hl_interp.br ";
+ p.printSuccessorAndUseList(op.getOperation(), 0);  // Successor index 0 is the branch destination.
+}
+
+Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }  // Successor 0 is the destination.
+
+void BranchOp::setDest(Block *block) {
+  getOperation()->setSuccessor(block, 0);  // Successor 0 is the destination.
+}
+
+void BranchOp::eraseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(0, index);  // Erases from successor 0's operand list.
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.cond_br
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCondBranchOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Value *, 4> destOperands;
+ Block *dest;
+ OpAsmParser::OperandType condInfo;
+ // Parse the condition operand and resolve it as i1. NOTE(review): a missing
+ // operand or comma also reports the i1 type diagnostic below.
+ Type int1Ty = parser.getBuilder().getI1Type();
+ if (parser.parseOperand(condInfo) || parser.parseComma() ||
+ parser.resolveOperand(condInfo, int1Ty, result.operands)) {
+ return parser.emitError(parser.getNameLoc(),
+ "expected condition type was boolean (i1)");
+ }
+
+ // Parse the true successor (added first).
+ if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
+ result.addSuccessor(dest, destOperands);
+
+ // Parse the false successor after a comma (added second); the list is reused.
+ destOperands.clear();
+ if (parser.parseComma() ||
+ parser.parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ result.addSuccessor(dest, destOperands);
+
+ return success();
+}
+
+static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
+ p << "iree_hl_interp.cond_br ";
+ p.printOperand(op.getCondition());
+ p << ", ";
+ p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);  // True destination first.
+ p << ", ";
+ p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);  // Then the false destination.
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.clone
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
+  // A source whose only use is this clone makes the clone unnecessary.
+  // TODO(b/135053584) More sophisticated analysis.
+  auto source = src();
+  return source->hasOneUse() ? OpFoldResult(source) : OpFoldResult();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_interp.concat
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Rewrites a concat into a heap allocation of the result type plus one copy
+// per source, each written at a running offset along the concat dimension.
+struct ConcatToCopies : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(ConcatOp concatOp,
+                                     PatternRewriter &rewriter) const override {
+    auto finalType = concatOp.getResult()->getType().cast<ShapedType>();
+    auto loc = concatOp.getLoc();
+    std::vector<Value *> dimPieces;
+    auto dst =
+        rewriter.create<IREEInterp::HL::AllocHeapOp>(loc, finalType, dimPieces);
+    llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0);
+    auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset);
+    auto concatDimension = concatOp.dimension().getZExtValue();
+    llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0);
+    for (auto *src : concatOp.srcs()) {
+      auto srcShape = src->getType().cast<ShapedType>().getShape();
+      auto lengths = createArrayConstant(rewriter, loc, srcShape);
+      auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices);
+      rewriter.create<IREEInterp::HL::CopyOp>(loc, src, srcIndices, dst,
+                                              dstIndicesOp, lengths);
+      dstIndices[concatDimension] += srcShape[concatDimension];
+    }
+
+    // Go through the rewriter so it is notified and the concat op is erased.
+    rewriter.replaceOp(concatOp, dst.getResult());
+    return matchSuccess();
+  }
+};
+}  // namespace
+
+void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ConcatToCopies>(context);  // Registers the ConcatToCopies rewrite pattern.
+}
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Interpreter/HLOps.cpp.inc"
+
+} // namespace HL
+} // namespace IREEInterp
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Interpreter/HLOps.h b/compiler/IR/Interpreter/HLOps.h
new file mode 100644
index 0000000..b02fbed
--- /dev/null
+++ b/compiler/IR/Interpreter/HLOps.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
+#define IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREEInterp {
+namespace HL {
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Interpreter/HLOps.h.inc"
+
+} // namespace HL
+} // namespace IREEInterp
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
diff --git a/compiler/IR/Interpreter/HLOps.td b/compiler/IR/Interpreter/HLOps.td
new file mode 100644
index 0000000..63fa5bf
--- /dev/null
+++ b/compiler/IR/Interpreter/HLOps.td
@@ -0,0 +1,658 @@
+// IREE high-level interpreter op definitions.
+// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in
+// normal SSA form.
+//
+// Through lowering these high-level ops are converted to low-level ops in the
+// LLOps.td (iree_ll_interp.*). These map 1:1 with the bytecode,
+// accept only MemRef types, and generally use output parameters instead of
+// return types.
+//
+// The source of truth for bytecode opcodes is:
+// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h
+
+#ifdef IREE_INTERPRETER_HL_OPS
+#else
+#define IREE_INTERPRETER_HL_OPS
+
+#ifdef IREE_OP_BASE
+#else
+include "compiler/IR/OpBase.td"
+#endif // IREE_OP_BASE
+
+// Dialect record for the high-level interpreter ops. Ops are named with the
+// `iree_hl_interp` prefix and their C++ classes are generated into the
+// `IREEInterp::HL` namespace.
+def IREEInterpHL_Dialect : Dialect {
+ let name = "iree_hl_interp";
+ let cppNamespace = "IREEInterp::HL";
+}
+
+//===----------------------------------------------------------------------===//
+// Base op classes
+//===----------------------------------------------------------------------===//
+
+// Base class of every op in the iree_hl_interp dialect.
+class IREEInterpHL_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREEInterpHL_Dialect, mnemonic, traits>;
+
+// Same as IREEInterpHL_Op, but NoSideEffect is always appended to the traits
+// supplied by the op.
+class IREEInterpHL_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ IREEInterpHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+
+//===----------------------------------------------------------------------===//
+// High-level interpreter ops
+//===----------------------------------------------------------------------===//
+
+// Direct call to the function named by the `callee` symbol reference
+// attribute, with a variadic list of memref operands and variadic memref
+// results.
+//
+// Two builders are provided: the first takes the callee FuncOp and copies its
+// result types; the second takes the callee name and explicit result types.
+// Parsing and printing are delegated to hand-written parse$cppClass /
+// print$cppClass functions.
+def IREEInterpHL_CallOp : IREEInterpHL_Op<"call"> {
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>);
+ let results = (outs Variadic<IREEHL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, FuncOp callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(callee.getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState &result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// Indirect call through a function-typed value. Operand 0 is the callee; the
+// remaining operands are the call arguments (the argument accessors below
+// skip operand 0). The builder takes the result types from the callee's
+// FunctionType.
+def IREEInterpHL_CallIndirectOp : IREEInterpHL_Op<"call_indirect"> {
+ let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands);
+ let results = (outs Variadic<IREEHL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.operands.push_back(callee);
+ result.addOperands(operands);
+ result.addTypes(callee->getType().cast<FunctionType>().getResults());
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ Value *getCallee() { return getOperand(0); }
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ operand_iterator arg_operand_begin() { return ++operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// Terminator carrying a variadic list of memref operands. The extra builder
+// creates a return with no operands.
+def IREEInterpHL_ReturnOp : IREEInterpHL_Op<"return", [Terminator]> {
+ let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
+
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// Unconditional branch terminator. The builder registers `dest` as the
+// successor block together with the operands forwarded to it. getDest,
+// setDest and eraseOperand are only declared here; they are defined in C++.
+def IREEInterpHL_BranchOp : IREEInterpHL_Op<"br", [Terminator]> {
+ let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Block *dest,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addSuccessor(dest, operands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ Block *getDest();
+ void setDest(Block *block);
+
+ /// Erase the operand at 'index' from the operand list.
+ void eraseOperand(unsigned index);
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// Conditional branch terminator with two successors (index 0 = true
+// destination, index 1 = false destination).
+//
+// Operand layout, as established by the builder and the index helpers below:
+//   [condition, true-destination operands..., false-destination operands...]
+def IREEInterpHL_CondBranchOp : IREEInterpHL_Op<"cond_br", [Terminator]> {
+ let arguments = (ins
+ IREEHL_BoolScalar:$condition,
+ Variadic<IREEHL_MemRef>:$branchOperands
+ );
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *condition,"
+ "Block *trueDest, ArrayRef<Value *> trueOperands,"
+ "Block *falseDest, ArrayRef<Value *> falseOperands", [{
+ result.addOperands(condition);
+ result.addSuccessor(trueDest, trueOperands);
+ result.addSuccessor(falseDest, falseOperands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // These are the indices into the dests list.
+ enum { trueIndex = 0, falseIndex = 1 };
+
+ // The condition operand is the first operand in the list.
+ Value *getCondition() { return getOperand(0); }
+
+ /// Return the destination if the condition is true.
+ Block *getTrueDest() {
+ return getOperation()->getSuccessor(trueIndex);
+ }
+
+ /// Return the destination if the condition is false.
+ Block *getFalseDest() {
+ return getOperation()->getSuccessor(falseIndex);
+ }
+
+ // Accessors for operands to the 'true' destination.
+ Value *getTrueOperand(unsigned idx) {
+ assert(idx < getNumTrueOperands());
+ return getOperand(getTrueDestOperandIndex() + idx);
+ }
+
+ void setTrueOperand(unsigned idx, Value *value) {
+ assert(idx < getNumTrueOperands());
+ setOperand(getTrueDestOperandIndex() + idx, value);
+ }
+
+ operand_iterator true_operand_begin() {
+ return operand_begin() + getTrueDestOperandIndex();
+ }
+ operand_iterator true_operand_end() {
+ return true_operand_begin() + getNumTrueOperands();
+ }
+ operand_range getTrueOperands() {
+ return {true_operand_begin(), true_operand_end()};
+ }
+
+ unsigned getNumTrueOperands() {
+ return getOperation()->getNumSuccessorOperands(trueIndex);
+ }
+
+ /// Erase the operand at 'index' from the true operand list.
+ void eraseTrueOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(trueIndex, index);
+ }
+
+ // Accessors for operands to the 'false' destination.
+ Value *getFalseOperand(unsigned idx) {
+ assert(idx < getNumFalseOperands());
+ return getOperand(getFalseDestOperandIndex() + idx);
+ }
+ void setFalseOperand(unsigned idx, Value *value) {
+ assert(idx < getNumFalseOperands());
+ setOperand(getFalseDestOperandIndex() + idx, value);
+ }
+
+ // The false operands start right where the true operands end.
+ operand_iterator false_operand_begin() { return true_operand_end(); }
+ operand_iterator false_operand_end() {
+ return false_operand_begin() + getNumFalseOperands();
+ }
+ operand_range getFalseOperands() {
+ return {false_operand_begin(), false_operand_end()};
+ }
+
+ unsigned getNumFalseOperands() {
+ return getOperation()->getNumSuccessorOperands(falseIndex);
+ }
+
+ /// Erase the operand at 'index' from the false operand list.
+ void eraseFalseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(falseIndex, index);
+ }
+
+ private:
+ /// Get the index of the first true destination operand.
+ // Index 1 because operand 0 is the condition.
+ unsigned getTrueDestOperandIndex() { return 1; }
+
+ /// Get the index of the first false destination operand.
+ unsigned getFalseDestOperandIndex() {
+ return getTrueDestOperandIndex() + getNumTrueOperands();
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// Pure comparison of two integer memrefs of identical type, producing a
+// boolean memref of the same shape. The comparison kind is carried by the
+// I32 `predicate` attribute.
+// NOTE(review): the meaning of the predicate values is not defined in this
+// file - presumably mirrors the standard cmpi predicates; confirm.
+def IREEInterpHL_CmpIOp :
+ IREEInterpHL_PureOp<"cmp_i", [SameOperandsAndResultShape,
+ AllTypesMatch<["lhs", "rhs"]>]> {
+ let arguments = (ins
+ I32Attr:$predicate,
+ IREEHL_IntMemRef:$lhs,
+ IREEHL_IntMemRef:$rhs
+ );
+ let results = (outs IREEHL_BoolMemRef);
+}
+
+// Pure comparison of two float memrefs of identical type, producing a boolean
+// memref of the same shape. The comparison kind is carried by the I32
+// `predicate` attribute.
+// NOTE(review): the meaning of the predicate values is not defined in this
+// file - presumably mirrors the standard cmpf predicates; confirm.
+def IREEInterpHL_CmpFOp :
+ IREEInterpHL_PureOp<"cmp_f", [SameOperandsAndResultShape,
+ AllTypesMatch<["lhs", "rhs"]>]> {
+ let arguments = (ins
+ I32Attr:$predicate,
+ IREEHL_FloatMemRef:$lhs,
+ IREEHL_FloatMemRef:$rhs
+ );
+ let results = (outs IREEHL_BoolMemRef);
+}
+
+// TODO(b/142012496): Add trait that enables DCE but not CSE.
+// Produces a single memref result from a variadic list of memref
+// `dim_pieces` operands. Declared with IREEInterpHL_Op, i.e. without the
+// NoSideEffect trait.
+// NOTE(review): `dim_pieces` presumably supplies the sizes of dynamic
+// dimensions - confirm.
+def IREEInterpHL_AllocHeapOp : IREEInterpHL_Op<"alloc_heap"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins
+ Variadic<IREEHL_MemRef>:$dim_pieces
+ );
+ let results = (outs
+ IREEHL_MemRef
+ );
+}
+
+// Takes a single memref operand and produces no results. Not marked
+// NoSideEffect.
+def IREEInterpHL_DiscardOp : IREEInterpHL_Op<"discard"> {
+ let arguments = (ins IREEHL_MemRef);
+}
+
+// Pure op: one memref operand, one integer scalar result.
+def IREEInterpHL_RankOp : IREEInterpHL_PureOp<"rank"> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_IntScalar);
+}
+
+// Pure op: one memref operand, one integer scalar result. The dimension to
+// query is not yet expressible (see the TODO below).
+def IREEInterpHL_DimOp : IREEInterpHL_PureOp<"dim"> {
+ // TODO(benvanik) add dim attr (I32Attr:$dim)
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_IntScalar);
+}
+
+// Pure op: one memref operand, one 1-D integer memref result.
+def IREEInterpHL_ShapeOp : IREEInterpHL_PureOp<"shape"> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_1DIntMemRef);
+}
+
+// Pure op: one memref operand, one index scalar result.
+def IREEInterpHL_LengthOp : IREEInterpHL_PureOp<"length"> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_IndexScalar);
+}
+
+// Pure op producing a memref from `src` given 1-D index memrefs `srcIndices`
+// and `lengths`. Verified constraints: `src` and the result share an element
+// type, and `srcIndices` and `lengths` have the same type.
+def IREEInterpHL_SliceOp :
+ IREEInterpHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>,
+ AllTypesMatch<["srcIndices", "lengths"]>]> {
+ let arguments = (ins
+ IREEHL_MemRef:$src,
+ IREEHL_1DIndexMemRef:$srcIndices,
+ IREEHL_1DIndexMemRef:$lengths
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Copy between two memrefs, parameterized by 1-D index memrefs giving the
+// source offsets, destination offsets and lengths. Has no results and is not
+// marked NoSideEffect.
+//
+// Verified constraints: `src` and `dst` have the same rank and element type,
+// and `srcIndices`, `dstIndices` and `lengths` each have as many elements as
+// that rank.
+def IREEInterpHL_CopyOp : IREEInterpHL_Op<"copy", [
+ AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
+ AllRanksMatch<["src", "dst"]>,
+ // The checks above are redundant with this one, but they give more specific
+ // error messages.
+ AllMatch<[
+ Rank<"src">.result,
+ Rank<"dst">.result,
+ ElementCount<"srcIndices">.result,
+ ElementCount<"dstIndices">.result,
+ ElementCount<"lengths">.result
+ ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">,
+ AllElementTypesMatch<["src", "dst"]>
+]> {
+ let arguments = (ins
+ IREEHL_MemRef:$src,
+ IREEHL_1DIndexMemRef:$srcIndices,
+ IREEHL_MemRef:$dst,
+ IREEHL_1DIndexMemRef:$dstIndices,
+ IREEHL_1DIndexMemRef:$lengths
+ );
+}
+
+// Pure op whose result has the same type as its `src` operand. A folder is
+// declared (hasFolder) and implemented in C++.
+def IREEInterpHL_CloneOp :
+ IREEInterpHL_PureOp<"clone", [SameOperandsAndResultType]> {
+ let arguments = (ins IREEHL_MemRef:$src);
+ let results = (outs IREEHL_MemRef);
+
+ let hasFolder = 1;
+}
+
+// A pseudo op provided for convenience. This gets canonicalized to a series of
+// copies.
+// Takes a variadic list of source memrefs and an I32 `dimension` attribute
+// and produces one memref. The canonicalizer is declared here
+// (hasCanonicalizer) and implemented in C++.
+def IREEInterpHL_ConcatOp : IREEInterpHL_PureOp<"concat"> {
+ let arguments = (ins
+ Variadic<IREEHL_MemRef>:$srcs,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_MemRef);
+
+ let hasCanonicalizer = 1;
+}
+
+// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops?
+// Pure op taking one memref and producing a variadic list of memrefs that all
+// share the operand's element type. How the source is divided is not yet
+// expressible (see the TODO above).
+def IREEInterpHL_SplitOp :
+ IREEInterpHL_PureOp<"split", [SameOperandsAndResultElementType]> {
+ let arguments = (ins IREEHL_MemRef:$src);
+ let results = (outs Variadic<IREEHL_MemRef>);
+}
+
+// Pure op with one memref operand and one memref result of the same type.
+def IREEInterpHL_AssignOp :
+ IREEInterpHL_PureOp<"assign", [SameOperandsAndResultType]> {
+ let arguments = (ins IREEHL_MemRef:$src);
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a boolean scalar `cond` and two memrefs `lhs`/`rhs`; the
+// result has the same type as both `lhs` and `rhs`.
+// NOTE(review): presumably yields `lhs` when `cond` is true and `rhs`
+// otherwise - confirm against the lowering.
+def IREEInterpHL_CondAssignOp :
+ IREEInterpHL_PureOp<"cond_assign",
+ [AllTypesMatch<["lhs", "rhs", "result"]>]> {
+ let arguments = (ins
+ IREEHL_BoolScalar:$cond,
+ IREEHL_MemRef:$lhs,
+ IREEHL_MemRef:$rhs
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a source memref and a `shape` memref and producing a memref.
+// No type constraints between operands and result are declared here.
+def IREEInterpHL_ReshapeOp : IREEInterpHL_PureOp<"reshape"> {
+ let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape);
+ let results = (outs IREEHL_MemRef);
+}
+
+// Pure op taking a boolean memref `cond` and two memrefs `lhs`/`rhs`; the
+// result has the same type as both `lhs` and `rhs`. Unlike cond_assign, the
+// condition is a memref rather than a scalar.
+def IREEInterpHL_SelectOp :
+ IREEInterpHL_PureOp<"select", [AllTypesMatch<["lhs", "rhs", "result"]>]> {
+ let arguments = (ins
+ IREEHL_BoolMemRef:$cond,
+ IREEHL_MemRef:$lhs,
+ IREEHL_MemRef:$rhs
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a scalar memref `operand` and a 1-D integer memref `shape`
+// and producing a memref with the operand's element type.
+def IREEInterpHL_BroadcastOp :
+ IREEInterpHL_PureOp<"broadcast",
+ [AllElementTypesMatch<["operand", "result"]>]> {
+ let arguments = (ins
+ IREE_ScalarMemRefOf<[AnyType]>:$operand,
+ IREEHL_1DIntMemRef:$shape
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a source memref, a scalar `padding_value` and three 1-D
+// index memrefs (`edge_padding_low`, `edge_padding_high`,
+// `interior_padding`). `src`, `padding_value` and the result share one
+// element type.
+def IREEInterpHL_PadOp :
+ IREEInterpHL_PureOp<
+ "pad", [AllElementTypesMatch<["src", "result", "padding_value"]>]> {
+ let arguments = (ins
+ IREEHL_MemRef:$src,
+ IREEHL_AnyScalar:$padding_value,
+ IREEHL_1DIndexMemRef:$edge_padding_low,
+ IREEHL_1DIndexMemRef:$edge_padding_high,
+ IREEHL_1DIndexMemRef:$interior_padding
+ );
+
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a memref `operand` and a 1-D integer memref `shape` and
+// producing a memref with the operand's element type.
+def IREEInterpHL_TileOp :
+ IREEInterpHL_PureOp<"tile", [AllElementTypesMatch<["operand", "result"]>]> {
+ let arguments = (ins
+ IREEHL_MemRef:$operand,
+ IREEHL_1DIntMemRef:$shape
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a memref `operand` and a 1-D integer memref `permutation`.
+// The result matches the operand in element type, rank and element count.
+def IREEInterpHL_TransposeOp :
+ IREEInterpHL_PureOp<"transpose", [
+ AllElementTypesMatch<["operand", "result"]>,
+ AllRanksMatch<["operand", "result"]>,
+ AllElementCountsMatch<["operand", "result"]>
+ ]> {
+ let arguments = (ins
+ IREEHL_MemRef:$operand,
+ IREEHL_1DIntMemRef:$permutation
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Pure op taking a memref `operand` and a 1-D integer memref `dims`; the
+// result has exactly the operand's type.
+def IREEInterpHL_ReverseOp :
+ IREEInterpHL_PureOp<"reverse", [AllTypesMatch<["operand", "result"]>]> {
+ let arguments = (ins
+ IREEHL_MemRef:$operand,
+ IREEHL_1DIntMemRef:$dims
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// Base class for pure ops with one operand and one result of the given
+// `type`; SameOperandsAndResultType is always appended to the supplied
+// traits.
+class IREEInterpHL_UnaryElementwiseOp<string mnemonic, Type type,
+ list<OpTrait> traits = []> :
+ IREEInterpHL_PureOp<mnemonic,
+ !listconcat(traits, [SameOperandsAndResultType])> {
+ let arguments = (ins type);
+ let results = (outs type);
+}
+
+// Unary element-wise op over float memrefs.
+class IREEInterpHL_UnaryElementwiseFloatOp<string mnemonic,
+ list<OpTrait> traits = []> :
+ IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_FloatMemRef, traits>;
+
+// Unary element-wise op over integer memrefs.
+class IREEInterpHL_UnaryElementwiseIntOp<string mnemonic,
+ list<OpTrait> traits = []> :
+ IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_IntMemRef, traits>;
+
+// Base class for pure ops with two operands (`lhs`, `rhs`) and one result,
+// all of the given `type`; SameOperandsAndResultType is always appended to
+// the supplied traits. `traits` defaults to the empty list, consistent with
+// the other base classes in this file.
+class IREEInterpHL_BinaryElementwiseOp<string mnemonic, Type type,
+                                       list<OpTrait> traits = []> :
+    IREEInterpHL_PureOp<mnemonic,
+                        !listconcat(traits, [SameOperandsAndResultType])> {
+  let arguments = (ins type:$lhs, type:$rhs);
+  let results = (outs type);
+}
+
+// Binary element-wise op over float memrefs.
+class IREEInterpHL_BinaryElementwiseFloatOp<string mnemonic,
+ list<OpTrait> traits = []> :
+ IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_FloatMemRef,
+ traits>;
+
+// Binary element-wise op over integer memrefs.
+class IREEInterpHL_BinaryElementwiseIntOp<string mnemonic,
+ list<OpTrait> traits = []> :
+ IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_IntMemRef,
+ traits>;
+
+// Base class for pure ops with three operands (`a`, `b`, `c`) and one result,
+// all declared with the given `type` (any memref by default). Unlike the
+// unary/binary bases, no SameOperandsAndResultType trait is added.
+class IREEInterpHL_TernaryOp<string mnemonic,
+ Type type = IREEHL_MemRef,
+ list<OpTrait> traits = []> :
+ IREEInterpHL_PureOp<mnemonic, traits> {
+ let arguments = (ins type:$a, type:$b, type:$c);
+ let results = (outs type);
+}
+
+// TODO(benvanik): add traits for broadcasting support.
+
+// Bitwise and shift ops over integer memrefs (`not` is unary, the rest are
+// binary).
+def IREEInterpHL_NotOp : IREEInterpHL_UnaryElementwiseIntOp<"not">;
+def IREEInterpHL_AndOp : IREEInterpHL_BinaryElementwiseIntOp<"and">;
+def IREEInterpHL_OrOp : IREEInterpHL_BinaryElementwiseIntOp<"or">;
+def IREEInterpHL_XorOp : IREEInterpHL_BinaryElementwiseIntOp<"xor">;
+def IREEInterpHL_ShiftLeftOp : IREEInterpHL_BinaryElementwiseIntOp<"sll">;
+def IREEInterpHL_ShiftRightLogicalOp : IREEInterpHL_BinaryElementwiseIntOp<"srl">;
+def IREEInterpHL_ShiftRightArithmeticOp : IREEInterpHL_BinaryElementwiseIntOp<"sra">;
+
+// Arithmetic and math ops. The `_i` / `_f` mnemonic suffix matches the
+// integer / float memref base class each op uses.
+// NOTE(review): the `_s` / `_u` suffixes on the integer divides presumably
+// mean signed / unsigned - confirm.
+def IREEInterpHL_AddIOp : IREEInterpHL_BinaryElementwiseIntOp<"add_i">;
+def IREEInterpHL_AddFOp : IREEInterpHL_BinaryElementwiseFloatOp<"add_f">;
+def IREEInterpHL_SubIOp : IREEInterpHL_BinaryElementwiseIntOp<"sub_i">;
+def IREEInterpHL_SubFOp : IREEInterpHL_BinaryElementwiseFloatOp<"sub_f">;
+def IREEInterpHL_AbsIOp : IREEInterpHL_UnaryElementwiseIntOp<"abs_i">;
+def IREEInterpHL_AbsFOp : IREEInterpHL_UnaryElementwiseFloatOp<"abs_f">;
+def IREEInterpHL_MulIOp : IREEInterpHL_BinaryElementwiseIntOp<"mul_i">;
+def IREEInterpHL_MulFOp : IREEInterpHL_BinaryElementwiseFloatOp<"mul_f">;
+def IREEInterpHL_DivISOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_s">;
+def IREEInterpHL_DivIUOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_u">;
+def IREEInterpHL_DivFOp : IREEInterpHL_BinaryElementwiseFloatOp<"div_f">;
+def IREEInterpHL_MulAddIOp : IREEInterpHL_TernaryOp<"madd_i", IREEHL_IntMemRef>;
+def IREEInterpHL_MulAddFOp : IREEInterpHL_TernaryOp<"madd_f", IREEHL_FloatMemRef>;
+def IREEInterpHL_ExpFOp : IREEInterpHL_UnaryElementwiseFloatOp<"exp_f">;
+def IREEInterpHL_LogFOp : IREEInterpHL_UnaryElementwiseFloatOp<"log_f">;
+def IREEInterpHL_RsqrtFOp : IREEInterpHL_UnaryElementwiseFloatOp<"rsqrt_f">;
+def IREEInterpHL_CosFOp : IREEInterpHL_UnaryElementwiseFloatOp<"cos_f">;
+def IREEInterpHL_SinFOp : IREEInterpHL_UnaryElementwiseFloatOp<"sin_f">;
+def IREEInterpHL_TanhFOp : IREEInterpHL_UnaryElementwiseFloatOp<"tanh_f">;
+// NOTE(review): atan2_f is declared with the unary base class (a single
+// operand), although atan2 is mathematically a two-argument function -
+// confirm whether this should use the binary float base class instead.
+def IREEInterpHL_Atan2FOp : IREEInterpHL_UnaryElementwiseFloatOp<"atan2_f">;
+
+// Min/max (binary), clamp (ternary over float memrefs) and floor/ceil (unary
+// float) ops.
+// NOTE(review): the `_s` / `_u` suffixes presumably mean signed / unsigned,
+// and the roles of clamp_f's three operands (`a`, `b`, `c`) are not defined
+// in this file - confirm both.
+def IREEInterpHL_MinISOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_s">;
+def IREEInterpHL_MinIUOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_u">;
+def IREEInterpHL_MinFOp : IREEInterpHL_BinaryElementwiseFloatOp<"min_f">;
+def IREEInterpHL_MaxISOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_s">;
+def IREEInterpHL_MaxIUOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_u">;
+def IREEInterpHL_MaxFOp : IREEInterpHL_BinaryElementwiseFloatOp<"max_f">;
+def IREEInterpHL_ClampFOp : IREEInterpHL_TernaryOp<"clamp_f", IREEHL_FloatMemRef>;
+def IREEInterpHL_FloorFOp : IREEInterpHL_UnaryElementwiseFloatOp<"floor_f">;
+def IREEInterpHL_CeilFOp : IREEInterpHL_UnaryElementwiseFloatOp<"ceil_f">;
+
+// Base class for pure conversion ops: one operand of `inputType`, one result
+// of `outputType`, with operand and result required to have the same shape.
+class IREEInterpHL_ConversionOp<string mnemonic, Type inputType,
+ Type outputType> :
+ IREEInterpHL_PureOp<mnemonic, [SameOperandsAndResultShape]> {
+ let arguments = (ins inputType);
+ let results = (outs outputType);
+}
+
+// Element type conversion ops, named convert_<src>_<dst>.
+// NOTE(review): the letters presumably stand for s = signed integer,
+// u = unsigned integer, f = float - confirm. Signed and unsigned variants
+// both use IREEHL_IntMemRef, so signedness is not visible in the types.
+def IREEInterpHL_ConvertSSOp :
+ IREEInterpHL_ConversionOp<"convert_s_s", IREEHL_IntMemRef,
+ IREEHL_IntMemRef>;
+def IREEInterpHL_ConvertSUOp :
+ IREEInterpHL_ConversionOp<"convert_s_u", IREEHL_IntMemRef,
+ IREEHL_IntMemRef>;
+def IREEInterpHL_ConvertSFOp :
+ IREEInterpHL_ConversionOp<"convert_s_f", IREEHL_IntMemRef,
+ IREEHL_FloatMemRef>;
+
+def IREEInterpHL_ConvertUSOp :
+ IREEInterpHL_ConversionOp<"convert_u_s", IREEHL_IntMemRef,
+ IREEHL_IntMemRef>;
+def IREEInterpHL_ConvertUUOp :
+ IREEInterpHL_ConversionOp<"convert_u_u", IREEHL_IntMemRef,
+ IREEHL_IntMemRef>;
+def IREEInterpHL_ConvertUFOp :
+ IREEInterpHL_ConversionOp<"convert_u_f", IREEHL_IntMemRef,
+ IREEHL_FloatMemRef>;
+
+def IREEInterpHL_ConvertFSOp :
+ IREEInterpHL_ConversionOp<"convert_f_s", IREEHL_FloatMemRef,
+ IREEHL_IntMemRef>;
+def IREEInterpHL_ConvertFUOp :
+ IREEInterpHL_ConversionOp<"convert_f_u", IREEHL_FloatMemRef,
+ IREEHL_IntMemRef>;
+def IREEInterpHL_ConvertFFOp :
+ IREEInterpHL_ConversionOp<"convert_f_f", IREEHL_FloatMemRef,
+ IREEHL_FloatMemRef>;
+
+// Pure integer matrix multiply of `lhs` and `rhs` with two additional integer
+// memref operands, `multiplier_mantissa` and `multiplier_exponent`. Only
+// `lhs`, `rhs` and the result are required to share an element type.
+// NOTE(review): the multiplier operands presumably describe a fixed-point
+// rescale of the result - confirm.
+def IREEInterpHL_MatMulIOp :
+ IREEInterpHL_PureOp<"matmul_i",
+ [AllElementTypesMatch<["lhs", "rhs", "result"]>]> {
+ let arguments = (ins
+ IREEHL_IntMemRef:$lhs,
+ IREEHL_IntMemRef:$rhs,
+ IREEHL_IntMemRef:$multiplier_mantissa,
+ IREEHL_IntMemRef:$multiplier_exponent
+ );
+ let results = (outs IREEHL_IntMemRef:$result);
+}
+// Pure float matrix multiply of `lhs` and `rhs`; operands and result share
+// one element type.
+def IREEInterpHL_MatMulFOp :
+ IREEInterpHL_PureOp<"matmul_f", [SameOperandsAndResultElementType]> {
+ let arguments = (ins
+ IREEHL_FloatMemRef:$lhs,
+ IREEHL_FloatMemRef:$rhs
+ );
+ let results = (outs IREEHL_FloatMemRef);
+}
+
+// Reduction ops (sum / min / max, each in an integer and a float variant).
+// All six have the same form: a `src` memref, an `init` memref and an I32
+// `dimension` attribute, producing one memref; `src`, `init` and the result
+// share an element type.
+// NOTE(review): `dimension` presumably selects the axis being reduced and
+// `init` the initial accumulator value - confirm against the lowering.
+def IREEInterpHL_ReduceSumIOp :
+ IREEInterpHL_PureOp<"reduce_sum_i",
+ [AllElementTypesMatch<["src", "result", "init"]>]> {
+ let arguments = (ins
+ IREEHL_IntMemRef:$src,
+ IREEHL_IntMemRef:$init,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_IntMemRef:$result);
+}
+def IREEInterpHL_ReduceSumFOp :
+ IREEInterpHL_PureOp<"reduce_sum_f",
+ [AllElementTypesMatch<["src", "result", "init"]>]> {
+ let arguments = (ins
+ IREEHL_FloatMemRef:$src,
+ IREEHL_FloatMemRef:$init,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_FloatMemRef:$result);
+}
+def IREEInterpHL_ReduceMinIOp :
+ IREEInterpHL_PureOp<"reduce_min_i",
+ [AllElementTypesMatch<["src", "result", "init"]>]> {
+ let arguments = (ins
+ IREEHL_IntMemRef:$src,
+ IREEHL_IntMemRef:$init,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_IntMemRef:$result);
+}
+def IREEInterpHL_ReduceMinFOp :
+ IREEInterpHL_PureOp<"reduce_min_f",
+ [AllElementTypesMatch<["src", "result", "init"]>]> {
+ let arguments = (ins
+ IREEHL_FloatMemRef:$src,
+ IREEHL_FloatMemRef:$init,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_FloatMemRef:$result);
+}
+def IREEInterpHL_ReduceMaxIOp :
+ IREEInterpHL_PureOp<"reduce_max_i",
+ [AllElementTypesMatch<["src", "result", "init"]>]> {
+ let arguments = (ins
+ IREEHL_IntMemRef:$src,
+ IREEHL_IntMemRef:$init,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_IntMemRef:$result);
+}
+def IREEInterpHL_ReduceMaxFOp :
+ IREEInterpHL_PureOp<"reduce_max_f",
+ [AllElementTypesMatch<["src", "result", "init"]>]> {
+ let arguments = (ins
+ IREEHL_FloatMemRef:$src,
+ IREEHL_FloatMemRef:$init,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_FloatMemRef:$result);
+}
+
+// Takes a variadic list of memrefs and has no results; not marked
+// NoSideEffect.
+def IREEInterpHL_TraceOp : IREEInterpHL_Op<"trace"> {
+ let arguments = (ins Variadic<IREEHL_MemRef>:$srcs);
+}
+
+// Takes a boolean scalar condition and has no results.
+// NOTE(review): presumably triggers a debugger break when `cond` is true -
+// confirm.
+def IREEInterpHL_CondBreakOp : IREEInterpHL_Op<"cond_break"> {
+ let arguments = (ins IREEHL_BoolScalar:$cond);
+}
+
+// Op with no operands and no results.
+def IREEInterpHL_BreakOp : IREEInterpHL_Op<"break">;
+
+#endif // IREE_INTERPRETER_HL_OPS
diff --git a/compiler/IR/Interpreter/LLDialect.cpp b/compiler/IR/Interpreter/LLDialect.cpp
new file mode 100644
index 0000000..0b76111
--- /dev/null
+++ b/compiler/IR/Interpreter/LLDialect.cpp
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Interpreter/LLDialect.h"
+
+#include "compiler/IR/Interpreter/LLOps.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Constructs the low-level interpreter dialect and registers its ops. With
+// GET_OP_LIST defined, the generated LLOps.cpp.inc is included inside the
+// template argument list of addOperations<>, so the include must stay exactly
+// between `addOperations<` and `>();`.
+IREELLInterpreterDialect::IREELLInterpreterDialect(MLIRContext* context)
+ : Dialect(getDialectNamespace(), context) {
+#define GET_OP_LIST
+ addOperations<
+#include "compiler/IR/Interpreter/LLOps.cpp.inc"
+ >();
+}
+
+// File-local static object that registers the dialect via DialectRegistration.
+static DialectRegistration<IREELLInterpreterDialect> iree_ll_interp_dialect;
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/LLDialect.h b/compiler/IR/Interpreter/LLDialect.h
similarity index 100%
rename from iree/compiler/IR/Interpreter/LLDialect.h
rename to compiler/IR/Interpreter/LLDialect.h
diff --git a/compiler/IR/Interpreter/LLOps.cpp b/compiler/IR/Interpreter/LLOps.cpp
new file mode 100644
index 0000000..b5acace
--- /dev/null
+++ b/compiler/IR/Interpreter/LLOps.cpp
@@ -0,0 +1,228 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Interpreter/LLOps.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREEInterp {
+namespace LL {
+
+//===----------------------------------------------------------------------===//
+// iree_ll_interp.call
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of iree_ll_interp.call:
+//   <callee-symbol> `(` operands `)` attr-dict? `:` function-type
+// Result types and operand types are both taken from the trailing function
+// type.
+static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
+  SymbolRefAttr symbol;
+  FunctionType fnType;
+  SmallVector<OpAsmParser::OperandType, 4> args;
+  auto nameLoc = parser.getNameLoc();
+
+  if (parser.parseAttribute(symbol, "callee", state.attributes))
+    return failure();
+  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren))
+    return failure();
+  if (parser.parseOptionalAttributeDict(state.attributes)) return failure();
+  if (parser.parseColonType(fnType)) return failure();
+  if (parser.addTypesToList(fnType.getResults(), state.types))
+    return failure();
+  if (parser.resolveOperands(args, fnType.getInputs(), nameLoc,
+                             state.operands))
+    return failure();
+  return success();
+}
+
+// Prints iree_ll_interp.call in its custom form; the `callee` attribute is
+// printed up front and therefore elided from the attribute dictionary.
+static void printCallOp(OpAsmPrinter &p, CallOp op) {
+  p << "iree_ll_interp.call ";
+  p << op.getAttr("callee");
+  p << '(';
+  p.printOperands(op.getOperands());
+  p << ')';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : ";
+  p.printType(op.getCalleeType());
+}
+
+// Builds the callee's function type from this op's own operand and result
+// types.
+FunctionType CallOp::getCalleeType() {
+  SmallVector<Type, 8> inputs(getOperandTypes());
+  SmallVector<Type, 4> outputs(getResultTypes());
+  return FunctionType::get(inputs, outputs, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_interp.call_import
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of iree_ll_interp.call_import:
+//   <callee-symbol> `(` operands `)` attr-dict? `:` function-type
+// Result types and operand types are both taken from the trailing function
+// type.
+static ParseResult parseCallImportOp(OpAsmParser &parser,
+                                     OperationState &state) {
+  SymbolRefAttr symbol;
+  FunctionType fnType;
+  SmallVector<OpAsmParser::OperandType, 4> args;
+  auto nameLoc = parser.getNameLoc();
+
+  if (parser.parseAttribute(symbol, "callee", state.attributes))
+    return failure();
+  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren))
+    return failure();
+  if (parser.parseOptionalAttributeDict(state.attributes)) return failure();
+  if (parser.parseColonType(fnType)) return failure();
+  if (parser.addTypesToList(fnType.getResults(), state.types))
+    return failure();
+  if (parser.resolveOperands(args, fnType.getInputs(), nameLoc,
+                             state.operands))
+    return failure();
+  return success();
+}
+
+// Prints iree_ll_interp.call_import in its custom form; the `callee`
+// attribute is printed up front and therefore elided from the attribute
+// dictionary.
+static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) {
+  p << "iree_ll_interp.call_import ";
+  p << op.getAttr("callee");
+  p << '(';
+  p.printOperands(op.getOperands());
+  p << ')';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : ";
+  p.printType(op.getCalleeType());
+}
+
+// Builds the callee's function type from this op's own operand and result
+// types.
+FunctionType CallImportOp::getCalleeType() {
+  SmallVector<Type, 8> inputs(getOperandTypes());
+  SmallVector<Type, 4> outputs(getResultTypes());
+  return FunctionType::get(inputs, outputs, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_interp.call_indirect
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of iree_ll_interp.call_indirect:
+//   callee-operand `(` operands `)` attr-dict? `:` function-type
+// The callee is resolved first so it becomes operand 0, followed by the call
+// arguments; result types come from the function type.
+static ParseResult parseCallIndirectOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType calleeOperand;
+  SmallVector<OpAsmParser::OperandType, 4> argOperands;
+  llvm::SMLoc argsLoc;
+  FunctionType fnType;
+
+  if (parser.parseOperand(calleeOperand)) return failure();
+  if (parser.getCurrentLocation(&argsLoc)) return failure();
+  if (parser.parseOperandList(argOperands, OpAsmParser::Delimiter::Paren))
+    return failure();
+  if (parser.parseOptionalAttributeDict(result.attributes)) return failure();
+  if (parser.parseColonType(fnType)) return failure();
+  if (parser.resolveOperand(calleeOperand, fnType, result.operands))
+    return failure();
+  if (parser.resolveOperands(argOperands, fnType.getInputs(), argsLoc,
+                             result.operands))
+    return failure();
+  if (parser.addTypesToList(fnType.getResults(), result.types))
+    return failure();
+  return success();
+}
+
+// Prints iree_ll_interp.call_indirect in its custom form: the callee operand,
+// the parenthesized call arguments (every operand after the first), the
+// attribute dictionary and finally the callee's type.
+static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
+  p << "iree_ll_interp.call_indirect ";
+  p.printOperand(op.getCallee());
+
+  // Skip operand 0 (the callee) when printing the argument list.
+  auto allOperands = op.getOperands();
+  auto firstArg = allOperands.begin();
+  ++firstArg;
+  p << '(';
+  p.printOperands(firstArg, allOperands.end());
+  p << ')';
+
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : ";
+  p << op.getCallee()->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_interp.return
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_interp.return: an optional operand list which, when
+// non-empty, is followed by `:` and the operand types.
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 2> returned;
+  SmallVector<Type, 2> returnedTypes;
+  llvm::SMLoc startLoc = parser.getCurrentLocation();
+
+  if (parser.parseOperandList(returned)) return failure();
+  // The type list is only present when there is at least one operand.
+  if (!returned.empty() && parser.parseColonTypeList(returnedTypes))
+    return failure();
+  if (parser.resolveOperands(returned, returnedTypes, startLoc,
+                             state.operands))
+    return failure();
+  return success();
+}
+
+// Prints iree_ll_interp.return; operands and their types are only printed
+// when the op has at least one operand.
+static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
+  p << "iree_ll_interp.return";
+  if (op.getNumOperands() == 0) return;
+
+  p << ' ';
+  p.printOperands(op.operand_begin(), op.operand_end());
+  p << " : ";
+  interleaveComma(op.getOperandTypes(), p);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_interp.br
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_interp.br: a single successor block with the operands
+// forwarded to it.
+static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
+  Block *target;
+  SmallVector<Value *, 4> forwarded;
+  if (parser.parseSuccessorAndUseList(target, forwarded)) {
+    return failure();
+  }
+  result.addSuccessor(target, forwarded);
+  return success();
+}
+
+// Prints iree_ll_interp.br followed by its only successor (index 0) and the
+// operands forwarded to it.
+static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
+ p << "iree_ll_interp.br ";
+ p.printSuccessorAndUseList(op.getOperation(), 0);
+}
+
+// Returns the branch target, i.e. successor 0 of the operation.
+Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
+
+// Replaces the branch target (successor 0) with `block`.
+void BranchOp::setDest(Block *block) {
+ return getOperation()->setSuccessor(block, 0);
+}
+
+// Erases the operand at `index` from the operands forwarded to successor 0.
+void BranchOp::eraseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(0, index);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_interp.cond_br
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_interp.cond_br:
+//   condition `,` true-successor(operands) `,` false-successor(operands)
+// The condition is resolved as an i1 value; the true successor is added
+// before the false successor so that they get successor indices 0 and 1.
+static ParseResult parseCondBranchOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  // Condition operand, resolved against i1.
+  OpAsmParser::OperandType condition;
+  Type boolType = parser.getBuilder().getI1Type();
+  if (parser.parseOperand(condition) || parser.parseComma() ||
+      parser.resolveOperand(condition, boolType, result.operands)) {
+    return parser.emitError(parser.getNameLoc(),
+                            "expected condition type was boolean (i1)");
+  }
+
+  // True successor.
+  Block *trueBlock;
+  SmallVector<Value *, 4> trueArgs;
+  if (parser.parseSuccessorAndUseList(trueBlock, trueArgs)) {
+    return failure();
+  }
+  result.addSuccessor(trueBlock, trueArgs);
+
+  // False successor, separated from the true successor by a comma.
+  Block *falseBlock;
+  SmallVector<Value *, 4> falseArgs;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(falseBlock, falseArgs)) {
+    return failure();
+  }
+  result.addSuccessor(falseBlock, falseArgs);
+
+  return success();
+}
+
+// Prints iree_ll_interp.cond_br: the condition operand, then the true and
+// false successors (each with its forwarded operands), comma-separated.
+static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
+  auto *operation = op.getOperation();
+  p << "iree_ll_interp.cond_br ";
+  p.printOperand(op.getCondition());
+  p << ", ";
+  p.printSuccessorAndUseList(operation, CondBranchOp::trueIndex);
+  p << ", ";
+  p.printSuccessorAndUseList(operation, CondBranchOp::falseIndex);
+}
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Interpreter/LLOps.cpp.inc"
+
+} // namespace LL
+} // namespace IREEInterp
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Interpreter/LLOps.h b/compiler/IR/Interpreter/LLOps.h
new file mode 100644
index 0000000..5c99ce8
--- /dev/null
+++ b/compiler/IR/Interpreter/LLOps.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
+#define IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREEInterp {
+namespace LL {
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Interpreter/LLOps.h.inc"
+
+} // namespace LL
+} // namespace IREEInterp
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
diff --git a/compiler/IR/Interpreter/LLOps.td b/compiler/IR/Interpreter/LLOps.td
new file mode 100644
index 0000000..7453835
--- /dev/null
+++ b/compiler/IR/Interpreter/LLOps.td
@@ -0,0 +1,633 @@
+// IREE low-level interpreter op definitions.
+// These map 1:1 with the bytecode, accept only MemRef types and generally use
+// output parameters instead of return types.
+//
+// The source of truth for bytecode opcodes is:
+// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h
+
+#ifdef IREE_INTERPRETER_LL_OPS
+#else
+#define IREE_INTERPRETER_LL_OPS
+
+#ifdef IREE_OP_BASE
+#else
+include "compiler/IR/OpBase.td"
+#endif // IREE_OP_BASE
+// Dialect for the low-level interpreter ops (op names: `iree_ll_interp.*`).
+def IREEInterpLL_Dialect : Dialect {
+ let name = "iree_ll_interp";
+ let cppNamespace = "IREEInterp::LL";
+}
+
+//===----------------------------------------------------------------------===//
+// Base op classes
+//===----------------------------------------------------------------------===//
+// Base class for every op defined in the iree_ll_interp dialect.
+class IREEInterpLL_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREEInterpLL_Dialect, mnemonic, traits>;
+// Op base that appends the NoSideEffect trait to the supplied traits.
+class IREEInterpLL_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ IREEInterpLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+// Op with one $input and one destination ($dst) operand, both of `type`.
+class IREEInterpLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef,
+ list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> {
+ let arguments = (ins type:$input, type:$dst);
+}
+// Op with $lhs and $rhs inputs and a destination ($dst), all of `type`.
+class IREEInterpLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef,
+ list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> {
+ let arguments = (ins type:$lhs, type:$rhs, type:$dst);
+}
+
+class IREEInterpLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef,
+    list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> {
+  // Three inputs followed by the destination, all constrained to `type`.
+  let arguments = (ins type:$a, type:$b, type:$c, type:$dst);
+}
+
+//===----------------------------------------------------------------------===//
+// Low-level interpreter ops
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): declare the "value" attribute (read by the bytecode writer).
+def IREEInterpLL_ConstantOp : IREEInterpLL_PureOp<"constant"> {
+ let results = (outs IREELL_MemRef);
+}
+// Direct call to the function referenced by the `callee` symbol attribute.
+def IREEInterpLL_CallOp : IREEInterpLL_Op<"call"> {
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
+ let results = (outs Variadic<IREELL_MemRef>);
+ // Builders: from a FuncOp (result types come from its signature) or by name.
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, FuncOp callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(callee.getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState &result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+// Call whose `callee` symbol is expected to name an external function.
+// TODO(benvanik): add verifier that target isExternal.
+def IREEInterpLL_CallImportOp : IREEInterpLL_Op<"call_import"> {
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
+ let results = (outs Variadic<IREELL_MemRef>);
+ // Builders: from a FuncOp (result types come from its signature) or by name.
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, FuncOp callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(callee.getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState &result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+// Call through a function-typed value passed as operand 0.
+def IREEInterpLL_CallIndirectOp : IREEInterpLL_Op<"call_indirect"> {
+ let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands);
+ let results = (outs Variadic<IREELL_MemRef>);
+ // Builder: result types are taken from the callee value's function type.
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.operands.push_back(callee);
+ result.addOperands(operands);
+ result.addTypes(callee->getType().cast<FunctionType>().getResults());
+ }]>];
+
+ let extraClassDeclaration = [{
+ Value *getCallee() { return getOperand(0); }
+
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return ++operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+// Terminator with variadic memref operands; the extra builder passes none.
+def IREEInterpLL_ReturnOp : IREEInterpLL_Op<"return", [Terminator]> {
+ let arguments = (ins Variadic<IREELL_MemRef>:$operands);
+
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+// Terminator with a single successor block that receives the op's operands.
+def IREEInterpLL_BranchOp : IREEInterpLL_Op<"br", [Terminator]> {
+ let arguments = (ins Variadic<IREELL_MemRef>:$operands);
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Block *dest,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addSuccessor(dest, operands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ Block *getDest();
+ void setDest(Block *block);
+
+ /// Erase the operand at 'index' from the operand list.
+ void eraseOperand(unsigned index);
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+// Two-way branch; operand layout: [condition, true args..., false args...].
+def IREEInterpLL_CondBranchOp : IREEInterpLL_Op<"cond_br", [Terminator]> {
+ let arguments = (ins
+ IREELL_BoolScalar:$condition,
+ Variadic<IREELL_MemRef>:$branchOperands
+ );
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *condition,"
+ "Block *trueDest, ArrayRef<Value *> trueOperands,"
+ "Block *falseDest, ArrayRef<Value *> falseOperands", [{
+ result.addOperands(condition);
+ result.addSuccessor(trueDest, trueOperands);
+ result.addSuccessor(falseDest, falseOperands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // These are the indices into the dests list.
+ enum { trueIndex = 0, falseIndex = 1 };
+
+ // The condition operand is the first operand in the list.
+ Value *getCondition() { return getOperand(0); }
+
+ /// Return the destination if the condition is true.
+ Block *getTrueDest() {
+ return getOperation()->getSuccessor(trueIndex);
+ }
+
+ /// Return the destination if the condition is false.
+ Block *getFalseDest() {
+ return getOperation()->getSuccessor(falseIndex);
+ }
+
+ // Accessors for operands to the 'true' destination.
+ Value *getTrueOperand(unsigned idx) {
+ assert(idx < getNumTrueOperands());
+ return getOperand(getTrueDestOperandIndex() + idx);
+ }
+
+ void setTrueOperand(unsigned idx, Value *value) {
+ assert(idx < getNumTrueOperands());
+ setOperand(getTrueDestOperandIndex() + idx, value);
+ }
+
+ operand_iterator true_operand_begin() {
+ return operand_begin() + getTrueDestOperandIndex();
+ }
+ operand_iterator true_operand_end() {
+ return true_operand_begin() + getNumTrueOperands();
+ }
+ operand_range getTrueOperands() {
+ return {true_operand_begin(), true_operand_end()};
+ }
+
+ unsigned getNumTrueOperands() {
+ return getOperation()->getNumSuccessorOperands(trueIndex);
+ }
+
+ /// Erase the operand at 'index' from the true operand list.
+ void eraseTrueOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(trueIndex, index);
+ }
+
+ // Accessors for operands to the 'false' destination.
+ Value *getFalseOperand(unsigned idx) {
+ assert(idx < getNumFalseOperands());
+ return getOperand(getFalseDestOperandIndex() + idx);
+ }
+ void setFalseOperand(unsigned idx, Value *value) {
+ assert(idx < getNumFalseOperands());
+ setOperand(getFalseDestOperandIndex() + idx, value);
+ }
+
+ // The false operands start immediately after the last true operand.
+ operand_iterator false_operand_begin() { return true_operand_end(); }
+ operand_iterator false_operand_end() {
+ return false_operand_begin() + getNumFalseOperands();
+ }
+ operand_range getFalseOperands() {
+ return {false_operand_begin(), false_operand_end()};
+ }
+
+ unsigned getNumFalseOperands() {
+ return getOperation()->getNumSuccessorOperands(falseIndex);
+ }
+
+ /// Erase the operand at 'index' from the false operand list.
+ void eraseFalseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(falseIndex, index);
+ }
+
+ private:
+ /// Get the index of the first true destination operand.
+ unsigned getTrueDestOperandIndex() { return 1; }
+
+ /// Get the index of the first false destination operand.
+ unsigned getFalseDestOperandIndex() {
+ return getTrueDestOperandIndex() + getNumTrueOperands();
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+// cmp_i: i32 $predicate attribute, integer $lhs/$rhs memrefs, bool $dst memref.
+def IREEInterpLL_CmpIOp : IREEInterpLL_Op<"cmp_i"> {
+ let arguments = (ins
+ I32Attr:$predicate,
+ IREELL_IntMemRef:$lhs,
+ IREELL_IntMemRef:$rhs,
+ IREELL_BoolMemRef:$dst
+ );
+}
+// cmp_f: i32 $predicate attribute, float $lhs/$rhs memrefs, bool $dst memref.
+def IREEInterpLL_CmpFOp : IREEInterpLL_Op<"cmp_f"> {
+ let arguments = (ins
+ I32Attr:$predicate,
+ IREELL_FloatMemRef:$lhs,
+ IREELL_FloatMemRef:$rhs,
+ IREELL_BoolMemRef:$dst
+ );
+}
+
+def IREEInterpLL_AllocStaticOp : IREEInterpLL_PureOp<"alloc_static"> {
+ // TODO(benvanik): attributes and args.
+ let results = (outs IREELL_MemRef);
+}
+// Side-effect-free op: variadic $dim_pieces memrefs in, one memref result out.
+def IREEInterpLL_AllocStackOp : IREEInterpLL_PureOp<"alloc_stack"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins
+ Variadic<IREELL_MemRef>:$dim_pieces
+ );
+ let results = (outs
+ IREELL_MemRef
+ );
+}
+
+def IREEInterpLL_AllocStackInitOp : IREEInterpLL_PureOp<"alloc_stack_init"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins
+ Variadic<IREELL_MemRef>:$dim_pieces
+ );
+ let results = (outs
+ IREELL_MemRef
+ );
+}
+
+// TODO(b/142012496): Add trait that enables DCE but not CSE.
+def IREEInterpLL_AllocHeapOp : IREEInterpLL_Op<"alloc_heap"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins
+ Variadic<IREELL_MemRef>:$dim_pieces
+ );
+ let results = (outs
+ IREELL_MemRef
+ );
+}
+
+def IREEInterpLL_DiscardOp : IREEInterpLL_Op<"discard"> {
+ let arguments = (ins IREELL_MemRef);
+}
+
+def IREEInterpLL_RankOp : IREEInterpLL_Op<"rank"> {
+ let arguments = (ins
+ IREELL_MemRef:$input,
+ IREELL_I32Scalar:$dst
+ );
+}
+
+def IREEInterpLL_DimOp : IREEInterpLL_Op<"dim"> {
+ // TODO(benvanik) add dim attr (I32Attr:$dim)
+ let arguments = (ins
+ IREELL_MemRef:$input,
+ IREELL_I32Scalar:$dst
+ );
+}
+
+def IREEInterpLL_ShapeOp : IREEInterpLL_Op<"shape"> {
+ let arguments = (ins
+ IREELL_MemRef:$input,
+ IREELL_I32MemRef:$dst
+ );
+}
+
+def IREEInterpLL_LengthOp : IREEInterpLL_Op<"length"> {
+ let arguments = (ins
+ IREELL_MemRef:$input,
+ IREELL_I32Scalar:$dst
+ );
+}
+
+
+def IREEInterpLL_DynamicSliceOp : IREEInterpLL_PureOp<"dynamic_slice"> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ IREELL_1DIndexMemRef:$srcIndices,
+ IREELL_1DIndexMemRef:$lengths
+ );
+ let results = (outs
+ IREELL_MemRef
+ );
+}
+
+// TODO(benvanik): add attribute requirements/types.
+def IREEInterpLL_StaticSliceOp :
+ IREEInterpLL_PureOp<"static_slice", [SameOperandsAndResultElementType]> {
+ let arguments = (ins IREELL_MemRef:$src);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREEInterpLL_DynamicCopyOp : IREEInterpLL_Op<"dynamic_copy", [
+ AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
+]> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ IREELL_1DIndexMemRef:$srcIndices,
+ IREELL_MemRef:$dst,
+ IREELL_1DIndexMemRef:$dstIndices,
+ IREELL_1DIndexMemRef:$lengths
+ );
+}
+// static_copy: $src/$dst memrefs plus constant i32 index/length attributes.
+def IREEInterpLL_StaticCopyOp : IREEInterpLL_Op<"static_copy", [
+ AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
+]> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ I32ElementsAttr:$srcIndices,
+ IREELL_MemRef:$dst,
+ I32ElementsAttr:$dstIndices,
+ I32ElementsAttr:$lengths
+ );
+}
+
+def IREEInterpLL_CloneOp :
+ IREEInterpLL_PureOp<"clone", [SameOperandsAndResultType]> {
+ let arguments = (ins IREELL_MemRef:$src);
+ let results = (outs IREELL_MemRef);
+}
+
+// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops?
+def IREEInterpLL_SplitOp : IREEInterpLL_PureOp<"split"> {
+ let arguments = (ins
+ IREELL_MemRef:$src
+ );
+ let results = (outs
+ Variadic<IREELL_MemRef>
+ );
+}
+
+def IREEInterpLL_AssignOp :
+ IREEInterpLL_Op<"assign", [SameOperandsAndResultType]> {
+ let arguments = (ins IREELL_MemRef:$src);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREEInterpLL_CondAssignOp : IREEInterpLL_Op<"cond_assign"> {
+ let arguments = (ins
+ IREELL_BoolScalar:$cond,
+ IREELL_MemRef:$lhs,
+ IREELL_MemRef:$rhs
+ );
+ let results = (outs
+ IREELL_MemRef
+ );
+}
+
+def IREEInterpLL_ReshapeOp : IREEInterpLL_Op<"reshape"> {
+ let arguments = (ins
+ IREELL_MemRef:$input,
+ IREELL_1DIntMemRef:$shape
+ );
+ let results = (outs
+ IREELL_MemRef
+ );
+}
+
+def IREEInterpLL_SelectOp : IREEInterpLL_Op<"select"> {
+ let arguments = (ins
+ IREELL_MemRef:$cond,
+ IREELL_MemRef:$lhs,
+ IREELL_MemRef:$rhs,
+ IREELL_MemRef:$dst
+ );
+}
+
+def IREEInterpLL_PadOp :
+ IREEInterpLL_Op<
+ "pad", [AllElementTypesMatch<["src", "dst", "padding_value"]>]> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ IREELL_ElementScalar:$padding_value,
+ IREELL_1DIndexMemRef:$edge_padding_low,
+ IREELL_1DIndexMemRef:$edge_padding_high,
+ IREELL_1DIndexMemRef:$interior_padding,
+ IREELL_MemRef:$dst
+ );
+}
+
+def IREEInterpLL_TransposeOp : IREEInterpLL_BinaryOp<"transpose">;
+// reverse: uses the shared BinaryOp operand layout ($lhs, $rhs, $dst).
+def IREEInterpLL_ReverseOp : IREEInterpLL_BinaryOp<"reverse">;
+
+def IREEInterpLL_BroadcastOp : IREEInterpLL_BinaryOp<"broadcast">;
+
+def IREEInterpLL_TileOp : IREEInterpLL_BinaryOp<"tile">;
+
+// TODO(benvanik): add traits for broadcasting support.
+
+def IREEInterpLL_NotOp : IREEInterpLL_UnaryOp<"not">;
+def IREEInterpLL_AndOp : IREEInterpLL_BinaryOp<"and">;
+def IREEInterpLL_OrOp : IREEInterpLL_BinaryOp<"or">;
+def IREEInterpLL_XorOp : IREEInterpLL_BinaryOp<"xor">;
+def IREEInterpLL_ShiftLeftOp : IREEInterpLL_BinaryOp<"sll">;
+def IREEInterpLL_ShiftRightLogicalOp : IREEInterpLL_BinaryOp<"srl">;
+def IREEInterpLL_ShiftRightArithmeticOp : IREEInterpLL_BinaryOp<"sra">;
+// Arithmetic ops on integer (_i) and floating-point (_f) memrefs.
+def IREEInterpLL_AddIOp : IREEInterpLL_BinaryOp<"add_i", IREELL_IntMemRef>;
+def IREEInterpLL_AddFOp : IREEInterpLL_BinaryOp<"add_f", IREELL_FloatMemRef>;
+def IREEInterpLL_SubIOp : IREEInterpLL_BinaryOp<"sub_i", IREELL_IntMemRef>;
+def IREEInterpLL_SubFOp : IREEInterpLL_BinaryOp<"sub_f", IREELL_FloatMemRef>;
+def IREEInterpLL_AbsIOp : IREEInterpLL_UnaryOp<"abs_i", IREELL_IntMemRef>;
+def IREEInterpLL_AbsFOp : IREEInterpLL_UnaryOp<"abs_f", IREELL_FloatMemRef>;
+def IREEInterpLL_MulIOp : IREEInterpLL_BinaryOp<"mul_i", IREELL_IntMemRef>;
+def IREEInterpLL_MulFOp : IREEInterpLL_BinaryOp<"mul_f", IREELL_FloatMemRef>;
+def IREEInterpLL_DivISOp : IREEInterpLL_BinaryOp<"div_i_s", IREELL_IntMemRef>;
+def IREEInterpLL_DivIUOp : IREEInterpLL_BinaryOp<"div_i_u", IREELL_IntMemRef>;
+def IREEInterpLL_DivFOp : IREEInterpLL_BinaryOp<"div_f", IREELL_FloatMemRef>;
+def IREEInterpLL_MulAddIOp : IREEInterpLL_BinaryOp<"madd_i", IREELL_IntMemRef>; // NOTE(review): only two inputs declared; confirm against the bytecode.
+def IREEInterpLL_MulAddFOp : IREEInterpLL_BinaryOp<"madd_f", IREELL_FloatMemRef>; // NOTE(review): only two inputs declared; confirm against the bytecode.
+def IREEInterpLL_ExpFOp : IREEInterpLL_UnaryOp<"exp_f", IREELL_FloatMemRef>;
+def IREEInterpLL_LogFOp : IREEInterpLL_UnaryOp<"log_f", IREELL_FloatMemRef>;
+def IREEInterpLL_RsqrtFOp : IREEInterpLL_UnaryOp<"rsqrt_f", IREELL_FloatMemRef>;
+def IREEInterpLL_CosFOp : IREEInterpLL_UnaryOp<"cos_f", IREELL_FloatMemRef>;
+def IREEInterpLL_SinFOp : IREEInterpLL_UnaryOp<"sin_f", IREELL_FloatMemRef>;
+def IREEInterpLL_TanhFOp : IREEInterpLL_UnaryOp<"tanh_f", IREELL_FloatMemRef>;
+def IREEInterpLL_Atan2FOp : IREEInterpLL_UnaryOp<"atan2_f", IREELL_FloatMemRef>; // NOTE(review): declared with a single input; confirm against the bytecode.
+
+def IREEInterpLL_MinISOp : IREEInterpLL_BinaryOp<"min_i_s", IREELL_IntMemRef>;
+def IREEInterpLL_MinIUOp : IREEInterpLL_BinaryOp<"min_i_u", IREELL_IntMemRef>;
+def IREEInterpLL_MinFOp : IREEInterpLL_BinaryOp<"min_f", IREELL_FloatMemRef>;
+def IREEInterpLL_MaxISOp : IREEInterpLL_BinaryOp<"max_i_s", IREELL_IntMemRef>;
+def IREEInterpLL_MaxIUOp : IREEInterpLL_BinaryOp<"max_i_u", IREELL_IntMemRef>;
+def IREEInterpLL_MaxFOp : IREEInterpLL_BinaryOp<"max_f", IREELL_FloatMemRef>;
+def IREEInterpLL_ClampFOp : IREEInterpLL_TernaryOp<"clamp_f", IREELL_FloatMemRef>;
+def IREEInterpLL_FloorFOp : IREEInterpLL_UnaryOp<"floor_f", IREELL_FloatMemRef>;
+def IREEInterpLL_CeilFOp : IREEInterpLL_UnaryOp<"ceil_f", IREELL_FloatMemRef>;
+// Conversions; OpWriters.cpp has custom writers for s_s, u_u, s_u and u_s only.
+def IREEInterpLL_ConvertSSOp : IREEInterpLL_UnaryOp<"convert_s_s", IREELL_MemRef>;
+def IREEInterpLL_ConvertSUOp : IREEInterpLL_UnaryOp<"convert_s_u", IREELL_MemRef>;
+def IREEInterpLL_ConvertSFOp : IREEInterpLL_UnaryOp<"convert_s_f", IREELL_MemRef>;
+
+def IREEInterpLL_ConvertUSOp : IREEInterpLL_UnaryOp<"convert_u_s", IREELL_MemRef>;
+def IREEInterpLL_ConvertUUOp : IREEInterpLL_UnaryOp<"convert_u_u", IREELL_MemRef>;
+def IREEInterpLL_ConvertUFOp : IREEInterpLL_UnaryOp<"convert_u_f", IREELL_MemRef>;
+
+def IREEInterpLL_ConvertFSOp : IREEInterpLL_UnaryOp<"convert_f_s", IREELL_MemRef>;
+def IREEInterpLL_ConvertFUOp : IREEInterpLL_UnaryOp<"convert_f_u", IREELL_MemRef>;
+def IREEInterpLL_ConvertFFOp : IREEInterpLL_UnaryOp<"convert_f_f", IREELL_MemRef>;
+
+def IREEInterpLL_MatMulIOp : IREEInterpLL_Op<"matmul_i"> {
+ let arguments = (ins
+ IREELL_IntMemRef:$lhs,
+ IREELL_IntMemRef:$rhs,
+ IREELL_IntMemRef:$multiplier_mantissa,
+ IREELL_IntMemRef:$multiplier_exponent,
+ IREELL_IntMemRef:$dst
+ );
+}
+def IREEInterpLL_MatMulFOp : IREEInterpLL_Op<"matmul_f"> {
+ let arguments = (ins
+ IREELL_FloatMemRef:$lhs,
+ IREELL_FloatMemRef:$rhs,
+ IREELL_FloatMemRef:$dst
+ );
+}
+
+def IREEInterpLL_ReduceSumIOp : IREEInterpLL_Op<"reduce_sum_i"> {
+ let arguments = (ins
+ IREELL_IntMemRef:$src,
+ IREELL_IntMemRef:$init,
+ I32Attr:$dimension,
+ IREELL_IntMemRef:$dst
+ );
+}
+def IREEInterpLL_ReduceSumFOp : IREEInterpLL_Op<"reduce_sum_f"> {
+ let arguments = (ins
+ IREELL_FloatMemRef:$src,
+ IREELL_FloatMemRef:$init,
+ I32Attr:$dimension,
+ IREELL_FloatMemRef:$dst
+ );
+}
+
+def IREEInterpLL_ReduceMinIOp : IREEInterpLL_Op<"reduce_min_i"> {
+ let arguments = (ins
+ IREELL_IntMemRef:$src,
+ IREELL_IntMemRef:$init,
+ I32Attr:$dimension,
+ IREELL_IntMemRef:$dst
+ );
+}
+def IREEInterpLL_ReduceMinFOp : IREEInterpLL_Op<"reduce_min_f"> {
+ let arguments = (ins
+ IREELL_FloatMemRef:$src,
+ IREELL_FloatMemRef:$init,
+ I32Attr:$dimension,
+ IREELL_FloatMemRef:$dst
+ );
+}
+
+def IREEInterpLL_ReduceMaxIOp : IREEInterpLL_Op<"reduce_max_i"> {
+ let arguments = (ins
+ IREELL_IntMemRef:$src,
+ IREELL_IntMemRef:$init,
+ I32Attr:$dimension,
+ IREELL_IntMemRef:$dst
+ );
+}
+def IREEInterpLL_ReduceMaxFOp : IREEInterpLL_Op<"reduce_max_f"> {
+ let arguments = (ins
+ IREELL_FloatMemRef:$src,
+ IREELL_FloatMemRef:$init,
+ I32Attr:$dimension,
+ IREELL_FloatMemRef:$dst
+ );
+}
+
+def IREEInterpLL_TraceOp : IREEInterpLL_Op<"trace"> {
+ let arguments = (ins
+ Variadic<IREELL_MemRef>:$srcs
+ );
+}
+
+def IREEInterpLL_CondBreakOp : IREEInterpLL_Op<"cond_break"> {
+ let arguments = (ins
+ IREELL_BoolScalar:$cond
+ );
+}
+
+def IREEInterpLL_BreakOp : IREEInterpLL_Op<"break">;
+
+#endif // IREE_INTERPRETER_LL_OPS
diff --git a/compiler/IR/Interpreter/OpWriters.cpp b/compiler/IR/Interpreter/OpWriters.cpp
new file mode 100644
index 0000000..bc2a118
--- /dev/null
+++ b/compiler/IR/Interpreter/OpWriters.cpp
@@ -0,0 +1,261 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Interpreter/OpWriters.h"
+
+#include "compiler/IR/Interpreter/LLOps.h"
+#include "compiler/Serialization/BytecodeWriter.h"
+#include "compiler/Utils/Macros.h"
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Interpreter ops
+//===----------------------------------------------------------------------===//
+
+LogicalResult writeOp(IREEInterp::LL::ConstantOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConstant));
+  // Only memref-typed constants can be serialized; anything else is an error.
+  if (auto memrefType = op.getType().dyn_cast<MemRefType>()) {
+    RETURN_IF_FAILURE(writer->WriteConstant(memrefType, op.getAttr("value")));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
+    return success();
+  }
+  return op.emitOpError()
+         << "Constant has an unsupported type; must be a memref: "
+         << op.getType();
+}
+
+LogicalResult writeOp(IREEInterp::LL::CallOp op, BytecodeWriter *writer) {
+  auto module = op.getOperation()->getParentOfType<ModuleOp>();
+  auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
+  // An unresolved symbol yields a null FuncOp; report it instead of crashing.
+  if (!callee)
+    return op.emitOpError() << "undefined callee " << op.getCallee();
+  // TODO(benvanik): transforms to convert Call->CallImport.
+  // TODO(benvanik): switch with kCallTail if attr exists.
+  auto opcode = callee.isExternal() ? iree::InterpreterOpcode::kCallImport
+                                    : iree::InterpreterOpcode::kCall;
+  RETURN_IF_FAILURE(writer->WriteOpcode(opcode));
+  RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee));
+  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+  return success();
+}
+// NOTE(review): `callee` is null if the symbol is undefined; confirm handling.
+LogicalResult writeOp(IREEInterp::LL::CallImportOp op, BytecodeWriter *writer) {
+ auto module = op.getOperation()->getParentOfType<ModuleOp>();
+ auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
+ // TODO(benvanik): switch with kCallTail if attr exists.
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCallImport));
+ RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+ return success();
+}
+// Emits kCallIndirect: callee type index, callee local, arguments, results.
+LogicalResult writeOp(IREEInterp::LL::CallIndirectOp op,
+ BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(
+ writer->WriteOpcode(iree::InterpreterOpcode::kCallIndirect));
+ RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+ return success();
+}
+
+LogicalResult WriteConvertOperands(Operation *op, BytecodeWriter *writer) {
+  // Operand 0 is the source and operand 1 the destination; each is written as
+  // its element type index followed by the local itself.
+  for (unsigned i = 0; i < 2; ++i) {
+    auto *operand = op->getOperand(i);
+    RETURN_IF_FAILURE(
+        writer->WriteTypeIndex(getElementTypeOrSelf(operand->getType())));
+    RETURN_IF_FAILURE(writer->WriteLocal(operand));
+  }
+  return success();
+}
+// Emits kConvertSS followed by the typed source/destination operands.
+LogicalResult writeOp(IREEInterp::LL::ConvertSSOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSS));
+ return WriteConvertOperands(op, writer);
+}
+// Emits kConvertUU followed by the typed source/destination operands.
+LogicalResult writeOp(IREEInterp::LL::ConvertUUOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUU));
+ return WriteConvertOperands(op, writer);
+}
+// Emits kConvertSU followed by the typed source/destination operands.
+LogicalResult writeOp(IREEInterp::LL::ConvertSUOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSU));
+ return WriteConvertOperands(op, writer);
+}
+// Emits kConvertUS followed by the typed source/destination operands.
+LogicalResult writeOp(IREEInterp::LL::ConvertUSOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUS));
+ return WriteConvertOperands(op, writer);
+}
+// Emits kBranch: target block offset, operand count, then the copy pairs.
+LogicalResult writeOp(IREEInterp::LL::BranchOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kBranch));
+  RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest()));
+  RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands()));
+  for (unsigned i = 0, e = op.getNumOperands(); i < e; ++i) {
+    // One (src, dst) pair per operand: operand i -> target block argument i.
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i)));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i)));
+  }
+  return success();
+}
+// Emits kCondBranch: condition, then offset/count/copy pairs per successor.
+LogicalResult writeOp(IREEInterp::LL::CondBranchOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCondBranch));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition()));
+  RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest()));
+  RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands()));
+  for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) {
+    // Copy src->dst: true operand i -> true block argument i.
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i)));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i)));
+  }
+  RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest()));
+  RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands()));
+  for (unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) {
+    // Copy src->dst: false operand i -> false block argument i.
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i)));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i)));
+  }
+  return success();
+}
+// Emits kCmpI: predicate byte, then the lhs, rhs and dst locals.
+LogicalResult writeOp(IREEInterp::LL::CmpIOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpI));
+  const auto predicate = static_cast<uint8_t>(op.predicate().getZExtValue());
+  RETURN_IF_FAILURE(writer->WriteUint8(predicate));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.lhs()));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.rhs()));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
+  return success();
+}
+// Emits kCmpF: predicate byte, then the lhs, rhs and dst locals.
+LogicalResult writeOp(IREEInterp::LL::CmpFOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpF));
+  const auto predicate = static_cast<uint8_t>(op.predicate().getZExtValue());
+  RETURN_IF_FAILURE(writer->WriteUint8(predicate));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.lhs()));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.rhs()));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
+  return success();
+}
+// Emits kAllocHeap: an int32 0, element type, shape, operands, then result.
+LogicalResult writeOp(IREEInterp::LL::AllocHeapOp op, BytecodeWriter *writer) {
+ auto memrefType = op.getType().cast<MemRefType>();
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kAllocHeap));
+ RETURN_IF_FAILURE(writer->WriteInt32(0)); // NOTE(review): field meaning not visible here.
+ RETURN_IF_FAILURE(writer->WriteTypeIndex(memrefType.getElementType()));
+ RETURN_IF_FAILURE(writer->WriteShapePieces(memrefType));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
+ return success();
+}
+// Emits kStaticCopy: src, src indices, dst, dst indices, then lengths.
+LogicalResult writeOp(IREEInterp::LL::StaticCopyOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kStaticCopy));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
+ RETURN_IF_FAILURE(writer->WriteShapePieces(op.srcIndices()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
+ RETURN_IF_FAILURE(writer->WriteShapePieces(op.dstIndices()));
+ RETURN_IF_FAILURE(writer->WriteShapePieces(op.lengths()));
+ return success();
+}
+// Writes operand 0, operand 1, `dimension` as an int32, then operand 2.
+LogicalResult writeReduceOperands(Operation *op, BytecodeWriter *writer,
+ APInt dimension) {
+ RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(0)));
+ RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(1)));
+ RETURN_IF_FAILURE(writer->WriteInt32(dimension.getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(2)));
+ return success();
+}
+// Emits kReduceSumI followed by the shared reduce operand encoding.
+LogicalResult writeOp(IREEInterp::LL::ReduceSumIOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumI));
+ return writeReduceOperands(op, writer, op.dimension());
+}
+// Emits kReduceSumF followed by the shared reduce operand encoding.
+LogicalResult writeOp(IREEInterp::LL::ReduceSumFOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumF));
+ return writeReduceOperands(op, writer, op.dimension());
+}
+// Emits kReduceMinI followed by the shared reduce operand encoding.
+LogicalResult writeOp(IREEInterp::LL::ReduceMinIOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinI));
+ return writeReduceOperands(op, writer, op.dimension());
+}
+// Emits kReduceMinF followed by the shared reduce operand encoding.
+LogicalResult writeOp(IREEInterp::LL::ReduceMinFOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinF));
+ return writeReduceOperands(op, writer, op.dimension());
+}
+// Emits kReduceMaxI followed by the shared reduce operand encoding.
+LogicalResult writeOp(IREEInterp::LL::ReduceMaxIOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxI));
+ return writeReduceOperands(op, writer, op.dimension());
+}
+// Emits kReduceMaxF followed by the shared reduce operand encoding.
+LogicalResult writeOp(IREEInterp::LL::ReduceMaxFOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxF));
+ return writeReduceOperands(op, writer, op.dimension());
+}
+
+} // namespace
+// Registers every custom bytecode writer defined above, keyed by op name.
+void registerInterpreterCustomWriters(VMFunctionBuilder *builder) {
+#define REGISTER_CUSTOM_WRITER_IMPL(op_type)        \
+  builder->RegisterCustomWriter(                    \
+      op_type::getOperationName(),                  \
+      +[](Operation *op, BytecodeWriter *writer) {  \
+        return writeOp(cast<op_type>(op), writer); })
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConstantOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallImportOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallIndirectOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::BranchOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CondBranchOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSSOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUUOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSUOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUSOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpIOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpFOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::AllocHeapOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::StaticCopyOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumIOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumFOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinIOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinFOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxIOp);
+  REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxFOp);
+#undef REGISTER_CUSTOM_WRITER_IMPL
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Interpreter/OpWriters.h b/compiler/IR/Interpreter/OpWriters.h
new file mode 100644
index 0000000..a02c57f
--- /dev/null
+++ b/compiler/IR/Interpreter/OpWriters.h
@@ -0,0 +1,30 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
+#define IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
+
+#include "compiler/Serialization/VMFunctionBuilder.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Registers custom op writers with the builder.
+// Ops not registered will use the generic writer.
+//
+// |builder| receives the registrations; the implementation dereferences it
+// unconditionally, so it must not be null.
+void registerInterpreterCustomWriters(VMFunctionBuilder *builder);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
diff --git a/compiler/IR/Interpreter/test/BUILD b/compiler/IR/Interpreter/test/BUILD
new file mode 100644
index 0000000..44a5820
--- /dev/null
+++ b/compiler/IR/Interpreter/test/BUILD
@@ -0,0 +1,15 @@
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Tool binaries passed as `data` to the lit package setup.
+# Labels are workspace-absolute ("//pkg:target"); a third leading slash is not
+# a valid Bazel label.
+iree_setup_lit_package(
+    data = [
+        "//tools:iree-opt",
+        "//tools:iree-run-mlir",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/IR/Interpreter/test/concat.mlir b/compiler/IR/Interpreter/test/concat.mlir
similarity index 100%
rename from iree/compiler/IR/Interpreter/test/concat.mlir
rename to compiler/IR/Interpreter/test/concat.mlir
diff --git a/iree/compiler/IR/Interpreter/test/invalid_types_hl.mlir b/compiler/IR/Interpreter/test/invalid_types_hl.mlir
similarity index 100%
rename from iree/compiler/IR/Interpreter/test/invalid_types_hl.mlir
rename to compiler/IR/Interpreter/test/invalid_types_hl.mlir
diff --git a/iree/compiler/IR/Interpreter/test/invalid_types_ll.mlir b/compiler/IR/Interpreter/test/invalid_types_ll.mlir
similarity index 100%
rename from iree/compiler/IR/Interpreter/test/invalid_types_ll.mlir
rename to compiler/IR/Interpreter/test/invalid_types_ll.mlir
diff --git a/iree/compiler/IR/OpBase.td b/compiler/IR/OpBase.td
similarity index 100%
rename from iree/compiler/IR/OpBase.td
rename to compiler/IR/OpBase.td
diff --git a/compiler/IR/Ops.cpp b/compiler/IR/Ops.cpp
new file mode 100644
index 0000000..5179f41
--- /dev/null
+++ b/compiler/IR/Ops.cpp
@@ -0,0 +1,639 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Ops.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/SMLoc.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/STLExtras.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+//===----------------------------------------------------------------------===//
+// iree.constant
+//===----------------------------------------------------------------------===//
+
+// Parses `iree.constant[<value>] <attr-dict>? : <type>`.
+// The bracketed attribute is recorded under the name "value" and the trailing
+// type is appended to the result types.
+static ParseResult parseConstantOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  // `[` value `]`
+  Attribute constantValue;
+  if (parser.parseLSquare() ||
+      parser.parseAttribute(constantValue, "value", result.attributes) ||
+      parser.parseRSquare()) {
+    return failure();
+  }
+
+  // Optional attribute dictionary followed by `:` result-type.
+  Type resultType;
+  if (parser.parseOptionalAttributeDict(result.attributes) ||
+      parser.parseColonType(resultType)) {
+    return failure();
+  }
+
+  return parser.addTypeToList(resultType, result.types);
+}
+
+// Prints `iree.constant[<value>] <attr-dict> : <type>`, leaving the "value"
+// attribute out of the dictionary since it already appears in the brackets.
+static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
+  p << "iree.constant[";
+  p.printAttribute(op.getValue());
+  p << "] ";
+
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
+  p << " : " << op.getType();
+}
+
+namespace {
+
+// TODO(gcmn) this is duplicated from MemRefUtils to avoid a circular
+// dependency. Extract op-dependent parts of memref utils to allow reuse.
+//
+// Maps |type| to a MemRefType:
+//  - int/index/float scalars become memrefs with an empty shape,
+//  - ranked tensors and memrefs become memrefs rebuilt from their shape and
+//    element type.
+// Any other type hits llvm_unreachable.
+MemRefType convertTypeToMemRef(Type type) {
+  if (type.isIntOrIndexOrFloat()) {
+    return MemRefType::get({}, type, {}, 0);
+  }
+  if (auto rankedTensor = type.dyn_cast<RankedTensorType>()) {
+    return MemRefType::get(rankedTensor.getShape(),
+                           rankedTensor.getElementType());
+  }
+  if (auto sourceMemRef = type.dyn_cast<MemRefType>()) {
+    return MemRefType::get(sourceMemRef.getShape(),
+                           sourceMemRef.getElementType());
+  }
+  llvm_unreachable("Unconvertable type");
+}
+
+} // namespace
+
+// Convenience builder: derives the result memref type from the type of |value|
+// and forwards to the builder overload that takes an explicit result type.
+void ConstantOp::build(Builder *builder, OperationState &state,
+                       ElementsAttr value) {
+  auto resultType = convertTypeToMemRef(value.getType());
+  build(builder, state, resultType, value);
+}
+
+// TODO(b/134575149): enable folder when we store the correct type.
+// OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
+// assert(operands.empty() && "constant has no operands");
+// return getValue();
+// }
+
+//===----------------------------------------------------------------------===//
+// iree.tensor_to_memref
+//===----------------------------------------------------------------------===//
+
+// Parses `(` operand `:` operand-type `)` `:` result-type.
+static ParseResult parseTensorToMemRefOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  // Parenthesized, typed source operand.
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  if (failed(parser.parseLParen()) || failed(parser.parseOperand(source)) ||
+      failed(parser.parseColonType(sourceType)) ||
+      failed(parser.resolveOperand(source, sourceType, state.operands)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  // Trailing result type.
+  Type convertedType;
+  if (failed(parser.parseColonType(convertedType)) ||
+      failed(parser.addTypeToList(convertedType, state.types))) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `iree.tensor_to_memref(<operand> : <operand-type>) : <result-type>`.
+static void printTensorToMemRefOp(OpAsmPrinter &p, TensorToMemRefOp &op) {
+  auto *source = op.getOperand();
+  p << "iree.tensor_to_memref(";
+  p.printOperand(source);
+  p << " : " << source->getType() << ") : " << op.getType();
+}
+
+// Folds tensor_to_memref(memref_to_tensor(x)) to x. Returns an empty result
+// when the operand is not produced by a MemRefToTensorOp.
+OpFoldResult TensorToMemRefOp::fold(ArrayRef<Attribute> operands) {
+  auto *producer = getOperand()->getDefiningOp();
+  auto inverseOp = dyn_cast_or_null<IREE::MemRefToTensorOp>(producer);
+  if (!inverseOp) {
+    return {};
+  }
+  return inverseOp.getOperand();
+}
+
+// Convenience builder: the result type is the memref conversion of the type
+// of |arg|.
+void TensorToMemRefOp::build(Builder *builder, OperationState &state,
+                             Value *arg) {
+  auto resultType = convertTypeToMemRef(arg->getType());
+  build(builder, state, resultType, arg);
+}
+
+//===----------------------------------------------------------------------===//
+// iree.memref_to_tensor
+//===----------------------------------------------------------------------===//
+
+// Parses `(` operand `:` operand-type `)` `:` result-type.
+static ParseResult parseMemRefToTensorOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  // Parenthesized, typed source operand.
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  if (failed(parser.parseLParen()) || failed(parser.parseOperand(source)) ||
+      failed(parser.parseColonType(sourceType)) ||
+      failed(parser.resolveOperand(source, sourceType, state.operands)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  // Trailing result type.
+  Type convertedType;
+  if (failed(parser.parseColonType(convertedType)) ||
+      failed(parser.addTypeToList(convertedType, state.types))) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `iree.memref_to_tensor(<operand> : <operand-type>) : <result-type>`.
+static void printMemRefToTensorOp(OpAsmPrinter &p, MemRefToTensorOp &op) {
+  auto *source = op.getOperand();
+  p << "iree.memref_to_tensor(";
+  p.printOperand(source);
+  p << " : " << source->getType() << ") : " << op.getType();
+}
+
+// Folds memref_to_tensor(tensor_to_memref(x)) to x. Returns an empty result
+// when the operand is not produced by a TensorToMemRefOp.
+OpFoldResult MemRefToTensorOp::fold(ArrayRef<Attribute> operands) {
+  auto *producer = getOperand()->getDefiningOp();
+  auto inverseOp = dyn_cast_or_null<IREE::TensorToMemRefOp>(producer);
+  if (!inverseOp) {
+    return {};
+  }
+  return inverseOp.getOperand();
+}
+
+// Convenience builder: the result is a ranked tensor with the shape and
+// element type of |arg|, whose type must be a MemRefType (enforced by cast<>).
+void MemRefToTensorOp::build(Builder *builder, OperationState &state,
+                             Value *arg) {
+  // TODO(gcmn) Use getTensorType from MemRefUtils when circular dependency can
+  // be avoided.
+  auto sourceType = arg->getType().cast<MemRefType>();
+  auto resultType = RankedTensorType::get(sourceType.getShape(),
+                                          sourceType.getElementType());
+  build(builder, state, resultType, arg);
+}
+
+//===----------------------------------------------------------------------===//
+// iree.scalar_to_memref
+//===----------------------------------------------------------------------===//
+
+// Parses `(` operand `:` operand-type `)` `:` result-type.
+static ParseResult parseScalarToMemRefOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  // Parenthesized, typed source operand.
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  if (failed(parser.parseLParen()) || failed(parser.parseOperand(source)) ||
+      failed(parser.parseColonType(sourceType)) ||
+      failed(parser.resolveOperand(source, sourceType, state.operands)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  // Trailing result type.
+  Type convertedType;
+  if (failed(parser.parseColonType(convertedType)) ||
+      failed(parser.addTypeToList(convertedType, state.types))) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `iree.scalar_to_memref(<operand> : <operand-type>) : <result-type>`.
+static void printScalarToMemRefOp(OpAsmPrinter &p, ScalarToMemRefOp &op) {
+  auto *source = op.getOperand();
+  p << "iree.scalar_to_memref(";
+  p.printOperand(source);
+  p << " : " << source->getType() << ") : " << op.getType();
+}
+
+// Folds scalar_to_memref(memref_to_scalar(x)) to x. Returns an empty result
+// when the operand is not produced by a MemRefToScalarOp.
+OpFoldResult ScalarToMemRefOp::fold(ArrayRef<Attribute> operands) {
+  auto *producer = getOperand()->getDefiningOp();
+  auto inverseOp = dyn_cast_or_null<IREE::MemRefToScalarOp>(producer);
+  if (!inverseOp) {
+    return {};
+  }
+  return inverseOp.getOperand();
+}
+
+// Convenience builder: the result type is the memref conversion of the type
+// of |arg|.
+void ScalarToMemRefOp::build(Builder *builder, OperationState &state,
+                             Value *arg) {
+  auto resultType = convertTypeToMemRef(arg->getType());
+  build(builder, state, resultType, arg);
+}
+
+//===----------------------------------------------------------------------===//
+// iree.memref_to_scalar
+//===----------------------------------------------------------------------===//
+
+// Parses `(` operand `:` operand-type `)` `:` result-type.
+static ParseResult parseMemRefToScalarOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  // Parenthesized, typed source operand.
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  if (failed(parser.parseLParen()) || failed(parser.parseOperand(source)) ||
+      failed(parser.parseColonType(sourceType)) ||
+      failed(parser.resolveOperand(source, sourceType, state.operands)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  // Trailing result type.
+  Type convertedType;
+  if (failed(parser.parseColonType(convertedType)) ||
+      failed(parser.addTypeToList(convertedType, state.types))) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `iree.memref_to_scalar(<operand> : <operand-type>) : <result-type>`.
+static void printMemRefToScalarOp(OpAsmPrinter &p, MemRefToScalarOp &op) {
+  auto *source = op.getOperand();
+  p << "iree.memref_to_scalar(";
+  p.printOperand(source);
+  p << " : " << source->getType() << ") : " << op.getType();
+}
+
+// Folds memref_to_scalar(scalar_to_memref(x)) to x. Returns an empty result
+// when the operand is not produced by a ScalarToMemRefOp.
+OpFoldResult MemRefToScalarOp::fold(ArrayRef<Attribute> operands) {
+  auto *producer = getOperand()->getDefiningOp();
+  auto inverseOp = dyn_cast_or_null<IREE::ScalarToMemRefOp>(producer);
+  if (!inverseOp) {
+    return {};
+  }
+  return inverseOp.getOperand();
+}
+
+// Convenience builder: the result type is whatever getElementTypeOrSelf
+// yields for |arg|.
+void MemRefToScalarOp::build(Builder *builder, OperationState &state,
+                             Value *arg) {
+  auto resultType = getElementTypeOrSelf(arg);
+  build(builder, state, resultType, arg);
+}
+
+//===----------------------------------------------------------------------===//
+// iree.dispatch_region
+//===----------------------------------------------------------------------===//
+
+// Populates |state| for a dispatch region.
+// Operand layout: the workload comes first, followed by |operands|. A single
+// (empty) region is added and the operand list is marked resizable.
+void DispatchRegionOp::build(Builder *builder, OperationState &state,
+                             ArrayRef<Type> resultTypes, Value *workload,
+                             ArrayRef<Value *> operands,
+                             ArrayRef<NamedAttribute> attributes) {
+  // Operands: order matters, workload must stay at index 0.
+  state.addOperands({workload});
+  state.addOperands(operands);
+
+  // Results and attributes.
+  state.addTypes(resultTypes);
+  state.addAttributes(attributes);
+
+  // Body region.
+  state.addRegion();
+  state.setOperandListToResizable();
+}
+
+// Parses:
+//   `[` workload `:` type `]`
+//   `(` (region-arg `=` operand `:` type) (`,` ...)* `)`
+//   (`:` result-types)?  region  attr-dict?
+// Each `region-arg = operand : type` entry both resolves an op operand and
+// declares an entry-block argument of the same type.
+ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState &state) {
+  // Required workload.
+  OpAsmParser::OperandType workloadOperand;
+  Type workloadType;
+  if (failed(parser.parseLSquare()) ||
+      failed(parser.parseOperand(workloadOperand)) ||
+      failed(parser.parseColonType(workloadType)) ||
+      failed(parser.parseRSquare()) ||
+      failed(parser.resolveOperand(workloadOperand, workloadType,
+                                   state.operands))) {
+    return failure();
+  }
+
+  // Argument bindings; the list may be empty, i.e. `()`.
+  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
+  SmallVector<Type, 16> regionArgTypes;
+  if (failed(parser.parseLParen())) {
+    return failure();
+  }
+  const bool hasBindings = failed(parser.parseOptionalRParen());
+  if (hasBindings) {
+    SmallVector<OpAsmParser::OperandType, 16> boundOperands;
+    auto bindingsLoc = parser.getCurrentLocation();
+    do {
+      OpAsmParser::OperandType regionArg;
+      OpAsmParser::OperandType boundOperand;
+      Type bindingType;
+      if (failed(parser.parseRegionArgument(regionArg)) ||
+          failed(parser.parseEqual()) ||
+          failed(parser.parseOperand(boundOperand)) ||
+          failed(parser.parseColonType(bindingType))) {
+        return failure();
+      }
+      regionArgs.push_back(regionArg);
+      boundOperands.push_back(boundOperand);
+      regionArgTypes.push_back(bindingType);
+    } while (succeeded(parser.parseOptionalComma()));
+
+    if (failed(parser.parseRParen()) ||
+        failed(parser.resolveOperands(boundOperands, regionArgTypes,
+                                      bindingsLoc, state.operands))) {
+      return failure();
+    }
+  }
+  state.setOperandListToResizable();
+
+  // Optional result types.
+  if (failed(parser.parseOptionalColonTypeList(state.types))) {
+    return failure();
+  }
+
+  // Region body, then the optional trailing attribute dictionary.
+  Region *body = state.addRegion();
+  if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
+      failed(parser.parseOptionalAttributeDict(state.attributes))) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints a dispatch region in the form accepted by parseDispatchRegionOp:
+// workload in brackets, `region-arg = operand : type` bindings in parens,
+// optional result types, the region (entry block arguments elided since the
+// bindings already name them), and finally the attribute dictionary.
+void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) {
+  p << "iree.dispatch_region";
+
+  // Workload: `[` %workload `:` type `]`.
+  auto *workload = op.getWorkload();
+  p << "[";
+  p.printOperand(workload);
+  p << " : " << workload->getType() << "]";
+
+  // Entry block arguments paired with the operands they are bound to.
+  p << "(";
+  interleaveComma(
+      llvm::zip(op.getBody().front().getArguments(), op.getArgOperands()), p,
+      [&](std::tuple<BlockArgument *, Value *> binding) {
+        BlockArgument *regionArg = std::get<0>(binding);
+        Value *boundOperand = std::get<1>(binding);
+        p << *regionArg << " = " << *boundOperand << " : "
+          << boundOperand->getType();
+      });
+  p << ")";
+
+  // Result types are only printed when there is at least one result.
+  if (op.getNumResults() > 0) {
+    p << " : ";
+    interleaveComma(op.getResultTypes(), p);
+  }
+
+  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{});
+}
+
+//===----------------------------------------------------------------------===//
+// iree.reduction_region
+//===----------------------------------------------------------------------===//
+
+// Populates |state| for a dimension-based reduction region.
+// Operand layout: workload, then |operands|, then |initialValues|.
+// |dimensions| is stored as a 1-D i64 dense elements attribute named
+// "dimensions".
+void ReductionRegionOp::build(Builder *builder, OperationState &state,
+                              ArrayRef<Type> resultTypes, Value *workload,
+                              ArrayRef<Value *> operands,
+                              ArrayRef<Value *> initialValues,
+                              ArrayRef<int64_t> dimensions,
+                              ArrayRef<NamedAttribute> attributes) {
+  state.addTypes(resultTypes);
+
+  // Operands: order matters.
+  state.addOperands({workload});
+  state.addOperands(operands);
+  state.addOperands(initialValues);
+
+  // "dimensions" is added before the caller-provided attributes.
+  auto dimensionsType =
+      RankedTensorType::get({static_cast<int64_t>(dimensions.size())},
+                            builder->getIntegerType(64));
+  state.addAttribute("dimensions",
+                     DenseIntElementsAttr::get(dimensionsType, dimensions));
+  state.addAttributes(attributes);
+
+  state.addRegion();
+  state.setOperandListToResizable();
+}
+
+// Populates |state| for a windowed reduction region.
+// Operand layout: workload, then |operands|, then |initialValues|.
+// The four window parameters are stored as 1-D i64 dense elements attributes
+// and |paddingMode| as an i32 integer attribute named "padding_mode".
+void ReductionRegionOp::build(
+    Builder *builder, OperationState &state, ArrayRef<Type> resultTypes,
+    Value *workload, ArrayRef<Value *> operands,
+    ArrayRef<Value *> initialValues, ArrayRef<int64_t> windowDimensions,
+    ArrayRef<int64_t> windowStrides, ArrayRef<int64_t> baseDilations,
+    ArrayRef<int64_t> windowDilations, PaddingMode paddingMode,
+    ArrayRef<NamedAttribute> attributes) {
+  // Builds a tensor<Nxi64> dense attribute holding |values|.
+  auto makeI64VectorAttr = [&](ArrayRef<int64_t> values) {
+    return DenseIntElementsAttr::get(
+        RankedTensorType::get({static_cast<int64_t>(values.size())},
+                              builder->getIntegerType(64)),
+        values);
+  };
+
+  state.addTypes(resultTypes);
+
+  // Operands: order matters.
+  state.addOperands({workload});
+  state.addOperands(operands);
+  state.addOperands(initialValues);
+
+  // Window attributes, added before the caller-provided attributes.
+  state.addAttribute("window_dimensions", makeI64VectorAttr(windowDimensions));
+  state.addAttribute("window_strides", makeI64VectorAttr(windowStrides));
+  state.addAttribute("base_dilations", makeI64VectorAttr(baseDilations));
+  state.addAttribute("window_dilations", makeI64VectorAttr(windowDilations));
+  state.addAttribute("padding_mode", builder->getI32IntegerAttr(
+                                         static_cast<int32_t>(paddingMode)));
+  state.addAttributes(attributes);
+
+  state.addRegion();
+  state.setOperandListToResizable();
+}
+
+// Parses:
+//   `[` workload `:` type `]`
+//   `(` reduction-operands `)` `:` function-type
+//   `invocation` `(` `(` lhs `,` rhs `)` `=` initial-value `:` type
+//                    (`,` ...)* `)`
+//   region  attr-dict?
+// The function type supplies the reduction operand types (inputs) and the op
+// result types (results). Every invocation entry resolves one initial-value
+// operand and declares two entry-block arguments of that entry's type.
+//
+// Emits a diagnostic and fails if the type after the reduction operands is
+// not a function type.
+ParseResult parseReductionRegionOp(OpAsmParser &parser, OperationState &state) {
+  // Required workload.
+  OpAsmParser::OperandType workloadArg;
+  Type workloadArgType;
+  if (failed(parser.parseLSquare()) ||
+      failed(parser.parseOperand(workloadArg)) ||
+      failed(parser.parseColonType(workloadArgType)) ||
+      failed(parser.parseRSquare()) ||
+      failed(parser.resolveOperand(workloadArg, workloadArgType,
+                                   state.operands))) {
+    return failure();
+  }
+
+  // Reduction operands.
+  SmallVector<OpAsmParser::OperandType, 8> reductionOperands;
+  auto operandsLoc = parser.getCurrentLocation();
+  if (failed(parser.parseLParen()) ||
+      failed(parser.parseOperandList(reductionOperands)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  // Signature of the reduction. The type comes from user input, so it is
+  // checked with dyn_cast<> rather than assumed to be a FunctionType.
+  Type reductionType;
+  auto typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseColonType(reductionType))) {
+    return failure();
+  }
+  auto functionType = reductionType.dyn_cast<FunctionType>();
+  if (!functionType) {
+    return parser.emitError(
+        typeLoc, "expected function type for reduction operands and results");
+  }
+  if (failed(parser.resolveOperands(reductionOperands,
+                                    functionType.getInputs(), operandsLoc,
+                                    state.operands))) {
+    return failure();
+  }
+  for (auto type : functionType.getResults()) {
+    state.types.push_back(type);
+  }
+  state.setOperandListToResizable();
+
+  // Invocation bindings: two region arguments per initial value.
+  SmallVector<OpAsmParser::OperandType, 8> regionArgs;
+  SmallVector<Type, 8> regionArgTypes;
+  if (failed(parser.parseKeyword("invocation")) ||
+      failed(parser.parseLParen())) {
+    return failure();
+  }
+  do {
+    Type argType;
+    SmallVector<OpAsmParser::OperandType, 2> reductionRegionArgs;
+    OpAsmParser::OperandType initialValue;
+    if (failed(parser.parseLParen()) ||
+        failed(parser.parseOperandList(reductionRegionArgs, 2)) ||
+        failed(parser.parseRParen()) || failed(parser.parseEqual()) ||
+        failed(parser.parseOperand(initialValue)) ||
+        failed(parser.parseColonType(argType)) ||
+        failed(parser.resolveOperand(initialValue, argType, state.operands))) {
+      return failure();
+    }
+    regionArgs.push_back(reductionRegionArgs[0]);
+    regionArgTypes.push_back(argType);
+    regionArgs.push_back(reductionRegionArgs[1]);
+    regionArgTypes.push_back(argType);
+  } while (succeeded(parser.parseOptionalComma()));
+  if (failed(parser.parseRParen())) {
+    return failure();
+  }
+
+  // Parse region body.
+  Region *body = state.addRegion();
+  if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
+      failed(parser.parseOptionalAttributeDict(state.attributes))) {
+    return failure();
+  }
+
+  return success();
+}
+
+// Prints a reduction region in the form accepted by parseReductionRegionOp:
+// workload in brackets, the reduction operands with their function-style
+// signature, the `invocation(...)` bindings, the region (entry block arguments
+// elided since the bindings already name them), and the attribute dictionary.
+void printReductionRegionOp(OpAsmPrinter &p, ReductionRegionOp op) {
+  p << "iree.reduction_region";
+
+  // Print the workload argument.
+  p << "[";
+  p.printOperand(op.getWorkload());
+  p << " : ";
+  p.printType(op.getWorkload()->getType());
+  p << "]";
+
+  // Print the reduction operands and their signature. The signature is
+  // printed even when the op has no results because parseReductionRegionOp
+  // requires it unconditionally; omitting it would produce text that cannot
+  // be parsed back.
+  p << "(";
+  p.printOperands(op.getODSOperands(1));
+  p << ")";
+  p << " : (";
+  interleaveComma(op.getODSOperands(1), p,
+                  [&](Value *operand) { p.printType(operand->getType()); });
+  p << ")";
+  p << " -> (";
+  interleaveComma(op.getResultTypes(), p);
+  p << ")";
+  p << "\n";
+
+  // Each initial value is bound to two consecutive entry block arguments.
+  p << " invocation(";
+  auto &entryBlock = op.getBody().getBlocks().front();
+  int regionArgIndex = 0;
+  interleaveComma(op.getODSOperands(2), p, [&](Value *operand) {
+    p << "(";
+    p.printOperand(entryBlock.getArgument(regionArgIndex++));
+    p << ", ";
+    p.printOperand(entryBlock.getArgument(regionArgIndex++));
+    p << ") = ";
+    p.printOperand(operand);
+    p << " : ";
+    p.printType(operand->getType());
+  });
+  p << ") ";
+
+  p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{});
+}
+
+//===----------------------------------------------------------------------===//
+// iree.return
+//===----------------------------------------------------------------------===//
+
+// Parses `operand-list (`:` type-list)?`. The type list is only expected when
+// at least one operand was parsed.
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 2> returnOperands;
+  SmallVector<Type, 2> returnTypes;
+  llvm::SMLoc loc = parser.getCurrentLocation();
+
+  if (parser.parseOperandList(returnOperands)) {
+    return failure();
+  }
+  if (!returnOperands.empty() && parser.parseColonTypeList(returnTypes)) {
+    return failure();
+  }
+  if (parser.resolveOperands(returnOperands, returnTypes, loc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `iree.return`, followed by ` <operands> : <operand-types>` when the
+// op has operands.
+static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
+  p << "iree.return";
+  if (op.getNumOperands() == 0) {
+    return;
+  }
+
+  p << ' ';
+  p.printOperands(op.operand_begin(), op.operand_end());
+  p << " : ";
+  interleaveComma(op.getOperandTypes(), p);
+}
+
+//===----------------------------------------------------------------------===//
+// iree.load_input
+//===----------------------------------------------------------------------===//
+
+// Parses `(` operand `:` operand-type `)` `:` result-type.
+ParseResult parseLoadInputOp(OpAsmParser &parser, OperationState &state) {
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  Type resultType;
+  // A single short-circuiting chain: parsing stops at the first failure.
+  if (parser.parseLParen() || parser.parseOperand(source) ||
+      parser.parseColonType(sourceType) || parser.parseRParen() ||
+      parser.resolveOperand(source, sourceType, state.operands) ||
+      parser.parseColonType(resultType) ||
+      parser.addTypeToList(resultType, state.types)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `<op-name>(<operand> : <operand-type>) : <result-type>` using the
+// op's first operand and first result.
+void printLoadInputOp(OpAsmPrinter &printer, Operation *op) {
+  auto *source = op->getOperand(0);
+  auto *loaded = op->getResult(0);
+
+  printer << op->getName() << '(';
+  printer.printOperand(source);
+  printer << " : " << source->getType() << ") : " << loaded->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// iree.store_output
+//===----------------------------------------------------------------------===//
+
+// Parses `(` src `:` src-type `,` dst `:` dst-type `)`. Operands are resolved
+// in source order, so src becomes operand 0 and dst operand 1.
+ParseResult parseStoreOutputOp(OpAsmParser &parser, OperationState &state) {
+  // `(` src `:` type `,`
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  if (parser.parseLParen() || parser.parseOperand(source) ||
+      parser.parseColonType(sourceType) || parser.parseComma() ||
+      parser.resolveOperand(source, sourceType, state.operands)) {
+    return failure();
+  }
+
+  // dst `:` type `)`
+  OpAsmParser::OperandType destination;
+  Type destinationType;
+  if (parser.parseOperand(destination) ||
+      parser.parseColonType(destinationType) || parser.parseRParen() ||
+      parser.resolveOperand(destination, destinationType, state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints `<op-name>(<src> : <src-type>, <dst> : <dst-type>)` using the op's
+// first two operands.
+void printStoreOutputOp(OpAsmPrinter &printer, Operation *op) {
+  auto *source = op->getOperand(0);
+  auto *destination = op->getOperand(1);
+
+  printer << op->getName() << '(';
+  printer.printOperand(source);
+  printer << " : " << source->getType() << ", ";
+  printer.printOperand(destination);
+  printer << " : " << destination->getType() << ")";
+}
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Ops.cpp.inc"
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Ops.h b/compiler/IR/Ops.h
new file mode 100644
index 0000000..f8a1580
--- /dev/null
+++ b/compiler/IR/Ops.h
@@ -0,0 +1,36 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_OPS_H_
+#define IREE_COMPILER_IR_OPS_H_
+
+#include "compiler/IR/Types.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Ops.h.inc"
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_OPS_H_
diff --git a/compiler/IR/Ops.td b/compiler/IR/Ops.td
new file mode 100644
index 0000000..a304532
--- /dev/null
+++ b/compiler/IR/Ops.td
@@ -0,0 +1,200 @@
+// IREE ops for working with buffers and buffer views.
+// These are used by common transforms between the sequencer and interpreter and
+// allow us to share some of the common lowering passes from other dialects.
+
+#ifdef IREE_OPS
+#else
+#define IREE_OPS
+
+#ifdef IREE_OP_BASE
+#else
+include "compiler/IR/OpBase.td"
+#endif // IREE_OP_BASE
+
+// Base class for ops in IREE_Dialect.
+// Parsing and printing are routed to C++ free functions named
+// parse<CppClass> and print<CppClass>, so every op deriving from this class
+// needs both functions to be visible where the generated code is included.
+class IREE_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREE_Dialect, mnemonic, traits> {
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ print$cppClass(p, *this); }];
+}
+
+// An IREE_Op with the NoSideEffect trait appended to the provided traits.
+class IREE_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ IREE_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+
+// TODO(b/134575149): determine if we want multiple constant op types.
+//
+// iree.constant: produces an IREEHL_MemRef result from an ElementsAttr value.
+// The traits require the value and the result to agree in shape and element
+// type. A custom builder taking only the ElementsAttr is declared; its
+// definition is expected in C++.
+def IREE_ConstantOp : IREE_PureOp<"constant", [
+ AllShapesMatch<["value", "result"]>,
+ AllElementTypesMatch<["value", "result"]>
+]> {
+ let arguments = (ins ElementsAttr:$value);
+ let results = (outs IREEHL_MemRef:$result);
+
+ // TODO(b/132296600): make tablegen follow the style guide.
+ let extraClassDeclaration = [{
+ Attribute getValue() { return value(); }
+ }];
+
+ let builders = [OpBuilder<"Builder*, OperationState&, ElementsAttr">];
+
+ // TODO(b/134575149): enable folder when we store the correct type.
+ // let hasFolder = 1;
+}
+
+// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor.
+//
+// iree.tensor_to_memref: takes any tensor and yields an IREEHL_MemRef with the
+// same shape and element type (enforced by the traits). Declares a builder
+// taking only the operand, and a folder.
+def IREE_TensorToMemRefOp : IREE_PureOp<"tensor_to_memref", [
+ SameOperandsAndResultShape, SameOperandsAndResultElementType
+]> {
+ let arguments = (ins AnyTensor);
+ let results = (outs IREEHL_MemRef);
+
+ let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
+
+ let hasFolder = 1;
+}
+
+// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor.
+//
+// iree.memref_to_tensor: takes an IREEHL_MemRef and yields a tensor with the
+// same shape and element type (enforced by the traits). Declares a builder
+// taking only the operand, and a folder.
+def IREE_MemRefToTensorOp : IREE_PureOp<"memref_to_tensor", [
+ SameOperandsAndResultShape, SameOperandsAndResultElementType
+]> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs AnyTensor);
+ let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
+
+ let hasFolder = 1;
+}
+
+// iree.scalar_to_memref: takes an IREEHL_Element operand and yields an
+// IREEHL_AnyScalar result with the same element type (enforced by the trait).
+// NOTE(review): IREEHL_AnyScalar is presumably a scalar-holding memref type -
+// confirm in OpBase.td.
+def IREE_ScalarToMemRefOp : IREE_PureOp<"scalar_to_memref", [
+ SameOperandsAndResultElementType
+]> {
+ let arguments = (ins IREEHL_Element);
+ let results = (outs IREEHL_AnyScalar);
+
+ let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
+
+ let hasFolder = 1;
+}
+
+// iree.memref_to_scalar: the inverse of iree.scalar_to_memref. Takes an
+// IREEHL_AnyScalar operand and yields an IREEHL_Element result with the same
+// element type (enforced by the trait).
+def IREE_MemRefToScalarOp : IREE_PureOp<"memref_to_scalar", [
+ SameOperandsAndResultElementType
+]> {
+ let arguments = (ins IREEHL_AnyScalar);
+ let results = (outs IREEHL_Element);
+
+ let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
+
+ let hasFolder = 1;
+}
+
+// Type constraint for workload operands: a tensor of any integer type.
+def IREE_Workload : TensorOf<[AnyInteger]>;
+
+// iree.dispatch_region: a region op with a leading workload operand followed
+// by a variadic list of argument operands, a variadic list of results, and a
+// single body region.
+//
+// Operand 0 is always the workload; the arg-operand helpers below hide that
+// offset from callers. Default builders are skipped in favour of the custom
+// builder declared at the bottom.
+def IREE_DispatchRegionOp : IREE_PureOp<"dispatch_region"> {
+ let arguments = (ins
+ IREE_Workload:$workload,
+ Variadic<AnyType>:$args
+ );
+ let results = (outs Variadic<AnyType>);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ Value *getWorkload() { return workload(); }
+ Region& getBody() { return body(); }
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ // Arg operand i lives at op operand i + 1 because the workload is first.
+ unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; }
+ unsigned getNumArgOperands() { return getNumOperands() - 1; }
+ Value *getArgOperand(unsigned i) {
+ return getOperand(mapArgOperandToOpOperand(i));
+ }
+ void setArgOperand(unsigned i, Value *arg) {
+ setOperand(mapArgOperandToOpOperand(i), arg);
+ }
+
+ operand_iterator arg_operand_begin() {
+ return operand_begin() + mapArgOperandToOpOperand(0);
+ }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<"Builder *builder, OperationState &state,"
+ "ArrayRef<Type> resultTypes, Value *workload,"
+ "ArrayRef<Value *> args,"
+ "ArrayRef<NamedAttribute> attributes = {}">,
+ ];
+}
+
+// iree.reduction_region: a region op with a leading workload operand followed
+// by two variadic operand groups (reduction operands and initial values) that
+// must have the same size (SameVariadicOperandSize).
+//
+// All attributes are optional: `dimensions` is set by the first custom
+// builder, while the window_* / base_dilations / padding_mode attributes are
+// set by the second (windowed) builder. Default builders are skipped.
+def IREE_ReductionRegionOp : IREE_PureOp<"reduction_region", [
+ SameVariadicOperandSize,
+]> {
+ let arguments = (ins
+ IREE_Workload:$workload,
+ Variadic<AnyType>:$operands,
+ Variadic<AnyType>:$initial_values,
+ OptionalAttr<I64ElementsAttr>:$dimensions,
+ OptionalAttr<I64ElementsAttr>:$window_dimensions,
+ OptionalAttr<I64ElementsAttr>:$window_strides,
+ OptionalAttr<I64ElementsAttr>:$base_dilations,
+ OptionalAttr<I64ElementsAttr>:$window_dilations,
+ OptionalAttr<IREE_PaddingModeAttr>:$padding_mode
+ );
+ let results = (outs Variadic<AnyType>);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ Value *getWorkload() { return workload(); }
+ Region& getBody() { return body(); }
+
+ // True when the window_dimensions attribute is present.
+ bool isWindowed() {
+ return window_dimensions().hasValue();
+ }
+
+ // Unwraps padding_mode without checking for presence; only call this when
+ // the attribute is known to be set.
+ PaddingMode getPaddingMode() {
+ return static_cast<PaddingMode>(padding_mode().getValue());
+ }
+
+ // Excludes the workload, then halves: operands and initial values are
+ // equally sized groups.
+ unsigned getNumReductionOperands() { return (getNumOperands() - 1) / 2; }
+ operand_range getReductionOperands() { return getODSOperands(1); }
+ operand_range getInitialValueOperands() { return getODSOperands(2); }
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<"Builder *builder, OperationState &state,"
+ "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands,"
+ "ArrayRef<Value *> initialValues,"
+ "ArrayRef<int64_t> dimensions,"
+ "ArrayRef<NamedAttribute> attributes = {}">,
+ OpBuilder<"Builder *builder, OperationState &state,"
+ "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands,"
+ "ArrayRef<Value *> initialValues,"
+ "ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,"
+ "ArrayRef<int64_t> baseDilations, ArrayRef<int64_t> windowDilations,"
+ "PaddingMode paddingMode,"
+ "ArrayRef<NamedAttribute> attributes = {}">,
+ ];
+}
+
+// iree.return: terminator taking any number of operands of any type.
+// The extra builder creates a return with no operands by forwarding
+// llvm::None to another build overload.
+def IREE_ReturnOp : IREE_Op<"return", [Terminator]> {
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+}
+
+// iree.load_input: side-effect-free op taking an IREEHL_MemRef source and
+// producing a single result of any type.
+def IREE_LoadInputOp : IREE_PureOp<"load_input"> {
+ let arguments = (ins IREEHL_MemRef:$src);
+ let results = (outs AnyType);
+}
+
+// iree.store_output: takes a source value of any type and an IREEHL_MemRef
+// destination; has no results. Derives from IREE_Op rather than IREE_PureOp,
+// so it is not marked NoSideEffect.
+def IREE_StoreOutputOp : IREE_Op<"store_output"> {
+ let arguments = (ins AnyType:$src, IREEHL_MemRef:$dst);
+}
+
+#endif // IREE_OPS
diff --git a/compiler/IR/Sequencer/BUILD b/compiler/IR/Sequencer/BUILD
new file mode 100644
index 0000000..4ca9ba2
--- /dev/null
+++ b/compiler/IR/Sequencer/BUILD
@@ -0,0 +1,76 @@
+# Package-wide defaults: targets are publicly visible unless they say otherwise.
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+load("@local_config_mlir//:tblgen.bzl", "gentbl")
+
+# Every TableGen (*.td) source in this package; referenced as ":td_files" by
+# the gentbl rules in this file.
+filegroup(
+ name = "td_files",
+ srcs = glob(["*.td"]),
+)
+
+# C++ library containing the HL/LL sequencer dialects, their ops (including the
+# tablegen-generated .inc files) and the op writers.
+cc_library(
+    name = "Sequencer",
+    srcs = [
+        "HLDialect.cpp",
+        "HLOps.cpp",
+        "HLOps.cpp.inc",
+        "LLDialect.cpp",
+        "LLOps.cpp",
+        "LLOps.cpp.inc",
+        "OpWriters.cpp",
+    ],
+    hdrs = [
+        "HLDialect.h",
+        "HLOps.h",
+        "HLOps.h.inc",
+        "LLDialect.h",
+        "LLOps.h",
+        "LLOps.h.inc",
+        "OpWriters.h",
+    ],
+    deps = [
+        ":HLOpsGen",
+        ":LLOpsGen",
+        # Workspace-relative labels start with exactly two slashes.
+        "//compiler/IR",
+        "//compiler/Serialization",
+        "//compiler/Utils",
+        "//schemas/bytecode:sequencer_bytecode_v0",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:StandardOps",
+        "@local_config_mlir//:Support",
+    ],
+    # The dialect sources define static registration objects; alwayslink
+    # presumably keeps them from being dropped at link time.
+    alwayslink = 1,
+)
+
+# Runs mlir-tblgen over HLOps.td to produce the op declarations (HLOps.h.inc)
+# and definitions (HLOps.cpp.inc).
+gentbl(
+    name = "HLOpsGen",
+    tbl_outs = [
+        ("-gen-op-decls", "HLOps.h.inc"),
+        ("-gen-op-defs", "HLOps.cpp.inc"),
+    ],
+    tblgen = "@local_config_mlir//:mlir-tblgen",
+    td_file = "HLOps.td",
+    td_srcs = [
+        ":td_files",
+        "@local_config_mlir//:include/mlir/IR/OpBase.td",
+        # Workspace-relative labels start with exactly two slashes.
+        "//compiler/IR:OpBase.td",
+    ],
+)
+
+# Runs mlir-tblgen over LLOps.td to produce the op declarations (LLOps.h.inc)
+# and definitions (LLOps.cpp.inc).
+gentbl(
+    name = "LLOpsGen",
+    tbl_outs = [
+        ("-gen-op-decls", "LLOps.h.inc"),
+        ("-gen-op-defs", "LLOps.cpp.inc"),
+    ],
+    tblgen = "@local_config_mlir//:mlir-tblgen",
+    td_file = "LLOps.td",
+    td_srcs = [
+        ":td_files",
+        "@local_config_mlir//:include/mlir/IR/OpBase.td",
+        # Workspace-relative labels start with exactly two slashes.
+        "//compiler/IR:OpBase.td",
+    ],
+)
diff --git a/iree/compiler/IR/Sequencer/CMakeLists.txt b/compiler/IR/Sequencer/CMakeLists.txt
similarity index 100%
rename from iree/compiler/IR/Sequencer/CMakeLists.txt
rename to compiler/IR/Sequencer/CMakeLists.txt
diff --git a/compiler/IR/Sequencer/HLDialect.cpp b/compiler/IR/Sequencer/HLDialect.cpp
new file mode 100644
index 0000000..2bc46bc
--- /dev/null
+++ b/compiler/IR/Sequencer/HLDialect.cpp
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/HLDialect.h"
+
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Constructs the dialect under getDialectNamespace() in |context| and adds
+// the operations listed by the GET_OP_LIST section of the generated
+// HLOps.cpp.inc (the #include expands to the template argument list).
+IREEHLSequencerDialect::IREEHLSequencerDialect(MLIRContext* context)
+ : Dialect(getDialectNamespace(), context) {
+#define GET_OP_LIST
+ addOperations<
+#include "compiler/IR/Sequencer/HLOps.cpp.inc"
+ >();
+}
+
+// File-local static DialectRegistration object for IREEHLSequencerDialect.
+static DialectRegistration<IREEHLSequencerDialect> iree_hl_seq_dialect;
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/HLDialect.h b/compiler/IR/Sequencer/HLDialect.h
similarity index 100%
rename from iree/compiler/IR/Sequencer/HLDialect.h
rename to compiler/IR/Sequencer/HLDialect.h
diff --git a/compiler/IR/Sequencer/HLOps.cpp b/compiler/IR/Sequencer/HLOps.cpp
new file mode 100644
index 0000000..6840804
--- /dev/null
+++ b/compiler/IR/Sequencer/HLOps.cpp
@@ -0,0 +1,379 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/HLOps.h"
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREESeq {
+namespace HL {
+
+namespace {
+
+/// Verifies that |workload| is a statically shaped memref holding exactly the
+/// three (x,y,z) workload elements. Emits an op error on |op| and returns
+/// failure otherwise.
+static LogicalResult verifyWorkload(Operation *op, Value *workload) {
+  auto workloadType = workload->getType().dyn_cast<MemRefType>();
+  if (!workloadType) {
+    return op->emitOpError(
+               "workload must be specified as an (x,y,z) memref but has type ")
+           << workload->getType();
+  }
+  // The element count is only defined for fully static shapes, so report
+  // dynamic shapes as a verification error instead of querying the count.
+  if (!workloadType.hasStaticShape()) {
+    return op->emitOpError(
+               "workload must be specified as a statically shaped (x,y,z) "
+               "memref but has type ")
+           << workload->getType();
+  }
+  if (workloadType.getNumElements() != 3) {
+    return op->emitOpError("workload must be specified as (x,y,z) but has ")
+           << workloadType.getNumElements()
+           << " elements (type=" << workload->getType() << ")";
+  }
+  return success();
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.call
+//===----------------------------------------------------------------------===//
+
+/// Parses a call: the "callee" symbol attribute, a parenthesized operand list,
+/// an optional attribute dictionary and a trailing colon + function type. The
+/// function type supplies both the operand types and the result types.
+static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
+  auto nameLoc = parser.getNameLoc();
+  SymbolRefAttr calleeSymbol;
+  SmallVector<OpAsmParser::OperandType, 4> callArgs;
+  FunctionType signature;
+  // The chain short-circuits, so |signature| is only inspected once it has
+  // been parsed successfully.
+  return failure(
+      parser.parseAttribute(calleeSymbol, "callee", state.attributes) ||
+      parser.parseOperandList(callArgs, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes) ||
+      parser.parseColonType(signature) ||
+      parser.addTypesToList(signature.getResults(), state.types) ||
+      parser.resolveOperands(callArgs, signature.getInputs(), nameLoc,
+                             state.operands));
+}
+
+/// Prints a call as: op name, the "callee" attribute, the parenthesized
+/// operands, any remaining attributes, then " : " and the callee type.
+static void printCallOp(OpAsmPrinter &p, CallOp op) {
+  p << "iree_hl_seq.call ";
+  p << op.getAttr("callee");
+  p << '(';
+  p.printOperands(op.getOperands());
+  p << ')';
+  // "callee" was already printed inline, so keep it out of the dictionary.
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : ";
+  p.printType(op.getCalleeType());
+}
+
+/// Builds the function type implied by this op: its operand types are the
+/// inputs and its result types are the results.
+FunctionType CallOp::getCalleeType() {
+  SmallVector<Type, 8> inputTypes(getOperandTypes());
+  SmallVector<Type, 4> outputTypes(getResultTypes());
+  return FunctionType::get(inputTypes, outputTypes, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.call_indirect
+//===----------------------------------------------------------------------===//
+
+/// Parses an indirect call: the callee operand, a parenthesized operand list,
+/// an optional attribute dictionary and a trailing colon + function type.
+static ParseResult parseCallIndirectOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType calleeOperand;
+  SmallVector<OpAsmParser::OperandType, 4> argOperands;
+  llvm::SMLoc argsLoc;
+  FunctionType signature;
+
+  // Syntax first.
+  if (parser.parseOperand(calleeOperand) ||
+      parser.getCurrentLocation(&argsLoc) ||
+      parser.parseOperandList(argOperands, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(result.attributes) ||
+      parser.parseColonType(signature)) {
+    return failure();
+  }
+
+  // Then resolution: the callee is resolved before the arguments so that it
+  // ends up first in the operand list.
+  if (parser.resolveOperand(calleeOperand, signature, result.operands) ||
+      parser.resolveOperands(argOperands, signature.getInputs(), argsLoc,
+                             result.operands) ||
+      parser.addTypesToList(signature.getResults(), result.types)) {
+    return failure();
+  }
+  return success();
+}
+
+/// Prints an indirect call: the callee operand, the remaining operands in
+/// parentheses, the attribute dictionary, then " : " and the callee's type.
+static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
+  Value *callee = op.getCallee();
+  p << "iree_hl_seq.call_indirect ";
+  p.printOperand(callee);
+  p << '(';
+  // getArgOperands() is every operand after the callee.
+  p.printOperands(op.getArgOperands());
+  p << ')';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : " << callee->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.return
+//===----------------------------------------------------------------------===//
+
+/// Parses a return: an operand list and, only when that list is non-empty, a
+/// colon followed by the operand types.
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  SmallVector<OpAsmParser::OperandType, 2> returned;
+  SmallVector<Type, 2> returnedTypes;
+
+  if (parser.parseOperandList(returned)) return failure();
+  // A bare return has no type list to parse.
+  if (!returned.empty() && parser.parseColonTypeList(returnedTypes)) {
+    return failure();
+  }
+  if (parser.resolveOperands(returned, returnedTypes, operandsLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+/// Prints a return; operands and their comma-separated types are only emitted
+/// when the op has at least one operand.
+static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
+  p << "iree_hl_seq.return";
+  if (op.getNumOperands() == 0) return;
+
+  p << ' ';
+  p.printOperands(op.operand_begin(), op.operand_end());
+  p << " : ";
+  interleaveComma(op.getOperandTypes(), p);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.br
+//===----------------------------------------------------------------------===//
+
+/// Parses a branch: a single successor block together with its operand list,
+/// which is recorded on |result| as the op's only successor.
+static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
+  Block *target;
+  SmallVector<Value *, 4> targetOperands;
+  if (failed(parser.parseSuccessorAndUseList(target, targetOperands))) {
+    return failure();
+  }
+  result.addSuccessor(target, targetOperands);
+  return success();
+}
+
+/// Prints a branch: the op name followed by successor 0 and its operands.
+static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
+  Operation *branch = op.getOperation();
+  p << "iree_hl_seq.br ";
+  p.printSuccessorAndUseList(branch, 0);
+}
+
+/// Returns the block this branch targets (successor 0 of the operation).
+Block *BranchOp::getDest() {
+  Operation *branch = getOperation();
+  return branch->getSuccessor(0);
+}
+
+/// Retargets this branch by replacing successor 0 with |block|.
+void BranchOp::setDest(Block *block) {
+  Operation *branch = getOperation();
+  branch->setSuccessor(block, 0);
+}
+
+/// Erases the operand at |index| from the operands passed to successor 0.
+void BranchOp::eraseOperand(unsigned index) {
+  Operation *branch = getOperation();
+  branch->eraseSuccessorOperand(0, index);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.cond_br
+//===----------------------------------------------------------------------===//
+
+/// Parses a conditional branch: the condition operand (resolved as i1), a
+/// comma, the true successor with its operands, a comma, and the false
+/// successor with its operands. Successors are added in that order.
+static ParseResult parseCondBranchOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  // Condition: parsed, followed by a comma, and resolved against i1.
+  OpAsmParser::OperandType condition;
+  Type i1Type = parser.getBuilder().getI1Type();
+  if (parser.parseOperand(condition) || parser.parseComma() ||
+      parser.resolveOperand(condition, i1Type, result.operands)) {
+    return parser.emitError(parser.getNameLoc(),
+                            "expected condition type was boolean (i1)");
+  }
+
+  Block *successor;
+  SmallVector<Value *, 4> successorOperands;
+
+  // True successor.
+  if (parser.parseSuccessorAndUseList(successor, successorOperands)) {
+    return failure();
+  }
+  result.addSuccessor(successor, successorOperands);
+
+  // False successor; the operand scratch list is reused, so reset it first.
+  successorOperands.clear();
+  if (parser.parseComma()) return failure();
+  if (parser.parseSuccessorAndUseList(successor, successorOperands)) {
+    return failure();
+  }
+  result.addSuccessor(successor, successorOperands);
+
+  return success();
+}
+
+/// Prints a conditional branch: the condition, then the true and the false
+/// successor (each with its operands), separated by ", ".
+static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
+  Operation *branch = op.getOperation();
+  p << "iree_hl_seq.cond_br ";
+  p.printOperand(op.getCondition());
+
+  p << ", ";
+  p.printSuccessorAndUseList(branch, CondBranchOp::trueIndex);
+
+  p << ", ";
+  p.printSuccessorAndUseList(branch, CondBranchOp::falseIndex);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.dispatch
+//===----------------------------------------------------------------------===//
+
+/// Parses a dispatch in three stages:
+///   1. the "executable" attribute, two colons and the "entry_point" attribute;
+///   2. a square-bracketed workload operand with its type;
+///   3. the parenthesized operands, optional attribute dictionary and the
+///      colon + function type that types the operands and results.
+/// The workload is resolved before the other operands, so it is operand 0.
+static ParseResult parseDispatchOp(OpAsmParser &parser, OperationState &state) {
+  auto nameLoc = parser.getNameLoc();
+
+  // Stage 1: executable and entry point symbols.
+  SymbolRefAttr executableSymbol;
+  SymbolRefAttr entryPointSymbol;
+  if (parser.parseAttribute(executableSymbol, "executable",
+                            state.attributes) ||
+      parser.parseColon() || parser.parseColon() ||
+      parser.parseAttribute(entryPointSymbol, "entry_point",
+                            state.attributes)) {
+    return failure();
+  }
+
+  // Stage 2: the bracketed workload operand and its type.
+  OpAsmParser::OperandType workload;
+  Type workloadType;
+  if (parser.parseLSquare() || parser.parseOperand(workload) ||
+      parser.parseColonType(workloadType) || parser.parseRSquare() ||
+      parser.resolveOperand(workload, workloadType, state.operands)) {
+    return failure();
+  }
+
+  // Stage 3: the arguments, attributes and the entry point signature.
+  SmallVector<OpAsmParser::OperandType, 4> args;
+  FunctionType signature;
+  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes) ||
+      parser.parseColonType(signature) ||
+      parser.addTypesToList(signature.getResults(), state.types) ||
+      parser.resolveOperands(args, signature.getInputs(), nameLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+/// Prints a dispatch: executable and entry point joined by "::", the workload
+/// operand and its type in square brackets, the argument operands in
+/// parentheses, the remaining attributes, then " : " and the entry point type.
+static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
+  Value *workload = op.getWorkload();
+
+  p << "iree_hl_seq.dispatch " << op.getExecutable();
+  p << "::" << op.getEntryPoint();
+
+  p << "[";
+  p.printOperand(workload);
+  p << " : ";
+  p.printType(workload->getType());
+  p << "](";
+
+  p.printOperands(op.getArgOperands());
+  p << ')';
+
+  // Both symbols were printed inline above, so elide them here.
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{"executable", "entry_point"});
+  p << " : ";
+  p.printType(op.getEntryPointType());
+}
+
+/// Verifies a dispatch op; the only check is that its workload operand passes
+/// verifyWorkload.
+static LogicalResult verifyDispatchOp(DispatchOp op) {
+  return verifyWorkload(op, op.getWorkload());
+}
+
+/// Builds the function type implied by this op: the argument operand types
+/// (the workload is not an argument) are the inputs and the op's result types
+/// are the results.
+FunctionType DispatchOp::getEntryPointType() {
+  SmallVector<Type, 8> inputTypes(getArgOperandTypes());
+  SmallVector<Type, 4> outputTypes(getResultTypes());
+  return FunctionType::get(inputTypes, outputTypes, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.rank
+//===----------------------------------------------------------------------===//
+
+/// Folds to a 32-bit integer attribute holding the rank of the operand when
+/// the operand is a constant ElementsAttr; otherwise does not fold.
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  auto elements = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!elements) return {};
+
+  Builder builder(getContext());
+  return builder.getIntegerAttr(builder.getIntegerType(32),
+                                elements.getType().getRank());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.shape
+//===----------------------------------------------------------------------===//
+
+/// Builds a shape op on |operand|. The result type is a 1-D memref of 32-bit
+/// integers whose single dimension is the operand's rank, or 0 when the
+/// operand is not a ShapedType.
+void ShapeOp::build(Builder *builder, OperationState &state, Value *operand) {
+  state.addOperands(operand);
+
+  auto shapedType = operand->getType().dyn_cast<ShapedType>();
+  int64_t rank = shapedType ? shapedType.getRank() : 0;
+  state.addTypes(MemRefType::get({rank}, builder->getIntegerType(32)));
+}
+
+/// Folds to a dense integer attribute containing the shape of the operand
+/// when the operand is a constant ElementsAttr; otherwise does not fold.
+OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
+  auto elements = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!elements) return {};
+
+  Builder builder(getContext());
+  auto operandType = elements.getType();
+  // NOTE(review): the attribute is typed as a tensor of i32 while getShape()
+  // presumably yields 64-bit extents -- confirm that DenseIntElementsAttr::get
+  // accepts this width difference.
+  return DenseIntElementsAttr::get(
+      RankedTensorType::get({operandType.getRank()},
+                            builder.getIntegerType(32)),
+      operandType.getShape());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.length
+//===----------------------------------------------------------------------===//
+
+/// Folds to a 32-bit integer attribute holding the element count of the
+/// operand when the operand is a constant ElementsAttr; otherwise does not
+/// fold.
+OpFoldResult LengthOp::fold(ArrayRef<Attribute> operands) {
+  auto elements = operands[0].dyn_cast_or_null<ElementsAttr>();
+  if (!elements) return {};
+
+  Builder builder(getContext());
+  return builder.getIntegerAttr(builder.getIntegerType(32),
+                                elements.getNumElements());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hl_seq.concat
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites a concat into an alloc_heap of the result type followed by one
+/// copy per source. Each copy reads its source from offset zero and writes it
+/// at a destination offset that advances along the concat dimension by the
+/// extent of the source just copied.
+struct ConcatToCopies : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ConcatOp concatOp,
+                                     PatternRewriter &rewriter) const override {
+    auto finalType = concatOp.getResult()->getType().cast<ShapedType>();
+    auto loc = concatOp.getLoc();
+
+    // The destination is allocated without any dim_pieces operands.
+    std::vector<Value *> dimPieces;
+    auto dst =
+        rewriter.create<IREESeq::HL::AllocHeapOp>(loc, finalType, dimPieces);
+
+    // Every source is read from its origin, so one all-zero index array is
+    // shared by all of the copies.
+    llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0);
+    auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset);
+
+    auto concatDimension = concatOp.dimension().getZExtValue();
+    llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0);
+    for (auto *src : concatOp.srcs()) {
+      auto srcShape = src->getType().cast<ShapedType>().getShape();
+      auto lengths = createArrayConstant(rewriter, loc, srcShape);
+      auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices);
+      rewriter.create<IREESeq::HL::CopyOp>(loc, src, srcIndices, dst,
+                                           dstIndicesOp, lengths);
+      dstIndices[concatDimension] += srcShape[concatDimension];
+    }
+
+    // Replace through the rewriter rather than mutating uses directly: the
+    // pattern driver is then notified of the change and the matched concat is
+    // erased instead of being left behind as a dead op.
+    rewriter.replaceOp(concatOp, {dst.getResult()});
+
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the ConcatToCopies pattern, constructed with |context|, to |results|.
+void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ConcatToCopies>(context);
+}
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Sequencer/HLOps.cpp.inc"
+
+} // namespace HL
+} // namespace IREESeq
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Sequencer/HLOps.h b/compiler/IR/Sequencer/HLOps.h
new file mode 100644
index 0000000..9661021
--- /dev/null
+++ b/compiler/IR/Sequencer/HLOps.h
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
+#define IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
+
+#include "compiler/IR/Types.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREESeq {
+namespace HL {
+
+#define GET_OP_CLASSES
+#include "compiler/IR/Sequencer/HLOps.h.inc"
+
+} // namespace HL
+} // namespace IREESeq
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
diff --git a/compiler/IR/Sequencer/HLOps.td b/compiler/IR/Sequencer/HLOps.td
new file mode 100644
index 0000000..6ca114c
--- /dev/null
+++ b/compiler/IR/Sequencer/HLOps.td
@@ -0,0 +1,429 @@
+// IREE high-level sequencer op definitions.
+// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in
+// normal SSA form.
+//
+// Through lowering these high-level ops are converted to low-level ops in the
+// LLOps.td (iree_ll_seq.*). These map 1:1 with the bytecode, accept
+// only MemRef types, and generally use output parameters instead of return
+// types.
+//
+// The source of truth for bytecode opcodes is:
+// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h
+
+#ifdef IREE_SEQUENCER_HL_OPS
+#else
+#define IREE_SEQUENCER_HL_OPS
+
+#ifdef IREE_OP_BASE
+#else
+include "compiler/IR/OpBase.td"
+#endif // IREE_OP_BASE
+
+// Dialect record: ops are prefixed with "iree_hl_seq" and their C++ classes
+// are generated into the IREESeq::HL namespace.
+def IREESeqHL_Dialect : Dialect {
+ let name = "iree_hl_seq";
+ let cppNamespace = "IREESeq::HL";
+}
+
+//===----------------------------------------------------------------------===//
+// Base op classes
+//===----------------------------------------------------------------------===//
+
+// Base class for ops in the iree_hl_seq dialect; forwards |mnemonic| and
+// |traits| to Op.
+class IREESeqHL_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREESeqHL_Dialect, mnemonic, traits>;
+
+// Same as IREESeqHL_Op but with the NoSideEffect trait appended to |traits|.
+class IREESeqHL_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ IREESeqHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+
+//===----------------------------------------------------------------------===//
+// High-level sequencer ops
+//===----------------------------------------------------------------------===//
+
+// "call": direct call to the function named by the $callee symbol attribute,
+// taking and returning variadic lists of memrefs. Uses the custom
+// parser/printer parseCallOp/printCallOp.
+// NOTE(review): this op derives from IREESeqHL_PureOp (NoSideEffect) while
+// call_indirect does not -- confirm that the difference is intended.
+def IREESeqHL_CallOp : IREESeqHL_PureOp<"call"> {
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>);
+ let results = (outs Variadic<IREEHL_MemRef>);
+
+ // Two builders: one takes the callee as a FuncOp and copies its result
+ // types; the other takes the callee name plus explicit result types.
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, FuncOp callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(callee.getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState &result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// "call_indirect": call through a function-typed $callee operand (operand 0),
+// followed by a variadic list of memref arguments; returns a variadic list of
+// memrefs. Uses the custom parser/printer for this class.
+def IREESeqHL_CallIndirectOp : IREESeqHL_Op<"call_indirect"> {
+ let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands);
+ let results = (outs Variadic<IREEHL_MemRef>);
+
+ // Builder: the callee is pushed first, then the arguments; result types are
+ // taken from the callee's function type.
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.operands.push_back(callee);
+ result.addOperands(operands);
+ result.addTypes(callee->getType().cast<FunctionType>().getResults());
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ Value *getCallee() { return getOperand(0); }
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ operand_iterator arg_operand_begin() { return ++operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// "return": terminator taking a variadic list of memref operands. Uses the
+// custom parser/printer for this class.
+def IREESeqHL_ReturnOp : IREESeqHL_Op<"return", [Terminator]> {
+ let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
+
+ // Extra builder without an operand argument; it forwards llvm::None as the
+ // operand list to the generated builder.
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// "br": terminator with a single successor block; the memref operands are
+// passed to that successor. Default builders are skipped in favor of the one
+// below. Uses the custom parser/printer for this class.
+def IREESeqHL_BranchOp : IREESeqHL_Op<"br", [Terminator]> {
+ let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
+
+ let skipDefaultBuilders = 1;
+ // Builder: records |dest| and its operands as the only successor.
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Block *dest,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addSuccessor(dest, operands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ Block *getDest();
+ void setDest(Block *block);
+
+ /// Erase the operand at 'index' from the operand list.
+ void eraseOperand(unsigned index);
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// "cond_br": terminator with two successors selected by a boolean $condition.
+// Operand layout, as encoded by the accessors below: operand 0 is the
+// condition, followed by the true-successor operands and then the
+// false-successor operands. Uses the custom parser/printer for this class.
+def IREESeqHL_CondBranchOp : IREESeqHL_Op<"cond_br", [Terminator]> {
+ let arguments = (ins
+ IREEHL_BoolScalar:$condition,
+ Variadic<IREEHL_MemRef>:$branchOperands
+ );
+
+ let skipDefaultBuilders = 1;
+ // Builder: adds the condition, then the true successor, then the false
+ // successor (matching trueIndex/falseIndex below).
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *condition, "
+ "Block *trueDest, ArrayRef<Value *> trueOperands, "
+ "Block *falseDest, ArrayRef<Value *> falseOperands", [{
+ result.addOperands(condition);
+ result.addSuccessor(trueDest, trueOperands);
+ result.addSuccessor(falseDest, falseOperands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // These are the indices into the dests list.
+ enum { trueIndex = 0, falseIndex = 1 };
+
+ // The condition operand is the first operand in the list.
+ Value *getCondition() { return getOperand(0); }
+
+ /// Return the destination if the condition is true.
+ Block *getTrueDest() {
+ return getOperation()->getSuccessor(trueIndex);
+ }
+
+ /// Return the destination if the condition is false.
+ Block *getFalseDest() {
+ return getOperation()->getSuccessor(falseIndex);
+ }
+
+ // Accessors for operands to the 'true' destination.
+ Value *getTrueOperand(unsigned idx) {
+ assert(idx < getNumTrueOperands());
+ return getOperand(getTrueDestOperandIndex() + idx);
+ }
+
+ void setTrueOperand(unsigned idx, Value *value) {
+ assert(idx < getNumTrueOperands());
+ setOperand(getTrueDestOperandIndex() + idx, value);
+ }
+
+ operand_iterator true_operand_begin() {
+ return operand_begin() + getTrueDestOperandIndex();
+ }
+ operand_iterator true_operand_end() {
+ return true_operand_begin() + getNumTrueOperands();
+ }
+ operand_range getTrueOperands() {
+ return {true_operand_begin(), true_operand_end()};
+ }
+
+ unsigned getNumTrueOperands() {
+ return getOperation()->getNumSuccessorOperands(trueIndex);
+ }
+
+ /// Erase the operand at 'index' from the true operand list.
+ void eraseTrueOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(trueIndex, index);
+ }
+
+ // Accessors for operands to the 'false' destination.
+ Value *getFalseOperand(unsigned idx) {
+ assert(idx < getNumFalseOperands());
+ return getOperand(getFalseDestOperandIndex() + idx);
+ }
+ void setFalseOperand(unsigned idx, Value *value) {
+ assert(idx < getNumFalseOperands());
+ setOperand(getFalseDestOperandIndex() + idx, value);
+ }
+
+ operand_iterator false_operand_begin() { return true_operand_end(); }
+ operand_iterator false_operand_end() {
+ return false_operand_begin() + getNumFalseOperands();
+ }
+ operand_range getFalseOperands() {
+ return {false_operand_begin(), false_operand_end()};
+ }
+
+ unsigned getNumFalseOperands() {
+ return getOperation()->getNumSuccessorOperands(falseIndex);
+ }
+
+ /// Erase the operand at 'index' from the false operand list.
+ void eraseFalseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(falseIndex, index);
+ }
+
+ private:
+ /// Get the index of the first true destination operand.
+ unsigned getTrueDestOperandIndex() { return 1; }
+
+ /// Get the index of the first false destination operand.
+ unsigned getFalseDestOperandIndex() {
+ return getTrueDestOperandIndex() + getNumTrueOperands();
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// "dispatch": references an executable and one of its entry points by symbol,
+// takes an integer memref $workload (operand 0) followed by variadic memref
+// arguments, and returns a variadic list of memrefs. Has a custom parser,
+// printer and verifier.
+def IREESeqHL_DispatchOp : IREESeqHL_Op<"dispatch"> {
+ let arguments = (ins
+ SymbolRefAttr:$executable,
+ SymbolRefAttr:$entry_point,
+ IREEHL_IntMemRef:$workload,
+ Variadic<IREEHL_MemRef>:$operands
+ );
+ let results = (outs Variadic<IREEHL_MemRef>);
+
+ let skipDefaultBuilders = 1;
+ // Builder: the workload is added before the arguments, which is what the
+ // accessors below rely on.
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, StringRef executable,"
+ "StringRef entry_point, Value *workload,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands({workload});
+ result.addOperands(operands);
+ result.addAttribute("executable", builder->getSymbolRefAttr(executable));
+ result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getExecutable() { return executable(); }
+ StringRef getEntryPoint() { return entry_point(); }
+ FunctionType getEntryPointType();
+
+ Value *getWorkload() { return getOperand(0); }
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ // Argument operands are all operands after the workload.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ operand_iterator arg_operand_begin() { return operand_begin() + 1; }
+ operand_iterator arg_operand_end() { return operand_end(); }
+
+ operand_type_range getArgOperandTypes() {
+ return {arg_operand_type_begin(), arg_operand_type_end()};
+ }
+ operand_type_iterator arg_operand_type_begin() {
+ return operand_type_iterator(arg_operand_begin());
+ }
+ operand_type_iterator arg_operand_type_end() {
+ return operand_type_iterator(arg_operand_end());
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+ let verifier = [{ return verify$cppClass(*this); }];
+}
+
+// "alloc_heap": takes a variadic list of integer memrefs ($dim_pieces) and
+// produces a single memref result.
+// TODO(b/142012496): Add trait that enables DCE but not CSE.
+def IREESeqHL_AllocHeapOp : IREESeqHL_Op<"alloc_heap"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins Variadic<IREEHL_IntMemRef>:$dim_pieces);
+ let results = (outs IREEHL_MemRef);
+}
+
+// "discard": takes a single memref operand and declares no results.
+def IREESeqHL_DiscardOp : IREESeqHL_Op<"discard"> {
+ let arguments = (ins IREEHL_MemRef);
+}
+
+// "rank": side-effect-free op from a memref to an integer scalar; has a
+// folder (RankOp::fold).
+def IREESeqHL_RankOp : IREESeqHL_PureOp<"rank"> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_IntScalar);
+
+ let hasFolder = 1;
+}
+
+// "dim": side-effect-free op from a memref to an integer scalar. The
+// dimension to query is not yet an attribute (see the TODO below).
+def IREESeqHL_DimOp : IREESeqHL_PureOp<"dim"> {
+ // TODO(benvanik) add dim attr (I32Attr:$dim)
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_IntScalar);
+}
+
+// "shape": side-effect-free op from a memref to a 1-D integer memref; has a
+// folder. The default builders are replaced by one that takes only the
+// operand; its body is defined in C++ (ShapeOp::build).
+def IREESeqHL_ShapeOp : IREESeqHL_PureOp<"shape"> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_1DIntMemRef);
+
+ let skipDefaultBuilders = 1;
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, Value *operand">];
+
+ let hasFolder = 1;
+}
+
+// "length": side-effect-free op from a memref to an index scalar; has a
+// folder (LengthOp::fold).
+def IREESeqHL_LengthOp : IREESeqHL_PureOp<"length"> {
+ let arguments = (ins IREEHL_MemRef);
+ let results = (outs IREEHL_IndexScalar);
+
+ let hasFolder = 1;
+}
+
+// "slice": side-effect-free op taking a source memref plus 1-D index memrefs
+// $indices and $lengths, producing a memref. Constraints: src and result
+// share an element type; indices and lengths have the same type.
+def IREESeqHL_SliceOp :
+ IREESeqHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>,
+ AllTypesMatch<["indices", "lengths"]>]> {
+ let arguments = (ins
+ IREEHL_MemRef:$src,
+ IREEHL_1DIndexMemRef:$indices,
+ IREEHL_1DIndexMemRef:$lengths
+ );
+ let results = (outs IREEHL_MemRef:$result);
+}
+
+// "copy": takes a source memref with $srcIndices, a destination memref with
+// $dstIndices, and $lengths; declares no results. Constraints: the three
+// index arrays have equal element counts, src and dst have equal rank and
+// element type, and that rank equals the index array size.
+def IREESeqHL_CopyOp : IREESeqHL_Op<"copy", [
+ AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
+ AllRanksMatch<["src", "dst"]>,
+ // The checks above are redundant with this one, but they give more specific
+ // error messages.
+ AllMatch<[
+ Rank<"src">.result,
+ Rank<"dst">.result,
+ ElementCount<"srcIndices">.result,
+ ElementCount<"dstIndices">.result,
+ ElementCount<"lengths">.result
+ ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">,
+ AllElementTypesMatch<["src", "dst"]>
+]> {
+ let arguments = (ins
+ IREEHL_MemRef:$src,
+ IREEHL_1DIndexMemRef:$srcIndices,
+ IREEHL_MemRef:$dst,
+ IREEHL_1DIndexMemRef:$dstIndices,
+ IREEHL_1DIndexMemRef:$lengths
+ );
+}
+
+// "fill": takes an i32 scalar $value, a destination memref, and 1-D index
+// memrefs $dstIndices and $lengths; declares no results.
+def IREESeqHL_FillOp : IREESeqHL_Op<"fill"> {
+ let arguments = (ins
+ IREEHL_I32Scalar:$value,
+ IREEHL_MemRef:$dst,
+ IREEHL_1DIndexMemRef:$dstIndices,
+ IREEHL_1DIndexMemRef:$lengths
+ );
+}
+
+// "clone": side-effect-free op from a memref to a memref of the same type
+// (SameOperandsAndResultType).
+def IREESeqHL_CloneOp : IREESeqHL_PureOp<"clone", [SameOperandsAndResultType]> {
+ let arguments = (ins IREEHL_MemRef:$src);
+ let results = (outs IREEHL_MemRef);
+}
+
+// A pseudo op provided for convenience. This gets canonicalized to a series of
+// copies.
+// Takes a variadic list of source memrefs and an i32 $dimension attribute
+// (the dimension along which the sources are joined) and produces one memref.
+def IREESeqHL_ConcatOp : IREESeqHL_PureOp<"concat"> {
+ // TODO(b/135032064) Add type constraints when they support variadic
+ let arguments = (ins
+ Variadic<IREEHL_MemRef>:$srcs,
+ I32Attr:$dimension
+ );
+ let results = (outs IREEHL_MemRef);
+
+ let hasCanonicalizer = 1;
+}
+
+// "assign": side-effect-free op from a memref to a memref of the same type
+// (SameOperandsAndResultType).
+def IREESeqHL_AssignOp :
+ IREESeqHL_PureOp<"assign", [SameOperandsAndResultType]> {
+ let arguments = (ins IREEHL_MemRef:$src);
+ let results = (outs IREEHL_MemRef);
+}
+
+// "cond_assign": side-effect-free op taking a boolean scalar $cond and two
+// memrefs ($lhs, $rhs), producing a memref. Presumably selects between lhs
+// and rhs based on cond -- confirm against the lowering.
+def IREESeqHL_CondAssignOp : IREESeqHL_PureOp<"cond_assign"> {
+ let arguments = (ins
+ IREEHL_BoolScalar:$cond,
+ IREEHL_MemRef:$lhs,
+ IREEHL_MemRef:$rhs
+ );
+ let results = (outs IREEHL_MemRef);
+}
+
+// "reshape": side-effect-free op taking a source memref and a memref $shape,
+// producing a memref.
+def IREESeqHL_ReshapeOp : IREESeqHL_PureOp<"reshape"> {
+ let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape);
+ let results = (outs IREEHL_MemRef);
+}
+
+// "trace": takes a variadic list of memrefs and declares no results.
+def IREESeqHL_TraceOp : IREESeqHL_Op<"trace"> {
+ let arguments = (ins Variadic<IREEHL_MemRef>:$srcs);
+}
+
+// "cond_break": takes a boolean scalar $cond and declares no results.
+def IREESeqHL_CondBreakOp : IREESeqHL_Op<"cond_break"> {
+ let arguments = (ins IREEHL_BoolScalar:$cond);
+}
+
+// "break": declares no arguments and no results.
+def IREESeqHL_BreakOp : IREESeqHL_Op<"break">;
+
+#endif // IREE_SEQUENCER_HL_OPS
diff --git a/compiler/IR/Sequencer/LLDialect.cpp b/compiler/IR/Sequencer/LLDialect.cpp
new file mode 100644
index 0000000..7fb2368
--- /dev/null
+++ b/compiler/IR/Sequencer/LLDialect.cpp
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/LLDialect.h"
+
+#include "compiler/IR/Sequencer/LLOps.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Constructs the dialect under its namespace and adds the operations listed in
+// the generated LLOps.cpp.inc.
+IREELLSequencerDialect::IREELLSequencerDialect(MLIRContext* context)
+ : Dialect(getDialectNamespace(), context) {
+ // GET_OP_LIST selects the op-list section of the generated file, which is
+ // included in place as the template argument list of addOperations<>.
+#define GET_OP_LIST
+ addOperations<
+#include "compiler/IR/Sequencer/LLOps.cpp.inc"
+ >();
+}
+
+// File-local static object that registers IREELLSequencerDialect.
+static DialectRegistration<IREELLSequencerDialect> iree_ll_seq_dialect;
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/LLDialect.h b/compiler/IR/Sequencer/LLDialect.h
similarity index 100%
rename from iree/compiler/IR/Sequencer/LLDialect.h
rename to compiler/IR/Sequencer/LLDialect.h
diff --git a/compiler/IR/Sequencer/LLOps.cpp b/compiler/IR/Sequencer/LLOps.cpp
new file mode 100644
index 0000000..a5205f0
--- /dev/null
+++ b/compiler/IR/Sequencer/LLOps.cpp
@@ -0,0 +1,671 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/LLOps.h"
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/OpUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/STLExtras.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREESeq {
+namespace LL {
+
+namespace {
+
+// Verifies a workload passed as an SSA value: it must be a memref holding
+// exactly three elements (x,y,z). On violation an op error is emitted on |op|
+// and the resulting failure is returned.
+static LogicalResult verifyWorkload(Operation *op, Value *workload) {
+  auto memRefType = workload->getType().dyn_cast<MemRefType>();
+  if (!memRefType) {
+    return op->emitOpError(
+               "workload must be specified as an (x,y,z) memref but has type ")
+           << workload->getType();
+  }
+  if (memRefType.getNumElements() == 3) {
+    return success();
+  }
+  return op->emitOpError("workload must be specified as (x,y,z) but has ")
+         << memRefType.getNumElements()
+         << " elements (type=" << workload->getType() << ")";
+}
+
+// Verifies a workload passed as a constant attribute: it must contain exactly
+// three elements (x,y,z). Otherwise an op error is emitted on |op|.
+static LogicalResult verifyWorkload(Operation *op, ElementsAttr workload) {
+  if (workload.getNumElements() == 3) {
+    return success();
+  }
+  return op->emitOpError("workload must be specified as (x,y,z) but has ")
+         << workload.getNumElements() << " elements (value=" << workload
+         << ")";
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.constant
+//===----------------------------------------------------------------------===//
+
+// Folds the op to the value returned by getValue(). |operands| is not used.
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
+ return getValue();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.call
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of iree_ll_seq.call: a symbol reference stored as the
+// "callee" attribute, a parenthesized operand list, an optional attribute
+// dictionary and a trailing `: <function type>`. The function type supplies
+// both the result types and the types the operands are resolved against.
+static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
+  // Operand resolution errors are reported at the op name location.
+  auto nameLoc = parser.getNameLoc();
+
+  SymbolRefAttr callee;
+  SmallVector<OpAsmParser::OperandType, 4> operandInfos;
+  if (parser.parseAttribute(callee, "callee", state.attributes) ||
+      parser.parseOperandList(operandInfos, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes)) {
+    return failure();
+  }
+
+  FunctionType signature;
+  if (parser.parseColonType(signature)) {
+    return failure();
+  }
+  if (parser.addTypesToList(signature.getResults(), state.types) ||
+      parser.resolveOperands(operandInfos, signature.getInputs(), nameLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints iree_ll_seq.call in the form accepted by parseCallOp. The "callee"
+// attribute is printed inline and therefore elided from the attribute dict.
+static void printCallOp(OpAsmPrinter &p, CallOp op) {
+  p << "iree_ll_seq.call ";
+  p << op.getAttr("callee");
+  p << '(';
+  p.printOperands(op.getOperands());
+  p << ')';
+
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+
+  p << " : ";
+  p.printType(op.getCalleeType());
+}
+
+// Builds the function type implied by this op's operand and result types.
+FunctionType CallOp::getCalleeType() {
+  SmallVector<Type, 8> inputs(getOperandTypes());
+  SmallVector<Type, 4> results(getResultTypes());
+  return FunctionType::get(inputs, results, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.call_import
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of iree_ll_seq.call_import: a symbol reference stored
+// as the "callee" attribute, a parenthesized operand list, an optional
+// attribute dictionary and a trailing `: <function type>` that provides the
+// operand and result types.
+static ParseResult parseCallImportOp(OpAsmParser &parser,
+                                     OperationState &state) {
+  // Operand resolution errors are reported at the op name location.
+  auto nameLoc = parser.getNameLoc();
+
+  SymbolRefAttr importRef;
+  SmallVector<OpAsmParser::OperandType, 4> argInfos;
+  if (parser.parseAttribute(importRef, "callee", state.attributes) ||
+      parser.parseOperandList(argInfos, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes)) {
+    return failure();
+  }
+
+  FunctionType importType;
+  if (parser.parseColonType(importType)) {
+    return failure();
+  }
+  if (parser.addTypesToList(importType.getResults(), state.types) ||
+      parser.resolveOperands(argInfos, importType.getInputs(), nameLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints iree_ll_seq.call_import in the form accepted by parseCallImportOp.
+// "callee" is printed inline, so it is elided from the attribute dictionary.
+static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) {
+  p << "iree_ll_seq.call_import ";
+  p << op.getAttr("callee");
+  p << '(';
+  p.printOperands(op.getOperands());
+  p << ')';
+
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+
+  p << " : ";
+  p.printType(op.getCalleeType());
+}
+
+// Builds the function type implied by this op's operand and result types.
+FunctionType CallImportOp::getCalleeType() {
+  SmallVector<Type, 8> inputs(getOperandTypes());
+  SmallVector<Type, 4> results(getResultTypes());
+  return FunctionType::get(inputs, results, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.call_indirect
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form of iree_ll_seq.call_indirect: the callee operand, a
+// parenthesized argument list, an optional attribute dictionary and a trailing
+// `: <function type>`. The callee is resolved first (with the function type
+// itself), followed by the arguments (with the function's input types).
+static ParseResult parseCallIndirectOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  OpAsmParser::OperandType callee;
+  if (parser.parseOperand(callee)) {
+    return failure();
+  }
+
+  // Remember where the argument list starts for operand resolution errors.
+  llvm::SMLoc operandsLoc;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  if (parser.getCurrentLocation(&operandsLoc) ||
+      parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(result.attributes)) {
+    return failure();
+  }
+
+  FunctionType calleeType;
+  if (parser.parseColonType(calleeType)) {
+    return failure();
+  }
+  return failure(
+      parser.resolveOperand(callee, calleeType, result.operands) ||
+      parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
+                             result.operands) ||
+      parser.addTypesToList(calleeType.getResults(), result.types));
+}
+
+// Prints iree_ll_seq.call_indirect in the form accepted by
+// parseCallIndirectOp: callee, parenthesized arguments, optional attributes
+// and the callee's type.
+static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
+  p << "iree_ll_seq.call_indirect ";
+  p.printOperand(op.getCallee());
+
+  // The callee was printed above, so the argument list starts at the second
+  // operand.
+  auto allOperands = op.getOperands();
+  p << '(';
+  p.printOperands(++allOperands.begin(), allOperands.end());
+  p << ')';
+
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
+  p << " : " << op.getCallee()->getType();
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.return
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_seq.return: an optional operand list which, when non-empty,
+// must be followed by `: <type list>` giving the operand types.
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+
+  SmallVector<OpAsmParser::OperandType, 2> returnedValues;
+  if (parser.parseOperandList(returnedValues)) {
+    return failure();
+  }
+
+  // The type list is only present when at least one operand was parsed.
+  SmallVector<Type, 2> valueTypes;
+  if (!returnedValues.empty() && parser.parseColonTypeList(valueTypes)) {
+    return failure();
+  }
+
+  if (parser.resolveOperands(returnedValues, valueTypes, operandsLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints iree_ll_seq.return, followed by `<operands> : <types>` only when the
+// op has operands.
+static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
+  p << "iree_ll_seq.return";
+  if (op.getNumOperands() == 0) {
+    return;
+  }
+
+  p << ' ';
+  p.printOperands(op.operand_begin(), op.operand_end());
+  p << " : ";
+  interleaveComma(op.getOperandTypes(), p);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.br
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_seq.br: a single successor block with its operand list, which
+// is recorded as the op's only successor.
+static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
+  Block *target;
+  SmallVector<Value *, 4> targetOperands;
+  if (parser.parseSuccessorAndUseList(target, targetOperands)) {
+    return failure();
+  }
+
+  result.addSuccessor(target, targetOperands);
+  return success();
+}
+
+// Prints iree_ll_seq.br followed by successor 0 and its operands.
+static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
+ p << "iree_ll_seq.br ";
+ p.printSuccessorAndUseList(op.getOperation(), 0);
+}
+
+// Returns the branch target, which is stored as successor 0.
+Block *BranchOp::getDest() {
+  return getOperation()->getSuccessor(0);
+}
+
+// Replaces the branch target (successor 0) with |block|.
+void BranchOp::setDest(Block *block) {
+  getOperation()->setSuccessor(block, 0);
+}
+
+// Erases the operand at |index| from the operands passed to successor 0.
+void BranchOp::eraseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(0, index);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.cond_br
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_seq.cond_br: `<cond>, <true successor>, <false successor>`.
+// The condition is resolved as an i1 value. Successors are added in the order
+// true, false.
+static ParseResult parseCondBranchOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  // Condition operand, followed by a comma.
+  OpAsmParser::OperandType condition;
+  Type i1Type = parser.getBuilder().getI1Type();
+  if (parser.parseOperand(condition) || parser.parseComma() ||
+      parser.resolveOperand(condition, i1Type, result.operands)) {
+    return parser.emitError(parser.getNameLoc(),
+                            "expected condition type was boolean (i1)");
+  }
+
+  Block *successor;
+  SmallVector<Value *, 4> successorOperands;
+
+  // True successor.
+  if (parser.parseSuccessorAndUseList(successor, successorOperands)) {
+    return failure();
+  }
+  result.addSuccessor(successor, successorOperands);
+
+  // False successor; the operand list is reused, so reset it first.
+  successorOperands.clear();
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(successor, successorOperands)) {
+    return failure();
+  }
+  result.addSuccessor(successor, successorOperands);
+
+  return success();
+}
+
+// Prints iree_ll_seq.cond_br as `<cond>, <true successor>, <false successor>`,
+// mirroring parseCondBranchOp.
+//
+// The mnemonic must use this dialect's `iree_ll_seq` prefix, like every other
+// printer in this file. It previously printed `iree_ll_interp.cond_br`, which
+// names an op of a different dialect, so the printed IR did not round-trip.
+static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
+  p << "iree_ll_seq.cond_br ";
+  p.printOperand(op.getCondition());
+  p << ", ";
+  p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
+  p << ", ";
+  p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.dynamic_dispatch
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_seq.dynamic_dispatch in three stages:
+//   1. `<executable>::<entry_point>`, stored as the "executable" and
+//      "entry_point" attributes;
+//   2. `[<workload> : <type>]`, resolved as the first operand;
+//   3. `(<args>) <attr-dict>? : <function type>`, where the function type
+//      provides the result types and the argument types.
+static ParseResult parseDynamicDispatchOp(OpAsmParser &parser,
+                                          OperationState &state) {
+  // Argument resolution errors are reported at the op name location.
+  auto nameLoc = parser.getNameLoc();
+
+  // Stage 1: executable and entry point symbols separated by `::`.
+  SymbolRefAttr executable;
+  SymbolRefAttr entryPoint;
+  if (parser.parseAttribute(executable, "executable", state.attributes) ||
+      parser.parseColon() || parser.parseColon() ||
+      parser.parseAttribute(entryPoint, "entry_point", state.attributes)) {
+    return failure();
+  }
+
+  // Stage 2: the bracketed workload operand with its type.
+  OpAsmParser::OperandType workload;
+  Type workloadType;
+  if (parser.parseLSquare() || parser.parseOperand(workload) ||
+      parser.parseColonType(workloadType) || parser.parseRSquare() ||
+      parser.resolveOperand(workload, workloadType, state.operands)) {
+    return failure();
+  }
+
+  // Stage 3: arguments, optional attributes and the entry point signature.
+  SmallVector<OpAsmParser::OperandType, 4> argInfos;
+  FunctionType signature;
+  if (parser.parseOperandList(argInfos, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes) ||
+      parser.parseColonType(signature) ||
+      parser.addTypesToList(signature.getResults(), state.types) ||
+      parser.resolveOperands(argInfos, signature.getInputs(), nameLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints iree_ll_seq.dynamic_dispatch in the form accepted by
+// parseDynamicDispatchOp. The executable and entry point are printed inline
+// and are therefore elided from the attribute dictionary.
+static void printDynamicDispatchOp(OpAsmPrinter &p, DynamicDispatchOp op) {
+  p << "iree_ll_seq.dynamic_dispatch ";
+  p << op.getExecutable() << "::" << op.getEntryPoint();
+
+  // `[<workload> : <workload type>]`
+  p << "[";
+  p.printOperand(op.getWorkload());
+  p << " : " << op.getWorkload()->getType();
+  p << "](";
+
+  p.printOperands(op.getArgOperands());
+  p << ')';
+
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{"executable", "entry_point"});
+  p << " : ";
+  p.printType(op.getEntryPointType());
+}
+
+// Verifies the op by checking its workload operand; see the Value* overload of
+// verifyWorkload for the exact requirement.
+static LogicalResult verifyDynamicDispatchOp(DynamicDispatchOp op) {
+  const bool workloadOk = succeeded(verifyWorkload(op, op.getWorkload()));
+  return workloadOk ? success() : failure();
+}
+
+// Builds the entry point's function type from the argument operand types (as
+// returned by getArgOperandTypes) and the op's result types.
+FunctionType DynamicDispatchOp::getEntryPointType() {
+  SmallVector<Type, 8> inputs(getArgOperandTypes());
+  SmallVector<Type, 4> results(getResultTypes());
+  return FunctionType::get(inputs, results, getContext());
+}
+
+namespace {
+// Rewrites a dynamic_dispatch whose workload operand is a constant into a
+// static_dispatch that carries the workload as an attribute.
+struct MakeDynamicDispatchOpStatic
+    : public OpRewritePattern<DynamicDispatchOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(DynamicDispatchOp dynamicDispatchOp,
+                                     PatternRewriter &rewriter) const override {
+    // Only applies when the workload is produced by a constant.
+    ElementsAttr constantWorkload;
+    const bool hasConstantWorkload = matchPattern(
+        dynamicDispatchOp.getWorkload(), m_Constant(&constantWorkload));
+    if (!hasConstantWorkload) {
+      return matchFailure();
+    }
+
+    // The static op keeps the same results and argument operands.
+    SmallVector<Type, 8> resultTypes{dynamicDispatchOp.getResultTypes()};
+    SmallVector<Value *, 8> argOperands{dynamicDispatchOp.getArgOperands()};
+    rewriter.replaceOpWithNewOp<IREESeq::LL::StaticDispatchOp>(
+        dynamicDispatchOp, dynamicDispatchOp.getExecutable(),
+        dynamicDispatchOp.getEntryPoint(), constantWorkload, resultTypes,
+        argOperands);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the MakeDynamicDispatchOpStatic pattern to |results|.
+void DynamicDispatchOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<MakeDynamicDispatchOpStatic>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.static_dispatch
+//===----------------------------------------------------------------------===//
+
+// Parses iree_ll_seq.static_dispatch in three stages:
+//   1. `<executable>::<entry_point>`, stored as the "executable" and
+//      "entry_point" attributes;
+//   2. `[<workload attribute>]`, stored as the "workload" attribute;
+//   3. `(<args>) <attr-dict>? : <function type>`, where the function type
+//      provides the result types and the argument types.
+static ParseResult parseStaticDispatchOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  // Argument resolution errors are reported at the op name location.
+  auto nameLoc = parser.getNameLoc();
+
+  // Stage 1: executable and entry point symbols separated by `::`.
+  SymbolRefAttr executable;
+  SymbolRefAttr entryPoint;
+  if (parser.parseAttribute(executable, "executable", state.attributes) ||
+      parser.parseColon() || parser.parseColon() ||
+      parser.parseAttribute(entryPoint, "entry_point", state.attributes)) {
+    return failure();
+  }
+
+  // Stage 2: the bracketed constant workload.
+  ElementsAttr workload;
+  if (parser.parseLSquare() ||
+      parser.parseAttribute(workload, "workload", state.attributes) ||
+      parser.parseRSquare()) {
+    return failure();
+  }
+
+  // Stage 3: arguments, optional attributes and the entry point signature.
+  SmallVector<OpAsmParser::OperandType, 4> argInfos;
+  FunctionType signature;
+  if (parser.parseOperandList(argInfos, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttributeDict(state.attributes) ||
+      parser.parseColonType(signature) ||
+      parser.addTypesToList(signature.getResults(), state.types) ||
+      parser.resolveOperands(argInfos, signature.getInputs(), nameLoc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+// Prints iree_ll_seq.static_dispatch in the form accepted by
+// parseStaticDispatchOp. The executable, entry point and workload are printed
+// inline and are therefore elided from the attribute dictionary.
+static void printStaticDispatchOp(OpAsmPrinter &p, StaticDispatchOp op) {
+  p << "iree_ll_seq.static_dispatch ";
+  p << op.getExecutable() << "::" << op.getEntryPoint();
+
+  // `[<workload attribute>]`
+  p << "[";
+  p.printAttribute(op.getWorkload());
+  p << "](";
+
+  p.printOperands(op.getArgOperands());
+  p << ')';
+
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{"executable", "entry_point", "workload"});
+  p << " : ";
+  p.printType(op.getEntryPointType());
+}
+
+// Verifies the op by checking its workload attribute; see the ElementsAttr
+// overload of verifyWorkload for the exact requirement.
+static LogicalResult verifyStaticDispatchOp(StaticDispatchOp op) {
+  const bool workloadOk = succeeded(verifyWorkload(op, op.getWorkload()));
+  return workloadOk ? success() : failure();
+}
+
+// Builds the entry point's function type from the argument operand types (as
+// returned by getArgOperandTypes) and the op's result types.
+FunctionType StaticDispatchOp::getEntryPointType() {
+  SmallVector<Type, 8> inputs(getArgOperandTypes());
+  SmallVector<Type, 4> results(getResultTypes());
+  return FunctionType::get(inputs, results, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.shape
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Replaces a shape op whose input memref has a static shape with a constant
+// holding that shape as a rank-1 memref of 64-bit integers. Because the result
+// is written to an output operand, uses of that operand after the op are
+// redirected to the constant before the op is erased.
+struct FoldShapeOp : public OpRewritePattern<ShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ShapeOp shapeOp,
+                                     PatternRewriter &rewriter) const override {
+    auto inputType = shapeOp.input()->getType().cast<MemRefType>();
+    // Nothing to fold unless every dimension is known.
+    if (!inputType.hasStaticShape()) {
+      return matchFailure();
+    }
+
+    auto i64Type = rewriter.getIntegerType(64);
+    auto shapeConstant = rewriter.create<IREESeq::LL::ConstantOp>(
+        shapeOp.getLoc(), MemRefType::get({inputType.getRank()}, i64Type),
+        DenseIntElementsAttr::get(
+            RankedTensorType::get({inputType.getRank()}, i64Type),
+            inputType.getShape()));
+    replaceSubsequentUses(shapeOp, shapeOp.dst(), shapeConstant.getResult());
+    rewriter.eraseOp(shapeOp);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the FoldShapeOp pattern to |results|.
+void ShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldShapeOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.length
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Replaces a length op whose input memref has a static shape with a rank-0
+// constant of 64-bit integer type holding the input's element count. Uses of
+// the output operand after the op are redirected to the constant before the op
+// is erased.
+struct FoldLengthOp : public OpRewritePattern<LengthOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(LengthOp lengthOp,
+                                     PatternRewriter &rewriter) const override {
+    auto inputType = lengthOp.input()->getType().cast<MemRefType>();
+    // Nothing to fold unless every dimension is known.
+    if (!inputType.hasStaticShape()) {
+      return matchFailure();
+    }
+
+    auto i64Type = rewriter.getIntegerType(64);
+    auto lengthConstant = rewriter.create<IREESeq::LL::ConstantOp>(
+        lengthOp.getLoc(), MemRefType::get({}, i64Type),
+        DenseIntElementsAttr::get(RankedTensorType::get({}, i64Type),
+                                  {inputType.getNumElements()}));
+    replaceSubsequentUses(lengthOp, lengthOp.dst(),
+                          lengthConstant.getResult());
+    rewriter.eraseOp(lengthOp);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the FoldLengthOp pattern to |results|.
+void LengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldLengthOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.compute_offset
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Folds compute_offset when both the shape and the indices are constants.
+// Each index is multiplied by all shape dimensions that follow its axis, the
+// products are summed, and the sum is scaled by the op's element size. The
+// result becomes a rank-0 64-bit integer constant that replaces later uses of
+// the output operand.
+struct FoldComputeOffsetOp : public OpRewritePattern<ComputeOffsetOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ComputeOffsetOp computeOffsetOp,
+                                     PatternRewriter &rewriter) const override {
+    ElementsAttr shapeAttr;
+    if (!matchPattern(computeOffsetOp.shape(), m_Constant(&shapeAttr))) {
+      return matchFailure();
+    }
+    ElementsAttr indicesAttr;
+    if (!matchPattern(computeOffsetOp.indices(), m_Constant(&indicesAttr))) {
+      return matchFailure();
+    }
+
+    // Accumulate the offset in elements, one axis at a time.
+    int64_t offset = 0;
+    for (unsigned axis = 0; axis < indicesAttr.getNumElements(); ++axis) {
+      int64_t axisOffset =
+          indicesAttr.getValue({axis}).cast<IntegerAttr>().getInt();
+      for (unsigned inner = axis + 1; inner < shapeAttr.getNumElements();
+           ++inner) {
+        axisOffset *= shapeAttr.getValue({inner}).cast<IntegerAttr>().getInt();
+      }
+      offset += axisOffset;
+    }
+    // Scale by the element size.
+    offset *= computeOffsetOp.elementSize().getZExtValue();
+
+    auto i64Type = rewriter.getIntegerType(64);
+    auto offsetConstant = rewriter.create<IREESeq::LL::ConstantOp>(
+        computeOffsetOp.getLoc(), MemRefType::get({}, i64Type),
+        DenseIntElementsAttr::get(RankedTensorType::get({}, i64Type),
+                                  {offset}));
+    replaceSubsequentUses(computeOffsetOp, computeOffsetOp.dst(),
+                          offsetConstant.getResult());
+    rewriter.eraseOp(computeOffsetOp);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the FoldComputeOffsetOp pattern to |results|.
+void ComputeOffsetOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldComputeOffsetOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.compute_range
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Folds compute_range when the shape, indices and lengths are all constants.
+// The offset is computed as in compute_offset (each index times the shape
+// dimensions after its axis, summed, then scaled by the element size). The
+// length is the element size multiplied by every entry of the lengths
+// constant. Both values become rank-0 64-bit integer constants that replace
+// later uses of the respective output operands.
+struct FoldComputeRangeOp : public OpRewritePattern<ComputeRangeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ComputeRangeOp computeRangeOp,
+                                     PatternRewriter &rewriter) const override {
+    ElementsAttr shapeAttr;
+    if (!matchPattern(computeRangeOp.shape(), m_Constant(&shapeAttr))) {
+      return matchFailure();
+    }
+    ElementsAttr indicesAttr;
+    if (!matchPattern(computeRangeOp.indices(), m_Constant(&indicesAttr))) {
+      return matchFailure();
+    }
+    ElementsAttr lengthsAttr;
+    if (!matchPattern(computeRangeOp.lengths(), m_Constant(&lengthsAttr))) {
+      return matchFailure();
+    }
+
+    int64_t offset = 0;
+    // The length starts at the element size and is scaled per axis below.
+    int64_t length = computeRangeOp.elementSize().getZExtValue();
+    for (unsigned axis = 0; axis < indicesAttr.getNumElements(); ++axis) {
+      int64_t axisOffset =
+          indicesAttr.getValue({axis}).cast<IntegerAttr>().getInt();
+      for (unsigned inner = axis + 1; inner < shapeAttr.getNumElements();
+           ++inner) {
+        axisOffset *= shapeAttr.getValue({inner}).cast<IntegerAttr>().getInt();
+      }
+      offset += axisOffset;
+      length *= lengthsAttr.getValue({axis}).cast<IntegerAttr>().getInt();
+    }
+    // Scale the element offset by the element size.
+    offset *= computeRangeOp.elementSize().getZExtValue();
+
+    // Materialize the offset first, then the length.
+    auto offsetConstant = rewriter.create<IREESeq::LL::ConstantOp>(
+        computeRangeOp.getLoc(),
+        MemRefType::get({}, rewriter.getIntegerType(64)),
+        DenseIntElementsAttr::get(
+            RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset}));
+    replaceSubsequentUses(computeRangeOp, computeRangeOp.dstOffset(),
+                          offsetConstant.getResult());
+
+    auto lengthConstant = rewriter.create<IREESeq::LL::ConstantOp>(
+        computeRangeOp.getLoc(),
+        MemRefType::get({}, rewriter.getIntegerType(64)),
+        DenseIntElementsAttr::get(
+            RankedTensorType::get({}, rewriter.getIntegerType(64)), {length}));
+    replaceSubsequentUses(computeRangeOp, computeRangeOp.dstLength(),
+                          lengthConstant.getResult());
+
+    rewriter.eraseOp(computeRangeOp);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the FoldComputeRangeOp pattern to |results|.
+void ComputeRangeOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldComputeRangeOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.dynamic_copy
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Rewrites a dynamic_copy whose source offset, destination offset and length
+// are all constants into a static_copy that carries them as integer
+// attributes.
+struct MakeDynamicCopyOpStatic : public OpRewritePattern<DynamicCopyOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(DynamicCopyOp dynamicCopyOp,
+                                     PatternRewriter &rewriter) const override {
+    ElementsAttr srcOffsetElements;
+    if (!matchPattern(dynamicCopyOp.srcOffset(),
+                      m_Constant(&srcOffsetElements))) {
+      return matchFailure();
+    }
+    ElementsAttr dstOffsetElements;
+    if (!matchPattern(dynamicCopyOp.dstOffset(),
+                      m_Constant(&dstOffsetElements))) {
+      return matchFailure();
+    }
+    ElementsAttr lengthElements;
+    if (!matchPattern(dynamicCopyOp.length(), m_Constant(&lengthElements))) {
+      return matchFailure();
+    }
+
+    // Each constant is read with an empty index list to get its single value.
+    auto srcOffset = srcOffsetElements.getValue({}).cast<IntegerAttr>();
+    auto dstOffset = dstOffsetElements.getValue({}).cast<IntegerAttr>();
+    auto length = lengthElements.getValue({}).cast<IntegerAttr>();
+    rewriter.replaceOpWithNewOp<IREESeq::LL::StaticCopyOp>(
+        dynamicCopyOp, dynamicCopyOp.src(), srcOffset, dynamicCopyOp.dst(),
+        dstOffset, length);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the MakeDynamicCopyOpStatic pattern to |results|.
+void DynamicCopyOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<MakeDynamicCopyOpStatic>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_ll_seq.dynamic_fill
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Rewrites a dynamic_fill whose value, destination offset and length are all
+// constants into a static_fill that carries them as integer attributes.
+struct MakeDynamicFillOpStatic : public OpRewritePattern<DynamicFillOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(DynamicFillOp dynamicFillOp,
+                                     PatternRewriter &rewriter) const override {
+    ElementsAttr valueElements;
+    if (!matchPattern(dynamicFillOp.value(), m_Constant(&valueElements))) {
+      return matchFailure();
+    }
+    ElementsAttr dstOffsetElements;
+    if (!matchPattern(dynamicFillOp.dstOffset(),
+                      m_Constant(&dstOffsetElements))) {
+      return matchFailure();
+    }
+    ElementsAttr lengthElements;
+    if (!matchPattern(dynamicFillOp.length(), m_Constant(&lengthElements))) {
+      return matchFailure();
+    }
+
+    // Each constant is read with an empty index list to get its single value.
+    auto fillValue = valueElements.getValue({}).cast<IntegerAttr>();
+    auto dstOffset = dstOffsetElements.getValue({}).cast<IntegerAttr>();
+    auto length = lengthElements.getValue({}).cast<IntegerAttr>();
+    rewriter.replaceOpWithNewOp<IREESeq::LL::StaticFillOp>(
+        dynamicFillOp, fillValue, dynamicFillOp.dst(), dstOffset, length);
+    return matchSuccess();
+  }
+};
+} // namespace
+
+// Adds the MakeDynamicFillOpStatic pattern to |results|.
+void DynamicFillOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<MakeDynamicFillOpStatic>(context);
+}
+
+// Pull in the op class section of the generated LLOps.cpp.inc. It is included
+// inside the IREESeq::LL namespace, after the hand-written code above.
+#define GET_OP_CLASSES
+#include "compiler/IR/Sequencer/LLOps.cpp.inc"
+
+} // namespace LL
+} // namespace IREESeq
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Sequencer/LLOps.h b/compiler/IR/Sequencer/LLOps.h
new file mode 100644
index 0000000..c5e83bd
--- /dev/null
+++ b/compiler/IR/Sequencer/LLOps.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
+#define IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
+
+#include "compiler/IR/Types.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREESeq {
+namespace LL {
+
+// Pull in the op class section of the generated LLOps.h.inc, inside the
+// IREESeq::LL namespace.
+#define GET_OP_CLASSES
+#include "compiler/IR/Sequencer/LLOps.h.inc"
+
+} // namespace LL
+} // namespace IREESeq
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
diff --git a/compiler/IR/Sequencer/LLOps.td b/compiler/IR/Sequencer/LLOps.td
new file mode 100644
index 0000000..1c3324d
--- /dev/null
+++ b/compiler/IR/Sequencer/LLOps.td
@@ -0,0 +1,579 @@
+// IREE low-level sequencer op definitions.
+// These map 1:1 with the bytecode, accept only MemRef types and generally use
+// output parameters instead of return types.
+//
+// The source of truth for bytecode opcodes is:
+// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h
+//
+// Note that in this dialect we cannot use folders: folders require that every
+// operand can be made constant, whereas we use output arguments that will
+// never be constant. Instead we use canonicalization patterns that match
+// constant input operands and perform the folding by replacing the output
+// operands with the new values.
+
+#ifdef IREE_SEQUENCER_LL_OPS
+#else
+#define IREE_SEQUENCER_LL_OPS
+
+#ifdef IREE_OP_BASE
+#else
+include "compiler/IR/OpBase.td"
+#endif // IREE_OP_BASE
+
+def IREESeqLL_Dialect : Dialect {
+ let name = "iree_ll_seq";
+ let cppNamespace = "IREESeq::LL";
+}
+
+//===----------------------------------------------------------------------===//
+// Base op classes
+//===----------------------------------------------------------------------===//
+
+class IREESeqLL_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREESeqLL_Dialect, mnemonic, traits> {
+ bit hasCustomSerializer = 0;
+}
+
+class IREESeqLL_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ IREESeqLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+
+class IREESeqLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef,
+ list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> {
+ let arguments = (ins type:$input, type:$dst);
+}
+
+class IREESeqLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef,
+ list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> {
+ let arguments = (ins type:$lhs, type:$rhs, type:$dst);
+}
+
+class IREESeqLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef,
+ list<OpTrait> traits = []>
+ : IREESeqLL_Op<mnemonic, traits> {
+ let arguments = (ins type : $a, type : $b, type : $c, type : $dst);
+}
+
+//===----------------------------------------------------------------------===//
+// Low-level sequencer ops
+//===----------------------------------------------------------------------===//
+
+def IREESeqLL_ConstantOp : IREESeqLL_PureOp<"constant"> {
+ let arguments = (ins ElementsAttr:$value);
+ let results = (outs IREELL_MemRef);
+
+ // TODO(b/132296600): make tablegen follow the style guide.
+ let extraClassDeclaration = [{
+ Attribute getValue() { return value(); }
+ }];
+
+ let hasFolder = 1;
+}
+
+def IREESeqLL_CallOp : IREESeqLL_Op<"call"> {
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
+ let results = (outs Variadic<IREELL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, FuncOp callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(callee.getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState &result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+// TODO(benvanik): add verifier that target isExternal.
+def IREESeqLL_CallImportOp : IREESeqLL_Op<"call_import"> {
+ let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
+ let results = (outs Variadic<IREELL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, FuncOp callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(callee.getType().getResults());
+ }]>, OpBuilder<
+ "Builder *builder, OperationState &result, StringRef callee,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands(operands);
+ result.addAttribute("callee", builder->getSymbolRefAttr(callee));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+def IREESeqLL_CallIndirectOp : IREESeqLL_Op<"call_indirect"> {
+ let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands);
+ let results = (outs Variadic<IREELL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *callee,"
+ "ArrayRef<Value *> operands = {}", [{
+ result.operands.push_back(callee);
+ result.addOperands(operands);
+ result.addTypes(callee->getType().cast<FunctionType>().getResults());
+ }]>];
+
+ let extraClassDeclaration = [{
+ Value *getCallee() { return getOperand(0); }
+
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ operand_iterator arg_operand_begin() { return ++operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+def IREESeqLL_ReturnOp : IREESeqLL_Op<"return", [Terminator]> {
+ let arguments = (ins Variadic<IREELL_MemRef>:$operands);
+
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+def IREESeqLL_BranchOp : IREESeqLL_Op<"br", [Terminator]> {
+ let arguments = (ins Variadic<IREELL_MemRef>:$operands);
+
+ let skipDefaultBuilders = 1;
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Block *dest, "
+ "ArrayRef<Value *> operands = {}", [{
+ result.addSuccessor(dest, operands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ Block *getDest();
+ void setDest(Block *block);
+
+ /// Erase the operand at 'index' from the operand list.
+ void eraseOperand(unsigned index);
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+def IREESeqLL_CondBranchOp : IREESeqLL_Op<"cond_br", [Terminator]> {
+ let arguments = (ins
+ IREELL_BoolScalar:$condition,
+ Variadic<IREELL_MemRef>:$branchOperands
+ );
+
+ let skipDefaultBuilders = 1;
+ let builders = [OpBuilder<
+ "Builder *, OperationState &result, Value *condition,"
+ "Block *trueDest, ArrayRef<Value *> trueOperands,"
+ "Block *falseDest, ArrayRef<Value *> falseOperands", [{
+ result.addOperands(condition);
+ result.addSuccessor(trueDest, trueOperands);
+ result.addSuccessor(falseDest, falseOperands);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // These are the indices into the dests list.
+ enum { trueIndex = 0, falseIndex = 1 };
+
+ // The condition operand is the first operand in the list.
+ Value *getCondition() { return getOperand(0); }
+
+ /// Return the destination if the condition is true.
+ Block *getTrueDest() {
+ return getOperation()->getSuccessor(trueIndex);
+ }
+
+ /// Return the destination if the condition is false.
+ Block *getFalseDest() {
+ return getOperation()->getSuccessor(falseIndex);
+ }
+
+ // Accessors for operands to the 'true' destination.
+ Value *getTrueOperand(unsigned idx) {
+ assert(idx < getNumTrueOperands());
+ return getOperand(getTrueDestOperandIndex() + idx);
+ }
+
+ void setTrueOperand(unsigned idx, Value *value) {
+ assert(idx < getNumTrueOperands());
+ setOperand(getTrueDestOperandIndex() + idx, value);
+ }
+
+ operand_iterator true_operand_begin() {
+ return operand_begin() + getTrueDestOperandIndex();
+ }
+ operand_iterator true_operand_end() {
+ return true_operand_begin() + getNumTrueOperands();
+ }
+ operand_range getTrueOperands() {
+ return {true_operand_begin(), true_operand_end()};
+ }
+
+ unsigned getNumTrueOperands() {
+ return getOperation()->getNumSuccessorOperands(trueIndex);
+ }
+
+ /// Erase the operand at 'index' from the true operand list.
+ void eraseTrueOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(trueIndex, index);
+ }
+
+ // Accessors for operands to the 'false' destination.
+ Value *getFalseOperand(unsigned idx) {
+ assert(idx < getNumFalseOperands());
+ return getOperand(getFalseDestOperandIndex() + idx);
+ }
+ void setFalseOperand(unsigned idx, Value *value) {
+ assert(idx < getNumFalseOperands());
+ setOperand(getFalseDestOperandIndex() + idx, value);
+ }
+
+ operand_iterator false_operand_begin() { return true_operand_end(); }
+ operand_iterator false_operand_end() {
+ return false_operand_begin() + getNumFalseOperands();
+ }
+ operand_range getFalseOperands() {
+ return {false_operand_begin(), false_operand_end()};
+ }
+
+ unsigned getNumFalseOperands() {
+ return getOperation()->getNumSuccessorOperands(falseIndex);
+ }
+
+ /// Erase the operand at 'index' from the false operand list.
+ void eraseFalseOperand(unsigned index) {
+ getOperation()->eraseSuccessorOperand(falseIndex, index);
+ }
+
+ private:
+ /// Get the index of the first true destination operand.
+ unsigned getTrueDestOperandIndex() { return 1; }
+
+ /// Get the index of the first false destination operand.
+ unsigned getFalseDestOperandIndex() {
+ return getTrueDestOperandIndex() + getNumTrueOperands();
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+def IREESeqLL_DynamicDispatchOp : IREESeqLL_Op<"dynamic_dispatch"> {
+ let arguments = (ins
+ SymbolRefAttr:$executable,
+ SymbolRefAttr:$entry_point,
+ IREELL_IntMemRef:$workload,
+ Variadic<IREELL_MemRef>:$operands
+ );
+ let results = (outs Variadic<IREELL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, StringRef executable,"
+ "StringRef entry_point, Value *workload,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addOperands({workload});
+ result.addOperands(operands);
+ result.addAttribute("executable", builder->getSymbolRefAttr(executable));
+ result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getExecutable() { return executable(); }
+ StringRef getEntryPoint() { return entry_point(); }
+ FunctionType getEntryPointType();
+
+ Value *getWorkload() { return getOperand(0); }
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ operand_iterator arg_operand_begin() { return operand_begin() + 1; }
+ operand_iterator arg_operand_end() { return operand_end(); }
+
+ operand_type_range getArgOperandTypes() {
+ return {arg_operand_type_begin(), arg_operand_type_end()};
+ }
+ operand_type_iterator arg_operand_type_begin() {
+ return operand_type_iterator(arg_operand_begin());
+ }
+ operand_type_iterator arg_operand_type_end() {
+ return operand_type_iterator(arg_operand_end());
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+ let verifier = [{ return verify$cppClass(*this); }];
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_StaticDispatchOp : IREESeqLL_Op<"static_dispatch"> {
+ let arguments = (ins
+ SymbolRefAttr:$executable,
+ SymbolRefAttr:$entry_point,
+ I32ElementsAttr:$workload,
+ Variadic<IREELL_MemRef>:$operands
+ );
+ let results = (outs Variadic<IREELL_MemRef>);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, StringRef executable,"
+ "StringRef entry_point, ElementsAttr workload,"
+ "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
+ result.addAttribute("workload", workload);
+ result.addOperands(operands);
+ result.addAttribute("executable", builder->getSymbolRefAttr(executable));
+ result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point));
+ result.addTypes(results);
+ }]>];
+
+ let extraClassDeclaration = [{
+ // TODO(b/132296600): make tablegen follow the style guide.
+ StringRef getExecutable() { return executable(); }
+ StringRef getEntryPoint() { return entry_point(); }
+ FunctionType getEntryPointType();
+
+ ElementsAttr getWorkload() { return workload(); }
+
+ // TODO(b/133879130): make tablegen support variadic operand accessors.
+ operand_range getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+
+ operand_type_range getArgOperandTypes() {
+ return {arg_operand_type_begin(), arg_operand_type_end()};
+ }
+ operand_type_iterator arg_operand_type_begin() {
+ return operand_type_iterator(arg_operand_begin());
+ }
+ operand_type_iterator arg_operand_type_end() {
+ return operand_type_iterator(arg_operand_end());
+ }
+ }];
+
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+ let verifier = [{ return verify$cppClass(*this); }];
+}
+
+def IREESeqLL_AllocStaticOp : IREESeqLL_PureOp<"alloc_static"> {
+ // TODO(benvanik): attributes and args.
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_AllocStackOp : IREESeqLL_PureOp<"alloc_stack"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_AllocStackInitOp : IREESeqLL_PureOp<"alloc_stack_init"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces);
+ let results = (outs IREELL_MemRef);
+}
+
+// TODO(b/142012496): Add trait that enables DCE but not CSE.
+def IREESeqLL_AllocHeapOp : IREESeqLL_Op<"alloc_heap"> {
+ // TODO(benvanik): attributes and args.
+ let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_DiscardOp : IREESeqLL_Op<"discard"> {
+ let arguments = (ins IREELL_MemRef);
+}
+
+def IREESeqLL_ShapeOp : IREESeqLL_Op<"shape"> {
+ let arguments = (ins IREELL_MemRef:$input, IREELL_I32MemRef:$dst);
+
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_LengthOp : IREESeqLL_Op<"length"> {
+ let arguments = (ins IREELL_MemRef:$input, IREELL_I32Scalar:$dst);
+
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_ComputeOffsetOp : IREESeqLL_Op<"compute_offset"> {
+ let arguments = (ins
+ IREELL_1DIntMemRef:$shape,
+ I8Attr:$elementSize,
+ IREELL_1DIntMemRef:$indices,
+ IREELL_I32Scalar:$dst
+ );
+
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_ComputeRangeOp : IREESeqLL_Op<"compute_range"> {
+ let arguments = (ins
+ IREELL_1DIntMemRef:$shape,
+ I8Attr:$elementSize,
+ IREELL_1DIntMemRef:$indices,
+ IREELL_1DIntMemRef:$lengths,
+ IREELL_I32Scalar:$dstOffset,
+ IREELL_I32Scalar:$dstLength
+ );
+
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_DynamicSliceOp : IREESeqLL_PureOp<"dynamic_slice", [
+ AllElementTypesMatch<["src", "result"]>
+]> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ IREELL_IntScalar:$offset,
+ IREELL_IntScalar:$length
+ );
+ let results = (outs IREELL_MemRef:$result);
+}
+
+def IREESeqLL_StaticSliceOp : IREESeqLL_PureOp<"static_slice", [
+ AllElementTypesMatch<["src", "result"]>
+]> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ I64Attr:$offset,
+ I64Attr:$length
+ );
+ let results = (outs IREELL_MemRef:$result);
+}
+
+def IREESeqLL_DynamicCopyOp : IREESeqLL_Op<"dynamic_copy"> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ IREELL_IndexScalar:$srcOffset,
+ IREELL_MemRef:$dst,
+ IREELL_IndexScalar:$dstOffset,
+ IREELL_IndexScalar:$length
+ );
+
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_StaticCopyOp : IREESeqLL_Op<"static_copy"> {
+ let arguments = (ins
+ IREELL_MemRef:$src,
+ I64Attr:$srcOffset,
+ IREELL_MemRef:$dst,
+ I64Attr:$dstOffset,
+ I64Attr:$length
+ );
+}
+
+def IREESeqLL_DynamicFillOp : IREESeqLL_Op<"dynamic_fill"> {
+ let arguments = (ins
+ IREELL_I32Scalar:$value,
+ IREELL_MemRef:$dst,
+ IREELL_IndexScalar:$dstOffset,
+ IREELL_IndexScalar:$length
+ );
+
+ let hasCanonicalizer = 1;
+}
+
+def IREESeqLL_StaticFillOp : IREESeqLL_Op<"static_fill"> {
+ let arguments = (ins
+ I32Attr:$value,
+ IREELL_MemRef:$dst,
+ I64Attr:$dstOffset,
+ I64Attr:$length
+ );
+}
+
+def IREESeqLL_CloneOp :
+ IREESeqLL_PureOp<"clone", [SameOperandsAndResultType]> {
+ let arguments = (ins IREELL_MemRef:$src);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_AssignOp :
+ IREESeqLL_Op<"assign", [SameOperandsAndResultType]> {
+ let arguments = (ins IREELL_MemRef:$src);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_CondAssignOp : IREESeqLL_Op<"cond_assign"> {
+ let arguments = (ins
+ IREELL_BoolScalar:$cond,
+ IREELL_MemRef:$lhs,
+ IREELL_MemRef:$rhs
+ );
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_ReshapeOp : IREESeqLL_Op<"reshape"> {
+ let arguments = (ins IREELL_MemRef:$input, IREELL_1DIntMemRef:$shape);
+ let results = (outs IREELL_MemRef);
+}
+
+def IREESeqLL_TraceOp : IREESeqLL_Op<"trace"> {
+ let arguments = (ins Variadic<IREELL_MemRef>:$srcs);
+}
+
+def IREESeqLL_CondBreakOp : IREESeqLL_Op<"cond_break"> {
+ let arguments = (ins IREELL_BoolScalar:$cond);
+}
+
+def IREESeqLL_BreakOp : IREESeqLL_Op<"break">;
+
+#endif // IREE_SEQUENCER_LL_OPS
diff --git a/compiler/IR/Sequencer/OpWriters.cpp b/compiler/IR/Sequencer/OpWriters.cpp
new file mode 100644
index 0000000..9d2be02
--- /dev/null
+++ b/compiler/IR/Sequencer/OpWriters.cpp
@@ -0,0 +1,266 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/OpWriters.h"
+
+#include "compiler/IR/Sequencer/LLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Serialization/BytecodeWriter.h"
+#include "compiler/Utils/Macros.h"
+#include "schemas/bytecode/sequencer_bytecode_v0.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Sequencer ops
+//===----------------------------------------------------------------------===//
+
+// Writes kConstant, the "value" attribute, and the result; requires a memref.
+LogicalResult writeOp(IREESeq::LL::ConstantOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kConstant));
+  if (auto memRefType = op.getType().dyn_cast<MemRefType>()) {
+    RETURN_IF_FAILURE(writer->WriteConstant(memRefType, op.getAttr("value")));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
+    return success();
+  }
+  return op.emitError()
+         << "Constant has an unsupported type; must be a memref: "
+         << op.getType();
+}
+
+LogicalResult writeOp(IREESeq::LL::CallOp op, BytecodeWriter *writer) {
+ auto module = op.getOperation()->getParentOfType<ModuleOp>();
+ auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
+ // TODO(benvanik): switch with kCallTail if attr exists.
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall));
+ RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+ return success();
+}
+
+// Writes kCallImport (external callee) or kCall; errors if callee is missing.
+LogicalResult writeOp(IREESeq::LL::CallImportOp op, BytecodeWriter *writer) {
+  auto module = op.getOperation()->getParentOfType<ModuleOp>();
+  auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
+  if (!callee) return op.emitError() << "Callee not found in module";
+  // TODO(benvanik): transforms to convert Call->CallImport.
+  // TODO(benvanik): switch with kCallTail if attr exists.
+  auto opcode = callee.isExternal() ? iree::SequencerOpcode::kCallImport
+                                    : iree::SequencerOpcode::kCall;
+  RETURN_IF_FAILURE(writer->WriteOpcode(opcode));
+  RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee));
+  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+  return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::CallIndirectOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallIndirect));
+ RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+ return success();
+}
+
+// Writes kBranch, the target block offset, and (operand, block arg) pairs.
+LogicalResult writeOp(IREESeq::LL::BranchOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kBranch));
+  RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest()));
+  RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands()));
+  for (unsigned i = 0, e = op.getNumOperands(); i < e; ++i) {
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i)));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i)));
+  }
+  return success();
+}
+
+// Writes kCondBranch and the condition local, then for the true and the false
+// successor: its block offset, operand count, and (operand, block arg) pairs.
+LogicalResult writeOp(IREESeq::LL::CondBranchOp op, BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCondBranch));
+  RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition()));
+  RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest()));
+  RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands()));
+  for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) {
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i)));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i)));
+  }
+  RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest()));
+  RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands()));
+  for (unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) {
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i)));
+    RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i)));
+  }
+  return success();
+}
+
+// Resolves `executable` and `entryPoint` to their "iree.ordinal" attributes and
+// writes them as uint32 then uint16; lookup failures are reported on `op`.
+LogicalResult writeDispatchOpExecutableRef(Operation *op, StringRef executable,
+                                           StringRef entryPoint,
+                                           BytecodeWriter *writer) {
+  auto module = op->getParentOfType<ModuleOp>();
+  auto multiArchExecutableOp =
+      module.lookupSymbol<IREE::MultiArchExecutableOp>(executable);
+  if (!multiArchExecutableOp) {
+    return op->emitError() << "Executable @" << executable.str()
+                           << " not found in module";
+  }
+  auto executableOrdinalAttr = multiArchExecutableOp.getAttr("iree.ordinal")
+                                   .dyn_cast_or_null<IntegerAttr>();
+  if (!executableOrdinalAttr) {
+    return op->emitError() << "No ordinal assigned to executable";
+  }
+  int executableOrdinal = executableOrdinalAttr.getInt();
+
+  // TODO(benvanik): move an export table to the MAE to make this cleaner.
+  auto executableOp =
+      cast<IREE::ExecutableOp>(multiArchExecutableOp.getBlock().front());
+  auto entryPointOp =
+      executableOp.getInnerModule().lookupSymbol<FuncOp>(entryPoint);
+  if (!entryPointOp) {
+    return op->emitError() << "Entry point @" << entryPoint.str()
+                           << " not found in executable @" << executable.str();
+  }
+  auto entryPointOrdinalAttr =
+      entryPointOp.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>();
+  if (!entryPointOrdinalAttr) {
+    return op->emitError() << "No ordinal assigned to entry point";
+  }
+  int entryPointOrdinal = entryPointOrdinalAttr.getInt();
+  RETURN_IF_FAILURE(writer->WriteUint32(executableOrdinal));
+  RETURN_IF_FAILURE(writer->WriteUint16(entryPointOrdinal));
+  return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::DynamicDispatchOp op,
+ BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(
+ writer->WriteOpcode(iree::SequencerOpcode::kDynamicDispatch));
+ RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(),
+ op.getEntryPoint(), writer));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.getWorkload()));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+ // TODO(benvanik): support output arg group (or change to tags).
+ RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+ return success();
+}
+
+// Writes kStaticDispatch, the executable/entry point ordinals, the three i32
+// workload elements, the argument locals, an output count of 0, and results.
+LogicalResult writeOp(IREESeq::LL::StaticDispatchOp op,
+                      BytecodeWriter *writer) {
+  RETURN_IF_FAILURE(
+      writer->WriteOpcode(iree::SequencerOpcode::kStaticDispatch));
+  RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(),
+                                                 op.getEntryPoint(), writer));
+  auto workloadAttr = op.getWorkload();
+  for (uint64_t dim = 0; dim < 3; ++dim) {
+    auto dimAttr = workloadAttr.getValue<IntegerAttr>({dim});
+    RETURN_IF_FAILURE(writer->WriteInt32(dimAttr.getInt()));
+  }
+  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
+  // TODO(benvanik): support output arg group (or change to tags).
+  RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0));
+  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
+  return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::AllocHeapOp op, BytecodeWriter *writer) {
+ auto memRefType = op.getType().cast<MemRefType>();
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kAllocHeap));
+ RETURN_IF_FAILURE(writer->WriteInt32(0));
+ RETURN_IF_FAILURE(writer->WriteTypeIndex(memRefType.getElementType()));
+ RETURN_IF_FAILURE(writer->WriteShapePieces(memRefType));
+ RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
+ return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::ComputeRangeOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kComputeRange));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.shape()));
+ RETURN_IF_FAILURE(writer->WriteUint8(op.elementSize().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.indices()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.lengths()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.dstOffset()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.dstLength()));
+ return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::StaticSliceOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticSlice));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.offset().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getResult()->getType()));
+ RETURN_IF_FAILURE(
+ writer->WriteShapePieces(op.getResult()->getType().cast<ShapedType>()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
+ return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::StaticCopyOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticCopy));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.srcOffset().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));
+ return success();
+}
+
+LogicalResult writeOp(IREESeq::LL::StaticFillOp op, BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticFill));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.value().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue()));
+ RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));
+ return success();
+}
+
+} // namespace
+
+void registerSequencerCustomWriters(VMFunctionBuilder *builder) {
+#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \
+ builder->RegisterCustomWriter( \
+ op_type::getOperationName(), \
+ +[](Operation *op, BytecodeWriter *writer) { \
+ return writeOp(cast<op_type>(op), writer); \
+ });
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ConstantOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallImportOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallIndirectOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::BranchOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CondBranchOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::DynamicDispatchOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticDispatchOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::AllocHeapOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ComputeRangeOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticSliceOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticCopyOp);
+ REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticFillOp);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Sequencer/OpWriters.h b/compiler/IR/Sequencer/OpWriters.h
new file mode 100644
index 0000000..d9b638f
--- /dev/null
+++ b/compiler/IR/Sequencer/OpWriters.h
@@ -0,0 +1,30 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
+#define IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
+
+#include "compiler/Serialization/VMFunctionBuilder.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Registers custom op writers with the builder.
+// Ops not registered will use the generic writer.
+void registerSequencerCustomWriters(VMFunctionBuilder *builder);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
diff --git a/compiler/IR/Sequencer/test/BUILD b/compiler/IR/Sequencer/test/BUILD
new file mode 100644
index 0000000..44a5820
--- /dev/null
+++ b/compiler/IR/Sequencer/test/BUILD
@@ -0,0 +1,15 @@
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+    data = [  # Tool targets handed to the lit package setup as data deps.
+        "//tools:iree-opt",
+        "//tools:iree-run-mlir",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/IR/Sequencer/test/concat.mlir b/compiler/IR/Sequencer/test/concat.mlir
similarity index 100%
rename from iree/compiler/IR/Sequencer/test/concat.mlir
rename to compiler/IR/Sequencer/test/concat.mlir
diff --git a/compiler/IR/StructureOps.cpp b/compiler/IR/StructureOps.cpp
new file mode 100644
index 0000000..0778c76
--- /dev/null
+++ b/compiler/IR/StructureOps.cpp
@@ -0,0 +1,250 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/StructureOps.h"
+
+#include "compiler/IR/Types.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+//===----------------------------------------------------------------------===//
+// Generic printers and parsers.
+//===----------------------------------------------------------------------===//
+
+// Parser for ops with neither inputs nor outputs: only an optional attribute
+// dictionary is parsed into |state|.
+static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
+  const bool parsedAttrs =
+      succeeded(parser.parseOptionalAttributeDict(state.attributes));
+  return parsedAttrs ? success() : failure();
+}
+
+// Prints a no-input/no-output op: its name, then its optional attribute dict.
+static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
+ printer << op->getName();
+ printer.printOptionalAttrDict(op->getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// iree.module
+//===----------------------------------------------------------------------===//
+// Adds a region to |state| and calls ensureTerminator on it.
+void ModuleOp::build(Builder *builder, OperationState &state) {
+ ensureTerminator(*state.addRegion(), *builder, state.location);
+}
+
+// Parses the module body region followed by an optional attribute dictionary.
+static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
+  Region *region = state.addRegion();
+  if (failed(parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) ||
+      failed(parser.parseOptionalAttributeDict(state.attributes))) {
+    return failure();
+  }
+  // The textual form may omit the terminator, so make sure one is present.
+  ModuleOp::ensureTerminator(*region, parser.getBuilder(), state.location);
+  return success();
+}
+// Prints the op name, region 0 (no entry args/terminator), then attributes.
+static void printModuleOp(OpAsmPrinter &printer, Operation *op) {
+ printer << op->getName();
+ printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+ printer.printOptionalAttrDict(op->getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// iree.multi_arch_executable
+//===----------------------------------------------------------------------===//
+// Stores |name| as the symbol-name string attribute and adds the body region.
+void MultiArchExecutableOp::build(Builder *builder, OperationState &state,
+ StringRef name) {
+ state.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ ensureTerminator(*state.addRegion(), *builder, state.location);
+}
+// Parses `@name [ordinal]? () <region> (attributes <attr-dict>)?`.
+static ParseResult parseMultiArchExecutableOp(OpAsmParser &parser,
+ OperationState &state) {
+ auto &builder = parser.getBuilder();
+
+ // Parse the name as a symbol reference attr and then convert to a string.
+ SymbolRefAttr nameAttr;
+ if (failed(parser.parseAttribute(nameAttr, SymbolTable::getSymbolAttrName(),
+ state.attributes))) {
+ return failure();
+ }
+ state.attributes.back().second = builder.getStringAttr(nameAttr.getValue());
+ // Optional `[N]`: a 32-bit integer stored as the "iree.ordinal" attribute.
+ if (succeeded(parser.parseOptionalLSquare())) {
+ IntegerAttr ordinalAttr;
+ if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32),
+ "iree.ordinal", state.attributes)) ||
+ failed(parser.parseRSquare())) {
+ return failure();
+ }
+ }
+ // An empty pair of parentheses is required next.
+ if (failed(parser.parseLParen()) || failed(parser.parseRParen())) {
+ return failure();
+ }
+ // Body region, optionally followed by `attributes` and an attribute dict.
+ Region *body = state.addRegion();
+ if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) {
+ return failure();
+ }
+ if (succeeded(parser.parseOptionalKeyword("attributes"))) {
+ if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
+ return failure();
+ }
+ }
+
+ MultiArchExecutableOp::ensureTerminator(*body, builder, state.location);
+
+ return success();
+}
+// Prints `<op-name> @name [ordinal]? ()`, the body, then remaining attributes.
+static void printMultiArchExecutableOp(OpAsmPrinter &printer,
+ MultiArchExecutableOp op) {
+ printer << op.getOperationName() << " @" << op.sym_name();
+ if (auto ordinalAttr =
+ op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) {
+ printer << "[" << ordinalAttr.getInt() << "]";
+ }
+ printer << "()";
+ // The body is printed without entry block arguments or block terminators.
+ printer.printRegion(op.body(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+
+ // Print out executable attributes, if present.
+ SmallVector<StringRef, 2> ignoredAttrs = {
+ SymbolTable::getSymbolAttrName(),
+ "iree.ordinal",
+ };
+ SmallVector<NamedAttribute, 4> attrs(
+ llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) {
+ return llvm::count(ignoredAttrs, attr.first) == 0;
+ }));
+ if (!attrs.empty()) {
+ printer << "\n attributes ";
+ printer.printOptionalAttrDict(attrs);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// iree.executable
+//===----------------------------------------------------------------------===//
+// Stores |format| as the i32 "format" attribute and adds the body region.
+void ExecutableOp::build(Builder *builder, OperationState &state,
+                         IREE::ExecutableFormat format) {
+  state.addAttribute("format",
+                     builder->getI32IntegerAttr(static_cast<int32_t>(format)));
+  ensureTerminator(*state.addRegion(), *builder, state.location);
+}
+
+// Parses `[ordinal]? (format) <region> (attributes <attr-dict>)?`.
+static ParseResult parseExecutableOp(OpAsmParser &parser,
+                                     OperationState &state) {
+  auto &builder = parser.getBuilder();
+  if (succeeded(parser.parseOptionalLSquare())) {
+    IntegerAttr ordinalAttr;
+    if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32),
+                                     "iree.ordinal", state.attributes)) ||
+        failed(parser.parseRSquare())) {
+      return failure();
+    }
+  }
+
+  // `(` format `)`: consume the closing paren that printExecutableOp emits.
+  StringAttr formatAttr;
+  llvm::SMLoc formatLoc;
+  if (failed(parser.parseLParen()) ||
+      failed(parser.getCurrentLocation(&formatLoc)) ||
+      failed(parser.parseAttribute(formatAttr, "format", state.attributes)) ||
+      failed(parser.parseRParen())) {
+    return failure();
+  }
+  auto format = symbolizeExecutableFormat(formatAttr.getValue());
+  if (!format.hasValue()) {
+    return parser.emitError(formatLoc)
+           << "Unknown executable format " << formatAttr.getValue();
+  }
+  // "format" is still the last attribute added; swap the string for its i32.
+  state.attributes.back().second =
+      builder.getI32IntegerAttr(static_cast<int32_t>(format.getValue()));
+
+  Region *body = state.addRegion();
+  if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) {
+    return failure();
+  }
+  if (succeeded(parser.parseOptionalKeyword("attributes"))) {
+    if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
+      return failure();
+    }
+  }
+  ExecutableOp::ensureTerminator(*body, parser.getBuilder(), state.location);
+  return success();
+}
+// Prints `<op-name> [ordinal]? (format)`, the body, then remaining attributes.
+static void printExecutableOp(OpAsmPrinter &printer, ExecutableOp op) {
+ printer << op.getOperationName();
+ if (auto ordinalAttr =
+ op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) {
+ printer << "[" << ordinalAttr.getInt() << "]";
+ }
+ printer << "(";
+ auto format = symbolizeExecutableFormat(op.format());
+ if (format.hasValue()) {
+ printer << stringifyExecutableFormat(format.getValue());
+ } else {
+ printer << "INVALID FORMAT";
+ }
+ printer << ")";
+ // The body is printed without entry block arguments or block terminators.
+ printer.printRegion(op.body(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+
+ // Print out executable attributes, if present.
+ SmallVector<StringRef, 2> ignoredAttrs = {
+ "iree.ordinal",
+ "format",
+ };
+ SmallVector<NamedAttribute, 4> attrs(
+ llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) {
+ return llvm::count(ignoredAttrs, attr.first) == 0;
+ }));
+ if (!attrs.empty()) {
+ printer << "\n attributes ";
+ printer.printOptionalAttrDict(attrs);
+ }
+}
+
+#define GET_OP_CLASSES
+#include "compiler/IR/StructureOps.cpp.inc"
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/StructureOps.h b/compiler/IR/StructureOps.h
new file mode 100644
index 0000000..f187b71
--- /dev/null
+++ b/compiler/IR/StructureOps.h
@@ -0,0 +1,40 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_STRUCTUREOPS_H_
+#define IREE_COMPILER_IR_STRUCTUREOPS_H_
+
+#include <cstdint>
+
+#include "compiler/IR/Types.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+#define GET_OP_CLASSES
+#include "compiler/IR/StructureOps.h.inc"
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_STRUCTUREOPS_H_
diff --git a/compiler/IR/StructureOps.td b/compiler/IR/StructureOps.td
new file mode 100644
index 0000000..afed1a0
--- /dev/null
+++ b/compiler/IR/StructureOps.td
@@ -0,0 +1,122 @@
+// Structural ops such as 'module' and 'executable'.
+// These are used to organize IREE IR into regions representing ops that act at
+// the sequencer level (coarse control flow/scheduling) and ops that perform
+// actual work (math/etc) on runtime execution backends.
+
+#ifdef IREE_STRUCTURE_OPS
+#else
+#define IREE_STRUCTURE_OPS
+
+#ifdef IREE_OP_BASE
+#else
+include "compiler/IR/OpBase.td"
+#endif // IREE_OP_BASE
+// Base class for structure ops; forwards to parse<Class>()/print<Class>().
+class IREE_StructureOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREE_Dialect, mnemonic, traits> {
+ let parser = [{ return parse$cppClass(parser, result); }];
+ let printer = [{ print$cppClass(p, *this); }];
+}
+// "module" op: a symbol table whose body is implicitly ended by ModuleEndOp.
+def IREE_ModuleOp :
+ IREE_StructureOp<"module", [
+ SingleBlockImplicitTerminator<"ModuleEndOp">,
+ NativeOpTrait<"SymbolTable">
+ ]> {
+ let regions = (region SizedRegion<1>:$body);
+ let extraClassDeclaration = [{
+ Block& getBlock() {
+ return this->getOperation()->getRegion(0).front();
+ }
+ }];
+ // Default builders are skipped; only the builder listed below is declared.
+ let skipDefaultBuilders = 1;
+ let builders = [OpBuilder<"Builder *, OperationState &state">];
+}
+// Terminator for the "module" body; parsed/printed by the no-IO helpers.
+def IREE_ModuleEndOp :
+ IREE_StructureOp<"_module_end", [
+ IREE_ModuleOnly,
+ Terminator
+ ]> {
+ let parser = [{ return parseNoIOOp(parser, result); }];
+ let printer = [{ printNoIOOp(getOperation(), p); }];
+}
+// "multi_arch_executable" op: named op with a single-block body region.
+def IREE_MultiArchExecutableOp :
+ IREE_StructureOp<"multi_arch_executable", [
+ // TODO(benvanik): make iree.module work and make this IREE_ModuleOnly.
+ SingleBlockImplicitTerminator<"MultiArchExecutableEndOp">
+ ]> {
+ let arguments = (ins
+ StrAttr:$sym_name,
+ OptionalAttr<I32Attr>:$ordinal
+ );
+ // NOTE(review): the C++ parser/printer use "iree.ordinal", not "ordinal".
+ let regions = (region SizedRegion<1>:$body);
+ let extraClassDeclaration = [{
+ StringRef getName() {
+ return this->getOperation()->template getAttrOfType<StringAttr>(
+ ::mlir::SymbolTable::getSymbolAttrName()).getValue();
+ }
+
+ Region& getBody() {
+ return this->getOperation()->getRegion(0);
+ }
+ Block& getBlock() {
+ return this->getOperation()->getRegion(0).front();
+ }
+ }];
+ // Default builders are skipped; only the builder listed below is declared.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<"Builder *builder, OperationState &state, StringRef name">,
+ ];
+}
+// Terminator for "multi_arch_executable"; uses the no-IO parse/print helpers.
+def IREE_MultiArchExecutableEndOp :
+ IREE_StructureOp<"_multi_arch_executable_end", [
+ IREE_MultiArchExecutableOnly,
+ Terminator
+ ]> {
+ let parser = [{ return parseNoIOOp(parser, result); }];
+ let printer = [{ printNoIOOp(getOperation(), p); }];
+}
+// "executable" op: a symbol table with a format attribute and a body region.
+def IREE_ExecutableOp :
+ IREE_StructureOp<"executable", [
+ SingleBlockImplicitTerminator<"ExecutableEndOp">,
+ NativeOpTrait<"SymbolTable">
+ ]> {
+ let arguments = (ins
+ IREE_ExecutableFormatAttr:$format,
+ OptionalAttr<I32Attr>:$ordinal
+ );
+ // NOTE(review): the C++ parser/printer use "iree.ordinal", not "ordinal".
+ let regions = (region SizedRegion<1>:$body);
+ let extraClassDeclaration = [{
+ Region& getBody() {
+ return this->getOperation()->getRegion(0);
+ }
+ Block& getBlock() {
+ return this->getOperation()->getRegion(0).front();
+ }
+ ::mlir::ModuleOp getInnerModule() {
+ return *getBlock().getOps<::mlir::ModuleOp>().begin();
+ }
+ }];
+ // NOTE(review): getInnerModule() dereferences begin() without an empty check.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<[{Builder *builder, OperationState &state,
+ ExecutableFormat executable_format}]>,
+ ];
+}
+// Terminator for the "executable" body; uses the no-IO parse/print helpers.
+def IREE_ExecutableEndOp :
+ IREE_StructureOp<"_executable_end", [Terminator, IREE_ExecutableOnly]> {
+ let parser = [{ return parseNoIOOp(parser, result); }];
+ let printer = [{ printNoIOOp(getOperation(), p); }];
+}
+
+#endif // IREE_STRUCTURE_OPS
diff --git a/compiler/IR/Traits.cpp b/compiler/IR/Traits.cpp
new file mode 100644
index 0000000..bc8a3fb
--- /dev/null
+++ b/compiler/IR/Traits.cpp
@@ -0,0 +1,23 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Traits.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// TODO(benvanik): traits.
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/IR/Traits.h b/compiler/IR/Traits.h
similarity index 100%
rename from iree/compiler/IR/Traits.h
rename to compiler/IR/Traits.h
diff --git a/compiler/IR/Types.cpp b/compiler/IR/Types.cpp
new file mode 100644
index 0000000..3e65bb5
--- /dev/null
+++ b/compiler/IR/Types.cpp
@@ -0,0 +1,53 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Types.h"
+
+#include "compiler/IR/Enums.cpp.inc"
+
+namespace mlir {
+namespace iree_compiler {
+
+// static: obtains the DeviceType for |context| through Base::get.
+DeviceType DeviceType::get(MLIRContext *context) {
+ return Base::get(context, TypeKind::Device);
+}
+
+// static: obtains the DeviceGroupType for |context| through Base::get.
+DeviceGroupType DeviceGroupType::get(MLIRContext *context) {
+ return Base::get(context, TypeKind::DeviceGroup);
+}
+
+// static: obtains the CommandBufferType for |context| through Base::get.
+CommandBufferType CommandBufferType::get(MLIRContext *context) {
+ return Base::get(context, TypeKind::CommandBuffer);
+}
+
+// static: obtains the EventType for |context| through Base::get.
+EventType EventType::get(MLIRContext *context) {
+ return Base::get(context, TypeKind::Event);
+}
+
+// static: obtains the SemaphoreType for |context| through Base::get.
+SemaphoreType SemaphoreType::get(MLIRContext *context) {
+ return Base::get(context, TypeKind::Semaphore);
+}
+
+// static: obtains the FenceType for |context| through Base::get.
+FenceType FenceType::get(MLIRContext *context) {
+ return Base::get(context, TypeKind::Fence);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/IR/Types.h b/compiler/IR/Types.h
new file mode 100644
index 0000000..9379877
--- /dev/null
+++ b/compiler/IR/Types.h
@@ -0,0 +1,115 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_IR_TYPES_H_
+#define IREE_COMPILER_IR_TYPES_H_
+
+#include <cstdint>
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+
+// Order matters.
+#include "compiler/IR/Enums.h.inc"
+
+namespace mlir {
+namespace iree_compiler {
+// Kind values for the types below, numbered from Type::FIRST_IREE_TYPE.
+namespace TypeKind {
+enum Kind {
+ Device = Type::FIRST_IREE_TYPE,
+ DeviceGroup,
+ CommandBuffer,
+ Event,
+ Semaphore,
+ Fence,
+};
+} // namespace TypeKind
+// Applies |map| to a (string name, TypeKind value, C++ class) triple per type.
+// clang-format off
+#define IREE_TYPE_TABLE(map) \
+ map("device", TypeKind::Device, DeviceType) \
+ map("device_group", TypeKind::DeviceGroup, DeviceGroupType) \
+ map("command_buffer", TypeKind::CommandBuffer, CommandBufferType) \
+ map("event", TypeKind::Event, EventType) \
+ map("semaphore", TypeKind::Semaphore, SemaphoreType) \
+ map("fence", TypeKind::Fence, FenceType)
+// clang-format on
+
+// iree.device mapping to a runtime-resolved device type.
+class DeviceType : public Type::TypeBase<DeviceType, Type> {
+ public:
+ using Base::Base;
+ // True when |kind| equals TypeKind::Device.
+ static bool kindof(unsigned kind) { return kind == TypeKind::Device; }
+ // Gets the DeviceType for |context|.
+ static DeviceType get(MLIRContext *context);
+};
+
+// iree.device_group relating multiple iree.device requirements with each other.
+class DeviceGroupType : public Type::TypeBase<DeviceGroupType, Type> {
+ public:
+ using Base::Base;
+ // True when |kind| equals TypeKind::DeviceGroup.
+ static bool kindof(unsigned kind) { return kind == TypeKind::DeviceGroup; }
+ // Gets the DeviceGroupType for |context|.
+ static DeviceGroupType get(MLIRContext *context);
+};
+
+// iree.command_buffer mapping to an iree::hal::CommandBuffer.
+class CommandBufferType : public Type::TypeBase<CommandBufferType, Type> {
+ public:
+ using Base::Base;
+ // True when |kind| equals TypeKind::CommandBuffer.
+ static bool kindof(unsigned kind) { return kind == TypeKind::CommandBuffer; }
+ // Gets the CommandBufferType for |context|.
+ static CommandBufferType get(MLIRContext *context);
+};
+
+// iree.event mapping to an iree::hal::Event.
+class EventType : public Type::TypeBase<EventType, Type> {
+ public:
+ using Base::Base;
+ // True when |kind| equals TypeKind::Event.
+ static bool kindof(unsigned kind) { return kind == TypeKind::Event; }
+ // Gets the EventType for |context|.
+ static EventType get(MLIRContext *context);
+};
+
+// iree.semaphore mapping to an iree::hal::Semaphore.
+class SemaphoreType : public Type::TypeBase<SemaphoreType, Type> {
+ public:
+ using Base::Base;
+ // True when |kind| equals TypeKind::Semaphore.
+ static bool kindof(unsigned kind) { return kind == TypeKind::Semaphore; }
+ // Gets the SemaphoreType for |context|.
+ static SemaphoreType get(MLIRContext *context);
+};
+
+// iree.fence mapping to an iree::hal::Fence.
+class FenceType : public Type::TypeBase<FenceType, Type> {
+ public:
+ using Base::Base;
+ // True when |kind| equals TypeKind::Fence.
+ static bool kindof(unsigned kind) { return kind == TypeKind::Fence; }
+ // Gets the FenceType for |context|.
+ static FenceType get(MLIRContext *context);
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_IR_TYPES_H_
diff --git a/compiler/IR/test/BUILD b/compiler/IR/test/BUILD
new file mode 100644
index 0000000..44a5820
--- /dev/null
+++ b/compiler/IR/test/BUILD
@@ -0,0 +1,15 @@
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+# Data deps passed to the lit package setup; labels are repo-root relative.
+iree_setup_lit_package(
+    data = [
+        "//tools:iree-opt",
+        "//tools:iree-run-mlir",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/IR/test/bindings.mlir b/compiler/IR/test/bindings.mlir
similarity index 100%
rename from iree/compiler/IR/test/bindings.mlir
rename to compiler/IR/test/bindings.mlir
diff --git a/iree/compiler/IR/test/constant.mlir b/compiler/IR/test/constant.mlir
similarity index 100%
rename from iree/compiler/IR/test/constant.mlir
rename to compiler/IR/test/constant.mlir
diff --git a/iree/compiler/IR/test/dispatch_regions.mlir b/compiler/IR/test/dispatch_regions.mlir
similarity index 100%
rename from iree/compiler/IR/test/dispatch_regions.mlir
rename to compiler/IR/test/dispatch_regions.mlir
diff --git a/iree/compiler/IR/test/reduction_regions.mlir b/compiler/IR/test/reduction_regions.mlir
similarity index 100%
rename from iree/compiler/IR/test/reduction_regions.mlir
rename to compiler/IR/test/reduction_regions.mlir
diff --git a/iree/compiler/IR/test/scalar_memref.mlir b/compiler/IR/test/scalar_memref.mlir
similarity index 100%
rename from iree/compiler/IR/test/scalar_memref.mlir
rename to compiler/IR/test/scalar_memref.mlir
diff --git a/iree/compiler/IR/test/tensor_memref.mlir b/compiler/IR/test/tensor_memref.mlir
similarity index 100%
rename from iree/compiler/IR/test/tensor_memref.mlir
rename to compiler/IR/test/tensor_memref.mlir
diff --git a/compiler/Serialization/BUILD b/compiler/Serialization/BUILD
new file mode 100644
index 0000000..523684e
--- /dev/null
+++ b/compiler/Serialization/BUILD
@@ -0,0 +1,43 @@
+# Serialization for the VM bytecode.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+# Bytecode serialization library; project deps use repo-root "//" labels.
+cc_library(
+    name = "Serialization",
+    srcs = [
+        "BytecodeTables.cpp",
+        "BytecodeWriter.cpp",
+        "VMDeviceTableBuilder.cpp",
+        "VMExecutableTableBuilder.cpp",
+        "VMFunctionBuilder.cpp",
+        "VMFunctionTableBuilder.cpp",
+        "VMModuleBuilder.cpp",
+        "VMSourceMapBuilder.cpp",
+    ],
+    hdrs = [
+        "BytecodeTables.h",
+        "BytecodeWriter.h",
+        "VMDeviceTableBuilder.h",
+        "VMExecutableTableBuilder.h",
+        "VMFunctionBuilder.h",
+        "VMFunctionTableBuilder.h",
+        "VMModuleBuilder.h",
+        "VMSourceMapBuilder.h",
+    ],
+    deps = [
+        "//compiler/IR",
+        "//compiler/Utils",
+        "//schemas",
+        "//schemas/bytecode:bytecode_v0",
+        "//schemas/bytecode:interpreter_bytecode_v0",
+        "//schemas/bytecode:sequencer_bytecode_v0",
+        "@com_github_google_flatbuffers//:flatbuffers",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:StandardOps",
+        "@local_config_mlir//:Support",
+    ],
+)
diff --git a/compiler/Serialization/BytecodeTables.cpp b/compiler/Serialization/BytecodeTables.cpp
new file mode 100644
index 0000000..80f63c9
--- /dev/null
+++ b/compiler/Serialization/BytecodeTables.cpp
@@ -0,0 +1,73 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/BytecodeTables.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Info tables mapping 1:1 with bytecode ops; DECLARE_INFO keeps only the
+// name, flags and operand encodings of each opcode list entry.
+// Note that we ensure the table is 256 elements long exactly to make sure
+// that unused opcodes are handled gracefully.
+#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \
+ { \
+ name, \
+ flags, \
+ {operand_encodings}, \
+ },
+// Interpreter opcode table; DECLARE_INFO is passed for both list callbacks.
+static const OpcodeInfo kInterpreterInfoTable[256] = {
+ IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)};
+// Sequencer opcode table; DECLARE_INFO is passed for both list callbacks.
+static const OpcodeInfo kSequencerInfoTable[256] = {
+ IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)};
+
+#undef DECLARE_INFO
+
+} // namespace
+// Linear mnemonic scan (unsigned index); the matching index is the opcode.
+llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName(
+    StringRef name) {
+  for (size_t i = 0; i < llvm::array_lengthof(kInterpreterInfoTable); ++i) {
+    if (name == kInterpreterInfoTable[i].mnemonic) {
+      return static_cast<iree::InterpreterOpcode>(i);
+    }
+  }
+  return llvm::None;
+}
+// Indexes the interpreter table with the opcode's uint8_t value.
+const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode) {
+ return kInterpreterInfoTable[static_cast<uint8_t>(opcode)];
+}
+// Linear mnemonic scan (unsigned index); the matching index is the opcode.
+llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name) {
+  for (size_t i = 0; i < llvm::array_lengthof(kSequencerInfoTable); ++i) {
+    if (name == kSequencerInfoTable[i].mnemonic) {
+      return static_cast<iree::SequencerOpcode>(i);
+    }
+  }
+  return llvm::None;
+}
+// Indexes the sequencer table with the opcode's uint8_t value.
+const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode) {
+ return kSequencerInfoTable[static_cast<uint8_t>(opcode)];
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/BytecodeTables.h b/compiler/Serialization/BytecodeTables.h
new file mode 100644
index 0000000..a649c97
--- /dev/null
+++ b/compiler/Serialization/BytecodeTables.h
@@ -0,0 +1,52 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
+#define IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Support/LLVM.h"
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+#include "schemas/bytecode/sequencer_bytecode_v0.h"
+
+namespace mlir {
+namespace iree_compiler {
+// Per-opcode table entry: mnemonic, flags, and 8 operand encoding slots.
+struct OpcodeInfo {
+ const char* mnemonic = nullptr;
+ iree::OpcodeFlagBitfield flags = iree::OpcodeFlagBitfield::kDefault;
+ union {
+ const char operands_value[8] = {0};
+ const iree::OperandEncoding operands[8];
+ };
+};
+
+// Returns the interpreter opcode whose mnemonic is |name|, or llvm::None.
+llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName(
+ StringRef name);
+
+// Returns the info table entry for the given interpreter opcode.
+const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode);
+
+// Returns the sequencer opcode whose mnemonic is |name|, or llvm::None.
+llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name);
+
+// Returns the info table entry for the given sequencer opcode.
+const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
diff --git a/compiler/Serialization/BytecodeWriter.cpp b/compiler/Serialization/BytecodeWriter.cpp
new file mode 100644
index 0000000..beb1e1e
--- /dev/null
+++ b/compiler/Serialization/BytecodeWriter.cpp
@@ -0,0 +1,334 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/BytecodeWriter.h"
+
+#include <algorithm>
+
+#include "compiler/Utils/Macros.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Writes |count| as a single unsigned byte.
+//
+// Fails (after logging to stderr) when |count| is outside 0-UINT8_MAX.
+// Negative values are rejected explicitly: they would otherwise pass the upper
+// bound check and be truncated into an unrelated byte value.
+LogicalResult BytecodeWriter::WriteCount(int count) {
+  if (count < 0) {
+    llvm::errs() << "Negative item count: " << count
+                 << "; only 0-UINT8_MAX are supported";
+    return failure();
+  }
+  if (count > UINT8_MAX) {
+    // TODO(benvanik): varints?
+    llvm::errs() << "Too many items: " << count
+                 << "; only 0-UINT8_MAX are supported";
+    return failure();
+  }
+  return WriteUint8(static_cast<uint8_t>(count));
+}
+
+// Encodes |type| as a one-byte iree::BuiltinType index.
+//
+// Only 8/16/32/64-bit integers and f16/f32/f64 are representable; any other
+// type produces an error diagnostic at an unknown location.
+LogicalResult BytecodeWriter::WriteTypeIndex(Type type) {
+  auto writeBuiltin = [this](iree::BuiltinType builtinType) {
+    return WriteUint8(static_cast<uint8_t>(builtinType));
+  };
+
+  if (type.isInteger(8)) return writeBuiltin(iree::BuiltinType::kI8);
+  if (type.isInteger(16)) return writeBuiltin(iree::BuiltinType::kI16);
+  if (type.isInteger(32)) return writeBuiltin(iree::BuiltinType::kI32);
+  if (type.isInteger(64)) return writeBuiltin(iree::BuiltinType::kI64);
+  if (type.isF16()) return writeBuiltin(iree::BuiltinType::kF16);
+  if (type.isF32()) return writeBuiltin(iree::BuiltinType::kF32);
+  if (type.isF64()) return writeBuiltin(iree::BuiltinType::kF64);
+
+  // TODO(benvanik): support unknown types as BuiltinType::kOpaque?
+  return emitError(UnknownLoc::get(type.getContext()))
+         << "Type " << type << " cannot be represented by a builtin type";
+}
+
+// Writes the value of |function|'s "iree.ordinal" integer attribute as a
+// uint32. Emits an error on the function when the attribute is missing.
+LogicalResult BytecodeWriter::WriteFunctionOrdinal(FuncOp function) {
+  if (auto ordinalAttr = function.getAttrOfType<IntegerAttr>("iree.ordinal")) {
+    RETURN_IF_FAILURE(WriteUint32(ordinalAttr.getInt()));
+    return success();
+  }
+  return function.emitError() << "Ordinal not assigned to function";
+}
+
+// Writes the ordinal of an imported function.
+// Currently this simply forwards to WriteFunctionOrdinal, so the encoding is
+// identical to that of internal function ordinals.
+LogicalResult BytecodeWriter::WriteImportOrdinal(FuncOp function) {
+ // For now this is the same as internal function ordinals, though we could
+ // probably shrink it.
+ return WriteFunctionOrdinal(function);
+}
+
+// Writes a constant as: element type index, rank (count byte), each dimension
+// as an int32, a one-byte iree::ConstantEncoding, then the attribute payload.
+//
+// Splat attributes are written as kSplat followed by the single splat value;
+// everything else is written as kDense followed by |baseAttr|'s payload.
+LogicalResult BytecodeWriter::WriteConstant(MemRefType memRefType,
+                                            Attribute baseAttr) {
+  // All types are memrefs, so we only need the element type.
+  RETURN_IF_FAILURE(WriteTypeIndex(memRefType.getElementType()));
+
+  // Write shape (we could optimize this for cases of scalars and such).
+  const int rank = memRefType.getRank();
+  RETURN_IF_FAILURE(WriteCount(rank));
+  for (int dimIndex = 0; dimIndex < rank; ++dimIndex) {
+    RETURN_IF_FAILURE(WriteInt32(memRefType.getDimSize(dimIndex)));
+  }
+
+  // Pick the encoding tag and the attribute whose payload follows it.
+  auto encoding = iree::ConstantEncoding::kDense;
+  Attribute payloadAttr = baseAttr;
+  if (auto splatAttr = baseAttr.dyn_cast<SplatElementsAttr>()) {
+    encoding = iree::ConstantEncoding::kSplat;
+    payloadAttr = splatAttr.getSplatValue();
+  }
+  RETURN_IF_FAILURE(WriteUint8(static_cast<uint8_t>(encoding)));
+  return WriteAttributeData(payloadAttr);
+}
+
+// Serializes the raw payload of |baseAttr| with no type or length prefix.
+//
+// Supported attributes:
+//   - BoolAttr: one byte, 0 or 1.
+//   - IntegerAttr: index-typed values as 4 bytes; otherwise 8/16/32/64-bit.
+//   - FloatAttr: the 16/32/64-bit pattern of the value.
+//   - SplatElementsAttr: the payload of the single splat value.
+//   - DenseIntElementsAttr / DenseFPElementsAttr: all elements back to back.
+// Every other attribute kind yields a "not implemented" diagnostic.
+//
+// Dense element types must occupy a whole number of bytes. Widths such as i1
+// are rejected with a diagnostic instead of failing silently or truncating
+// each element.
+//
+// NOTE(review): bytes are copied from the start of the APInt word storage,
+// which presumably assumes a little-endian host - confirm.
+LogicalResult BytecodeWriter::WriteAttributeData(Attribute baseAttr) {
+  if (auto boolAttr = baseAttr.dyn_cast<BoolAttr>()) {
+    return WriteUint8(boolAttr.getValue() ? 1 : 0);
+  }
+
+  if (auto intAttr = baseAttr.dyn_cast<IntegerAttr>()) {
+    if (intAttr.getType().isIndex()) {
+      // Index values are always narrowed to 32 bits.
+      int32_t value = static_cast<int32_t>(intAttr.getInt());
+      return WriteBytes(&value, 4);
+    }
+    int bitWidth = intAttr.getValue().getBitWidth();
+    switch (bitWidth) {
+      case 8:
+      case 16:
+      case 32:
+      case 64:
+        return WriteBytes(intAttr.getValue().getRawData(), bitWidth / 8);
+      default:
+        return emitError(UnknownLoc::get(baseAttr.getContext()))
+               << "Bit width for integers must be one of 8,16,32,64; others "
+                  "not implemented: "
+               << bitWidth;
+    }
+  }
+
+  if (auto floatAttr = baseAttr.dyn_cast<FloatAttr>()) {
+    int bitWidth = floatAttr.getType().getIntOrFloatBitWidth();
+    auto bitcastValue = floatAttr.getValue().bitcastToAPInt();
+    switch (bitWidth) {
+      case 16:
+      case 32:
+      case 64:
+        return WriteBytes(bitcastValue.getRawData(), bitWidth / 8);
+      default:
+        return emitError(UnknownLoc::get(baseAttr.getContext()))
+               << "Bit width for floats must be one of 16,32,64; others "
+                  "not implemented: "
+               << bitWidth;
+    }
+  }
+
+  // TODO(benvanik): other attribute encodings (StringAttr, ArrayAttr,
+  // AffineMapAttr, IntegerSetAttr, TypeAttr, SymbolRefAttr, generic
+  // DenseElementsAttr, OpaqueElementsAttr, SparseElementsAttr). These fall
+  // through to the "not implemented" diagnostic at the end.
+
+  if (auto splatAttr = baseAttr.dyn_cast<SplatElementsAttr>()) {
+    return WriteAttributeData(splatAttr.getSplatValue());
+  }
+
+  if (auto denseIntAttr = baseAttr.dyn_cast<DenseIntElementsAttr>()) {
+    int64_t elementCount = denseIntAttr.getType().getNumElements();
+    if (elementCount == 0) {
+      return success();
+    }
+    int bitWidth = denseIntAttr.getType().getElementTypeBitWidth();
+    if (bitWidth <= 0 || bitWidth % 8 != 0) {
+      return emitError(UnknownLoc::get(baseAttr.getContext()))
+             << "Bit width for dense integer elements must be a multiple of "
+                "8; others not implemented: "
+             << bitWidth;
+    }
+    size_t byteWidth = static_cast<size_t>(bitWidth) / 8;
+    auto dst = ReserveBytes(static_cast<size_t>(elementCount) * byteWidth);
+    if (dst.empty()) return failure();
+    uint8_t *dstPtr = dst.data();
+    for (auto element : denseIntAttr) {
+      assert(element.getBitWidth() == bitWidth);
+      std::memcpy(dstPtr, element.getRawData(), byteWidth);
+      dstPtr += byteWidth;
+    }
+    return success();
+  }
+
+  if (auto denseFPAttr = baseAttr.dyn_cast<DenseFPElementsAttr>()) {
+    int64_t elementCount = denseFPAttr.getType().getNumElements();
+    if (elementCount == 0) {
+      return success();
+    }
+    int bitWidth = denseFPAttr.getType().getElementTypeBitWidth();
+    if (bitWidth <= 0 || bitWidth % 8 != 0) {
+      return emitError(UnknownLoc::get(baseAttr.getContext()))
+             << "Bit width for dense float elements must be a multiple of "
+                "8; others not implemented: "
+             << bitWidth;
+    }
+    size_t byteWidth = static_cast<size_t>(bitWidth) / 8;
+    auto dst = ReserveBytes(static_cast<size_t>(elementCount) * byteWidth);
+    if (dst.empty()) return failure();
+    uint8_t *dstPtr = dst.data();
+    for (auto element : denseFPAttr) {
+      auto bitcastValue = element.bitcastToAPInt();
+      // Copy exactly the reserved per-element width so the destination can
+      // never be overrun.
+      std::memcpy(dstPtr, bitcastValue.getRawData(), byteWidth);
+      dstPtr += byteWidth;
+    }
+    return success();
+  }
+
+  return emitError(UnknownLoc::get(baseAttr.getContext()))
+         << "Serializer for attribute kind "
+         << static_cast<int>(baseAttr.getKind()) << " not implemented";
+}
+
+// Returns the ordinal assigned to |value|, assigning the next free ordinal
+// (the current map size) when the value has not been seen before.
+//
+// Returns llvm::None and emits an error when the ordinal exceeds UINT16_MAX.
+// Note that the value is still recorded in the map in that case.
+Optional<int> BytecodeWriter::LookupLocalOrdinal(Value *value) {
+  // insert() leaves an existing entry untouched, so this either finds the
+  // previously assigned ordinal or registers a new one in a single step.
+  auto entry = localMap_.insert(
+      std::make_pair(value, static_cast<int>(localMap_.size())));
+  const int ordinal = entry.first->second;
+  if (ordinal <= UINT16_MAX) {
+    return ordinal;
+  }
+
+  // TODO(benvanik): varints?
+  emitError(UnknownLoc::get(value->getContext()))
+      << "Too many ordinals: " << ordinal
+      << "; only 0-UINT16_MAX are supported";
+  return llvm::None;
+}
+
+// Ensures |value| has a local ordinal assigned without writing anything to
+// the bytecode stream. Fails when no ordinal could be assigned.
+LogicalResult BytecodeWriter::PrepareLocal(Value *value) {
+  return LookupLocalOrdinal(value).hasValue() ? success() : failure();
+}
+
+// Writes the ordinal of |value| as a uint16, assigning one if needed.
+//
+// Fails when no ordinal could be assigned. LookupLocalOrdinal already emits
+// the diagnostic and returns llvm::None for ordinals above UINT16_MAX, so no
+// second range check is required before narrowing.
+LogicalResult BytecodeWriter::WriteLocal(Value *value) {
+  auto ordinal = LookupLocalOrdinal(value);
+  if (!ordinal.hasValue()) {
+    return failure();
+  }
+  return WriteUint16(static_cast<uint16_t>(ordinal.getValue()));
+}
+
+// Writes a count byte followed by the local ordinal of every operand in
+// |values|.
+LogicalResult BytecodeWriter::WriteLocals(
+    llvm::iterator_range<Operation::operand_iterator> values) {
+  const int operandCount = std::distance(values.begin(), values.end());
+  RETURN_IF_FAILURE(WriteCount(operandCount));
+  for (auto it = values.begin(), end = values.end(); it != end; ++it) {
+    auto *operand = *it;
+    RETURN_IF_FAILURE(WriteLocal(operand));
+  }
+  return success();
+}
+
+// Writes a count byte followed by the local ordinal of every result in
+// |values|.
+LogicalResult BytecodeWriter::WriteLocals(
+    llvm::iterator_range<Operation::result_iterator> values) {
+  const int resultCount = std::distance(values.begin(), values.end());
+  RETURN_IF_FAILURE(WriteCount(resultCount));
+  for (auto it = values.begin(), end = values.end(); it != end; ++it) {
+    auto *result = *it;
+    RETURN_IF_FAILURE(WriteLocal(result));
+  }
+  return success();
+}
+
+// Grows the bytecode buffer by |dataLength| bytes and returns a mutable view
+// over the newly added region.
+//
+// The view points into the underlying std::vector, so it is invalidated by any
+// later write that resizes the buffer. The offset is kept as size_t so it is
+// not narrowed through int.
+MutableArrayRef<uint8_t> BytecodeWriter::ReserveBytes(size_t dataLength) {
+  const size_t offset = bytecode_.size();
+  bytecode_.resize(offset + dataLength);
+  return MutableArrayRef<uint8_t>(bytecode_.data() + offset, dataLength);
+}
+
+// Appends |dataLength| bytes read from |data| to the bytecode buffer.
+// Fails when the reserved region does not have the requested size.
+LogicalResult BytecodeWriter::WriteBytes(const void *data, size_t dataLength) {
+  MutableArrayRef<uint8_t> reserved = ReserveBytes(dataLength);
+  if (reserved.size() == dataLength) {
+    std::memcpy(reserved.data(), data, reserved.size());
+    return success();
+  }
+  return failure();
+}
+
+// Appends a single byte to the bytecode buffer.
+LogicalResult BytecodeWriter::WriteUint8(uint8_t value) {
+ return WriteBytes(&value, sizeof(value));
+}
+
+// Appends the two bytes of |value| exactly as laid out in memory, i.e. in the
+// host's byte order.
+LogicalResult BytecodeWriter::WriteUint16(uint16_t value) {
+ return WriteBytes(&value, sizeof(value));
+}
+
+// Appends the four bytes of the signed |value| exactly as laid out in memory,
+// i.e. in the host's byte order.
+LogicalResult BytecodeWriter::WriteInt32(int32_t value) {
+ return WriteBytes(&value, sizeof(value));
+}
+
+// Appends the four bytes of the unsigned |value| exactly as laid out in
+// memory, i.e. in the host's byte order.
+LogicalResult BytecodeWriter::WriteUint32(uint32_t value) {
+ return WriteBytes(&value, sizeof(value));
+}
+
+// Writes |attr| as a count byte followed by each element as an int32.
+LogicalResult BytecodeWriter::WriteElementsAttrInt32(ElementsAttr attr) {
+  const int numElements = attr.getType().getNumElements();
+  RETURN_IF_FAILURE(WriteCount(numElements));
+  for (int32_t element : attr.getValues<int32_t>()) {
+    RETURN_IF_FAILURE(WriteInt32(element));
+  }
+  return success();
+}
+
+// Writes the shape of |type| as a rank count byte followed by each dimension
+// narrowed to an int32.
+LogicalResult BytecodeWriter::WriteShapePieces(const ShapedType &type) {
+  RETURN_IF_FAILURE(WriteCount(type.getRank()));
+  const auto dims = type.getShape();
+  for (size_t dimIndex = 0; dimIndex < dims.size(); ++dimIndex) {
+    const int64_t dim = dims[dimIndex];
+    RETURN_IF_FAILURE(WriteInt32(dim));
+  }
+  return success();
+}
+
+// Writes shape pieces held in an elements attribute. Forwards to
+// WriteElementsAttrInt32, so the output is a count byte followed by each
+// element as an int32.
+LogicalResult BytecodeWriter::WriteShapePieces(ElementsAttr pieces) {
+ return WriteElementsAttrInt32(pieces);
+}
+
+// Records the current end of the bytecode buffer as the offset of |block|,
+// overwriting any offset previously recorded for it. Always succeeds.
+LogicalResult BytecodeWriter::MarkBlockOffset(Block *block) {
+ blockOffsets_[block] = bytecode_.size();
+ return success();
+}
+
+// Emits a zero-filled int32-sized placeholder for the offset of |targetBlock|
+// and queues it for patching by FixupOffsets. Always succeeds.
+LogicalResult BytecodeWriter::WriteBlockOffset(Block *targetBlock) {
+  // Reserve space for the offset and stash for later fixup.
+  const size_t fixupPosition = bytecode_.size();
+  blockOffsetFixups_.emplace_back(targetBlock, fixupPosition);
+  bytecode_.resize(fixupPosition + sizeof(int32_t));
+  return success();
+}
+
+// Patches every placeholder queued by WriteBlockOffset with the offset that
+// MarkBlockOffset recorded for its target block, then clears the queue.
+//
+// Fails (after logging to stderr) when a target block was never marked or
+// when its offset does not fit in the int32 slot.
+LogicalResult BytecodeWriter::FixupOffsets() {
+  for (const auto &fixup : blockOffsetFixups_) {
+    auto it = blockOffsets_.find(fixup.first);
+    if (it == blockOffsets_.end()) {
+      llvm::errs() << "Block offset not found: " << fixup.first;
+      return failure();
+    }
+    if (it->second > static_cast<size_t>(INT32_MAX)) {
+      llvm::errs() << "Block offset too large: " << it->second;
+      return failure();
+    }
+    // Convert to the 4-byte slot type first. Copying the leading bytes of the
+    // stored size_t would only produce the right value on little-endian
+    // hosts.
+    const int32_t blockOffset = static_cast<int32_t>(it->second);
+    std::memcpy(bytecode_.data() + fixup.second, &blockOffset,
+                sizeof(blockOffset));
+  }
+  blockOffsetFixups_.clear();
+  return success();
+}
+
+// Clears the local ordinal map and hands the accumulated bytecode to the
+// caller by moving it out of the writer.
+// Recorded block offsets and pending fixups are not reset here.
+std::vector<uint8_t> BytecodeWriter::Finish() {
+ localMap_.clear();
+ // bytecode_ is a member, so the explicit move is required to avoid a copy.
+ return std::move(bytecode_);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/BytecodeWriter.h b/compiler/Serialization/BytecodeWriter.h
new file mode 100644
index 0000000..c46e196
--- /dev/null
+++ b/compiler/Serialization/BytecodeWriter.h
@@ -0,0 +1,96 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
+#define IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
+
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+#include "compiler/IR/StructureOps.h"
+#include "llvm/ADT/Optional.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "schemas/bytecode/bytecode_v0.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Accumulates the bytecode for a function body in an in-memory byte buffer.
+// Also tracks the ordinals assigned to SSA values ("locals") and the byte
+// offsets of blocks, so branch targets can be patched in after all blocks
+// have been written.
+class BytecodeWriter {
+ public:
+ // Current size of the bytecode buffer in bytes.
+ int offset() const { return bytecode_.size(); }
+
+ // Number of values that have been assigned a local ordinal so far.
+ int local_count() const { return localMap_.size(); }
+
+ // Writes a single-byte opcode; T must be exactly one byte wide.
+ template <typename T>
+ LogicalResult WriteOpcode(T value) {
+ static_assert(sizeof(T) == sizeof(uint8_t), "Opcode enum size mismatch");
+ return WriteUint8(static_cast<uint8_t>(value));
+ }
+
+ // Writes |count| as one byte; fails when it exceeds UINT8_MAX.
+ LogicalResult WriteCount(int count);
+
+ // Writes the one-byte builtin type index for |type|; fails for types with
+ // no builtin equivalent.
+ LogicalResult WriteTypeIndex(Type type);
+
+ // Write the "iree.ordinal" attribute of |function| as a uint32; fail when
+ // the attribute is missing. Import ordinals currently share the encoding.
+ LogicalResult WriteFunctionOrdinal(FuncOp function);
+ LogicalResult WriteImportOrdinal(FuncOp function);
+
+ // Writes element type, shape, an encoding byte, and the payload of
+ // |baseAttr|.
+ LogicalResult WriteConstant(MemRefType memRefType, Attribute baseAttr);
+ // Writes only the raw payload of |baseAttr|; fails for attribute kinds that
+ // have no serializer.
+ LogicalResult WriteAttributeData(Attribute baseAttr);
+
+ // Returns the ordinal of |value|, assigning the next free one on first use.
+ // Returns llvm::None when the ordinal exceeds UINT16_MAX.
+ llvm::Optional<int> LookupLocalOrdinal(Value *value);
+ // Assigns an ordinal to |value| without writing anything.
+ LogicalResult PrepareLocal(Value *value);
+ // Writes the ordinal of |value| as a uint16.
+ LogicalResult WriteLocal(Value *value);
+ // Write a count byte followed by the ordinal of each value in the range.
+ LogicalResult WriteLocals(
+ llvm::iterator_range<Operation::operand_iterator> values);
+ LogicalResult WriteLocals(
+ llvm::iterator_range<Operation::result_iterator> values);
+
+ // Appends |dataLength| bytes copied from |data|.
+ LogicalResult WriteBytes(const void *data, size_t dataLength);
+ // Grows the buffer by |dataLength| bytes and returns a view of the new
+ // region; the view is invalidated by later writes that resize the buffer.
+ MutableArrayRef<uint8_t> ReserveBytes(size_t dataLength);
+ // Fixed-width writers; each appends the in-memory bytes of |value|.
+ LogicalResult WriteUint8(uint8_t value);
+ LogicalResult WriteUint16(uint16_t value);
+ LogicalResult WriteInt32(int32_t value);
+ LogicalResult WriteUint32(uint32_t value);
+
+ // Writes a count byte followed by each element of |attr| as an int32.
+ LogicalResult WriteElementsAttrInt32(ElementsAttr attr);
+
+ // Write a shape as a count byte followed by int32 dimensions, taken either
+ // from a shaped type or from an elements attribute.
+ LogicalResult WriteShapePieces(const ShapedType &type);
+ LogicalResult WriteShapePieces(ElementsAttr pieces);
+
+ // Records the current buffer size as the offset of |block|.
+ LogicalResult MarkBlockOffset(Block *block);
+ // Reserves an int32-sized placeholder for the offset of |targetBlock|.
+ LogicalResult WriteBlockOffset(Block *targetBlock);
+ // Patches all placeholders with recorded block offsets; fails when a target
+ // block was never marked.
+ LogicalResult FixupOffsets();
+
+ // Clears the local map and returns the accumulated bytecode by move.
+ std::vector<uint8_t> Finish();
+
+ private:
+ // The bytes written so far.
+ std::vector<uint8_t> bytecode_;
+
+ // Value -> local ordinal, assigned in order of first use.
+ llvm::DenseMap<Value *, int> localMap_;
+
+ // Block -> byte offset recorded by MarkBlockOffset.
+ llvm::DenseMap<Block *, size_t> blockOffsets_;
+ // (target block, placeholder position) pairs awaiting FixupOffsets.
+ std::vector<std::pair<Block *, size_t>> blockOffsetFixups_;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
diff --git a/iree/compiler/Serialization/CMakeLists.txt b/compiler/Serialization/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Serialization/CMakeLists.txt
rename to compiler/Serialization/CMakeLists.txt
diff --git a/compiler/Serialization/VMDeviceTableBuilder.cpp b/compiler/Serialization/VMDeviceTableBuilder.cpp
new file mode 100644
index 0000000..4c95ab9
--- /dev/null
+++ b/compiler/Serialization/VMDeviceTableBuilder.cpp
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/VMDeviceTableBuilder.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Stores the |fbb| pointer for later use. Finish() dereferences it, so the
+// builder must remain valid until then.
+VMDeviceTableBuilder::VMDeviceTableBuilder(
+ ::flatbuffers::FlatBufferBuilder *fbb)
+ : fbb_(fbb) {}
+
+// Queues an already-serialized device definition for inclusion in the table.
+// Always succeeds.
+LogicalResult VMDeviceTableBuilder::AddDevice(
+    ::flatbuffers::Offset<iree::DeviceDef> deviceDef) {
+  deviceDefs_.emplace_back(deviceDef);
+  return success();
+}
+
+// Queues an already-serialized device group definition for inclusion in the
+// table. Always succeeds.
+LogicalResult VMDeviceTableBuilder::AddDeviceGroup(
+    ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef) {
+  deviceGroupDefs_.emplace_back(deviceGroupDef);
+  return success();
+}
+
+// Serializes the queued devices and device groups into a DeviceTableDef and
+// returns its offset.
+::flatbuffers::Offset<iree::DeviceTableDef> VMDeviceTableBuilder::Finish() {
+  // Both vectors are created before the table builder is constructed.
+  const auto devices = fbb_->CreateVector(deviceDefs_);
+  const auto deviceGroups = fbb_->CreateVector(deviceGroupDefs_);
+
+  iree::DeviceTableDefBuilder tableBuilder(*fbb_);
+  tableBuilder.add_devices(devices);
+  tableBuilder.add_device_groups(deviceGroups);
+  return tableBuilder.Finish();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/VMDeviceTableBuilder.h b/compiler/Serialization/VMDeviceTableBuilder.h
new file mode 100644
index 0000000..1cd9147
--- /dev/null
+++ b/compiler/Serialization/VMDeviceTableBuilder.h
@@ -0,0 +1,45 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
+#define IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
+
+#include "flatbuffers/flatbuffers.h"
+#include "mlir/Support/LogicalResult.h"
+#include "schemas/device_table_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Collects serialized device and device group definitions and emits them as
+// a single DeviceTableDef flatbuffer table.
+class VMDeviceTableBuilder {
+ public:
+ // |fbb| is stored as a raw pointer and used by Finish().
+ explicit VMDeviceTableBuilder(::flatbuffers::FlatBufferBuilder *fbb);
+
+ // Queues a device definition; always succeeds.
+ LogicalResult AddDevice(::flatbuffers::Offset<iree::DeviceDef> deviceDef);
+
+ // Queues a device group definition; always succeeds.
+ LogicalResult AddDeviceGroup(
+ ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef);
+
+ // Writes the queued definitions into a DeviceTableDef and returns its
+ // offset.
+ ::flatbuffers::Offset<iree::DeviceTableDef> Finish();
+
+ private:
+ // Builder that receives the serialized table.
+ ::flatbuffers::FlatBufferBuilder *fbb_;
+ // Definitions queued so far, in insertion order.
+ std::vector<::flatbuffers::Offset<iree::DeviceDef>> deviceDefs_;
+ std::vector<::flatbuffers::Offset<iree::DeviceGroupDef>> deviceGroupDefs_;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
diff --git a/compiler/Serialization/VMExecutableTableBuilder.cpp b/compiler/Serialization/VMExecutableTableBuilder.cpp
new file mode 100644
index 0000000..ad71958
--- /dev/null
+++ b/compiler/Serialization/VMExecutableTableBuilder.cpp
@@ -0,0 +1,41 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/VMExecutableTableBuilder.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Stores the |fbb| pointer for later use. Finish() dereferences it, so the
+// builder must remain valid until then.
+VMExecutableTableBuilder::VMExecutableTableBuilder(
+ ::flatbuffers::FlatBufferBuilder *fbb)
+ : fbb_(fbb) {}
+
+// Queues an already-serialized multi-arch executable definition for inclusion
+// in the table. Always succeeds.
+LogicalResult VMExecutableTableBuilder::AddMultiArchExecutable(
+    ::flatbuffers::Offset<iree::MultiArchExecutableDef>
+        multiArchExecutableDef) {
+  multiArchExecutableDefs_.emplace_back(multiArchExecutableDef);
+  return success();
+}
+
+// Serializes the queued executables into an ExecutableTableDef and returns
+// its offset.
+::flatbuffers::Offset<iree::ExecutableTableDef>
+VMExecutableTableBuilder::Finish() {
+  // The vector is created before the table builder is constructed.
+  const auto executables = fbb_->CreateVector(multiArchExecutableDefs_);
+
+  iree::ExecutableTableDefBuilder tableBuilder(*fbb_);
+  tableBuilder.add_multi_arch_executables(executables);
+  return tableBuilder.Finish();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/VMExecutableTableBuilder.h b/compiler/Serialization/VMExecutableTableBuilder.h
new file mode 100644
index 0000000..0106c09
--- /dev/null
+++ b/compiler/Serialization/VMExecutableTableBuilder.h
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
+#define IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
+
+#include "flatbuffers/flatbuffers.h"
+#include "mlir/Support/LogicalResult.h"
+#include "schemas/executable_table_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Collects serialized multi-arch executable definitions and emits them as a
+// single ExecutableTableDef flatbuffer table.
+class VMExecutableTableBuilder {
+ public:
+ // |fbb| is stored as a raw pointer and used by Finish().
+ explicit VMExecutableTableBuilder(::flatbuffers::FlatBufferBuilder *fbb);
+
+ // Queues an executable definition; always succeeds.
+ LogicalResult AddMultiArchExecutable(
+ ::flatbuffers::Offset<iree::MultiArchExecutableDef>
+ multiArchExecutableDef);
+
+ // Writes the queued definitions into an ExecutableTableDef and returns its
+ // offset.
+ ::flatbuffers::Offset<iree::ExecutableTableDef> Finish();
+
+ private:
+ // Builder that receives the serialized table.
+ ::flatbuffers::FlatBufferBuilder *fbb_;
+ // Definitions queued so far, in insertion order.
+ std::vector<::flatbuffers::Offset<iree::MultiArchExecutableDef>>
+ multiArchExecutableDefs_;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
diff --git a/compiler/Serialization/VMFunctionBuilder.cpp b/compiler/Serialization/VMFunctionBuilder.cpp
new file mode 100644
index 0000000..23eda88
--- /dev/null
+++ b/compiler/Serialization/VMFunctionBuilder.cpp
@@ -0,0 +1,359 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/VMFunctionBuilder.h"
+
+#include "flatbuffers/flatbuffers.h"
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Serialization/BytecodeTables.h"
+#include "compiler/Utils/Macros.h"
+#include "schemas/bytecode/bytecode_v0.h"
+#include "schemas/type_def_generated.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Module.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Serializes |op| using the operand encodings from the opcode tables.
+//
+// The opcode is looked up by the op name with everything up to the first '.'
+// removed: sequencer ops ("iree_ll_seq") in the sequencer table, interpreter
+// and IREE dialect ops in the interpreter table. After the opcode byte, each
+// slot encoding is written as a local ordinal (or a count byte plus ordinals
+// for variadic slots). Encodings that need custom handling (constants,
+// function ordinals, block offsets, ...) are rejected with a diagnostic.
+//
+// |block| is currently unused. Returns failure with a diagnostic on |op| for
+// unknown dialects, unknown opcodes, unsupported encodings, or an encoding
+// that names more operands/results than the op has.
+LogicalResult WriteGenericIreeOp(Block *block, Operation *op,
+                                 BytecodeWriter *writer) {
+  // Strip the dialect name from the op name and lookup the opcode.
+  // TODO(benvanik): adjust for supporting sequencer opcodes.
+
+  auto opName = op->getName().getStringRef();
+  auto dialect = op->getDialect();
+  if (!dialect) {
+    return op->emitOpError() << "Op does not belong to a registered dialect";
+  }
+
+  auto dialectNamespace = dialect->getNamespace();
+  // The tables hand out const references, so the entry is referenced in place
+  // rather than copied to the heap.
+  const OpcodeInfo *operandInfo = nullptr;
+  auto strippedOpName = opName.substr(opName.find('.') + 1).str();
+  if (dialectNamespace == "iree_ll_seq") {
+    auto opcode = GetSequencerOpcodeByName(strippedOpName);
+    if (!opcode.hasValue()) {
+      return op->emitOpError()
+             << "No sequencer opcode found for op; is it a pseudo op?";
+    }
+    RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue()));
+    operandInfo = &GetSequencerOpcodeInfo(opcode.getValue());
+  } else if (dialectNamespace == "iree_ll_interp" ||
+             // TODO(gcmn) remove special case for IREE dialect?
+             dialectNamespace == IREEDialect::getDialectNamespace()) {
+    auto opcode = GetInterpreterOpcodeByName(strippedOpName);
+    if (!opcode.hasValue()) {
+      return op->emitOpError()
+             << "No interpreter opcode found for op; is it a pseudo op?";
+    }
+    RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue()));
+    operandInfo = &GetInterpreterOpcodeInfo(opcode.getValue());
+  } else {
+    return op->emitOpError()
+           << "Op belongs to unknown dialect " << dialectNamespace.str();
+  }
+
+  // Write inputs and outputs based on the bytecode encoding.
+  int operandIndex = 0;
+  int resultIndex = 0;
+  for (size_t i = 0; i < llvm::array_lengthof(operandInfo->operands); ++i) {
+    auto op_encoding = operandInfo->operands[i];
+    if (op_encoding == iree::OperandEncoding::kNone) break;
+    switch (op_encoding) {
+      case iree::OperandEncoding::kInputSlot:
+      case iree::OperandEncoding::kOutputSlot: {
+        if (operandIndex >= static_cast<int>(op->getNumOperands())) {
+          return op->emitOpError()
+                 << "Operand encoding expects more operands than the op has";
+        }
+        auto *value = op->getOperand(operandIndex++);
+        RETURN_IF_FAILURE(writer->WriteLocal(value));
+        break;
+      }
+      case iree::OperandEncoding::kVariadicInputSlots:
+      case iree::OperandEncoding::kVariadicOutputSlots: {
+        // Consumes all remaining operands. The fixed-slot bounds checks keep
+        // operandIndex within range, so the count cannot be negative.
+        int count = op->getNumOperands() - operandIndex;
+        RETURN_IF_FAILURE(writer->WriteCount(count));
+        for (; count > 0; --count) {
+          auto *value = op->getOperand(operandIndex++);
+          RETURN_IF_FAILURE(writer->WriteLocal(value));
+        }
+        break;
+      }
+      case iree::OperandEncoding::kResultSlot: {
+        if (resultIndex >= static_cast<int>(op->getNumResults())) {
+          return op->emitOpError()
+                 << "Operand encoding expects more results than the op has";
+        }
+        auto *value = op->getResult(resultIndex++);
+        RETURN_IF_FAILURE(writer->WriteLocal(value));
+        break;
+      }
+      case iree::OperandEncoding::kVariadicResultSlots: {
+        // Consumes all remaining results.
+        int count = op->getNumResults() - resultIndex;
+        RETURN_IF_FAILURE(writer->WriteCount(count));
+        for (; count > 0; --count) {
+          auto *value = op->getResult(resultIndex++);
+          RETURN_IF_FAILURE(writer->WriteLocal(value));
+        }
+        break;
+      }
+      case iree::OperandEncoding::kConstant:
+      case iree::OperandEncoding::kFunctionOrdinal:
+      case iree::OperandEncoding::kBlockOffset:
+      case iree::OperandEncoding::kTypeIndex:
+      case iree::OperandEncoding::kIndex:
+      case iree::OperandEncoding::kIndexList:
+      case iree::OperandEncoding::kCmpIPredicate:
+      case iree::OperandEncoding::kCmpFPredicate:
+        // These need op-specific knowledge and must use a custom writer.
+        return op->emitOpError()
+               << "Operand encoding " << static_cast<char>(op_encoding)
+               << " not supported by generic writer for " << opName.str();
+      default:
+        return op->emitOpError()
+               << "Operand encoding " << static_cast<char>(op_encoding) << " ("
+               << static_cast<int>(op_encoding) << ") not recognized (typo?)";
+    }
+  }
+
+  return success();
+}
+
+} // namespace
+
+// Captures the function to serialize, its MLIR context, the function table,
+// and the flatbuffer builder that will receive the output. The pointers are
+// stored as-is.
+VMFunctionBuilder::VMFunctionBuilder(FuncOp function,
+ VMFunctionTableBuilder *functionTable,
+ ::flatbuffers::FlatBufferBuilder *fbb)
+ : context_(function.getContext()),
+ function_(function),
+ functionTable_(functionTable),
+ fbb_(fbb) {}
+
+// Registers |writerFn| as the serializer for ops named |operationName|.
+// WriteOperation consults these writers before the generic one.
+// NOTE(review): this uses insert(), which presumably keeps an existing entry
+// when the same name is registered twice - confirm against the container type.
+void VMFunctionBuilder::RegisterCustomWriter(StringRef operationName,
+ CustomWriterFn writerFn) {
+ customWriters_.insert({operationName, writerFn});
+}
+
+// Serializes every block and operation of the function into bytecode and
+// stores the resulting BytecodeDef offset for use by Finish().
+//
+// Resets the source map first. Fails as soon as any hook or operation fails;
+// a failing operation additionally gets an "Unable to serialize operation"
+// diagnostic.
+LogicalResult VMFunctionBuilder::ConvertBytecode() {
+  BytecodeWriter writer;
+  sourceMap_ = {};
+
+  RETURN_IF_FAILURE(BeginFunction(function_, &writer));
+  for (auto &block : function_.getBlocks()) {
+    RETURN_IF_FAILURE(BeginBlock(&block, &writer));
+    for (auto &op : block.getOperations()) {
+      if (!failed(WriteOperation(&block, &op, &writer))) continue;
+      op.emitError() << "Unable to serialize operation";
+      return failure();
+    }
+    RETURN_IF_FAILURE(EndBlock(&block, block.getTerminator(), &writer));
+  }
+  RETURN_IF_FAILURE(EndFunction(function_, &writer));
+
+  // The local count must be read before Finish(), which clears the local map.
+  const int numLocals = writer.local_count();
+  const std::vector<uint8_t> bytecode = writer.Finish();
+  const auto *contents = reinterpret_cast<const int8_t *>(bytecode.data());
+  auto contentsOffset = fbb_->CreateVector(contents, bytecode.size());
+
+  iree::BytecodeDefBuilder bytecodeBuilder(*fbb_);
+  bytecodeBuilder.add_local_count(numLocals);
+  bytecodeBuilder.add_contents(contentsOffset);
+  bytecodeDef_ = bytecodeBuilder.Finish();
+
+  return success();
+}
+
+// Serializes the function signature, name, and previously converted bytecode
+// into a FunctionDef and returns its offset.
+//
+// Returns a null offset when any input or result type cannot be serialized.
+// Empty input/result lists are left as null vector offsets.
+::flatbuffers::Offset<iree::FunctionDef> VMFunctionBuilder::Finish() {
+  using TypeDefVector =
+      ::flatbuffers::Vector<::flatbuffers::Offset<iree::TypeDef>>;
+
+  // Serializes each type in |types| and, if there are any, stores the vector
+  // offset in |listOffset|. Returns false if a type failed to serialize.
+  auto serializeTypeList =
+      [this](ArrayRef<Type> types,
+             ::flatbuffers::Offset<TypeDefVector> *listOffset) -> bool {
+    std::vector<::flatbuffers::Offset<iree::TypeDef>> typeDefs;
+    for (const auto &type : types) {
+      auto typeDefOffset = SerializeType(type, fbb_);
+      if (typeDefOffset.IsNull()) return false;
+      typeDefs.push_back(typeDefOffset);
+    }
+    if (!typeDefs.empty()) {
+      *listOffset = fbb_->CreateVector(typeDefs);
+    }
+    return true;
+  };
+
+  const auto &functionType = function_.getType();
+  ::flatbuffers::Offset<TypeDefVector> inputsOffset;
+  if (!serializeTypeList(functionType.getInputs(), &inputsOffset)) return {};
+  ::flatbuffers::Offset<TypeDefVector> resultsOffset;
+  if (!serializeTypeList(functionType.getResults(), &resultsOffset)) return {};
+
+  iree::FunctionTypeDefBuilder typeBuilder(*fbb_);
+  typeBuilder.add_inputs(inputsOffset);
+  typeBuilder.add_results(resultsOffset);
+  auto functionTypeOffset = typeBuilder.Finish();
+
+  // TODO(benvanik): strip names of internal functions.
+  auto nameOffset = fbb_->CreateString(function_.getName().str());
+  iree::FunctionDefBuilder functionBuilder(*fbb_);
+  functionBuilder.add_name(nameOffset);
+  functionBuilder.add_type(functionTypeOffset);
+  functionBuilder.add_bytecode(bytecodeDef_);
+  return functionBuilder.Finish();
+}
+
+// Hook run once before any block is written. Fails if an argument cannot be
+// assigned a local ordinal.
+LogicalResult VMFunctionBuilder::BeginFunction(FuncOp function,
+ BytecodeWriter *writer) {
+ // Assign value slots for all function arguments (only arguments are
+ // prepared here; no result values are pre-assigned).
+ // Keeping them at the front will make it easier to find during debugging
+ // and makes spans easier to compute at runtime.
+ for (auto argument : function.getArguments()) {
+ RETURN_IF_FAILURE(writer->PrepareLocal(argument));
+ }
+ return success();
+}
+
+// Hook run once after all blocks are written: patches the block offset
+// placeholders now that every block's position is known. |function| is
+// unused.
+LogicalResult VMFunctionBuilder::EndFunction(FuncOp function,
+ BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->FixupOffsets());
+ return success();
+}
+
+// Hook run before the operations of |block| are written: records the current
+// bytecode position as the block's offset.
+LogicalResult VMFunctionBuilder::BeginBlock(Block *block,
+ BytecodeWriter *writer) {
+ RETURN_IF_FAILURE(writer->MarkBlockOffset(block));
+ return success();
+}
+
+// Hook run after the operations of |block| are written. Currently a no-op
+// that always succeeds; ConvertBytecode passes the block terminator as |op|.
+LogicalResult VMFunctionBuilder::EndBlock(Block *block, Operation *op,
+ BytecodeWriter *writer) {
+ return success();
+}
+
+// Serializes a single operation.
+//
+// Records a source map entry (bytecode offset -> location) unless the op has
+// an unknown location. A custom writer registered for the op name takes
+// precedence; otherwise ops from dialects whose namespace starts with "iree"
+// go through the generic table-driven writer. Any other op, including
+// unregistered ops that have no abstract operation, is rejected with an
+// "Unsupported op" diagnostic.
+LogicalResult VMFunctionBuilder::WriteOperation(Block *block, Operation *baseOp,
+                                                BytecodeWriter *writer) {
+  if (!baseOp->getLoc().isa<UnknownLoc>()) {
+    sourceMap_.locations.push_back({writer->offset(), baseOp->getLoc()});
+  }
+
+  // Check registered writers first to allow overrides.
+  auto writerIt = customWriters_.find(baseOp->getName().getStringRef());
+  if (writerIt != customWriters_.end()) {
+    return writerIt->second(baseOp, writer);
+  }
+
+  // Fallback to using the generic writer. getAbstractOperation() is null for
+  // unregistered ops, so guard the dereference.
+  auto *abstractOp = baseOp->getAbstractOperation();
+  if (abstractOp && abstractOp->dialect.getNamespace().startswith("iree")) {
+    RETURN_IF_FAILURE(WriteGenericIreeOp(block, baseOp, writer));
+    return success();
+  }
+  return baseOp->emitError()
+         << "Unsupported op " << baseOp->getName().getStringRef().str()
+         << "; incorrectly outlined or not yet implemented";
+}
+
+// Serializes a function I/O |type| into a TypeDef written to |fbb|.
+//
+// Supports memrefs and the device, command buffer, event, semaphore, and
+// fence types. Returns a null offset (after emitting an error on the
+// function) for anything else, or when a memref fails to serialize.
+//
+// Everything is written through the |fbb| argument, including nested memref
+// types, so the result is consistent when |fbb| differs from the builder's
+// own flatbuffer builder.
+::flatbuffers::Offset<iree::TypeDef> VMFunctionBuilder::SerializeType(
+    Type type, ::flatbuffers::FlatBufferBuilder *fbb) {
+  ::flatbuffers::Offset<void> typeDefUnion;
+  iree::TypeDefUnion typeUnionType;
+  if (auto memRefType = type.dyn_cast<MemRefType>()) {
+    auto memRefTypeOffset = SerializeMemRefType(memRefType, fbb);
+    if (memRefTypeOffset.IsNull()) return {};
+    typeDefUnion = memRefTypeOffset.Union();
+    typeUnionType = iree::TypeDefUnion::MemRefTypeDef;
+  } else if (auto deviceType = type.dyn_cast<DeviceType>()) {
+    typeDefUnion = iree::CreateDeviceTypeDef(*fbb).Union();
+    typeUnionType = iree::TypeDefUnion::DeviceTypeDef;
+  } else if (auto commandBufferType = type.dyn_cast<CommandBufferType>()) {
+    typeDefUnion = iree::CreateCommandBufferTypeDef(*fbb).Union();
+    typeUnionType = iree::TypeDefUnion::CommandBufferTypeDef;
+  } else if (auto eventType = type.dyn_cast<EventType>()) {
+    typeDefUnion = iree::CreateEventTypeDef(*fbb).Union();
+    typeUnionType = iree::TypeDefUnion::EventTypeDef;
+  } else if (auto semaphoreType = type.dyn_cast<SemaphoreType>()) {
+    typeDefUnion = iree::CreateSemaphoreTypeDef(*fbb).Union();
+    typeUnionType = iree::TypeDefUnion::SemaphoreTypeDef;
+  } else if (auto fenceType = type.dyn_cast<FenceType>()) {
+    typeDefUnion = iree::CreateFenceTypeDef(*fbb).Union();
+    typeUnionType = iree::TypeDefUnion::FenceTypeDef;
+  } else {
+    function_.emitError() << "Function " << function_.getName().str()
+                          << " has unsupported I/O with type " << type;
+    return {};
+  }
+
+  iree::TypeDefBuilder tdb(*fbb);
+  tdb.add_type_union_type(typeUnionType);
+  tdb.add_type_union(typeDefUnion);
+  return tdb.Finish();
+}
+
+// Serializes a memref |type| (element type, shape, memory space) into a
+// MemRefTypeDef written to |fbb|. Returns a null offset when the element type
+// cannot be serialized.
+::flatbuffers::Offset<iree::MemRefTypeDef>
+VMFunctionBuilder::SerializeMemRefType(const MemRefType &type,
+                                       ::flatbuffers::FlatBufferBuilder *fbb) {
+  auto elementTypeOffset = SerializeElementType(type.getElementType(), fbb);
+  if (elementTypeOffset.IsNull()) return {};
+
+  // Dimensions are narrowed to int for the flatbuffer shape vector.
+  const auto dims = type.getShape();
+  std::vector<int> shape;
+  shape.reserve(dims.size());
+  for (auto dimIt = dims.begin(); dimIt != dims.end(); ++dimIt) {
+    shape.push_back(static_cast<int>(*dimIt));
+  }
+  auto shapeOffset = fbb->CreateVector(shape);
+
+  iree::MemRefTypeDefBuilder memRefBuilder(*fbb);
+  memRefBuilder.add_element_type(elementTypeOffset);
+  memRefBuilder.add_shape(shapeOffset);
+  memRefBuilder.add_memory_space(type.getMemorySpace());
+  return memRefBuilder.Finish();
+}
+
+// Serializes an element type into an ElementTypeDef written to |fbb|.
+//
+// Float and integer types are stored by bit width; opaque types are stored as
+// their dialect namespace and type data. Returns a null offset (after
+// emitting an error on the function) for any other type.
+::flatbuffers::Offset<iree::ElementTypeDef>
+VMFunctionBuilder::SerializeElementType(const Type &genericType,
+                                        ::flatbuffers::FlatBufferBuilder *fbb) {
+  ::flatbuffers::Offset<void> typeDefUnion;
+  iree::ElementTypeDefUnion typeUnionType;
+  if (auto type = genericType.dyn_cast<FloatType>()) {
+    iree::FloatTypeDefBuilder tb(*fbb);
+    tb.add_width(type.getWidth());
+    typeDefUnion = tb.Finish().Union();
+    typeUnionType = iree::ElementTypeDefUnion::FloatTypeDef;
+  } else if (auto type = genericType.dyn_cast<IntegerType>()) {
+    iree::IntegerTypeDefBuilder tb(*fbb);
+    tb.add_width(type.getWidth());
+    typeDefUnion = tb.Finish().Union();
+    typeUnionType = iree::ElementTypeDefUnion::IntegerTypeDef;
+  } else if (auto type = genericType.dyn_cast<OpaqueType>()) {
+    auto dialectOffset = fbb->CreateString(type.getDialectNamespace().c_str());
+    // The type data is a StringRef and is not guaranteed to be
+    // NUL-terminated, so pass its length explicitly instead of relying on
+    // strlen.
+    auto typeData = type.getTypeData();
+    auto typeDataOffset = fbb->CreateString(typeData.data(), typeData.size());
+    iree::UnknownTypeDefBuilder tb(*fbb);
+    tb.add_dialect(dialectOffset);
+    tb.add_type_data(typeDataOffset);
+    typeDefUnion = tb.Finish().Union();
+    typeUnionType = iree::ElementTypeDefUnion::UnknownTypeDef;
+  } else {
+    function_.emitError()
+        << "Unimplemented type encoding: " << genericType
+        << "; ensure IREE lowering passes are converting types to the IREE "
+           "set";
+    return {};
+  }
+
+  iree::ElementTypeDefBuilder tdb(*fbb);
+  tdb.add_type_union_type(typeUnionType);
+  tdb.add_type_union(typeDefUnion);
+  return tdb.Finish();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/VMFunctionBuilder.h b/compiler/Serialization/VMFunctionBuilder.h
new file mode 100644
index 0000000..fcfefde
--- /dev/null
+++ b/compiler/Serialization/VMFunctionBuilder.h
@@ -0,0 +1,77 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
+#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
+
+#include "compiler/Serialization/BytecodeWriter.h"
+#include "compiler/Serialization/VMFunctionTableBuilder.h"
+#include "compiler/Serialization/VMSourceMapBuilder.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "schemas/bytecode_def_generated.h"
+#include "schemas/function_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+class VMFunctionBuilder {  // Produces an iree::FunctionDef (see Finish()); presumably for the FuncOp given at construction.
+ public:
+ using CustomWriterFn =  // Callback type accepted by RegisterCustomWriter().
+ std::function<LogicalResult(Operation *, BytecodeWriter *writer)>;
+ // NOTE(review): the constructor is defined in the .cpp; presumably it stores its arguments in function_/functionTable_/fbb_ -- confirm.
+ VMFunctionBuilder(FuncOp function, VMFunctionTableBuilder *functionTable,
+ ::flatbuffers::FlatBufferBuilder *fbb);
+ ~VMFunctionBuilder() = default;
+ // Registers |writerFn| for operations named |operationName| (presumably kept in customWriters_ -- confirm in the .cpp).
+ void RegisterCustomWriter(StringRef operationName, CustomWriterFn writerFn);
+ // Read-only access to the source map held by this builder.
+ const VMFunctionSourceMap &source_map() const { return sourceMap_; }
+ // NOTE(review): presumably encodes the function body as bytecode; defined in the .cpp -- confirm.
+ LogicalResult ConvertBytecode();
+ // Returns the flatbuffer offset of the iree::FunctionDef for this function.
+ ::flatbuffers::Offset<iree::FunctionDef> Finish();
+ // Type serialization helpers: each takes the target builder |fbb| and returns the offset of the table it created.
+ ::flatbuffers::Offset<iree::TypeDef> SerializeType(
+ Type type, ::flatbuffers::FlatBufferBuilder *fbb);
+ ::flatbuffers::Offset<iree::MemRefTypeDef> SerializeMemRefType(
+ const MemRefType &genericType, ::flatbuffers::FlatBufferBuilder *fbb);
+ ::flatbuffers::Offset<iree::ElementTypeDef> SerializeElementType(  // Emits an error and returns a null offset for unsupported types.
+ const Type &genericType, ::flatbuffers::FlatBufferBuilder *fbb);
+
+ private:
+ LogicalResult BeginFunction(FuncOp function, BytecodeWriter *writer);
+ LogicalResult EndFunction(FuncOp function, BytecodeWriter *writer);
+ LogicalResult BeginBlock(Block *block, BytecodeWriter *writer);
+ LogicalResult EndBlock(Block *block, Operation *op, BytecodeWriter *writer);
+
+ LogicalResult WriteOperation(Block *block, Operation *baseOp,
+ BytecodeWriter *writer);
+
+ llvm::StringMap<CustomWriterFn> customWriters_;  // Custom writers keyed by string (presumably the operation name).
+
+ MLIRContext *context_;
+ FuncOp function_;  // Function that diagnostics are emitted on (see SerializeElementType).
+ VMFunctionTableBuilder *functionTable_;
+ ::flatbuffers::FlatBufferBuilder *fbb_;
+ ::flatbuffers::Offset<iree::BytecodeDef> bytecodeDef_;
+ VMFunctionSourceMap sourceMap_;  // Exposed through source_map().
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
diff --git a/compiler/Serialization/VMFunctionTableBuilder.cpp b/compiler/Serialization/VMFunctionTableBuilder.cpp
new file mode 100644
index 0000000..f013b07
--- /dev/null
+++ b/compiler/Serialization/VMFunctionTableBuilder.cpp
@@ -0,0 +1,87 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/VMFunctionTableBuilder.h"
+
+#include "compiler/Serialization/VMSourceMapBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+VMFunctionTableBuilder::VMFunctionTableBuilder(
+ ::flatbuffers::FlatBufferBuilder *fbb)
+ : fbb_(fbb) {}  // Only stores |fbb|; it is written to later by Finish().
+
+// Returns true when a function with |funcOp|'s name has already been recorded
+// in the set of declared function names.
+bool VMFunctionTableBuilder::IsFunctionDeclared(FuncOp funcOp) {
+  const auto entry = functionSet_.find(funcOp.getName());
+  return entry != functionSet_.end();
+}
+
+// Declares |funcOp| in the table with the given |linkageType|.
+//
+// The slot is selected by the function's "iree.ordinal" integer attribute; the
+// definition and source map tables are grown so that the slot exists. For
+// imports and exports the ordinal is additionally appended to the matching
+// index list.
+//
+// Fails (after emitting an error on |funcOp|) when the name has already been
+// declared, when no ordinal attribute is present, or when the ordinal is
+// negative.
+LogicalResult VMFunctionTableBuilder::DeclareFunction(FuncOp funcOp,
+                                                      LinkageType linkageType) {
+  if (functionSet_.count(funcOp.getName())) {
+    return funcOp.emitError() << "Function has already been declared/defined";
+  }
+  auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal");
+  if (!functionOrdinal) {
+    return funcOp.emitError() << "Ordinal not assigned to function";
+  }
+  int ordinal = functionOrdinal.getInt();
+  if (ordinal < 0) {
+    // A negative ordinal would wrap to an enormous size_t in the resizes
+    // below.
+    return funcOp.emitError() << "Function ordinal must be non-negative";
+  }
+
+  // Grow each table against its own current size so neither can shrink.
+  const size_t requiredSize = static_cast<size_t>(ordinal) + 1u;
+  functionDefs_.resize(std::max(functionDefs_.size(), requiredSize));
+  functionSourceMaps_.resize(
+      std::max(functionSourceMaps_.size(), requiredSize));
+  functionSet_.insert(funcOp.getName());
+
+  switch (linkageType) {
+    case LinkageType::kInternal:
+      // Internal functions appear in neither index list.
+      break;
+    case LinkageType::kImport:
+      importIndices_.push_back(ordinal);
+      break;
+    case LinkageType::kExport:
+      exportIndices_.push_back(ordinal);
+      break;
+  }
+  return success();
+}
+
+// Stores |functionDef| and |functionSourceMap| in the slot selected by
+// |funcOp|'s "iree.ordinal" attribute.
+//
+// Fails (after emitting an error on |funcOp|) when the attribute is missing,
+// when the ordinal lies outside the tables sized by DeclareFunction(), or when
+// the slot already holds a non-null definition.
+LogicalResult VMFunctionTableBuilder::DefineFunction(
+    FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef,
+    VMFunctionSourceMap functionSourceMap) {
+  auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal");
+  if (!functionOrdinal) {
+    return funcOp.emitError() << "Ordinal not assigned to function";
+  }
+  int ordinal = functionOrdinal.getInt();
+  if (ordinal < 0 || static_cast<size_t>(ordinal) >= functionDefs_.size()) {
+    // Indexing a slot that was never declared would read and write out of
+    // bounds.
+    return funcOp.emitError() << "Function has not been declared";
+  }
+  if (!functionDefs_[ordinal].IsNull()) {
+    return funcOp.emitOpError() << "Function has already been defined";
+  }
+  functionDefs_[ordinal] = functionDef;
+  functionSourceMaps_[ordinal] = std::move(functionSourceMap);
+  return success();
+}
+
+// Writes the function, import and export vectors into the flatbuffer and
+// returns the offset of the resulting iree::FunctionTableDef.
+::flatbuffers::Offset<iree::FunctionTableDef> VMFunctionTableBuilder::Finish() {
+  // All three vectors are created before the table builder is constructed.
+  const auto functions = fbb_->CreateVector(functionDefs_);
+  const auto imports = fbb_->CreateVector(importIndices_);
+  const auto exports = fbb_->CreateVector(exportIndices_);
+
+  iree::FunctionTableDefBuilder tableBuilder(*fbb_);
+  tableBuilder.add_functions(functions);
+  tableBuilder.add_imports(imports);
+  tableBuilder.add_exports(exports);
+  return tableBuilder.Finish();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/VMFunctionTableBuilder.h b/compiler/Serialization/VMFunctionTableBuilder.h
new file mode 100644
index 0000000..ee3427e
--- /dev/null
+++ b/compiler/Serialization/VMFunctionTableBuilder.h
@@ -0,0 +1,75 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
+#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
+
+#include <string>
+#include <vector>
+
+#include "compiler/Serialization/VMSourceMapBuilder.h"
+#include "flatbuffers/flatbuffers.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "schemas/function_def_generated.h"
+#include "schemas/function_table_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+enum class LinkageType {  // Selects how DeclareFunction() records a function's ordinal.
+ kInternal,  // Ordinal is added to neither the import nor the export indices.
+ kImport,  // Ordinal is appended to the table's import indices.
+ kExport,  // Ordinal is appended to the table's export indices.
+};
+
+class VMFunctionTableBuilder {  // Collects function definitions by ordinal and emits an iree::FunctionTableDef.
+ public:
+ explicit VMFunctionTableBuilder(::flatbuffers::FlatBufferBuilder *fbb);
+ // Despite the name, this is the table size: the largest declared ordinal + 1 (0 when nothing is declared).
+ int max_function_ordinal() const { return functionDefs_.size(); }
+ // View of the per-function source maps, indexed by function ordinal.
+ ArrayRef<VMFunctionSourceMap> function_source_maps() {
+ return llvm::makeArrayRef(functionSourceMaps_);
+ }
+
+ // Returns true if |funcOp| has already been declared in the table (matched by function name).
+ bool IsFunctionDeclared(FuncOp funcOp);
+
+ // Declares |funcOp| with the given |linkageType|, using its "iree.ordinal" attribute as the table slot.
+ // Fails if the function has already been declared or defined, or has no ordinal attribute.
+ LogicalResult DeclareFunction(FuncOp funcOp, LinkageType linkageType);
+
+ // Defines |funcOp| using the given |functionDef| and stores |functionSourceMap| at the same ordinal.
+ LogicalResult DefineFunction(
+ FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef,
+ VMFunctionSourceMap functionSourceMap);
+ // Serializes the function, import and export vectors; returns the FunctionTableDef offset.
+ ::flatbuffers::Offset<iree::FunctionTableDef> Finish();
+
+ private:
+ ::flatbuffers::FlatBufferBuilder *fbb_;
+ llvm::StringSet<> functionSet_;  // Names of all declared functions.
+ std::vector<::flatbuffers::Offset<iree::FunctionDef>> functionDefs_;  // Indexed by ordinal; null until defined.
+ std::vector<VMFunctionSourceMap> functionSourceMaps_;  // Indexed by ordinal, resized together with functionDefs_.
+ std::vector<int> importIndices_;  // Ordinals declared with LinkageType::kImport.
+ std::vector<int> exportIndices_;  // Ordinals declared with LinkageType::kExport.
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
diff --git a/compiler/Serialization/VMModuleBuilder.cpp b/compiler/Serialization/VMModuleBuilder.cpp
new file mode 100644
index 0000000..0aa7d60
--- /dev/null
+++ b/compiler/Serialization/VMModuleBuilder.cpp
@@ -0,0 +1,70 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/VMModuleBuilder.h"
+
+#include "schemas/executable_table_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+VMModuleBuilder::VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb)
+ : fbb_(fbb),
+ deviceTable_(fbb),
+ functionTable_(fbb),
+ executableTable_(fbb),
+ sourceMap_(fbb) {}  // Every sub-builder is handed the same |fbb|.
+
+// Finishes every sub-table and assembles them into an iree::ModuleDef.
+//
+// Returns a default-constructed (null) offset as soon as any sub-table yields
+// a null offset or a function source map cannot be added.
+::flatbuffers::Offset<iree::ModuleDef> VMModuleBuilder::Finish() {
+  const auto moduleName = fbb_->CreateString("module");
+
+  const auto devices = deviceTable_.Finish();
+  if (devices.IsNull()) return {};
+  const auto functions = functionTable_.Finish();
+  if (functions.IsNull()) return {};
+  const auto executables = executableTable_.Finish();
+  if (executables.IsNull()) return {};
+
+  // Hand each function's source map to the source map builder, keyed by its
+  // ordinal in the function table.
+  const auto functionSourceMaps = functionTable_.function_source_maps();
+  for (size_t i = 0; i < functionSourceMaps.size(); ++i) {
+    const auto added =
+        sourceMap_.AddFunction(static_cast<int>(i), functionSourceMaps[i]);
+    if (failed(added)) return {};
+  }
+  const auto sourceMap =
+      sourceMap_.Finish(functionTable_.max_function_ordinal());
+  if (sourceMap.IsNull()) return {};
+
+  iree::ModuleDefBuilder moduleBuilder(*fbb_);
+  moduleBuilder.add_name(moduleName);
+  moduleBuilder.add_device_table(devices);
+  moduleBuilder.add_function_table(functions);
+  moduleBuilder.add_executable_table(executables);
+  moduleBuilder.add_source_map(sourceMap);
+  return moduleBuilder.Finish();
+}
+
+// Finalizes the flatbuffer with |module_def| as its root and returns a copy of
+// the finished buffer contents.
+std::vector<uint8_t> VMModuleBuilder::Serialize(
+    ::flatbuffers::Offset<iree::ModuleDef> module_def) {
+  FinishModuleDefBuffer(*fbb_, module_def);
+  // Copy through the iterator-range constructor: this needs no <cstring>
+  // (which this file never includes) and skips zero-filling the vector first.
+  const uint8_t *begin = fbb_->GetBufferPointer();
+  return std::vector<uint8_t>(begin, begin + fbb_->GetSize());
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/VMModuleBuilder.h b/compiler/Serialization/VMModuleBuilder.h
new file mode 100644
index 0000000..e9a9516
--- /dev/null
+++ b/compiler/Serialization/VMModuleBuilder.h
@@ -0,0 +1,57 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
+#define IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
+
+#include <vector>
+
+#include "compiler/Serialization/VMDeviceTableBuilder.h"
+#include "compiler/Serialization/VMExecutableTableBuilder.h"
+#include "compiler/Serialization/VMFunctionTableBuilder.h"
+#include "compiler/Serialization/VMSourceMapBuilder.h"
+#include "flatbuffers/flatbuffers.h"
+#include "schemas/module_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+class VMModuleBuilder {  // Owns the per-table builders and assembles them into an iree::ModuleDef.
+ public:
+ explicit VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb);
+ // Accessors for the flatbuffer builder and for the sub-builders held by this object.
+ ::flatbuffers::FlatBufferBuilder *fbb() const { return fbb_; }
+ VMDeviceTableBuilder *device_table() { return &deviceTable_; }
+ VMFunctionTableBuilder *function_table() { return &functionTable_; }
+ VMExecutableTableBuilder *executable_table() { return &executableTable_; }
+ VMSourceMapBuilder *source_map() { return &sourceMap_; }
+ // Finishes all sub-tables and returns the iree::ModuleDef offset (a null offset if any step fails).
+ ::flatbuffers::Offset<iree::ModuleDef> Finish();
+ // Finalizes the buffer with |module_def| as root and returns a copy of its bytes.
+ std::vector<uint8_t> Serialize(
+ ::flatbuffers::Offset<iree::ModuleDef> module_def);
+
+ private:
+ ::flatbuffers::FlatBufferBuilder *fbb_;
+ // Sub-builders; each one is constructed with fbb_.
+ VMDeviceTableBuilder deviceTable_;
+ VMFunctionTableBuilder functionTable_;
+ VMExecutableTableBuilder executableTable_;
+ VMSourceMapBuilder sourceMap_;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
diff --git a/compiler/Serialization/VMSourceMapBuilder.cpp b/compiler/Serialization/VMSourceMapBuilder.cpp
new file mode 100644
index 0000000..a534540
--- /dev/null
+++ b/compiler/Serialization/VMSourceMapBuilder.cpp
@@ -0,0 +1,164 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Serialization/VMSourceMapBuilder.h"
+
+#include "flatbuffers/flatbuffers.h"
+#include "schemas/source_map_def_generated.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Location.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+VMSourceMapBuilder::VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb)
+ : fbb_(fbb) {}  // Only stores |fbb|; serialization into it happens in later calls.
+
+// Returns the string table index for |value|, appending |value| to the table
+// the first time it is seen so that equal strings share one entry.
+int VMSourceMapBuilder::GetUniqueString(std::string value) {
+  const auto existing = stringTableMap_.find(value);
+  if (existing == stringTableMap_.end()) {
+    // First occurrence: the new entry goes at the end of the table.
+    const int newIndex = stringTable_.size();
+    stringTableMap_.insert({value, newIndex});
+    stringTable_.push_back(std::move(value));
+    return newIndex;
+  }
+  return existing->second;
+}
+
+// Stores |functionSourceMap| as the source map for |functionOrdinal|, growing
+// the table as needed. An existing entry at that ordinal is overwritten.
+//
+// Returns failure (after printing to llvm::errs()) for a negative ordinal.
+LogicalResult VMSourceMapBuilder::AddFunction(
+    int functionOrdinal, VMFunctionSourceMap functionSourceMap) {
+  if (functionOrdinal < 0) {
+    // A negative ordinal would otherwise be compared as a huge unsigned value
+    // and then used to index out of bounds.
+    llvm::errs() << "Invalid negative function ordinal " << functionOrdinal
+                 << "\n";
+    return failure();
+  }
+  const size_t index = static_cast<size_t>(functionOrdinal);
+  if (functionMaps_.size() <= index) {
+    functionMaps_.resize(index + 1);
+  }
+  functionMaps_[index] = std::move(functionSourceMap);
+  return success();
+}
+
+// Serializes every function source map plus the shared string table into an
+// iree::SourceMapDef and returns its offset.
+//
+// |maxFunctionOrdinal| is used as the number of entries in the emitted
+// function table; functions without a recorded source map get an empty entry.
+// Returns a null offset when |maxFunctionOrdinal| is negative or smaller than
+// the number of source maps already added, or when a function map fails to
+// serialize.
+::flatbuffers::Offset<iree::SourceMapDef> VMSourceMapBuilder::Finish(
+    int maxFunctionOrdinal) {
+  // NOTE: we always ensure the source map table is the same size as the
+  // function table so that lookups at runtime can be validated once at load
+  // time (ensuring the tables match up) instead of on each lookup.
+  if (maxFunctionOrdinal < 0 ||
+      static_cast<size_t>(maxFunctionOrdinal) < functionMaps_.size()) {
+    llvm::errs() << "Max function ordinal defined as " << maxFunctionOrdinal
+                 << " but there are " << functionMaps_.size()
+                 << " function source maps present"
+                 << "\n";
+    return {};
+  }
+  const size_t functionCount = static_cast<size_t>(maxFunctionOrdinal);
+  functionMaps_.resize(functionCount);
+
+  std::vector<::flatbuffers::Offset<iree::FunctionSourceMapDef>> functionDefs;
+  functionDefs.reserve(functionCount);
+  for (const auto &functionMap : functionMaps_) {
+    auto functionDef = SerializeVMFunctionSourceMap(functionMap);
+    if (functionDef.IsNull()) return {};
+    functionDefs.push_back(functionDef);
+  }
+
+  // Vectors are created before the table builder is constructed.
+  auto functionTableOffset = fbb_->CreateVector(functionDefs);
+  auto stringTableOffset = fbb_->CreateVectorOfStrings(stringTable_);
+  iree::SourceMapDefBuilder smdb(*fbb_);
+  smdb.add_function_table(functionTableOffset);
+  smdb.add_string_table(stringTableOffset);
+  return smdb.Finish();
+}
+
+// Serializes one function's source map into an iree::FunctionSourceMapDef.
+//
+// Each (int, Location) entry becomes a BytecodeSourceLocation that pairs the
+// int with the index of the serialized location in this function's location
+// table.
+::flatbuffers::Offset<iree::FunctionSourceMapDef>
+VMSourceMapBuilder::SerializeVMFunctionSourceMap(
+    const VMFunctionSourceMap &functionMap) {
+  const auto &entries = functionMap.locations;
+  if (entries.empty()) {
+    // Nothing recorded: emit a table with no fields so the function table
+    // still holds a non-null value without wasting much space.
+    iree::FunctionSourceMapDefBuilder emptyBuilder(*fbb_);
+    return emptyBuilder.Finish();
+  }
+
+  LocationOffsetTable offsets;
+  std::vector<iree::BytecodeSourceLocation> sourceLocations;
+  for (const auto &entry : entries) {
+    const int locationIndex = SerializeLocation(entry.second, &offsets);
+    sourceLocations.push_back({entry.first, locationIndex});
+  }
+
+  // Both vectors are created before the table builder is constructed.
+  const auto locationTable = fbb_->CreateVector(offsets.locationDefs);
+  const auto bytecodeMap = fbb_->CreateVectorOfStructs(sourceLocations);
+
+  iree::FunctionSourceMapDefBuilder mapBuilder(*fbb_);
+  mapBuilder.add_location_table(locationTable);
+  mapBuilder.add_bytecode_map(bytecodeMap);
+  return mapBuilder.Finish();
+}
+
+int VMSourceMapBuilder::SerializeLocation(  // Returns |location|'s index in locationOffsetTable->locationDefs, serializing it on first use.
+ const Location &location, LocationOffsetTable *locationOffsetTable) {
+ auto existingIt = locationOffsetTable->locationMap.find(location);  // Already serialized for this table? Reuse its index.
+ if (existingIt != locationOffsetTable->locationMap.end()) {
+ return existingIt->getSecond();
+ }
+ // Encode by location kind; each branch sets the union type and the offset of the kind-specific table.
+ iree::LocationDefUnion locationUnionType;
+ ::flatbuffers::Offset<void> locationUnionOffset;
+ if (auto fileLoc = location.dyn_cast<FileLineColLoc>()) {
+ locationUnionType = iree::LocationDefUnion::FileLocationDef;
+ int filenameIndex = GetUniqueString(fileLoc.getFilename().str());  // Filename is stored as a string table index.
+ iree::FileLocationDefBuilder lb(*fbb_);
+ lb.add_filename(filenameIndex);
+ lb.add_line(fileLoc.getLine());
+ lb.add_column(fileLoc.getColumn());
+ locationUnionOffset = lb.Finish().Union();
+ } else if (auto nameLoc = location.dyn_cast<NameLoc>()) {
+ locationUnionType = iree::LocationDefUnion::NameLocationDef;
+ int nameIndex = GetUniqueString(nameLoc.getName().str());  // Name is stored as a string table index.
+ iree::NameLocationDefBuilder lb(*fbb_);
+ lb.add_name(nameIndex);
+ locationUnionOffset = lb.Finish().Union();
+ } else if (auto callSiteLoc = location.dyn_cast<CallSiteLoc>()) {
+ locationUnionType = iree::LocationDefUnion::CallSiteLocationDef;  // Callee and caller are serialized recursively before this table is started.
+ int calleeIndex =
+ SerializeLocation(callSiteLoc.getCallee(), locationOffsetTable);
+ int callerIndex =
+ SerializeLocation(callSiteLoc.getCaller(), locationOffsetTable);
+ iree::CallSiteLocationDefBuilder lb(*fbb_);
+ lb.add_callee_location(calleeIndex);
+ lb.add_caller_location(callerIndex);
+ locationUnionOffset = lb.Finish().Union();
+ } else if (auto fusedLoc = location.dyn_cast<FusedLoc>()) {
+ locationUnionType = iree::LocationDefUnion::FusedLocationDef;  // Child locations are serialized recursively and referenced by index.
+ std::vector<int> locationIndices;
+ locationIndices.reserve(fusedLoc.getLocations().size());
+ for (const auto &child_loc : fusedLoc.getLocations()) {
+ int child_index = SerializeLocation(child_loc, locationOffsetTable);
+ locationIndices.push_back(child_index);
+ }
+ auto locationIndicesOffset = fbb_->CreateVector(locationIndices);
+ iree::FusedLocationDefBuilder lb(*fbb_);
+ lb.add_locations(locationIndicesOffset);
+ locationUnionOffset = lb.Finish().Union();
+ } else {
+ llvm_unreachable("Unimplemented location kind");  // NOTE(review): any other location kind (e.g. an unknown location) lands here -- confirm callers never pass one.
+ }
+ // Wrap the kind-specific table in a LocationDef, append it and remember its index for reuse.
+ iree::LocationDefBuilder ldb(*fbb_);
+ ldb.add_location_union_type(locationUnionType);
+ ldb.add_location_union(locationUnionOffset);
+ int locationIndex = locationOffsetTable->locationDefs.size();  // Index the new entry will occupy.
+ locationOffsetTable->locationDefs.push_back(ldb.Finish());
+ locationOffsetTable->locationMap.insert({location, locationIndex});
+ return locationIndex;
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Serialization/VMSourceMapBuilder.h b/compiler/Serialization/VMSourceMapBuilder.h
new file mode 100644
index 0000000..13faf71
--- /dev/null
+++ b/compiler/Serialization/VMSourceMapBuilder.h
@@ -0,0 +1,64 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
+#define IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
+
+#include <vector>
+
+#include "flatbuffers/flatbuffers.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "schemas/source_map_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+struct VMFunctionSourceMap {  // Source locations recorded for a single function.
+ std::vector<std::pair<int, Location>> locations;  // (int, Location) pairs; the int is presumably a bytecode offset -- TODO confirm against the writer.
+};
+
+class VMSourceMapBuilder {  // Collects per-function source maps and serializes them as an iree::SourceMapDef.
+ public:
+ explicit VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb);
+ // Stores |functionSourceMap| at index |functionOrdinal|, growing the table if needed.
+ LogicalResult AddFunction(int functionOrdinal,
+ VMFunctionSourceMap functionSourceMap);
+ // Serializes all maps; the table is resized to |maxFunctionOrdinal| entries. Returns a null offset if more maps were added than that.
+ ::flatbuffers::Offset<iree::SourceMapDef> Finish(int maxFunctionOrdinal);
+
+ private:
+ struct LocationOffsetTable {  // Per-function table of serialized locations.
+ std::vector<::flatbuffers::Offset<iree::LocationDef>> locationDefs;  // Serialized locations, in index order.
+ llvm::DenseMap<Location, int> locationMap;  // Location -> index into locationDefs.
+ };
+ // Returns the index of |value| in the string table, adding it on first use.
+ int GetUniqueString(std::string value);
+ // Serialization helpers for one function's map and for a single location.
+ ::flatbuffers::Offset<iree::FunctionSourceMapDef>
+ SerializeVMFunctionSourceMap(const VMFunctionSourceMap &functionMap);
+ int SerializeLocation(const Location &location,
+ LocationOffsetTable *locationOffsetTable);
+
+ ::flatbuffers::FlatBufferBuilder *fbb_;
+ std::vector<std::string> stringTable_;  // Unique strings, in insertion order.
+ llvm::StringMap<int> stringTableMap_;  // String -> index into stringTable_.
+ std::vector<VMFunctionSourceMap> functionMaps_;  // Indexed by function ordinal.
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
diff --git a/compiler/Transforms/AggressiveOpElimination.cpp b/compiler/Transforms/AggressiveOpElimination.cpp
new file mode 100644
index 0000000..9dd1781
--- /dev/null
+++ b/compiler/Transforms/AggressiveOpElimination.cpp
@@ -0,0 +1,77 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <deque>
+#include <memory>
+
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "compiler/IR/Interpreter/LLOps.h"
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/Sequencer/LLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Rewrite pattern that erases an op of type T once it has no remaining uses.
+template <typename T>
+struct EraseUnused : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(T op,
+                                     PatternRewriter &rewriter) const override {
+    // Ops that still have uses do not match and are left in place.
+    if (!op.use_empty()) return this->matchFailure();
+
+    rewriter.eraseOp(op);
+    return this->matchSuccess();
+  }
+};
+
+// Adds an EraseUnused pattern to |patterns| for each op kind this pass
+// removes when it has no uses. Registration order matches the list below.
+void populateAggressiveOpEliminationPatterns(OwningRewritePatternList &patterns,
+                                             MLIRContext *ctx) {
+  // Loads and allocations.
+  patterns.insert<EraseUnused<LoadOp>, EraseUnused<AllocOp>>(ctx);
+  // Sequencer heap allocations (HL, then LL).
+  patterns.insert<EraseUnused<IREESeq::HL::AllocHeapOp>,
+                  EraseUnused<IREESeq::LL::AllocHeapOp>>(ctx);
+  // Interpreter heap allocations (HL, then LL).
+  patterns.insert<EraseUnused<IREEInterp::HL::AllocHeapOp>,
+                  EraseUnused<IREEInterp::LL::AllocHeapOp>>(ctx);
+}
+
+} // namespace
+
+// Function pass that greedily applies the patterns registered by
+// populateAggressiveOpEliminationPatterns() to the current function.
+//
+// TODO(b/142012496) Make these be handled by normal DCE.
+class AggressiveOpEliminationPass
+    : public FunctionPass<AggressiveOpEliminationPass> {
+ public:
+  void runOnFunction() override {
+    OwningRewritePatternList eliminationPatterns;
+    populateAggressiveOpEliminationPatterns(eliminationPatterns,
+                                            &getContext());
+
+    auto function = getFunction();
+    applyPatternsGreedily(function, eliminationPatterns);
+  }
+};
+
+std::unique_ptr<OpPassBase<FuncOp>> createAggressiveOpEliminationPass() {  // Factory: returns a newly allocated AggressiveOpEliminationPass.
+ return std::make_unique<AggressiveOpEliminationPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Transforms/AssignFunctionOrdinals.cpp b/compiler/Transforms/AssignFunctionOrdinals.cpp
similarity index 100%
rename from iree/compiler/Transforms/AssignFunctionOrdinals.cpp
rename to compiler/Transforms/AssignFunctionOrdinals.cpp
diff --git a/compiler/Transforms/BUILD b/compiler/Transforms/BUILD
new file mode 100644
index 0000000..9590416
--- /dev/null
+++ b/compiler/Transforms/BUILD
@@ -0,0 +1,41 @@
+package(  # Package-wide defaults for the targets in this BUILD file.
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Compiler transformation passes and the shared conversion helpers they use.
+cc_library(
+    name = "Transforms",
+    srcs = [
+        "AggressiveOpElimination.cpp",
+        "AssignFunctionOrdinals.cpp",
+        "ConvertFromTupleCallingConvention.cpp",
+        "ConvertToMemRefCallingConvention.cpp",
+        "DropUnreachableFunctions.cpp",
+        "DropUnusedExecutables.cpp",
+        "LegalizeTypeStorage.cpp",
+        "LowerStdToIreeDialect.cpp",
+        "LowerXLAToIreeDialect.cpp",
+    ],
+    hdrs = [
+        "ConversionUtils.h",
+        "Passes.h",
+        "Rewrites.h",
+    ],
+    deps = [
+        # Workspace-relative labels start with exactly two slashes; a package
+        # path may not begin with "/".
+        "//compiler/IR",
+        "//compiler/IR/Interpreter",
+        "//compiler/IR/Sequencer",
+        "//compiler/Utils",
+        "@llvm//:support",
+        "@local_config_mlir//:Analysis",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardDialectRegistration",
+        "@local_config_mlir//:StandardOps",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:TransformUtils",
+        "@local_config_mlir//:Transforms",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+    ],
+    # Keep every object file when linking, even if nothing references its
+    # symbols directly.
+    alwayslink = 1,
+)
diff --git a/iree/compiler/Transforms/CMakeLists.txt b/compiler/Transforms/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Transforms/CMakeLists.txt
rename to compiler/Transforms/CMakeLists.txt
diff --git a/compiler/Transforms/ConversionUtils.h b/compiler/Transforms/ConversionUtils.h
new file mode 100644
index 0000000..01f484c
--- /dev/null
+++ b/compiler/Transforms/ConversionUtils.h
@@ -0,0 +1,101 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
+#define IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
+
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Conversion pattern that replaces a one-operand SrcOp with a DstOp.
+//
+// The operand goes through loadAccessValue() and wrapAsMemRef(); DstOp is
+// created with the result type from convertTypeToMemRef(); its result goes
+// through wrapAsTensor() and loadResultValue() before replacing |op|.
+template <typename SrcOp, typename DstOp>
+struct UnaryOpLowering : public OpConversionPattern<SrcOp> {
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      SrcOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    const auto loc = op.getLoc();
+
+    auto *input = loadAccessValue(loc, operands[0], rewriter);
+    input = wrapAsMemRef(input, op, rewriter);
+
+    auto memRefResultType = convertTypeToMemRef(op.getResult());
+    auto loweredOp = rewriter.create<DstOp>(loc, memRefResultType, input);
+    auto loweredResult = loweredOp.getResult();
+    loweredResult = wrapAsTensor(loweredResult, op, rewriter);
+
+    auto *replacement =
+        loadResultValue(loc, op.getType(), loweredResult, rewriter);
+    rewriter.replaceOp(op, {replacement});
+    return this->matchSuccess();
+  }
+};
+
+// Conversion pattern that replaces a two-operand SrcOp with a DstOp.
+//
+// Both operands go through loadAccessValue() and then wrapAsMemRef(); DstOp is
+// created with the result type from convertTypeToMemRef(); its result goes
+// through wrapAsTensor() and loadResultValue() before replacing |op|.
+template <typename SrcOp, typename DstOp>
+struct BinaryOpLowering : public OpConversionPattern<SrcOp> {
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      SrcOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    const auto loc = op.getLoc();
+
+    // Both loads are issued before either value is wrapped.
+    auto *lhs = loadAccessValue(loc, operands[0], rewriter);
+    auto *rhs = loadAccessValue(loc, operands[1], rewriter);
+    auto memRefResultType = convertTypeToMemRef(op.getResult());
+
+    lhs = wrapAsMemRef(lhs, op, rewriter);
+    rhs = wrapAsMemRef(rhs, op, rewriter);
+
+    auto loweredOp = rewriter.create<DstOp>(loc, memRefResultType, lhs, rhs);
+    auto loweredResult = loweredOp.getResult();
+    loweredResult = wrapAsTensor(loweredResult, op, rewriter);
+
+    auto *replacement =
+        loadResultValue(loc, op.getType(), loweredResult, rewriter);
+    rewriter.replaceOp(op, {replacement});
+    return this->matchSuccess();
+  }
+};
+
+// Conversion pattern that replaces a three-operand SrcOp with a DstOp.
+//
+// All operands go through loadAccessValue() and then wrapAsMemRef(); DstOp is
+// created with the result type from convertTypeToMemRef(); its result goes
+// through wrapAsTensor() and loadResultValue() before replacing |op|.
+template <typename SrcOp, typename DstOp>
+struct TernaryOpLowering : public OpConversionPattern<SrcOp> {
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      SrcOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    const auto loc = op.getLoc();
+
+    // All three loads are issued before any value is wrapped.
+    auto *first = loadAccessValue(loc, operands[0], rewriter);
+    auto *second = loadAccessValue(loc, operands[1], rewriter);
+    auto *third = loadAccessValue(loc, operands[2], rewriter);
+
+    first = wrapAsMemRef(first, op, rewriter);
+    second = wrapAsMemRef(second, op, rewriter);
+    third = wrapAsMemRef(third, op, rewriter);
+
+    auto memRefResultType = convertTypeToMemRef(op.getResult());
+    auto loweredOp =
+        rewriter.create<DstOp>(loc, memRefResultType, first, second, third);
+    auto loweredResult = loweredOp.getResult();
+    loweredResult = wrapAsTensor(loweredResult, op, rewriter);
+
+    auto *replacement =
+        loadResultValue(loc, op.getType(), loweredResult, rewriter);
+    rewriter.replaceOp(op, {replacement});
+    return this->matchSuccess();
+  }
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
diff --git a/iree/compiler/Transforms/ConvertFromTupleCallingConvention.cpp b/compiler/Transforms/ConvertFromTupleCallingConvention.cpp
similarity index 100%
rename from iree/compiler/Transforms/ConvertFromTupleCallingConvention.cpp
rename to compiler/Transforms/ConvertFromTupleCallingConvention.cpp
diff --git a/compiler/Transforms/ConvertToMemRefCallingConvention.cpp b/compiler/Transforms/ConvertToMemRefCallingConvention.cpp
new file mode 100644
index 0000000..1b65ede
--- /dev/null
+++ b/compiler/Transforms/ConvertToMemRefCallingConvention.cpp
@@ -0,0 +1,398 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Sets every attribute found on |oldOp| on |newOp| as well, using the same
+// name and value. Attributes already on |newOp| under other names are left
+// untouched.
+void copyOperationAttrs(Operation *oldOp, Operation *newOp) {
+  auto sourceAttrs = oldOp->getAttrs();
+  for (auto it = sourceAttrs.begin(), end = sourceAttrs.end(); it != end;
+       ++it) {
+    newOp->setAttr(it->first, it->second);
+  }
+}
+
+// Builds the function type obtained by passing every input and result type of
+// |type| through convertTypeToMemRef.
+//
+// Returns a null FunctionType as soon as any single type fails to convert.
+FunctionType getMemRefFunctionType(FunctionType type) {
+  // Converts each entry of |sourceTypes| and appends it to |converted|.
+  // Stops and returns false at the first type that cannot be converted.
+  auto convertAll = [](ArrayRef<Type> sourceTypes,
+                       llvm::SmallVectorImpl<Type> &converted) {
+    for (auto sourceType : sourceTypes) {
+      auto memRefType = convertTypeToMemRef(sourceType);
+      if (!memRefType) {
+        return false;
+      }
+      converted.push_back(memRefType);
+    }
+    return true;
+  };
+
+  llvm::SmallVector<Type, 8> memRefInputs;
+  if (!convertAll(type.getInputs(), memRefInputs)) {
+    return nullptr;
+  }
+  llvm::SmallVector<Type, 8> memRefResults;
+  if (!convertAll(type.getResults(), memRefResults)) {
+    return nullptr;
+  }
+
+  Builder typeBuilder(type.getContext());
+  return typeBuilder.getFunctionType(memRefInputs, memRefResults);
+}
+
+// Records in |mapping| the value that stands in for the old block argument
+// |oldArg| now that the block receives |newArg| instead:
+//   * memref arguments map directly to |newArg|;
+//   * tensor arguments map to an IREE::MemRefToTensorOp applied to |newArg|;
+//   * anything else maps to a LoadOp (no indices) reading from |newArg|.
+// New ops are created at |builder|'s current insertion point.
+//
+// Always returns false; callers treat a true return as failure.
+bool insertLoad(BlockArgument *oldArg, BlockArgument *newArg,
+                OpBuilder &builder, BlockAndValueMapping *mapping) {
+  auto argLoc = oldArg->getOwner()->getParent()->getLoc();
+  auto oldType = oldArg->getType();
+
+  Value *replacement = nullptr;
+  if (oldType.isa<MemRefType>()) {
+    // Nothing to unbox, but the argument still has to be remapped so that the
+    // use lists match through conversion.
+    replacement = newArg;
+  } else if (oldType.isa<TensorType>()) {
+    replacement =
+        builder.create<IREE::MemRefToTensorOp>(argLoc, newArg).getResult();
+  } else {
+    // Unbox the value with a load.
+    replacement =
+        builder.create<LoadOp>(argLoc, newArg, ArrayRef<Value *>{}).getResult();
+  }
+
+  mapping->map(oldArg, replacement);
+  return false;
+}
+
+// Records in |mapping| the value that stands in for |oldValue| (a result of
+// |oldOp|) now that the converted op produces |newValue| instead:
+//   * memref values map directly to |newValue|;
+//   * tensor values map to an IREE::MemRefToTensorOp applied to |newValue|;
+//   * anything else maps to a LoadOp (no indices) reading from |newValue|,
+//     which is asserted to be a memref.
+// New ops are created at |builder|'s current insertion point and use
+// |oldOp|'s location.
+//
+// Always returns false; callers treat a true return as failure.
+bool insertLoad(Operation *oldOp, Value *oldValue, Value *newValue,
+                OpBuilder &builder, BlockAndValueMapping *mapping) {
+  auto oldType = oldValue->getType();
+
+  Value *replacement = nullptr;
+  if (oldType.isa<MemRefType>()) {
+    // Already a memref: nothing to unbox.
+    replacement = newValue;
+  } else if (oldType.isa<TensorType>()) {
+    replacement =
+        builder.create<IREE::MemRefToTensorOp>(oldOp->getLoc(), newValue)
+            .getResult();
+  } else {
+    assert(newValue->getType().isa<MemRefType>());
+
+    // Unbox the value with a load.
+    replacement =
+        builder.create<LoadOp>(oldOp->getLoc(), newValue, ArrayRef<Value *>{})
+            .getResult();
+  }
+
+  mapping->map(oldValue, replacement);
+  return false;
+}
+
+// Produces the memref-side replacement for |oldValue|, an operand of |oldOp|.
+//
+// Resolution order:
+//   1. If |oldValue| was already a memref, its mapped value is returned as-is.
+//   2. If it was a tensor, the mapped value is wrapped in an
+//      IREE::TensorToMemRefOp.
+//   3. If resolveValueToSourceMemRef finds a source memref for it, the mapped
+//      value of that memref is returned.
+//   4. Otherwise a new AllocOp is created and the mapped value is stored into
+//      it with a StoreOp (no indices).
+//
+// Returns nullptr when |oldValue| has no entry in |mapping|, or when the
+// source memref found in step 3 has no entry in |mapping|.
+Value *insertStore(Operation *oldOp, Value *oldValue, OpBuilder &builder,
+                   BlockAndValueMapping *mapping) {
+  Value *mappedValue = mapping->lookupOrNull(oldValue);
+  if (!mappedValue) {
+    return nullptr;
+  }
+
+  auto loc = oldOp->getLoc();
+  auto oldType = oldValue->getType();
+
+  // Values that were already memrefs need no boxing.
+  // TODO(benvanik): ensure indices make sense.
+  if (oldType.isa<MemRefType>()) {
+    return mappedValue;
+  }
+  if (oldType.isa<TensorType>()) {
+    return builder.create<IREE::TensorToMemRefOp>(loc, mappedValue)
+        .getResult();
+  }
+
+  // Prefer reusing the memref the value was originally loaded from.
+  if (auto *sourceMemRef = resolveValueToSourceMemRef(oldValue, oldOp)) {
+    return mapping->lookupOrNull(sourceMemRef);
+  }
+
+  // No existing storage: allocate a box and store the value into it.
+  auto boxStorage = builder.create<AllocOp>(loc, convertTypeToMemRef(oldType));
+  builder.create<StoreOp>(loc, mappedValue, boxStorage, ArrayRef<Value *>{});
+  return boxStorage;
+}
+
+// Re-creates the call |oldOp| at |builder|'s insertion point using the
+// memref-based calling convention.
+//
+// Each operand is replaced by the value insertStore produces for it, each
+// result type is passed through convertTypeToMemRef, and the attributes of
+// the old call are copied onto the new one. The old results are then remapped
+// in |mapping| through insertLoad.
+//
+// Returns true on failure (an operand has no mapping, a result type cannot be
+// converted, or insertLoad fails) and false on success.
+bool convertCallOp(CallOp *oldOp, OpBuilder &builder,
+                   BlockAndValueMapping *mapping) {
+  llvm::SmallVector<Value *, 4> newArgs;
+  for (auto *oldArg : oldOp->getOperands()) {
+    auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
+    if (!newArg) {
+      return true;
+    }
+    newArgs.push_back(newArg);
+  }
+
+  SmallVector<Type, 4> resultTypes;
+  for (auto oldType : oldOp->getOperation()->getResultTypes()) {
+    // A null type must not reach CallOp creation; the function signature and
+    // block argument conversions reject unconvertible types the same way.
+    auto memRefType = convertTypeToMemRef(oldType);
+    if (!memRefType) {
+      oldOp->emitError("Unable to convert call result type to a memref type");
+      return true;
+    }
+    resultTypes.push_back(memRefType);
+  }
+  auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(),
+                                      resultTypes, newArgs);
+  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
+
+  for (unsigned i = 0, e = newOp.getNumResults(); i < e; ++i) {
+    auto *oldResult = oldOp->getResult(i);
+    auto *newResult = newOp.getResult(i);
+    if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder,
+                   mapping)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// Placeholder for converting CallIndirectOp to the memref calling convention.
+//
+// Conversion is not implemented: this always emits an error on |oldOp| and
+// returns true, which callers treat as failure. |builder| and |mapping| are
+// unused by the live code. The `#if 0` region below is a disabled draft that
+// follows the same shape as convertCallOp (store operands, create the new
+// call, copy attributes, remap results through insertLoad).
+bool convertCallIndirectOp(CallIndirectOp *oldOp, OpBuilder &builder,
+ BlockAndValueMapping *mapping) {
+ // TODO(benvanik): support wrapping callee values.
+ oldOp->emitError("CallIndirectOp not yet supported");
+ return true;
+#if 0
+ llvm::SmallVector<Value *, 4> newArgs;
+ for (auto *oldArg : oldOp->getArgOperands()) {
+ auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
+ if (!newArg) {
+ return true;
+ }
+ newArgs.push_back(newArg);
+ }
+
+ auto newOp = builder.create<CallIndirectOp>(oldOp->getLoc(),
+ oldOp->getCallee(), newArgs);
+ copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
+
+ for (int i = 0; i < newOp.getNumResults(); ++i) {
+ auto *oldResult = oldOp->getResult(i);
+ auto *newResult = newOp.getResult(i);
+ if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder,
+ mapping)) {
+ return true;
+ }
+ }
+
+ return false;
+#endif // 0
+}
+
+// Clones the return-like op |oldOp| at |builder|'s insertion point with every
+// operand replaced by the value insertStore produces for it.
+//
+// Returns true on failure (insertStore returned nullptr for an operand) and
+// false on success.
+bool convertReturnOp(Operation *oldOp, OpBuilder &builder,
+                     BlockAndValueMapping *mapping) {
+  // Operand remapping private to this op; the shared |mapping| is only read
+  // (through insertStore) and never updated here.
+  BlockAndValueMapping boxedOperands;
+  for (auto *operand : oldOp->getOperands()) {
+    auto *boxed = insertStore(oldOp, operand, builder, mapping);
+    if (boxed == nullptr) {
+      return true;
+    }
+    boxedOperands.map(operand, boxed);
+  }
+
+  builder.clone(*oldOp, boxedOperands);
+  return false;
+}
+
+// Re-creates the unconditional branch |oldOp| at |builder|'s insertion point.
+//
+// Branch arguments are replaced by the values insertStore produces for them,
+// the destination is the block |mapping| associates with the old destination,
+// and the old op's attributes are copied to the new branch.
+//
+// Returns true on failure (an argument could not be converted, or the
+// destination block has no mapping) and false on success.
+bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder,
+                     BlockAndValueMapping *mapping) {
+  Operation *oldOperation = oldOp->getOperation();
+
+  llvm::SmallVector<Value *, 4> boxedArgs;
+  for (auto *branchArg : oldOp->getOperands()) {
+    auto *boxed = insertStore(oldOperation, branchArg, builder, mapping);
+    if (boxed == nullptr) {
+      return true;
+    }
+    boxedArgs.push_back(boxed);
+  }
+
+  auto *newDest = mapping->lookupOrNull(oldOp->getDest());
+  if (newDest == nullptr) {
+    oldOp->emitError("Destination block mapping not found");
+    return true;
+  }
+
+  auto newBranch =
+      builder.create<BranchOp>(oldOp->getLoc(), newDest, boxedArgs);
+  copyOperationAttrs(oldOperation, newBranch.getOperation());
+  return false;
+}
+
+// Re-creates the conditional branch |oldOp| at |builder|'s insertion point.
+//
+// The true and false argument lists are replaced by the values insertStore
+// produces for them, both destinations and the condition are looked up in
+// |mapping|, and the old op's attributes are copied to the new branch.
+//
+// Returns true on failure (an argument could not be converted, or a
+// destination block or the condition value has no mapping) and false on
+// success.
+bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder,
+                         BlockAndValueMapping *mapping) {
+  llvm::SmallVector<Value *, 4> trueArgs;
+  for (auto *oldArg : oldOp->getTrueOperands()) {
+    auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
+    if (!newArg) {
+      return true;
+    }
+    trueArgs.push_back(newArg);
+  }
+  llvm::SmallVector<Value *, 4> falseArgs;
+  for (auto *oldArg : oldOp->getFalseOperands()) {
+    auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
+    if (!newArg) {
+      return true;
+    }
+    falseArgs.push_back(newArg);
+  }
+
+  auto *trueDest = mapping->lookupOrNull(oldOp->getTrueDest());
+  if (!trueDest) {
+    oldOp->emitError("True destination block mapping not found");
+    return true;
+  }
+  auto *falseDest = mapping->lookupOrNull(oldOp->getFalseDest());
+  if (!falseDest) {
+    oldOp->emitError("False destination block mapping not found");
+    return true;
+  }
+
+  // Lowering will take care of the condition store.
+  auto *newCondition = mapping->lookupOrNull(oldOp->getCondition());
+  if (!newCondition) {
+    oldOp->emitError("Condition value mapping not found");
+    // Report failure: no branch was created, so returning success here would
+    // leave the converted block without its terminator.
+    return true;
+  }
+
+  auto newOp = builder.create<CondBranchOp>(
+      oldOp->getLoc(), newCondition, trueDest, trueArgs, falseDest, falseArgs);
+  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
+
+  return false;
+}
+
+// Converts a single op from the old function body into the new one.
+//
+// Calls, indirect calls, returns (standard and IREE) and branches are routed
+// to their dedicated converters. Every other op, constants included, is
+// cloned through |mapping| unchanged.
+//
+// Returns true on failure and false on success.
+bool convertOperation(Operation *oldOp, OpBuilder &builder,
+                      BlockAndValueMapping *mapping) {
+  if (auto callOp = dyn_cast<CallOp>(oldOp)) {
+    return convertCallOp(&callOp, builder, mapping);
+  }
+  if (auto callIndirectOp = dyn_cast<CallIndirectOp>(oldOp)) {
+    return convertCallIndirectOp(&callIndirectOp, builder, mapping);
+  }
+  if (isa<ReturnOp>(oldOp) || isa<IREE::ReturnOp>(oldOp)) {
+    return convertReturnOp(oldOp, builder, mapping);
+  }
+  if (auto branchOp = dyn_cast<BranchOp>(oldOp)) {
+    return convertBranchOp(&branchOp, builder, mapping);
+  }
+  if (auto condBranchOp = dyn_cast<CondBranchOp>(oldOp)) {
+    return convertCondBranchOp(&condBranchOp, builder, mapping);
+  }
+
+  // No calling-convention impact: copy the op, remapping operands/results.
+  builder.clone(*oldOp, *mapping);
+  return false;
+}
+
+// Populates the body of |newFunc| from |oldFunc|, switching block arguments
+// over to memref types.
+//
+// Runs in two phases: first every block is re-created with converted
+// arguments (so branches can refer to blocks that appear later), then the ops
+// of each block are converted in order.
+//
+// Returns true on failure and false on success.
+bool convertFunction(FuncOp oldFunc, FuncOp newFunc) {
+  OpBuilder builder(newFunc.getBody());
+  BlockAndValueMapping mapping;
+
+  // Phase 1: mirror the block structure and record the block mappings.
+  newFunc.getBlocks().clear();
+  for (auto &sourceBlock : oldFunc.getBlocks()) {
+    auto *targetBlock = builder.createBlock(&newFunc.getBody());
+    for (auto *sourceArg : sourceBlock.getArguments()) {
+      // Each block argument becomes a memref.
+      auto memRefType = convertTypeToMemRef(sourceArg->getType());
+      if (!memRefType) {
+        return true;
+      }
+      auto *targetArg = targetBlock->addArgument(memRefType);
+
+      // Map the old argument to a value derived from the new one so that the
+      // block contents can keep using unwrapped values.
+      if (insertLoad(sourceArg, targetArg, builder, &mapping)) {
+        return true;
+      }
+    }
+    mapping.map(&sourceBlock, targetBlock);
+  }
+
+  // Phase 2: convert the ops, appending to the end of each mapped block.
+  for (auto &sourceBlock : oldFunc.getBlocks()) {
+    builder.setInsertionPointToEnd(mapping.lookupOrNull(&sourceBlock));
+    for (auto &sourceOp : sourceBlock.getOperations()) {
+      if (convertOperation(&sourceOp, builder, &mapping)) {
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+} // namespace
+
+// Module pass that replaces every function in the module with a copy whose
+// signature and block arguments use memref types.
+//
+// All replacement functions are built detached from the module first; only
+// when every function converted successfully are the originals erased and the
+// replacements inserted. On any failure the module is left untouched, the
+// detached replacements are destroyed and the pass is marked as failed.
+class ConvertToMemRefCallingConventionPass
+    : public ModulePass<ConvertToMemRefCallingConventionPass> {
+ public:
+  void runOnModule() override {
+    auto module = getModule();
+
+    // Build a list of (oldFunc, newFunc) for all functions we need to
+    // replace. This will ensure that when we go to convert function bodies we
+    // have only new functions defined.
+    std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions;
+
+    // The replacement functions are not owned by any block until they are
+    // pushed into the module, so they must be erased explicitly when bailing
+    // out or they would leak.
+    auto failAndCleanUp = [&]() {
+      for (auto &pair : convertedFunctions) {
+        pair.second.erase();
+      }
+      signalPassFailure();
+    };
+
+    for (auto oldFunc : module.getOps<FuncOp>()) {
+      // Create the replacement function, ensuring that we copy attributes.
+      auto functionType = getMemRefFunctionType(oldFunc.getType());
+      if (!functionType) {
+        oldFunc.emitError(
+            "Unable to convert function signature to memref types");
+        return failAndCleanUp();
+      }
+
+      auto newFunc = FuncOp::create(oldFunc.getLoc(), oldFunc.getName(),
+                                    functionType, oldFunc.getDialectAttrs());
+      convertedFunctions.push_back({oldFunc, newFunc});
+
+      // Perform the actual body conversion now.
+      if (convertFunction(oldFunc, newFunc)) {
+        return failAndCleanUp();
+      }
+    }
+
+    // Replace functions in the module.
+    for (auto &pair : convertedFunctions) {
+      pair.first.erase();
+      module.push_back(pair.second);
+    }
+  }
+};
+
+// Returns a newly allocated ConvertToMemRefCallingConventionPass, owned by
+// the caller through the returned unique_ptr.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createConvertToMemRefCallingConventionPass() {
+ return std::make_unique<ConvertToMemRefCallingConventionPass>();
+}
+
+// Static registration of the pass under the name
+// "convert-to-memref-calling-convention" with the description below.
+static PassRegistration<ConvertToMemRefCallingConventionPass> pass(
+ "convert-to-memref-calling-convention",
+ "Convert functions to use a memref-based calling convention.");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/DropUnreachableFunctions.cpp b/compiler/Transforms/DropUnreachableFunctions.cpp
new file mode 100644
index 0000000..f334dba
--- /dev/null
+++ b/compiler/Transforms/DropUnreachableFunctions.cpp
@@ -0,0 +1,67 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/ModuleUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Drops all functions in a module that are not reachable by functions with the
+// "iree.module.export" attribute.
+//
+// The work is delegated to dropUnusedFunctions, which is called on the pass's
+// module with that attribute name as the only entry of its second argument.
+class DropUnreachableModuleFunctionsPass
+ : public ModulePass<DropUnreachableModuleFunctionsPass> {
+ public:
+ void runOnModule() override {
+ dropUnusedFunctions(getModule(), {"iree.module.export"});
+ }
+};
+
+// Drops all functions in a module that are not reachable by functions with the
+// "iree.executable.export" attribute.
+//
+// The work is delegated to dropUnusedFunctions, which is called on the pass's
+// module with that attribute name as the only entry of its second argument.
+class DropUnreachableExecutableFunctionsPass
+ : public ModulePass<DropUnreachableExecutableFunctionsPass> {
+ public:
+ void runOnModule() override {
+ dropUnusedFunctions(getModule(), {"iree.executable.export"});
+ }
+};
+
+// Returns a newly allocated DropUnreachableModuleFunctionsPass, owned by the
+// caller through the returned unique_ptr.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createDropUnreachableModuleFunctionsPass() {
+ return std::make_unique<DropUnreachableModuleFunctionsPass>();
+}
+
+// Returns a newly allocated DropUnreachableExecutableFunctionsPass, owned by
+// the caller through the returned unique_ptr.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createDropUnreachableExecutableFunctionsPass() {
+ return std::make_unique<DropUnreachableExecutableFunctionsPass>();
+}
+
+// Static registrations for the two passes above. Each pass gets its own name;
+// NOTE(review): both share the exact same description string, so the two are
+// only distinguishable by name wherever descriptions are listed.
+static PassRegistration<DropUnreachableModuleFunctionsPass> moduleFunctionsPass(
+ "iree-drop-unreachable-module-functions",
+ "Drop all functions not reachable from an exported function");
+
+static PassRegistration<DropUnreachableExecutableFunctionsPass>
+ executableFunctionsPass(
+ "iree-drop-unreachable-executable-functions",
+ "Drop all functions not reachable from an exported function");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/DropUnusedExecutables.cpp b/compiler/Transforms/DropUnusedExecutables.cpp
new file mode 100644
index 0000000..73fa365
--- /dev/null
+++ b/compiler/Transforms/DropUnusedExecutables.cpp
@@ -0,0 +1,62 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "llvm/ADT/SetVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Module pass that erases every IREE::MultiArchExecutableOp whose name is not
+// referenced by any IREESeq::HL::DispatchOp found inside the module's
+// functions.
+class DropUnusedExecutablesPass : public ModulePass<DropUnusedExecutablesPass> {
+ public:
+  void runOnModule() override {
+    // Names of all executables referenced by a dispatch op in any function.
+    DenseSet<StringRef> referencedNames;
+    for (auto funcOp : getModule().getOps<FuncOp>()) {
+      funcOp.walk([&referencedNames](IREESeq::HL::DispatchOp dispatchOp) {
+        referencedNames.insert(dispatchOp.getExecutable());
+      });
+    }
+
+    // Executables are only collected during the scan; erasure is deferred
+    // until the scan over the module has finished.
+    DenseSet<Operation *> unreferenced;
+    for (auto executableOp :
+         getModule().getOps<IREE::MultiArchExecutableOp>()) {
+      if (referencedNames.count(executableOp.getName()) != 0) {
+        continue;
+      }
+      unreferenced.insert(executableOp);
+    }
+
+    for (auto *deadOp : unreferenced) {
+      deadOp->erase();
+    }
+  }
+};
+
+// Returns a newly allocated DropUnusedExecutablesPass, owned by the caller
+// through the returned unique_ptr.
+std::unique_ptr<OpPassBase<ModuleOp>> createDropUnusedExecutablesPass() {
+ return std::make_unique<DropUnusedExecutablesPass>(); // NOLINT
+}
+
+// Static registration of DropUnusedExecutablesPass under the name
+// "iree-drop-unused-executables".
+// NOTE(review): the variable name `executableFunctionsPass` does not describe
+// this pass (it reads like a leftover from a function-dropping pass), and the
+// description mentions reduce ops while the pass only walks
+// IREESeq::HL::DispatchOp - confirm both.
+static PassRegistration<DropUnusedExecutablesPass> executableFunctionsPass(
+ "iree-drop-unused-executables",
+ "Drop all executables not reachable from a dispatch/reduce op.");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/BUILD b/compiler/Transforms/Interpreter/BUILD
new file mode 100644
index 0000000..9405f06
--- /dev/null
+++ b/compiler/Transforms/Interpreter/BUILD
@@ -0,0 +1,41 @@
+# Transforms specific to the IREE interpreter.
+
+# Package-wide defaults: every target in this package is publicly visible
+# unless it overrides visibility, and the package is declared under the
+# "notice" license type.
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Interpreter-specific lowering and expansion passes.
+#
+# Workspace-local dependencies use the `//package` label form; a label that
+# starts with three slashes is not a valid Bazel label.
+# alwayslink is set presumably so that the statically registered passes in
+# the sources are kept by the linker even when nothing references them
+# directly - confirm before changing.
+cc_library(
+    name = "Interpreter",
+    srcs = [
+        "ExpandReductionsToOps.cpp",
+        "LowerInterpreterDialect.cpp",
+        "LowerStdToInterpreterDialect.cpp",
+        "LowerToInterpreterDialect.cpp",
+        "LowerXLAToInterpreterDialect.cpp",
+        "MakeExecutableABI.cpp",
+    ],
+    hdrs = [
+        "Passes.h",
+        "Rewrites.h",
+    ],
+    deps = [
+        "//compiler/IR",
+        "//compiler/IR/Interpreter",
+        "//compiler/Serialization",
+        "//compiler/Transforms",
+        "//compiler/Utils",
+        "//schemas/bytecode:interpreter_bytecode_v0",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardOps",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:TransformUtils",
+        "@local_config_mlir//:Transforms",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot",
+    ],
+    alwayslink = 1,
+)
diff --git a/iree/compiler/Transforms/Interpreter/CMakeLists.txt b/compiler/Transforms/Interpreter/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/CMakeLists.txt
rename to compiler/Transforms/Interpreter/CMakeLists.txt
diff --git a/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp b/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp
new file mode 100644
index 0000000..c8f80e6
--- /dev/null
+++ b/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp
@@ -0,0 +1,216 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Interpreter/HLDialect.h"
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/Transforms/ConversionUtils.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Expands the elemental op |elementOp| from |applyFunc| into a whole-value
+// reduction op created with |builder|, followed by a CopyOp that writes the
+// reduction result into the matching destination argument of |entryPoint|.
+//
+// Only pass-through element ops are accepted: every operand must come from a
+// block argument and every result may only be used by terminators. The first
+// operand must be an argument of |applyFunc|'s entry block; its position
+// selects which (source, initial value, destination) argument triple of
+// |entryPoint| is used. The reduction dimension is read from the
+// "iree.executable.reduction.dimension" attribute of |entryPoint|.
+//
+// Returns failure (after emitting an error on |elementOp|) when the op is
+// fused with other ops, has no operands, does not read from an entry block
+// argument, or has no matching expanded reduction op.
+LogicalResult convertReductionOp(FuncOp entryPoint, FuncOp applyFunc,
+                                 Operation *elementOp, OpBuilder &builder) {
+  // Ensure that this op is pass-through and does not interact with any other
+  // ops within the function.
+  // TODO(b/139313439): support fused reductions.
+  for (auto *operand : elementOp->getOperands()) {
+    if (operand->getDefiningOp() != nullptr) {
+      return elementOp->emitOpError()
+             << "Fused reductions are not supported (operand not sourced from "
+                "block args)";
+    }
+  }
+  for (auto *result : elementOp->getResults()) {
+    for (auto *user : result->getUsers()) {
+      if (!user->isKnownTerminator()) {
+        return elementOp->emitOpError() << "Fused reductions are not supported "
+                                           "(result used by non-terminator)";
+      }
+    }
+  }
+
+  // The argument set is identified through the first operand, so an op
+  // without operands cannot be matched against the apply function arguments.
+  if (elementOp->getNumOperands() == 0) {
+    return elementOp->emitOpError()
+           << "Reduction op has no operands to match against the apply "
+              "function arguments";
+  }
+
+  // Determine the index of the args we care about. We'll use these to match up
+  // the operands of the entry point with our application.
+  // Our arguments are expanded tuples like <lhs0, rhs0>, <lhs1, rhs1>, so this
+  // index gets the offset * 2.
+  auto &applyEntryBlock = applyFunc.getBlocks().front();
+  auto argIt =
+      llvm::find(applyEntryBlock.getArguments(), elementOp->getOperand(0));
+  if (argIt == applyEntryBlock.args_end()) {
+    // Without this check the index computed below would point past the
+    // argument sets and the entry point argument lookups would be out of
+    // range.
+    return elementOp->emitOpError()
+           << "Reduction op operand is not an argument of the apply function "
+              "entry block";
+  }
+  int setIndex = std::distance(applyEntryBlock.args_begin(), argIt) / 2;
+
+  // Map to the args from the entry point.
+  auto &entryPointEntryBlock = entryPoint.getBlocks().front();
+  Value *srcArg = entryPointEntryBlock.getArgument(setIndex);
+  Value *initArg = entryPointEntryBlock.getArgument(
+      applyFunc.getNumArguments() / 2 + setIndex);
+  Value *dstArg =
+      entryPointEntryBlock.getArgument(applyFunc.getNumArguments() + setIndex);
+  auto dstType = dstArg->getType().cast<ShapedType>();
+  Type elementType = dstType.getElementType();
+  auto loc = elementOp->getLoc();
+  auto dimensionAttr = entryPoint.getAttrOfType<IntegerAttr>(
+      "iree.executable.reduction.dimension");
+
+  // Pick the reduction op from the elemental op kind and whether the
+  // destination element type is floating point.
+  Operation *expandedOp = nullptr;
+  if (isa<IREEInterp::HL::AddFOp>(elementOp) ||
+      isa<IREEInterp::HL::AddIOp>(elementOp)) {
+    if (elementType.isa<FloatType>()) {
+      expandedOp = builder.create<IREEInterp::HL::ReduceSumFOp>(
+          loc, dstType, srcArg, initArg, dimensionAttr);
+    } else {
+      expandedOp = builder.create<IREEInterp::HL::ReduceSumIOp>(
+          loc, dstType, srcArg, initArg, dimensionAttr);
+    }
+  } else if (isa<IREEInterp::HL::MinFOp>(elementOp) ||
+             isa<IREEInterp::HL::MinISOp>(elementOp) ||
+             isa<IREEInterp::HL::MinIUOp>(elementOp)) {
+    if (elementType.isa<FloatType>()) {
+      expandedOp = builder.create<IREEInterp::HL::ReduceMinFOp>(
+          loc, dstType, srcArg, initArg, dimensionAttr);
+    } else {
+      expandedOp = builder.create<IREEInterp::HL::ReduceMinIOp>(
+          loc, dstType, srcArg, initArg, dimensionAttr);
+    }
+  } else if (isa<IREEInterp::HL::MaxFOp>(elementOp) ||
+             isa<IREEInterp::HL::MaxISOp>(elementOp) ||
+             isa<IREEInterp::HL::MaxIUOp>(elementOp)) {
+    if (elementType.isa<FloatType>()) {
+      expandedOp = builder.create<IREEInterp::HL::ReduceMaxFOp>(
+          loc, dstType, srcArg, initArg, dimensionAttr);
+    } else {
+      expandedOp = builder.create<IREEInterp::HL::ReduceMaxIOp>(
+          loc, dstType, srcArg, initArg, dimensionAttr);
+    }
+  }
+  if (!expandedOp) {
+    return elementOp->emitOpError()
+           << "No matching expanded reduction op for elemental op";
+  }
+
+  // Copy the whole reduction result (all-zero offsets, full destination
+  // shape) into the destination argument.
+  llvm::SmallVector<int64_t, 4> zeroOffset(dstType.getRank(), 0);
+  auto zeroIndices = createArrayConstant(builder, loc, zeroOffset);
+  auto lengths = createArrayConstant(builder, loc, dstType.getShape());
+  builder.create<IREEInterp::HL::CopyOp>(
+      loc, expandedOp->getResult(0), zeroIndices, dstArg, zeroIndices, lengths);
+
+  return success();
+}
+
+// Fills the empty reduction entry function |entryFunc| with the expanded form
+// of its elemental apply function.
+//
+// |entryFunc| must have no body yet, must carry the
+// "iree.executable.reduction.dimension" attribute, and must reference its
+// apply function through the "iree.executable.reduction.apply" symbol
+// attribute. Every non-terminator op in the apply function's first block is
+// expanded with convertReductionOp, an IREE::ReturnOp is appended, the apply
+// function is erased and both reduction attributes are removed.
+//
+// Returns failure, after emitting an error, if any precondition is not met or
+// an op cannot be expanded.
+LogicalResult expandReductionFunction(FuncOp entryFunc) {
+  if (!entryFunc.empty()) {
+    return entryFunc.emitError()
+           << "Function has already been expanded or has existing contents";
+  }
+  if (!entryFunc.getAttr("iree.executable.reduction.dimension")) {
+    return entryFunc.emitError() << "Windowed reductions are not yet supported";
+  }
+
+  auto applySym =
+      entryFunc.getAttrOfType<SymbolRefAttr>("iree.executable.reduction.apply");
+  if (!applySym) {
+    return entryFunc.emitError() << "No reduction application function defined";
+  }
+
+  auto parentModule = entryFunc.getParentOfType<ModuleOp>();
+  auto applyFunc = parentModule.lookupSymbol<FuncOp>(applySym.getValue());
+  if (!applyFunc) {
+    return entryFunc.emitError()
+           << "Unable to find apply function " << applySym;
+  }
+
+  auto *entryBlock = entryFunc.addEntryBlock();
+  OpBuilder builder(entryBlock);
+
+  // Expand each non-terminator op; stop at the first one that fails.
+  auto &applyBody = applyFunc.getBlocks().front();
+  auto walkResult = applyBody.walk([&](Operation *op) {
+    if (op->isKnownTerminator()) {
+      return WalkResult::advance();
+    }
+    if (failed(convertReductionOp(entryFunc, applyFunc, op, builder))) {
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted()) {
+    return applyFunc.emitError() << "Unable to convert apply func";
+  }
+
+  builder.create<IREE::ReturnOp>(builder.getUnknownLoc());
+
+  // The apply function has been inlined into the entry function, so it and
+  // the attributes describing the reduction are no longer needed.
+  applyFunc.erase();
+  entryFunc.removeAttr("iree.executable.reduction.apply");
+  entryFunc.removeAttr("iree.executable.reduction.dimension");
+
+  return success();
+}
+
+// Module pass performing a limited lowering of reductions to fat reduce_* ops.
+//
+// Supported subset:
+//   * 'min', 'max', and 'add' computations, with function names matching the
+//     computation
+//   * one op per reduction (no fusions yet).
+// Note: computations and shapes are not validated.
+//
+// Every function carrying the "iree.executable.reduction.apply" attribute is
+// expanded with expandReductionFunction; the pass fails at the first function
+// that cannot be expanded.
+//
+// TODO(b/139410773): Implement more generally, supporting custom computations.
+class ExpandReductionsToOpsPass : public ModulePass<ExpandReductionsToOpsPass> {
+ public:
+  void runOnModule() override {
+    // Snapshot the reduction functions before expanding any of them, since
+    // expansion erases apply functions from the module.
+    SmallVector<FuncOp, 4> pendingReductions;
+    for (auto candidate : getModule().getOps<FuncOp>()) {
+      if (!candidate.getAttr("iree.executable.reduction.apply")) {
+        continue;
+      }
+      pendingReductions.push_back(candidate);
+    }
+
+    for (auto reductionFunc : pendingReductions) {
+      if (failed(expandReductionFunction(reductionFunc))) {
+        signalPassFailure();
+        return;
+      }
+    }
+  }
+};
+
+} // namespace
+
+// Returns a newly allocated ExpandReductionsToOpsPass, owned by the caller
+// through the returned unique_ptr.
+std::unique_ptr<OpPassBase<ModuleOp>> createExpandReductionsToOpsPass() {
+ return std::make_unique<ExpandReductionsToOpsPass>();
+}
+
+// Static registration of ExpandReductionsToOpsPass under the name
+// "iree-expand-reductions-to-ops" with the description below.
+static PassRegistration<ExpandReductionsToOpsPass> pass(
+ "iree-expand-reductions-to-ops",
+ "Expands IREE reduction functions to their interpreter ops");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp
new file mode 100644
index 0000000..8238b48
--- /dev/null
+++ b/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp
@@ -0,0 +1,253 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Interpreter/HLDialect.h"
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "compiler/IR/Interpreter/LLDialect.h"
+#include "compiler/IR/Interpreter/LLOps.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/Serialization/BytecodeTables.h"
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Lowers an HL branch to the LL branch with the same destination and operands.
+struct LowerBranchOpPattern
+    : public OpRewritePattern<IREEInterp::HL::BranchOp> {
+  using OpRewritePattern<IREEInterp::HL::BranchOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(IREEInterp::HL::BranchOp op,
+                                     PatternRewriter &rewriter) const override {
+    SmallVector<Value *, 8> operands{op.getOperation()->getOperands()};
+    rewriter.replaceOpWithNewOp<IREEInterp::LL::BranchOp>(op, op.getDest(),
+                                                          operands);
+    return matchSuccess();
+  }
+};
+
+// Lowers an HL conditional branch to the LL one, keeping both destinations.
+struct LowerCondCondBranchOpPattern
+    : public OpRewritePattern<IREEInterp::HL::CondBranchOp> {
+  using OpRewritePattern<IREEInterp::HL::CondBranchOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(IREEInterp::HL::CondBranchOp op,
+                                     PatternRewriter &rewriter) const override {
+    SmallVector<Value *, 8> trueOperands{op.getTrueOperands()};
+    SmallVector<Value *, 8> falseOperands{op.getFalseOperands()};
+    rewriter.replaceOpWithNewOp<IREEInterp::LL::CondBranchOp>(
+        op, op.getCondition(), op.getTrueDest(), trueOperands,
+        op.getFalseDest(), falseOperands);
+    return matchSuccess();
+  }
+};
+
+// Returns true if the op defined by |opName| (like 'iree_ll_interp.reshape')
+// uses output operands for results (like iree_ll_interp.add_i) and false if it
+// returns real results. Unknown ops assert and, without asserts, yield false.
+bool opTakesOutputOperands(llvm::StringRef opName) {
+  const bool inDialect = opName.consume_front("iree_ll_interp.");
+  assert(inDialect && "op not part of IREE LL Interpreter dialect");
+  if (!inDialect) return false;
+  auto opcode = GetInterpreterOpcodeByName(opName.str());
+  assert(opcode.hasValue() && "op has no corresponding opcode");
+  if (!opcode.hasValue()) return false;
+  const auto &info = GetInterpreterOpcodeInfo(opcode.getValue());
+  for (auto &operand : info.operands) {
+    if (operand == iree::OperandEncoding::kOutputSlot ||
+        operand == iree::OperandEncoding::kVariadicOutputSlots) {
+      return true;
+    }
+  }
+  return false;
+}
+
+template <typename SrcOp, typename DstOp>
+class SimpleOpLowering : public OpRewritePattern<SrcOp> {
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  // Replaces |op| with a DstOp taking the same operands and attributes.
+  PatternMatchResult matchAndRewrite(SrcOp op,
+                                     PatternRewriter &rewriter) const override {
+    Operation *srcOp = op.getOperation();
+    SmallVector<Value *, 8> operands{srcOp->getOperands()};
+    // Most ops take results as output operands to populate during execution.
+    // Certain ops, like reshape, return references to existing memrefs and
+    // should still retain their results.
+    if (!opTakesOutputOperands(DstOp::getOperationName())) {
+      SmallVector<Type, 8> resultTypes{srcOp->getResultTypes()};
+      rewriter.replaceOpWithNewOp<DstOp>(op, resultTypes, operands,
+                                         op.getAttrs());
+      return this->matchSuccess();
+    }
+
+    // Reject dynamic shapes before creating any op so failure leaves no strays.
+    // TODO(benvanik): real thing here - dynamic shaping required; emit a shape
+    // calculation based on the operation and let DCE clean up static parts.
+    for (Value *result : srcOp->getResults()) {
+      if (result->getType().cast<MemRefType>().hasStaticShape()) continue;
+      op.emitOpError() << "uses unsupported dynamic shapes";
+      return this->matchFailure();
+    }
+    // Allocate one heap buffer per result and append it as an output operand.
+    ArrayRef<Value *> dimPieces;
+    SmallVector<Value *, 4> replacementValues;
+    for (Value *result : srcOp->getResults()) {
+      auto allocOp = rewriter.create<IREEInterp::LL::AllocHeapOp>(
+          op.getLoc(), result->getType().cast<MemRefType>(), dimPieces);
+      operands.push_back(allocOp);
+      replacementValues.push_back(allocOp);
+    }
+    ArrayRef<Type> resultTypes;
+    rewriter.create<DstOp>(op.getLoc(), resultTypes, operands, op.getAttrs());
+    rewriter.replaceOp(op, replacementValues);
+    return this->matchSuccess();
+  }
+};
+
+} // namespace
+
+void populateInterpreterLoweringPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {  // Adds the interpreter lowerings to |patterns|.
+ patterns.insert<LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx);
+ patterns.insert<  // Ops whose LL op has a different name or source dialect.
+ SimpleOpLowering<IREE::ConstantOp, IREEInterp::LL::ConstantOp>,
+ SimpleOpLowering<IREEInterp::HL::CopyOp, IREEInterp::LL::DynamicCopyOp>,
+ SimpleOpLowering<IREEInterp::HL::SliceOp,
+ IREEInterp::LL::DynamicSliceOp>>(ctx);
+#define SAME_NAME_SIMPLE_PATTERN(op_name) \
+ SimpleOpLowering<IREEInterp::HL::op_name, IREEInterp::LL::op_name>
+ // clang-format off
+ patterns.insert<  // HL ops that lower to the identically named LL op.
+ SAME_NAME_SIMPLE_PATTERN(AssignOp),
+ SAME_NAME_SIMPLE_PATTERN(AbsFOp),
+ SAME_NAME_SIMPLE_PATTERN(AbsIOp),
+ SAME_NAME_SIMPLE_PATTERN(AddFOp),
+ SAME_NAME_SIMPLE_PATTERN(AddIOp),
+ SAME_NAME_SIMPLE_PATTERN(AllocHeapOp),
+ SAME_NAME_SIMPLE_PATTERN(AndOp),
+ SAME_NAME_SIMPLE_PATTERN(Atan2FOp),
+ SAME_NAME_SIMPLE_PATTERN(BreakOp),
+ SAME_NAME_SIMPLE_PATTERN(BroadcastOp),
+ SAME_NAME_SIMPLE_PATTERN(CallOp),
+ SAME_NAME_SIMPLE_PATTERN(CallIndirectOp),
+ SAME_NAME_SIMPLE_PATTERN(CeilFOp),
+ SAME_NAME_SIMPLE_PATTERN(ClampFOp),
+ SAME_NAME_SIMPLE_PATTERN(CloneOp),
+ SAME_NAME_SIMPLE_PATTERN(CmpFOp),
+ SAME_NAME_SIMPLE_PATTERN(CmpIOp),
+ SAME_NAME_SIMPLE_PATTERN(CondAssignOp),
+ SAME_NAME_SIMPLE_PATTERN(ConvertSSOp),
+ SAME_NAME_SIMPLE_PATTERN(ConvertUUOp),
+ SAME_NAME_SIMPLE_PATTERN(ConvertSUOp),
+ SAME_NAME_SIMPLE_PATTERN(ConvertUSOp),
+ SAME_NAME_SIMPLE_PATTERN(CondBreakOp),
+ SAME_NAME_SIMPLE_PATTERN(CosFOp),
+ SAME_NAME_SIMPLE_PATTERN(DimOp),
+ SAME_NAME_SIMPLE_PATTERN(DivFOp),
+ SAME_NAME_SIMPLE_PATTERN(DivISOp),
+ SAME_NAME_SIMPLE_PATTERN(DivIUOp),
+ SAME_NAME_SIMPLE_PATTERN(ExpFOp),
+ SAME_NAME_SIMPLE_PATTERN(LogFOp),
+ SAME_NAME_SIMPLE_PATTERN(RsqrtFOp),
+ SAME_NAME_SIMPLE_PATTERN(FloorFOp),
+ SAME_NAME_SIMPLE_PATTERN(LengthOp),
+ SAME_NAME_SIMPLE_PATTERN(MatMulFOp),
+ SAME_NAME_SIMPLE_PATTERN(MatMulIOp),
+ SAME_NAME_SIMPLE_PATTERN(MaxFOp),
+ SAME_NAME_SIMPLE_PATTERN(MaxISOp),
+ SAME_NAME_SIMPLE_PATTERN(MaxIUOp),
+ SAME_NAME_SIMPLE_PATTERN(MinFOp),
+ SAME_NAME_SIMPLE_PATTERN(MinISOp),
+ SAME_NAME_SIMPLE_PATTERN(MinIUOp),
+ SAME_NAME_SIMPLE_PATTERN(MulAddFOp),
+ SAME_NAME_SIMPLE_PATTERN(MulAddIOp),
+ SAME_NAME_SIMPLE_PATTERN(MulFOp),
+ SAME_NAME_SIMPLE_PATTERN(MulIOp),
+ SAME_NAME_SIMPLE_PATTERN(NotOp),
+ SAME_NAME_SIMPLE_PATTERN(OrOp),
+ SAME_NAME_SIMPLE_PATTERN(PadOp),
+ SAME_NAME_SIMPLE_PATTERN(RankOp),
+ SAME_NAME_SIMPLE_PATTERN(ReduceSumIOp),
+ SAME_NAME_SIMPLE_PATTERN(ReduceSumFOp),
+ SAME_NAME_SIMPLE_PATTERN(ReduceMinIOp),
+ SAME_NAME_SIMPLE_PATTERN(ReduceMinFOp),
+ SAME_NAME_SIMPLE_PATTERN(ReduceMaxIOp),
+ SAME_NAME_SIMPLE_PATTERN(ReduceMaxFOp),
+ SAME_NAME_SIMPLE_PATTERN(ReshapeOp),
+ SAME_NAME_SIMPLE_PATTERN(ReturnOp),
+ SAME_NAME_SIMPLE_PATTERN(SelectOp),
+ SAME_NAME_SIMPLE_PATTERN(ShapeOp),
+ SAME_NAME_SIMPLE_PATTERN(ShiftLeftOp),
+ SAME_NAME_SIMPLE_PATTERN(ShiftRightArithmeticOp),
+ SAME_NAME_SIMPLE_PATTERN(ShiftRightLogicalOp),
+ SAME_NAME_SIMPLE_PATTERN(SinFOp),
+ SAME_NAME_SIMPLE_PATTERN(SplitOp),
+ SAME_NAME_SIMPLE_PATTERN(SubFOp),
+ SAME_NAME_SIMPLE_PATTERN(SubIOp),
+ SAME_NAME_SIMPLE_PATTERN(TanhFOp),
+ SAME_NAME_SIMPLE_PATTERN(TileOp),
+ SAME_NAME_SIMPLE_PATTERN(TraceOp),
+ SAME_NAME_SIMPLE_PATTERN(TransposeOp),
+ SAME_NAME_SIMPLE_PATTERN(ReverseOp),
+ SAME_NAME_SIMPLE_PATTERN(XorOp)>(ctx);
+ // clang-format on
+#undef SAME_NAME_SIMPLE_PATTERN
+}
+
+namespace {
+// Function pass that rewrites a function with the interpreter lowerings.
+class LowerInterpreterDialectPass
+    : public FunctionPass<LowerInterpreterDialectPass> {
+ public:
+  void runOnFunction() override {
+    OwningRewritePatternList loweringPatterns;
+    populateInterpreterLoweringPatterns(loweringPatterns, &getContext());
+    // Only LL interpreter ops, FuncOp and IREE::ReturnOp are legal afterwards.
+    ConversionTarget llTarget(getContext());
+    llTarget.addLegalDialect<IREELLInterpreterDialect>();
+    llTarget.addLegalOp<FuncOp, IREE::ReturnOp>();
+    auto status = applyFullConversion(getFunction(), llTarget, loweringPatterns);
+    if (failed(status)) signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> createLowerInterpreterDialectPass() {
+ return std::make_unique<LowerInterpreterDialectPass>();  // New pass instance.
+}
+// Registers the pass under the "lower-iree-interpreter-hl-to-ll" argument.
+static PassRegistration<LowerInterpreterDialectPass> pass(
+ "lower-iree-interpreter-hl-to-ll", "Lowers IREE HL ops to IREE LL ops");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp
new file mode 100644
index 0000000..c174669
--- /dev/null
+++ b/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp
@@ -0,0 +1,303 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Interpreter/HLDialect.h"
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "compiler/IR/Interpreter/LLDialect.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/Transforms/ConversionUtils.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Allocator.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Converts a CallOp into the HL interpreter call op with the same callee.
+struct CallOpLowering : public OpConversionPattern<CallOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      CallOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Result types are taken from the callee's function type.
+    rewriter.replaceOpWithNewOp<IREEInterp::HL::CallOp>(
+        op, op.getCallee(), op.getCalleeType().getResults(), operands);
+    return matchSuccess();
+  }
+};
+
+// Converts a CallIndirectOp into the HL interpreter indirect call op.
+struct CallIndirectOpLowering : public OpConversionPattern<CallIndirectOp> {
+  using OpConversionPattern::OpConversionPattern;
+  PatternMatchResult matchAndRewrite(
+      CallIndirectOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // NOTE(review): callee comes from the unconverted op - confirm intended.
+    rewriter.replaceOpWithNewOp<IREEInterp::HL::CallIndirectOp>(
+        op, op.getCallee(), operands);
+    return matchSuccess();
+  }
+};
+
+struct ReturnOpLowering : public OpConversionPattern<ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+ // Replaces the matched ReturnOp with an HL ReturnOp over |operands|.
+ PatternMatchResult matchAndRewrite(
+ ReturnOp op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREEInterp::HL::ReturnOp>(op, operands);
+ return matchSuccess();
+ }
+};
+
+// Converts a BranchOp into the HL branch to its first (only used) successor.
+struct BranchOpLowering : public OpConversionPattern<BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  PatternMatchResult matchAndRewrite(
+      BranchOp op, ArrayRef<Value *> properOperands,
+      ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<IREEInterp::HL::BranchOp>(
+        op, destinations.front(), operands.front());
+    return matchSuccess();
+  }
+};
+
+// Converts a CondBranchOp into the HL interpreter conditional branch op.
+struct CondBranchOpLowering : public OpConversionPattern<CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  PatternMatchResult matchAndRewrite(
+      CondBranchOp op, ArrayRef<Value *> properOperands,
+      ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    const auto onTrue = IREEInterp::HL::CondBranchOp::trueIndex;
+    const auto onFalse = IREEInterp::HL::CondBranchOp::falseIndex;
+    auto *condition = loadAccessValue(op.getLoc(), properOperands[0], rewriter);
+    rewriter.replaceOpWithNewOp<IREEInterp::HL::CondBranchOp>(
+        op, condition, destinations[onTrue], operands[onTrue],
+        destinations[onFalse], operands[onFalse]);
+    return matchSuccess();
+  }
+};
+
+// Lowers the comparison |SrcOp| to |DstOp|, which is built on memref-wrapped
+// operands and takes the predicate as an i32 attribute.
+template <typename SrcOp, typename DstOp>
+struct CompareOpLowering : public OpConversionPattern<SrcOp> {
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      SrcOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto lhs = loadAccessValue(loc, operands[0], rewriter);
+    auto rhs = loadAccessValue(loc, operands[1], rewriter);
+    lhs = wrapAsMemRef(lhs, op, rewriter);
+    rhs = wrapAsMemRef(rhs, op, rewriter);
+
+    // TODO(benvanik): map predicate to stable value.
+    auto predicateValue = static_cast<int32_t>(op.getPredicate());
+    auto compareOp = rewriter.create<DstOp>(
+        loc, convertTypeToMemRef(op.getResult()),
+        rewriter.getI32IntegerAttr(predicateValue), lhs, rhs);
+    // Wrap the memref result as a tensor, then load it as the op's own type.
+    auto asTensor = wrapAsTensor(compareOp.getResult(), op, rewriter);
+    auto loaded = loadResultValue(loc, op.getType(), asTensor, rewriter);
+    rewriter.replaceOp(op, {loaded});
+    return this->matchSuccess();
+  }
+};
+
+struct CmpIOpLowering  // CmpIOp -> IREEInterp::HL::CmpIOp.
+ : public CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp> {
+ using CompareOpLowering::CompareOpLowering;
+};
+
+struct CmpFOpLowering  // CmpFOp -> IREEInterp::HL::CmpFOp.
+ : public CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp> {
+ using CompareOpLowering::CompareOpLowering;
+};
+
+// Converts an AllocOp into an HL interpreter heap allocation of the same type.
+struct AllocOpLowering : public OpConversionPattern<AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+  PatternMatchResult matchAndRewrite(
+      AllocOp allocOp, ArrayRef<Value *> dimOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    // TODO(benvanik): replace with length computation.
+    rewriter.replaceOpWithNewOp<IREEInterp::HL::AllocHeapOp>(
+        allocOp, allocOp.getType(), dimOperands);
+    return matchSuccess();
+  }
+};
+
+struct DeallocOpLowering : public OpConversionPattern<DeallocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ // Replaces the matched DeallocOp with an HL DiscardOp on its first operand.
+ PatternMatchResult matchAndRewrite(
+ DeallocOp op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREEInterp::HL::DiscardOp>(op, operands[0]);
+ return matchSuccess();
+ }
+};
+
+// Lowers a rank-0 LoadOp: the source memref is copied into a fresh allocation
+// which is then converted to a scalar. Higher ranks are rejected.
+struct LoadOpLowering : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(LoadOp loadOp,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = loadOp.getLoc();
+    auto memRefType = loadOp.getMemRefType();
+    if (memRefType.getRank() != 0) {
+      loadOp.emitError() << "Cannot lower load of non-scalar";
+      return matchFailure();
+    }
+    // Scratch buffer of the source type, allocated without dimension operands.
+    ArrayRef<Value *> noDims;
+    auto scratch =
+        rewriter.create<AllocOp>(loc, memRefType, noDims).getResult();
+    auto noIndices = createArrayConstant(rewriter, loc, {});
+    rewriter.create<IREEInterp::HL::CopyOp>(
+        loc, loadOp.getMemRef(), /*srcIndices=*/noIndices, scratch,
+        /*dstIndices=*/noIndices, /*lengths=*/noIndices);
+    rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, scratch);
+    return matchSuccess();
+  }
+};
+
+// Lowers a rank-0 StoreOp: the scalar is wrapped in a memref and copied into
+// the destination memref. Higher ranks are rejected.
+struct StoreOpLowering : public OpRewritePattern<StoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(StoreOp storeOp,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = storeOp.getLoc();
+    if (storeOp.getMemRefType().getRank() != 0) {
+      storeOp.emitError() << "Cannot lower store of non-scalar";
+      return matchFailure();
+    }
+    auto scalarAsMemRef =
+        rewriter.create<IREE::ScalarToMemRefOp>(loc, storeOp.getValueToStore());
+    auto noIndices = createArrayConstant(rewriter, loc, {});
+    rewriter.replaceOpWithNewOp<IREEInterp::HL::CopyOp>(
+        storeOp, scalarAsMemRef, /*srcIndices=*/noIndices, storeOp.getMemRef(),
+        /*dstIndices=*/noIndices, /*lengths=*/noIndices);
+    return matchSuccess();
+  }
+};
+
+#define UNARY_OP_LOWERING(StdOpType, IREEOpType) \
+  struct StdOpType##Lowering : public UnaryOpLowering<StdOpType, IREEOpType> { \
+    using UnaryOpLowering::UnaryOpLowering; \
+  };
+#define BINARY_OP_LOWERING(StdOpType, IREEOpType) \
+  struct StdOpType##Lowering \
+      : public BinaryOpLowering<StdOpType, IREEOpType> { \
+    using BinaryOpLowering::BinaryOpLowering; \
+  };
+#define TERNARY_OP_LOWERING(StdOpType, IREEOpType) \
+  struct StdOpType##Lowering \
+      : public TernaryOpLowering<StdOpType, IREEOpType> { \
+    using TernaryOpLowering::TernaryOpLowering; \
+  };
+
+// UNARY_OP_LOWERING(RankOp, IREEInterp::HL::RankOp);
+UNARY_OP_LOWERING(DimOp, IREEInterp::HL::DimOp);
+// UNARY_OP_LOWERING(ShapeOp, IREEInterp::HL::ShapeOp);
+// UNARY_OP_LOWERING(LengthOp, IREEInterp::HL::LengthOp);
+// UNARY_OP_LOWERING(NotOp, IREEInterp::HL::NotOp);
+BINARY_OP_LOWERING(AndOp, IREEInterp::HL::AndOp);
+BINARY_OP_LOWERING(OrOp, IREEInterp::HL::OrOp);
+// BINARY_OP_LOWERING(XorOp, IREEInterp::HL::XorOp);
+// BINARY_OP_LOWERING(ShiftLeftOp, IREEInterp::HL::ShiftLeftOp);
+// BINARY_OP_LOWERING(ShiftRightLogicalOp, IREEInterp::HL::ShiftRightLogicalOp);
+// BINARY_OP_LOWERING(ShiftRightArithmeticOp,
+// IREEInterp::HL::ShiftRightArithmeticOp);
+
+BINARY_OP_LOWERING(AddIOp, IREEInterp::HL::AddIOp);
+BINARY_OP_LOWERING(AddFOp, IREEInterp::HL::AddFOp);
+BINARY_OP_LOWERING(SubIOp, IREEInterp::HL::SubIOp);
+BINARY_OP_LOWERING(SubFOp, IREEInterp::HL::SubFOp);
+// UNARY_OP_LOWERING(AbsIOp, IREEInterp::HL::AbsIOp);
+// UNARY_OP_LOWERING(AbsFOp, IREEInterp::HL::AbsFOp);
+BINARY_OP_LOWERING(MulIOp, IREEInterp::HL::MulIOp);
+BINARY_OP_LOWERING(MulFOp, IREEInterp::HL::MulFOp);
+BINARY_OP_LOWERING(DivISOp, IREEInterp::HL::DivISOp);
+BINARY_OP_LOWERING(DivIUOp, IREEInterp::HL::DivIUOp);
+BINARY_OP_LOWERING(DivFOp, IREEInterp::HL::DivFOp);
+// BINARY_OP_LOWERING(MulAddIOp, IREEInterp::HL::MulAddIOp);
+// BINARY_OP_LOWERING(MulAddFOp, IREEInterp::HL::MulAddFOp);
+// UNARY_OP_LOWERING(ExpFOp, IREEInterp::HL::ExpFOp);
+// UNARY_OP_LOWERING(LogFOp, IREEInterp::HL::LogFOp);
+// UNARY_OP_LOWERING(RsqrtFOp, IREEInterp::HL::RsqrtFOp);
+// UNARY_OP_LOWERING(CosFOp, IREEInterp::HL::CosFOp);
+// UNARY_OP_LOWERING(SinFOp, IREEInterp::HL::SinFOp);
+// UNARY_OP_LOWERING(TanhFOp, IREEInterp::HL::TanhFOp);
+// UNARY_OP_LOWERING(Atan2FOp, IREEInterp::HL::Atan2FOp);
+
+// BINARY_OP_LOWERING(MinISOp, IREEInterp::HL::MinISOp);
+// BINARY_OP_LOWERING(MinIUOp, IREEInterp::HL::MinIUOp);
+// BINARY_OP_LOWERING(MinFOp, IREEInterp::HL::MinFOp);
+// BINARY_OP_LOWERING(MaxISOp, IREEInterp::HL::MaxISOp);
+// BINARY_OP_LOWERING(MaxIUOp, IREEInterp::HL::MaxIUOp);
+// BINARY_OP_LOWERING(MaxFOp, IREEInterp::HL::MaxFOp);
+// TERNARY_OP_LOWERING(ClampFOp, IREEInterp::HL::ClampFOp);
+// UNARY_OP_LOWERING(FloorFOp, IREEInterp::HL::FloorFOp);
+// UNARY_OP_LOWERING(CeilFOp, IREEInterp::HL::CeilFOp);
+#undef UNARY_OP_LOWERING  // Helper macros are only needed for the list above.
+#undef BINARY_OP_LOWERING
+#undef TERNARY_OP_LOWERING
+
+} // namespace
+
+// Adds the lowerings of this file to |patterns|, calls and branches first.
+void populateLowerStdToInterpreterPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *ctx) {
+  patterns.insert<CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
+                  BranchOpLowering, CondBranchOpLowering>(ctx);
+  // Comparisons.
+  patterns.insert<CmpIOpLowering, CmpFOpLowering>(ctx);
+  // Memory management.
+  patterns.insert<AllocOpLowering, DeallocOpLowering, LoadOpLowering,
+                  StoreOpLowering>(ctx);
+  // Shape operations and logical ops.
+  patterns.insert<DimOpLowering, AndOpLowering, OrOpLowering>(ctx);
+  // Arithmetic ops.
+  patterns.insert<AddIOpLowering, AddFOpLowering, SubIOpLowering,
+                  SubFOpLowering, MulIOpLowering, MulFOpLowering,
+                  DivISOpLowering, DivIUOpLowering, DivFOpLowering>(ctx);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp
new file mode 100644
index 0000000..a1334de
--- /dev/null
+++ b/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp
@@ -0,0 +1,63 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Interpreter/HLDialect.h"
+#include "compiler/IR/Interpreter/LLDialect.h"
+#include "compiler/Transforms/Interpreter/Rewrites.h"
+#include "compiler/Transforms/Rewrites.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+// Function pass converting ops to the IREE HL interpreter and IREE dialects.
+class LowerToInterpreterDialectPass
+    : public FunctionPass<LowerToInterpreterDialectPass> {
+ public:
+  void runOnFunction() override {
+    auto *context = &getContext();
+    OwningRewritePatternList conversionPatterns;
+    xla_hlo::PopulateGeneralDotOpLoweringPatterns(&conversionPatterns, context);
+    xla_hlo::PopulateXlaToStdPatterns(&conversionPatterns, context);
+    populateLowerStdToIreePatterns(conversionPatterns, context);
+    populateLowerStdToInterpreterPatterns(conversionPatterns, context);
+    populateLowerXlaToIreePatterns(conversionPatterns, context);
+    populateLowerXlaToInterpreterPatterns(conversionPatterns, context);
+
+    ConversionTarget hlTarget(*context);
+    hlTarget.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>();
+    hlTarget.addLegalOp<FuncOp, ReturnOp>();
+    auto status = applyFullConversion(getFunction(), hlTarget, conversionPatterns);
+    if (failed(status)) signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> createLowerToInterpreterDialectPass() {
+ return std::make_unique<LowerToInterpreterDialectPass>();  // New instance.
+}
+// Registers the pass under the "lower-to-iree-interpreter" argument.
+static PassRegistration<LowerToInterpreterDialectPass> pass(
+ "lower-to-iree-interpreter",
+ "Convert all ops to the IREE interpreter dialect");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
new file mode 100644
index 0000000..00668ae
--- /dev/null
+++ b/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
@@ -0,0 +1,565 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Interpreter/HLDialect.h"
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/Transforms/ConversionUtils.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// TODO(suderman): tablegen this? or something a bit more flexible.
+
+#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \
+ struct XlaOpType##Lowering \
+ : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
+ using UnaryOpLowering::UnaryOpLowering; \
+ };
+
+#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \
+ struct XlaOpType##Lowering \
+ : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
+ using TernaryOpLowering::TernaryOpLowering; \
+ };
+// Each use below declares a struct named <XlaOpType>Lowering.
+UNARY_OP_LOWERING(CopyOp, IREEInterp::HL::CloneOp);
+UNARY_OP_LOWERING(ExpOp, IREEInterp::HL::ExpFOp);
+UNARY_OP_LOWERING(LogOp, IREEInterp::HL::LogFOp);
+UNARY_OP_LOWERING(FloorOp, IREEInterp::HL::FloorFOp);
+UNARY_OP_LOWERING(RsqrtOp, IREEInterp::HL::RsqrtFOp);
+UNARY_OP_LOWERING(TanhOp, IREEInterp::HL::TanhFOp);
+TERNARY_OP_LOWERING(SelectOp, IREEInterp::HL::SelectOp);
+
+#undef UNARY_OP_LOWERING  // The helper macros are not used past this point.
+#undef TERNARY_OP_LOWERING
+
+template <typename T>  // T: op class created as (targetType, input, shape).
+static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter,
+ Location loc, Value *input,
+ MemRefType targetType) {  // Returns the newly created T op.
+ auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape());
+ return rewriter.create<T>(loc, targetType, input, shapeOp);
+}
+
+static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op,
+ Value *tensor) {  // Applies loadAccessValue, then wrapAsMemRef, to |tensor|.
+ return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op,
+ rewriter);
+}
+
+// Base pattern for XLA ops: converts the operands to memrefs, delegates op
+// construction to rewriteInternal and wraps result 0 back as a tensor.
+template <typename SrcOp>
+class XlaOpLowering : public OpConversionPattern<SrcOp> {
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+  PatternMatchResult matchAndRewrite(
+      SrcOp srcOp, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value *, 4> memrefOperands;
+    for (Value *operand : operands) {
+      memrefOperands.push_back(inputAsMemref(rewriter, srcOp, operand));
+    }
+    Operation *dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter);
+    if (!dstOp) return this->matchFailure();
+    auto tensorResult = wrapAsTensor(dstOp->getResult(0), srcOp, rewriter);
+    rewriter.replaceOp(srcOp, tensorResult);
+    return this->matchSuccess();
+  }
+
+ protected:
+  // Builds the replacement op from memref |operands|; null fails the match.
+  virtual Operation *rewriteInternal(
+      SrcOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?");
+  }
+};
+
+// Lowers BroadcastInDimOp to a reshape followed by a broadcast or a tile.
+struct BroadcastInDimOpLowering
+    : public XlaOpLowering<xla_hlo::BroadcastInDimOp> {
+  using XlaOpLowering::XlaOpLowering;
+  Operation *rewriteInternal(
+      xla_hlo::BroadcastInDimOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto *inputValue = operands[0];
+    auto inputType = inputValue->getType().cast<MemRefType>();
+    auto finalType = convertTypeToMemRef(*op);
+
+    // Reshape to scalar and broadcast.
+    auto createFinal = createShapeTargetingOp<IREEInterp::HL::BroadcastOp>;
+    llvm::SmallVector<int64_t, 6> intermediateShape{};
+
+    // Or reshape to final rank and tile.
+    if (inputType.getNumElements() != 1) {
+      createFinal = createShapeTargetingOp<IREEInterp::HL::TileOp>;
+      // Place each input dim at its broadcast index; other dims stay 1.
+      intermediateShape = llvm::SmallVector<int64_t, 6>(finalType.getRank(), 1);
+      auto inputShape = inputType.getShape();
+      auto dimensions = op->broadcast_dimensions();
+      for (size_t i = 0, e = inputShape.size(); i < e; ++i) {
+        auto index = dimensions->getValue(i).cast<IntegerAttr>().getInt();
+        intermediateShape[index] = inputShape[i];
+      }
+    }
+
+    auto intermediateType =
+        MemRefType::get(intermediateShape, inputType.getElementType());
+    auto reshapeOp = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
+        rewriter, op->getLoc(), inputValue, intermediateType);
+    return createFinal(rewriter, op->getLoc(), reshapeOp->getResult(0),
+                       finalType);
+  }
+};
+
+// Lowers ConcatenateOp to the HL concat op along the same dimension.
+struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> {
+  using XlaOpLowering::XlaOpLowering;
+
+  Operation *rewriteInternal(
+      xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // The dimension is passed on as an i32 attribute.
+    auto axis = rewriter.getI32IntegerAttr(op->dimension().getZExtValue());
+    return rewriter.create<IREEInterp::HL::ConcatOp>(
+        op->getLoc(), convertTypeToMemRef(*op), operands, axis);
+  }
+};
+
+// Lowers xla_hlo::DotOp to the HL floating point matmul op.
+struct DotOpLowering : public XlaOpLowering<xla_hlo::DotOp> {
+  using XlaOpLowering::XlaOpLowering;
+
+  // Returns the new matmul op, or null (match failure) for non-float types.
+  Operation *rewriteInternal(
+      xla_hlo::DotOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto *lhsValue = operands[0];
+    auto *rhsValue = operands[1];
+    auto finalType = convertTypeToMemRef(*op);
+    auto elementType = finalType.getElementType();
+    if (!elementType.isa<FloatType>()) {
+      // Only MatMulFOp is emitted below, so fail instead of building it.
+      op->emitOpError("xla_hlo.dot only supports floating point values");
+      return nullptr;
+    }
+
+    return rewriter.create<IREEInterp::HL::MatMulFOp>(op->getLoc(), finalType,
+                                                      lhsValue, rhsValue);
+  }
+};
+
+// Lowers DynamicUpdateSliceOp to a clone of the operand plus a copy into it.
+struct DynamicUpdateSliceOpLowering
+    : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> {
+  using XlaOpLowering::XlaOpLowering;
+  Operation *rewriteInternal(
+      xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto operand = operands[0];
+    auto update = operands[1];
+    auto startIndices = makeArrayRef(operands).drop_front(2);
+    if (startIndices.empty()) {
+      // front() below requires at least one start index.
+      op->emitOpError("requires at least one start index");
+      return nullptr;
+    }
+
+    auto updateType = update->getType().cast<ShapedType>();
+    Value *lengthConstant =
+        createArrayConstant(rewriter, op->getLoc(), updateType.getShape());
+
+    const int rank = startIndices.size();
+    llvm::SmallVector<Value *, 4> valuesToConcat;
+    valuesToConcat.reserve(startIndices.size());
+    auto type = getElementTypeOrSelf(startIndices.front());
+    // To generate the offset matrix we need to convert the variadic tensors
+    // into a reshaped and concated value.
+    for (auto index : startIndices) {
+      auto reshapedIndex = rewriter.create<IREEInterp::HL::ReshapeOp>(
+          op->getLoc(), MemRefType::get({1}, type), index,
+          createArrayConstant(rewriter, op->getLoc(), {1}));
+      valuesToConcat.push_back(reshapedIndex);
+    }
+    auto dstOffset = rewriter
+                         .create<IREEInterp::HL::ConcatOp>(
+                             op->getLoc(), MemRefType::get({rank}, type),
+                             valuesToConcat, rewriter.getI32IntegerAttr(0))
+                         .getResult();
+
+    llvm::SmallVector<int64_t, 4> zero_offset;
+    zero_offset.resize(updateType.getRank(), 0);
+    auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset);
+
+    // The copy writes into a clone of the operand, which becomes the result.
+    auto copiedOperand = rewriter.create<IREEInterp::HL::CloneOp>(
+        op->getLoc(), operand->getType(), operand);
+    rewriter.create<IREEInterp::HL::CopyOp>(op->getLoc(), update, srcOffset,
+                                            copiedOperand, dstOffset,
+                                            lengthConstant);
+    return copiedOperand;
+  }
+};
+
+// Lowers a binary XLA op to |IreeFloatOpType| when the operand element type
+// is floating point and to |IreeIntOpType| otherwise.
+template <typename XlaOpType, typename IreeFloatOpType, typename IreeIntOpType>
+struct BinaryFloatIntOpLowering : public XlaOpLowering<XlaOpType> {
+  using XlaOpLowering<XlaOpType>::XlaOpLowering;
+  Operation *rewriteInternal(
+      XlaOpType *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    Value *lhs = operands[0];
+    Value *rhs = operands[1];
+    // The result reuses the memref type of the left-hand operand.
+    auto resultType = lhs->getType().cast<MemRefType>();
+    const bool isFloat = resultType.getElementType().isa<FloatType>();
+    if (!isFloat) {
+      return rewriter.create<IreeIntOpType>(loc, resultType, lhs, rhs);
+    }
+    return rewriter.create<IreeFloatOpType>(loc, resultType, lhs, rhs);
+  }
+};
+// Lowers xla_hlo::MaxOp to MaxFOp for float elements and MaxISOp otherwise.
+struct MaxOpLowering
+    : public BinaryFloatIntOpLowering<xla_hlo::MaxOp, IREEInterp::HL::MaxFOp,
+                                      IREEInterp::HL::MaxISOp> {
+  using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering;
+};
+// Lowers xla_hlo::MinOp to MinFOp for float elements and MinISOp otherwise.
+struct MinOpLowering
+    : public BinaryFloatIntOpLowering<xla_hlo::MinOp, IREEInterp::HL::MinFOp,
+                                      IREEInterp::HL::MinISOp> {
+  using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering;
+};
+
+struct ConvertLowering : public XlaOpLowering<xla_hlo::ConvertOp> {
+  using XlaOpLowering<xla_hlo::ConvertOp>::XlaOpLowering;
+
+  // Selects the interpreter convert op from the integer/float kind of the
+  // operand and result element types; returns nullptr if no pair matches.
+  Operation *rewriteInternal(
+      xla_hlo::ConvertOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    Value *src = operands[0];
+    Value *dst = op->getResult();
+    auto loc = op->getLoc();
+    auto srcElementType = src->getType().cast<MemRefType>().getElementType();
+    auto dstElementType = dst->getType().cast<ShapedType>().getElementType();
+    auto memRefResultType = convertTypeToMemRef(dst);
+
+    // Expands to an early return of NewOp for a (FromKind, ToKind) pair.
+#define LOWER_CONVERT(FromKind, ToKind, NewOp)                            \
+  if (srcElementType.isa<FromKind>() && dstElementType.isa<ToKind>()) {   \
+    return rewriter.create<NewOp>(loc, memRefResultType, src);            \
+  }
+    LOWER_CONVERT(IntegerType, IntegerType, IREEInterp::HL::ConvertSSOp);
+    LOWER_CONVERT(IntegerType, FloatType, IREEInterp::HL::ConvertSFOp);
+    LOWER_CONVERT(FloatType, IntegerType, IREEInterp::HL::ConvertFSOp);
+    LOWER_CONVERT(FloatType, FloatType, IREEInterp::HL::ConvertFFOp);
+#undef LOWER_CONVERT
+
+    return nullptr;
+  }
+};
+
+// Lowers a subset of gathers along axis 0 that are really just a slice and
+// reshape.
+struct GatherOpLowering : public OpConversionPattern<xla_hlo::GatherOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  // TODO(gcmn): This only handles a minimal number of cases. When XLA
+  // redefines gather to be simpler, lower it properly.
+  PatternMatchResult matchAndRewrite(
+      xla_hlo::GatherOp gatherOp, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (gatherOp.index_vector_dim() != 0) {
+      gatherOp.emitRemark()
+          << "Couldn't lower gather with index_vector_dim != 0";
+      return matchFailure();
+    }
+    if (gatherOp.start_index_map().getType().getRank() != 1 ||
+        gatherOp.start_index_map().getValue(0).cast<IntegerAttr>().getValue() !=
+            0) {
+      gatherOp.emitRemark()
+          << "Couldn't lower gather with start_index_map != [0]";
+      return matchFailure();
+    }
+    if (gatherOp.collapsed_slice_dims().getType().getRank() != 1 ||
+        gatherOp.collapsed_slice_dims()
+                .getValue(0)
+                .cast<IntegerAttr>()
+                .getValue() != 0) {
+      gatherOp.emitRemark()
+          << "Couldn't lower gather with collapsed_dims != [0]";
+      return matchFailure();
+    }
+    // offset_dims must have one entry per result dim, with entry i equal to i.
+    auto resultType = gatherOp.getResult()->getType().cast<RankedTensorType>();
+    if (gatherOp.offset_dims().getType().getNumElements() !=
+        resultType.getRank()) {
+      gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != "
+                               "[0,...,rank of output]";
+      return matchFailure();
+    }
+    for (auto it : llvm::enumerate(gatherOp.offset_dims())) {
+      if (it.index() != it.value()) {
+        gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != "
+                                 "[0,...,rank of output]";
+        return matchFailure();
+      }
+    }
+    // slice_sizes[i + 1] must equal dimension i of the result shape.
+    for (auto it : llvm::enumerate(resultType.getShape())) {
+      if (gatherOp.slice_sizes()
+              .getValue(it.index() + 1)
+              .cast<IntegerAttr>()
+              .getValue() != it.value()) {
+        gatherOp.emitRemark()
+            << "Couldn't lower gather with slice_sizes not [1] + final shape";
+        return matchFailure();
+      }
+    }
+    // NOTE(review): |operands| is unused; inputs are re-read from gatherOp.
+    auto inputType = gatherOp.operand()->getType().cast<RankedTensorType>();
+    // Start indices; zero-padded below when their count != the input rank.
+    auto startIndices =
+        inputAsMemref(rewriter, gatherOp, gatherOp.start_indices());
+    auto startIndicesType = startIndices->getType().cast<MemRefType>();
+    if (startIndicesType.getNumElements() != inputType.getRank()) {
+      auto extraDims = inputType.getRank() - startIndicesType.getNumElements();
+      auto elementType = startIndicesType.getElementType();
+      // NOTE(review): extraDims is negative if there are more indices than dims.
+      if (startIndicesType.getRank() != 1) {
+        startIndices = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
+                           rewriter, gatherOp.getLoc(), startIndices,
+                           MemRefType::get({1}, elementType))
+                           ->getResult(0);
+      }
+
+      llvm::SmallVector<int64_t, 4> zeroes;
+      zeroes.resize(extraDims, 0);
+
+      auto elementsAttr = DenseIntElementsAttr::get(
+          RankedTensorType::get(zeroes.size(), elementType),
+          llvm::makeArrayRef(zeroes));
+
+      auto extraStartIndices =
+          rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), elementsAttr);
+
+      auto memrefOutputType =
+          MemRefType::get({inputType.getRank()}, elementType);
+
+      SmallVector<Value *, 2> valuesToConcat = {startIndices,
+                                                extraStartIndices};
+      startIndices = rewriter.create<IREEInterp::HL::ConcatOp>(
+          gatherOp.getLoc(), memrefOutputType, valuesToConcat,
+          rewriter.getI32IntegerAttr(0));
+    }
+    // Copy the slice_sizes-shaped window at startIndices into a new buffer.
+    auto sliceSizeValues = gatherOp.slice_sizes().getValues<int64_t>();
+    std::vector<int64_t> sliceSizes = {sliceSizeValues.begin(),
+                                       sliceSizeValues.end()};
+    auto dstType = MemRefType::get(sliceSizes, inputType.getElementType());
+
+    auto src = inputAsMemref(rewriter, gatherOp, gatherOp.operand());
+    std::vector<Value *> dim_pieces;
+    auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>(
+        gatherOp.getLoc(), dstType, dim_pieces);
+    auto lengths = rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(),
+                                                     gatherOp.slice_sizes());
+    llvm::SmallVector<int64_t, 4> zero_offset;
+    zero_offset.resize(dstType.getRank(), 0);
+    auto dstIndices =
+        createArrayConstant(rewriter, gatherOp.getLoc(), zero_offset);
+
+    rewriter.create<IREEInterp::HL::CopyOp>(
+        gatherOp.getLoc(), src, startIndices, dst, dstIndices, lengths);
+    // Reshape to the result's memref type and hand it back as a tensor.
+    auto reshaped = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
+        rewriter, gatherOp.getLoc(), dst, convertTypeToMemRef(gatherOp));
+    rewriter.replaceOp(
+        gatherOp, wrapAsTensor(reshaped->getResult(0), gatherOp, rewriter));
+
+    return matchSuccess();
+  }
+};
+
+struct SliceOpLowering : public XlaOpLowering<xla_hlo::SliceOp> {
+  using XlaOpLowering<xla_hlo::SliceOp>::XlaOpLowering;
+
+  Operation *rewriteInternal(
+      xla_hlo::SliceOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // XLA slices are values while IREE slices are views, so a unit-stride
+    // slice is lowered to a copy into a new buffer; later optimizations may
+    // turn that copy back into a slice. Other strides are not supported.
+    for (APInt stride : op->strides()) {
+      if (stride != 1) {
+        op->emitRemark()
+            << "Could not lower slice op with non-singular strides";
+        return nullptr;
+      }
+    }
+
+    auto loc = op->getLoc();
+    auto memRefType = convertTypeToMemRef(*op);
+    std::vector<Value *> dimPieces;
+    auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>(loc, memRefType,
+                                                            dimPieces);
+    auto srcIndices =
+        rewriter.create<IREE::ConstantOp>(loc, op->start_indices());
+    auto lengths = createArrayConstant(rewriter, loc, memRefType.getShape());
+
+    llvm::SmallVector<int64_t, 4> zeroOffset;
+    zeroOffset.resize(memRefType.getRank(), 0);
+    auto dstIndices = createArrayConstant(rewriter, loc, zeroOffset);
+    rewriter.create<IREEInterp::HL::CopyOp>(loc, operands[0], srcIndices, dst,
+                                            dstIndices, lengths);
+    return dst;
+  }
+};
+
+struct PadOpLowering : public XlaOpLowering<xla_hlo::PadOp> {
+  using XlaOpLowering::XlaOpLowering;
+
+  // Emits the interpreter pad op with the low, high and interior padding
+  // attributes materialized as constants; negative edge padding is rejected.
+  Operation *rewriteInternal(
+      xla_hlo::PadOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto lowPadding = op->edge_padding_low();
+    auto highPadding = op->edge_padding_high();
+
+    // TODO(b/140836672) Support negative padding
+    for (int i = 0; i < highPadding.getNumElements(); ++i) {
+      if (highPadding.getValue<IntegerAttr>(i).getInt() < 0 ||
+          lowPadding.getValue<IntegerAttr>(i).getInt() < 0) {
+        op->emitRemark() << "Could not lower pad op with negative padding";
+        return nullptr;
+      }
+    }
+
+    auto lowOp = rewriter.create<IREE::ConstantOp>(loc, lowPadding);
+    auto highOp = rewriter.create<IREE::ConstantOp>(loc, highPadding);
+    auto interiorOp =
+        rewriter.create<IREE::ConstantOp>(loc, op->interior_padding());
+    return rewriter.create<IREEInterp::HL::PadOp>(
+        loc, convertTypeToMemRef(*op), operands[0], operands[1], lowOp, highOp,
+        interiorOp);
+  }
+};
+
+struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> {
+  using XlaOpLowering::XlaOpLowering;
+  // Reshapes the first operand to the memref form of the op's result type.
+  Operation *rewriteInternal(
+      xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    return createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
+        rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op));
+  }
+};
+
+struct TransposeOpLowering : public XlaOpLowering<xla_hlo::TransposeOp> {
+  using XlaOpLowering::XlaOpLowering;
+
+  // Emits the interpreter transpose with the permutation as a constant.
+  Operation *rewriteInternal(
+      xla_hlo::TransposeOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto perm = rewriter.create<IREE::ConstantOp>(loc, op->permutation());
+    return rewriter.create<IREEInterp::HL::TransposeOp>(
+        loc, convertTypeToMemRef(*op), operands[0], perm);
+  }
+};
+
+struct ReverseOpLowering : public XlaOpLowering<xla_hlo::ReverseOp> {
+  using XlaOpLowering::XlaOpLowering;
+
+  // Emits the interpreter reverse with the dimensions as a constant.
+  Operation *rewriteInternal(
+      xla_hlo::ReverseOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto dims = rewriter.create<IREE::ConstantOp>(loc, op->dimensions());
+    return rewriter.create<IREEInterp::HL::ReverseOp>(
+        loc, convertTypeToMemRef(*op), operands[0], dims);
+  }
+};
+
+} // namespace
+// Adds the XLA-to-interpreter lowering patterns listed below to |patterns|.
+void populateLowerXlaToInterpreterPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *ctx) {
+  patterns
+      .insert<BroadcastInDimOpLowering, ConcatOpLowering, ConvertLowering,
+              CopyOpLowering, DotOpLowering, DynamicUpdateSliceOpLowering,
+              ExpOpLowering, FloorOpLowering, GatherOpLowering, LogOpLowering,
+              MaxOpLowering, MinOpLowering, PadOpLowering, ReshapeOpLowering,
+              ReverseOpLowering, RsqrtOpLowering, SelectOpLowering,
+              SliceOpLowering, TransposeOpLowering, TanhOpLowering>(ctx);
+}
+
+namespace {
+// Just for testing these passes.
+// TODO(b/141337493) can we get rid of this pass entirely?
+class LowerXLAToInterpreterDialectPass
+    : public FunctionPass<LowerXLAToInterpreterDialectPass> {
+ public:
+  // Partially converts the function; only the two dialects below are legal.
+  void runOnFunction() override {
+    ConversionTarget target(getContext());
+    target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>();
+    OwningRewritePatternList patterns;
+    populateLowerXlaToInterpreterPatterns(patterns, &getContext());
+    if (failed(applyPartialConversion(getFunction(), target, patterns))) {
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+// Registers the pass with the name and description given below.
+static PassRegistration<LowerXLAToInterpreterDialectPass> pass(
+    "lower-xla-to-iree-interpreter",
+    "Convert all XLA functions to the IREE dialect");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/MakeExecutableABI.cpp b/compiler/Transforms/Interpreter/MakeExecutableABI.cpp
new file mode 100644
index 0000000..793cb06
--- /dev/null
+++ b/compiler/Transforms/Interpreter/MakeExecutableABI.cpp
@@ -0,0 +1,147 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Interpreter/HLOps.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "compiler/Utils/OpUtils.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Replaces a load_input op with IR that produces the loaded input value.
+LogicalResult replaceLoadInputOp(IREE::LoadInputOp bindOp) {
+  auto resultType = bindOp.getResult()->getType();
+  const bool isTensor = resultType.isa<TensorType>();
+  if (!isTensor && !resultType.isIntOrIndexOrFloat()) {
+    return bindOp.emitError()
+           << "Unsupported input destination type " << resultType;
+  }
+
+  // Tensor results become a memref-to-tensor cast of the source; scalar
+  // results become a load from the source with no indices.
+  OpBuilder builder(bindOp);
+  auto loc = bindOp.getLoc();
+  if (isTensor) {
+    auto castOp = builder.create<IREE::MemRefToTensorOp>(loc, bindOp.src());
+    bindOp.replaceAllUsesWith(castOp.getResult());
+  } else {
+    auto loadOp = builder.create<LoadOp>(loc, resultType, bindOp.src(),
+                                         ArrayRef<Value *>{});
+    bindOp.replaceAllUsesWith(loadOp.getResult());
+  }
+  bindOp.erase();
+  return success();
+}
+
+// Replaces a store_output op with valid IR that stores the output value.
+// Memref sources need nothing more, tensor sources are cast to a memref and
+// copied into the destination, and scalar sources are stored directly.
+LogicalResult replaceStoreOutputOp(IREE::StoreOutputOp bindOp) {
+  OpBuilder builder(bindOp);
+  auto loc = bindOp.getLoc();
+  auto srcType = bindOp.src()->getType();
+  if (srcType.isa<MemRefType>()) {
+    // Already stored into the output.
+  } else if (srcType.isa<TensorType>()) {
+    // Validate the destination before emitting anything so that the failure
+    // path does not leave an unused cast behind.
+    auto dst = bindOp.dst()->getType().cast<ShapedType>();
+    if (!dst.hasStaticShape()) {
+      return bindOp.emitError()
+             << "Dynamic output args are not yet implemented";
+    }
+
+    // Insert a copy to our output parameter.
+    auto castOp = builder.create<IREE::TensorToMemRefOp>(loc, bindOp.src());
+    auto zeroValues = llvm::SmallVector<int64_t, 4>(dst.getRank());
+    auto zeros = createArrayConstant(builder, loc, zeroValues);
+    auto lengths = createArrayConstant(builder, loc, dst.getShape());
+    builder.create<IREEInterp::HL::CopyOp>(loc, castOp.getResult(), zeros,
+                                           bindOp.dst(), zeros, lengths);
+  } else if (srcType.isIntOrIndexOrFloat()) {
+    builder.create<StoreOp>(loc, bindOp.src(), bindOp.dst(),
+                            ArrayRef<Value *>{});
+  } else {
+    return bindOp.emitError() << "Unsupported output src type " << srcType;
+  }
+
+  bindOp.erase();
+  return success();
+}
+
+// Rewrites every iree.load_input and iree.store_output op found in |func|.
+// Ops are collected first; each rewrite erases the op it handles.
+LogicalResult stripBindingOps(FuncOp func) {
+  // load_input ops are replaced with memref_to_tensor (or a load) as needed.
+  SmallVector<IREE::LoadInputOp, 8> loadInputOps;
+  func.walk([&](IREE::LoadInputOp op) { loadInputOps.push_back(op); });
+  for (auto &loadInputOp : loadInputOps) {
+    if (failed(replaceLoadInputOp(loadInputOp))) {
+      return failure();
+    }
+  }
+
+  // store_output ops are replaced with tensor_to_memref (or a store).
+  SmallVector<IREE::StoreOutputOp, 8> storeOutputOps;
+  func.walk([&](IREE::StoreOutputOp op) { storeOutputOps.push_back(op); });
+  for (auto &storeOutputOp : storeOutputOps) {
+    if (failed(replaceStoreOutputOp(storeOutputOp))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+} // namespace
+
+// Finds iree.executable.export functions and fixes up bindings.
+// For the interpreter this really just means stripping the bind ops entirely.
+class MakeExecutableABIPass : public ModulePass<MakeExecutableABIPass> {
+ public:
+  // Fails the pass as soon as any exported function cannot be rewritten.
+  void runOnModule() override {
+    for (auto func : getModule().getOps<FuncOp>()) {
+      // Functions without the export attribute are left untouched.
+      if (!func.getAttr("iree.executable.export")) continue;
+      if (failed(stripBindingOps(func))) {
+        return signalPassFailure();
+      }
+    }
+  }
+};
+// Creates a new MakeExecutableABIPass instance.
+std::unique_ptr<OpPassBase<ModuleOp>> createMakeExecutableABIPass() {
+  return std::make_unique<MakeExecutableABIPass>();
+}
+// Registers the pass with the name and description given below.
+static PassRegistration<MakeExecutableABIPass> pass(
+    "iree-make-executable-abi",
+    "Makes functions match the IREE dispatch executable ABI.");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/Passes.h b/compiler/Transforms/Interpreter/Passes.h
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/Passes.h
rename to compiler/Transforms/Interpreter/Passes.h
diff --git a/iree/compiler/Transforms/Interpreter/Rewrites.h b/compiler/Transforms/Interpreter/Rewrites.h
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/Rewrites.h
rename to compiler/Transforms/Interpreter/Rewrites.h
diff --git a/compiler/Transforms/Interpreter/test/BUILD b/compiler/Transforms/Interpreter/test/BUILD
new file mode 100644
index 0000000..fb38390
--- /dev/null
+++ b/compiler/Transforms/Interpreter/test/BUILD
@@ -0,0 +1,16 @@
+# Tests for lowering MLIR in various dialects to IREE interpreter bytecode.
+
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+    data = [
+        "//tools:iree-opt",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/Interpreter/test/clone.mlir b/compiler/Transforms/Interpreter/test/clone.mlir
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/test/clone.mlir
rename to compiler/Transforms/Interpreter/test/clone.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/make_executable_abi.mlir b/compiler/Transforms/Interpreter/test/make_executable_abi.mlir
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/test/make_executable_abi.mlir
rename to compiler/Transforms/Interpreter/test/make_executable_abi.mlir
diff --git a/compiler/Transforms/Interpreter/test/xla/BUILD b/compiler/Transforms/Interpreter/test/xla/BUILD
new file mode 100644
index 0000000..32c7031
--- /dev/null
+++ b/compiler/Transforms/Interpreter/test/xla/BUILD
@@ -0,0 +1,17 @@
+# Tests specific to lowering XLA to IREE.
+
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+    data = [
+        "//tools:iree-opt",
+        "//tools:iree-run-mlir",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/concat.mlir b/compiler/Transforms/Interpreter/test/xla/concat.mlir
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/test/xla/concat.mlir
rename to compiler/Transforms/Interpreter/test/xla/concat.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir b/compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir
rename to compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/gather.mlir b/compiler/Transforms/Interpreter/test/xla/gather.mlir
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/test/xla/gather.mlir
rename to compiler/Transforms/Interpreter/test/xla/gather.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/slice.mlir b/compiler/Transforms/Interpreter/test/xla/slice.mlir
similarity index 100%
rename from iree/compiler/Transforms/Interpreter/test/xla/slice.mlir
rename to compiler/Transforms/Interpreter/test/xla/slice.mlir
diff --git a/compiler/Transforms/LegalizeTypeStorage.cpp b/compiler/Transforms/LegalizeTypeStorage.cpp
new file mode 100644
index 0000000..780fc97
--- /dev/null
+++ b/compiler/Transforms/LegalizeTypeStorage.cpp
@@ -0,0 +1,145 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Recreates |oldOp| at the builder's insertion point with legalized result
+// types, remapped through |mapping|. True means failure; always returns false.
+bool convertOperation(Operation *oldOp, OpBuilder &builder,
+                      BlockAndValueMapping *mapping) {
+  OperationState state(oldOp->getLoc(), oldOp->getName());
+  if (oldOp->getNumSuccessors() == 0) {
+    // Non-branching operations can just add all the operands.
+    for (auto *oldOperand : oldOp->getOperands()) {
+      state.operands.push_back(mapping->lookupOrDefault(oldOperand));
+    }
+  } else {
+    // We add the operands separated by nullptr's for each successor. This
+    // branch always has at least one successor, so index 0 is valid.
+    unsigned firstSuccOperand = oldOp->getSuccessorOperandIndex(0);
+    auto opOperands = oldOp->getOpOperands();
+    for (unsigned i = 0; i != firstSuccOperand; ++i) {
+      state.operands.push_back(mapping->lookupOrDefault(opOperands[i].get()));
+    }
+    for (unsigned succ = 0, e = oldOp->getNumSuccessors(); succ != e; ++succ) {
+      state.successors.push_back(
+          mapping->lookupOrDefault(oldOp->getSuccessor(succ)));
+      // Add sentinel to delineate successor operands.
+      state.operands.push_back(nullptr);
+      // Remap the successor's operands.
+      for (auto *operand : oldOp->getSuccessorOperands(succ)) {
+        state.operands.push_back(mapping->lookupOrDefault(operand));
+      }
+    }
+  }
+  for (const auto &oldType : oldOp->getResultTypes()) {
+    state.types.push_back(legalizeType(oldType));
+  }
+  state.attributes = {oldOp->getAttrs().begin(), oldOp->getAttrs().end()};
+  auto newOp = builder.createOperation(state);
+  for (unsigned i = 0, e = newOp->getNumResults(); i != e; ++i) {
+    mapping->map(oldOp->getResult(i), newOp->getResult(i));
+  }
+  return false;
+}
+
+// Rebuilds the body of |oldFunction| inside |newFunction| with legalized
+// block argument types. Returns true if converting any operation failed.
+bool convertFunction(FuncOp oldFunction, FuncOp newFunction) {
+  OpBuilder builder(newFunction.getBody());
+  BlockAndValueMapping mapping;
+
+  // First pass: mirror every old block (and its arguments) so that ops
+  // converted later can refer to blocks that appear further ahead.
+  newFunction.getBlocks().clear();
+  for (auto &srcBlock : oldFunction.getBlocks()) {
+    auto *dstBlock = builder.createBlock(&newFunction.getBody());
+    mapping.map(&srcBlock, dstBlock);
+    for (auto *srcArg : srcBlock.getArguments()) {
+      mapping.map(srcArg,
+                  dstBlock->addArgument(legalizeType(srcArg->getType())));
+    }
+  }
+
+  // Second pass: convert the ops, appending to the mirrored block.
+  for (auto &srcBlock : oldFunction.getBlocks()) {
+    builder.setInsertionPointToEnd(mapping.lookupOrNull(&srcBlock));
+    for (auto &srcOp : srcBlock.getOperations()) {
+      if (convertOperation(&srcOp, builder, &mapping)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+} // namespace
+
+// Rebuilds every function of the module with legalized types: a replacement
+// is created and filled per function, then all of them are swapped in.
+class LegalizeTypeStoragePass : public ModulePass<LegalizeTypeStoragePass> {
+ public:
+  void runOnModule() override {
+    auto module = getModule();
+    // (old, new) function pairs. The new functions stay out of the module
+    // until every body has been converted.
+    std::vector<std::pair<FuncOp, FuncOp>> replacements;
+    for (auto oldFunction : module.getOps<FuncOp>()) {
+      // Same name, location and dialect attributes; legalized signature.
+      auto legalizedType =
+          legalizeType(oldFunction.getType()).cast<FunctionType>();
+      auto newFunction =
+          FuncOp::create(oldFunction.getLoc(), oldFunction.getName(),
+                         legalizedType, oldFunction.getDialectAttrs());
+      replacements.emplace_back(oldFunction, newFunction);
+
+      // Convert the body now that the new signature exists.
+      if (convertFunction(oldFunction, newFunction)) {
+        return signalPassFailure();
+      }
+    }
+
+    // Swap each old function for its replacement.
+    for (auto &replacement : replacements) {
+      replacement.first.erase();
+      module.push_back(replacement.second);
+    }
+  }
+};
+// Creates a new LegalizeTypeStoragePass instance.
+std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeTypeStoragePass() {
+  return std::make_unique<LegalizeTypeStoragePass>();
+}
+// Registers the pass with the name and description given below.
+static PassRegistration<LegalizeTypeStoragePass> pass(
+    "iree-legalize-type-storage",
+    "Legalizes types to ones supported by the IREE VM.");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/LowerStdToIreeDialect.cpp b/compiler/Transforms/LowerStdToIreeDialect.cpp
new file mode 100644
index 0000000..5fb5c2c
--- /dev/null
+++ b/compiler/Transforms/LowerStdToIreeDialect.cpp
@@ -0,0 +1,77 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct ConstantOpLowering : public OpRewritePattern<ConstantOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Elements constants become an IREE constant wrapped as a tensor; int/float
+  // scalars become a rank-0 IREE constant read through MemRefToScalarOp.
+  PatternMatchResult matchAndRewrite(ConstantOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto value = op.getValue();
+    if (auto elements = value.dyn_cast<ElementsAttr>()) {
+      auto ireeConst = rewriter.create<IREE::ConstantOp>(loc, elements);
+      rewriter.replaceOp(op, wrapAsTensor(ireeConst.getResult(), op, rewriter));
+      return matchSuccess();
+    }
+
+    auto scalarType = value.getType();
+    if (!scalarType.isIntOrFloat()) {
+      return matchFailure();
+    }
+    auto scalarElements =
+        DenseElementsAttr::get(RankedTensorType::get({}, scalarType), value);
+    auto scalarConst = rewriter.create<IREE::ConstantOp>(loc, scalarElements);
+    rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(op, scalarConst);
+    return matchSuccess();
+  }
+};
+
+struct ExtractElementOpLowering : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Replaces the op with a LoadOp that reads the same indices from the
+  // aggregate wrapped as a memref.
+  PatternMatchResult matchAndRewrite(ExtractElementOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto aggregate = loadAccessValue(op.getLoc(), op.getAggregate(), rewriter);
+    Value *memRef = wrapAsMemRef(aggregate, op, rewriter);
+    SmallVector<Value *, 4> loadIndices(op.indices().begin(),
+                                        op.indices().end());
+    rewriter.replaceOpWithNewOp<LoadOp>(op, memRef, loadIndices);
+    return matchSuccess();
+  }
+};
+
+} // namespace
+// Adds ConstantOpLowering and ExtractElementOpLowering to |patterns|.
+void populateLowerStdToIreePatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *ctx) {
+  patterns.insert<ConstantOpLowering, ExtractElementOpLowering>(ctx);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/LowerXLAToIreeDialect.cpp b/compiler/Transforms/LowerXLAToIreeDialect.cpp
new file mode 100644
index 0000000..e719c42
--- /dev/null
+++ b/compiler/Transforms/LowerXLAToIreeDialect.cpp
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct ConstOpLowering : public OpRewritePattern<xla_hlo::ConstOp> {
+  using OpRewritePattern::OpRewritePattern;
+  // Replaces the op with an IREE constant of its value, wrapped as a tensor.
+  PatternMatchResult matchAndRewrite(xla_hlo::ConstOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto ireeConst = rewriter.create<IREE::ConstantOp>(op.getLoc(), op.value());
+    rewriter.replaceOp(op, wrapAsTensor(ireeConst, op, rewriter));
+    return matchSuccess();
+  }
+};
+
+} // namespace
+// Adds ConstOpLowering to |patterns|.
+void populateLowerXlaToIreePatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *ctx) {
+  patterns.insert<ConstOpLowering>(ctx);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Transforms/Passes.h b/compiler/Transforms/Passes.h
similarity index 100%
rename from iree/compiler/Transforms/Passes.h
rename to compiler/Transforms/Passes.h
diff --git a/iree/compiler/Transforms/Rewrites.h b/compiler/Transforms/Rewrites.h
similarity index 100%
rename from iree/compiler/Transforms/Rewrites.h
rename to compiler/Transforms/Rewrites.h
diff --git a/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp b/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp
new file mode 100644
index 0000000..f64f313
--- /dev/null
+++ b/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp
@@ -0,0 +1,75 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/OpUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Sets "iree.ordinal" on multi-arch executables, executables and entry points.
+class AssignExecutableOrdinalsPass
+    : public ModulePass<AssignExecutableOrdinalsPass> {
+ public:
+  void runOnModule() override {
+    Builder builder(getModule());
+    int nextExecutableOrdinal = 0;
+    for (auto multiArchExecutableOp :
+         getModule().getOps<IREE::MultiArchExecutableOp>()) {
+      multiArchExecutableOp.setAttr(
+          "iree.ordinal", builder.getI32IntegerAttr(nextExecutableOrdinal++));
+      // Entry point ordinals are keyed by name and shared by all executables
+      // in this multi-arch executable. The counter is shared too, so a name
+      // first seen in a later executable cannot reuse an ordinal already taken.
+      llvm::DenseMap<StringRef, FuncOp> entryPointMap;
+      int nextEntryPointOrdinal = 0;
+      for (auto executableOp :
+           multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
+        executableOp.setAttr("iree.ordinal",
+                             multiArchExecutableOp.getAttr("iree.ordinal"));
+        for (auto funcOp : executableOp.getInnerModule().getOps<FuncOp>()) {
+          if (!funcOp.getAttr("iree.executable.export")) continue;
+          auto it = entryPointMap.find(funcOp.getName());
+          if (it == entryPointMap.end()) {
+            funcOp.setAttr("iree.ordinal",
+                           builder.getI32IntegerAttr(nextEntryPointOrdinal++));
+            entryPointMap.insert({funcOp.getName(), funcOp});
+          } else {
+            funcOp.setAttr("iree.ordinal", it->second.getAttr("iree.ordinal"));
+          }
+        }
+      }
+    }
+  }
+};
+// Creates a new AssignExecutableOrdinalsPass instance.
+std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableOrdinalsPass() {
+  return std::make_unique<AssignExecutableOrdinalsPass>();
+}
+// Registers the pass with the name and description given below.
+static PassRegistration<AssignExecutableOrdinalsPass> pass(
+    "iree-assign-executable-ordinals",
+    "Assigns executable and entry point ordinals");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp b/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp
new file mode 100644
index 0000000..20d51a2
--- /dev/null
+++ b/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp
@@ -0,0 +1,125 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Sequencer/LLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/OpUtils.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct WorkloadInfo {  // Workloads collected from the dispatch ops that target one entry point.
+ SmallVector<ElementsAttr, 4> staticWorkloads;  // Distinct workload attrs taken from static dispatches.
+ SmallVector<Value *, 4> dynamicWorkloads;  // Workload values taken from dynamic dispatches (one per dispatch).
+};
+
+// Finds all dispatches in |moduleOp| and records their workloads in a map
+// keyed by op.getExecutable() and then op.getEntryPoint() (string keys).
+llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos(
+ ModuleOp moduleOp) {
+ llvm::StringMap<llvm::StringMap<WorkloadInfo>> workloadInfos;
+ for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+ funcOp.walk([&](IREESeq::LL::DynamicDispatchOp op) {
+ auto &workloadInfo =
+ workloadInfos[op.getExecutable()][op.getEntryPoint()];  // operator[] creates the entry on first use.
+ workloadInfo.dynamicWorkloads.push_back(op.getWorkload());  // Every dynamic dispatch is recorded; no dedup.
+ });
+ funcOp.walk([&](IREESeq::LL::StaticDispatchOp op) {
+ auto &workloadInfo =
+ workloadInfos[op.getExecutable()][op.getEntryPoint()];
+ for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) {
+ if (existingWorkloadAttr == op.getWorkload()) {
+ return; // Already present, ignore (keeps staticWorkloads free of duplicates).
+ }
+ }
+ workloadInfo.staticWorkloads.push_back(op.getWorkload());
+ });
+ }
+ return workloadInfos;
+}
+
+// Sets "iree.executable.workload" on |entryPointOp| from |workloadInfo|; emits
+// an error unless it holds exactly one static and no dynamic workloads.
+LogicalResult attributeExecutableEntryPointWorkload(
+ FuncOp entryPointOp, const WorkloadInfo &workloadInfo) {
+ if (!workloadInfo.dynamicWorkloads.empty()) {
+ return entryPointOp.emitError() << "Dynamic workloads not yet supported";
+ }
+ if (workloadInfo.staticWorkloads.size() != 1) {  // Rejects both zero and multiple distinct static workloads.
+ return entryPointOp.emitError() << "Static workload sizes differ in shape";
+ }
+
+ // Easy because we just support static workloads now.
+ // When this code is adapted to support dynamic workloads we'll want to put
+ // a pair of attrs describing which dimensions may be static and which args
+ // have the dynamic values to reference.
+ entryPointOp.setAttr("iree.executable.workload",
+ workloadInfo.staticWorkloads.front());  // The single static workload attr.
+
+ return success();
+}
+
+} // namespace
+
+class AssignExecutableWorkloadAttrsPass  // Copies dispatch workloads onto the entry point functions they target.
+ : public ModulePass<AssignExecutableWorkloadAttrsPass> {
+ public:
+ void runOnModule() override {
+ Builder builder(getModule());  // NOTE(review): |builder| is not referenced again in this method.
+
+ // Find all dispatches and capture their workload information.
+ // The result is keyed by executable symbol name and then entry point name.
+ auto executableWorkloadInfos = gatherExecutableWorkloadInfos(getModule());
+
+ // Process each executable with the workload information.
+ for (auto &executableIt : executableWorkloadInfos) {
+ auto multiArchExecutableOp = cast<IREE::MultiArchExecutableOp>(
+ getModule().lookupSymbol(executableIt.first()));  // NOTE(review): no diagnostic if the symbol is missing; cast<> assumes it resolves.
+ for (auto executableOp :
+ multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
+ for (auto &entryPointIt : executableIt.second) {
+ auto funcOp = cast<FuncOp>(
+ executableOp.getInnerModule().lookupSymbol(entryPointIt.first()));  // Same assumption: each nested executable defines the entry point.
+ if (failed(attributeExecutableEntryPointWorkload(
+ funcOp, entryPointIt.second))) {
+ return signalPassFailure();  // Stop at the first entry point that cannot be attributed.
+ }
+ }
+ }
+ }
+ }
+};
+
+std::unique_ptr<OpPassBase<ModuleOp>>  // Factory for the workload-attribute module pass.
+createAssignExecutableWorkloadAttrsPass() {
+ return std::make_unique<AssignExecutableWorkloadAttrsPass>();  // Caller owns the new instance.
+}
+
+static PassRegistration<AssignExecutableWorkloadAttrsPass> pass(  // File-local static registration object for this pass.
+ "iree-assign-executable-workload-attrs",  // Pass argument string.
+ "Assigns executable entrypoint workload attributes");  // Pass description string.
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/BUILD b/compiler/Transforms/Sequencer/BUILD
new file mode 100644
index 0000000..a86088b
--- /dev/null
+++ b/compiler/Transforms/Sequencer/BUILD
@@ -0,0 +1,46 @@
+# Transforms specific to the IREE sequencer.
+
+package(
+ default_visibility = ["//visibility:public"],  # Targets here are visible to every package unless they override it.
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "Sequencer",  # Sequencer passes (srcs) plus their Passes.h / Rewrites.h entry points.
+ srcs = [
+ "AssignExecutableOrdinals.cpp",
+ "AssignExecutableWorkloadAttrs.cpp",
+ "FoldCompatibleDispatchRegions.cpp",
+ "IdentifyDispatchRegions.cpp",
+ "IdentifyReductionRegions.cpp",
+ "LegalizeInputs.cpp",
+ "LowerSequencerDialect.cpp",
+ "LowerStdToSequencerDialect.cpp",
+ "LowerToSequencerDialect.cpp",
+ "LowerXLAToSequencerDialect.cpp",
+ "OutlineDispatchRegions.cpp",
+ "OutlineReductionRegions.cpp",
+ "RematerializeDispatchConstants.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ "Rewrites.h",
+ ],
+ deps = [
+ "//compiler/IR",  # Workspace-rooted labels take exactly two leading slashes ("///" is not a valid label).
+ "//compiler/IR/Sequencer",
+ "//compiler/Transforms",
+ "//compiler/Utils",
+ "@llvm//:support",
+ "@local_config_mlir//:IR",
+ "@local_config_mlir//:Pass",
+ "@local_config_mlir//:StandardDialectRegistration",
+ "@local_config_mlir//:StandardOps",
+ "@local_config_mlir//:Support",
+ "@local_config_mlir//:TransformUtils",
+ "@local_config_mlir//:Transforms",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot",
+ ],
+)
diff --git a/iree/compiler/Transforms/Sequencer/CMakeLists.txt b/compiler/Transforms/Sequencer/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Transforms/Sequencer/CMakeLists.txt
rename to compiler/Transforms/Sequencer/CMakeLists.txt
diff --git a/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp b/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp
new file mode 100644
index 0000000..a48f735
--- /dev/null
+++ b/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp
@@ -0,0 +1,63 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/DispatchUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Identifies dispatch regions that have compatible workloads and folds them.
+// This relies on CSE having deduped workloads, which reduces the check to
+// looking for dispatch regions that use the same workload values.
+class FoldCompatibleDispatchRegionsPass
+ : public FunctionPass<FoldCompatibleDispatchRegionsPass> {
+ public:
+ void runOnFunction() override {
+ auto func = getFunction();
+ for (auto &block : func) {  // Each block of the function is processed independently.
+ if (failed(mergeBlockDispatchRegions(func, &block))) {  // Helper defined outside this file (presumably DispatchUtils.h).
+ return signalPassFailure();  // Stop at the first block that fails.
+ }
+ }
+ }
+};
+
+std::unique_ptr<OpPassBase<FuncOp>> createFoldCompatibleDispatchRegionsPass() {  // Factory for the folding function pass.
+ return std::make_unique<FoldCompatibleDispatchRegionsPass>();  // Caller owns the new instance.
+}
+
+static PassRegistration<FoldCompatibleDispatchRegionsPass> pass(  // File-local static registration object for this pass.
+ "iree-fold-compatible-dispatch-regions",  // Pass argument string.
+ "Folds dispatch regions that have compatible workloads.");  // Pass description string.
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp b/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp
new file mode 100644
index 0000000..54ce60e
--- /dev/null
+++ b/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp
@@ -0,0 +1,259 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/DispatchUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Returns true if the given |op| can be dispatched in all cases.
+// Other passes may handle special cases of these ops but this initial
+// identification is conservative: every op kind listed below is rejected.
+bool isDispatchableOp(Operation *op) {
+ if (op->getDialect() && op->getDialect()->getNamespace().startswith("iree")) {  // Any dialect namespace beginning with "iree".
+ // Ignore things we've already produced as they should only relate to
+ // sequencer operations.
+ return false;
+ } else if (op->isKnownTerminator()) {
+ // Currently we skip all terminators as we want to leave them in the block
+ // to keep it valid. Future folding passes may take care of them if they are
+ // worth bringing into the dispatch region.
+ return false;
+ } else if (isa<CallOp>(op)) {
+ // This may be handled by a control-flow folding pass later once we have
+ // done our initial analysis and know what functions are compatible.
+ return false;
+ } else if (isa<CallIndirectOp>(op)) {
+ // Indirect calls are not supported in dispatch code.
+ return false;
+ } else if (isa<AllocOp>(op)) {
+ // Allocations are sequencer ops.
+ // Note that we could support static allocations (convert to stack/etc).
+ return false;
+ } else if (isa<ConstantOp>(op)) {
+ // Constants are handled in the RematerializeDispatchConstants pass.
+ // We do that independently so that we can more easily see the use of
+ // constants across all dispatches instead of just on an individual basis
+ // as we do here.
+ return false;
+ } else if (isa<xla_hlo::DynamicUpdateSliceOp>(op)) {
+ // TODO(benvanik): lower these to the sequencer dialect prior to ID'ing.
+ return false;
+ }
+ return true;  // Everything not excluded above is considered dispatchable.
+}
+
+// Returns true if the given |op| can have other ops fused into it.
+// This is sketchy and it'd be nice to define this as an op property instead.
+//
+// What we are looking for in foldable ops is whether the execution of the op
+// when fused has some possible benefit (or at least, a non-negative cost).
+// Eventually we want to allow backends to vote on this and allow multiple
+// folding strategies within the same executable. For now we just hardcode what
+// we know for the ops we have.
+//
+// Preconditions: isDispatchableOp(op) == true.
+bool isFusionRootOp(Operation *op) {
+ if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
+ // We have hand-written kernels for these that, for now, should stand alone.
+ // When we do a bit more magic we should allow these ops to fold.
+ return false;
+ }
+ return true;  // Any other op may act as a fusion root.
+}
+
+// Returns true if the given |op| can be fused into other ops.
+//
+// Ops that perform narrowing on shapes (such as reduction ops) should not
+// generally be fused with other downstream ops (probably...). This avoids
+// potential oversampling and indexing issues and allows backends to perform
+// more efficient rooted cascading reduction dispatches.
+//
+// Preconditions: isDispatchableOp(op) == true.
+bool isFusableOp(Operation *op) {
+ if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {  // The same two ops isFusionRootOp rejects.
+ return false;
+ } else if (isa<xla_hlo::ReduceOp>(op)) {
+ // Reduction is usually a dedicated root operation - we can shove things in
+ // the front of it but not behind.
+ return false;
+ }
+ return true;  // Any other op may be pulled into a consumer's subgraph.
+}
+
+// Returns the ops of |unsortedOps| in an arbitrary topological order: each op
+// appears before the in-set users of its results.
+// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
+//
+// Preconditions: |unsortedOps| has no cycles within the set of ops.
+std::vector<Operation *> sortOpsTopologically(
+ const llvm::SetVector<Operation *> &unsortedOps) {
+ llvm::SetVector<Operation *> unmarkedOps;  // Worklist of DFS starting points.
+ unmarkedOps.insert(unsortedOps.begin(), unsortedOps.end());
+ llvm::SetVector<Operation *> markedOps;  // Finished ops, in DFS post-order.
+
+ using VisitFn = std::function<void(Operation * op)>;  // NOTE(review): <functional> is not directly included by this file.
+ VisitFn visit = [&](Operation *op) {
+ if (markedOps.count(op) > 0) return;  // Already finished.
+ for (auto *result : op->getResults()) {
+ for (auto *user : result->getUsers()) {
+ // Don't visit ops not in our set.
+ if (unsortedOps.count(user) == 0) continue;
+ visit(user);
+ }
+ }
+ markedOps.insert(op);  // Inserted only after every in-set user has been inserted.
+ };
+
+ while (!unmarkedOps.empty()) {
+ auto *op = unmarkedOps.pop_back_val();
+ visit(op);  // No-op for ops already reached through an earlier visit.
+ }
+
+ auto sortedOps = markedOps.takeVector();
+ std::reverse(sortedOps.begin(), sortedOps.end());  // Post-order has users first; flip so producers come first.
+ return sortedOps;
+}
+
+// Recursively traverses the IR DAG along the operand edges to find ops we are
+// able to fuse and appends them (producers before |op|) to |subgraph|.
+void gatherFusionOps(Operation *op, llvm::SetVector<Operation *> *subgraph) {
+ // Skip ops that are used outside of the subgraph we are building.
+ for (auto *result : op->getResults()) {
+ if (result->use_empty() || result->hasOneUse()) continue;  // Only results with more than one use are checked.
+ for (auto *user : result->getUsers()) {
+ if (subgraph->count(user) == 0) {
+ // Op that consumes the result is not (yet) in the subgraph.
+ // For now we'll ignore these as it may represent a fork that we don't
+ // want to join too early.
+ return;  // |op| is not added and its operands are not explored.
+ }
+ }
+ }
+
+ // Walk backward up to ops providing our input operands.
+ for (auto *operand : op->getOperands()) {
+ auto *sourceOp = operand->getDefiningOp();
+ if (!sourceOp) continue;  // Operand has no defining op.
+ if (subgraph->count(sourceOp) == 0) {
+ if (isDispatchableOp(sourceOp) && isFusableOp(sourceOp)) {
+ gatherFusionOps(sourceOp, subgraph);
+ }
+ }
+ }
+
+ subgraph->insert(op);  // SetVector insert: no effect if |op| is already present.
+}
+
+// Finds all ops that can be fused together with the given |rootOp| by searching
+// backwards in the op order through input edges.
+// Returns a topologically sorted list of all fused ops with |rootOp| at the
+// end; the list is just {rootOp} when |rootOp| is not a fusion root.
+std::vector<Operation *> findFusionSubgraphFromRoot(Operation *rootOp) {
+ if (!isFusionRootOp(rootOp)) {
+ return {rootOp};
+ }
+ llvm::SetVector<Operation *> subgraph;
+ subgraph.insert(rootOp);  // Seed the set so the root's own users-in-subgraph check can see it.
+ gatherFusionOps(rootOp, &subgraph);
+ return sortOpsTopologically(subgraph);
+}
+
+// Identifies ranges of dispatchable ops and moves them into dispatch regions.
+LogicalResult identifyBlockDispatchRegions(FuncOp func, Block *block) {
+ // Fixed point iteration until we can no longer fuse anything.
+ bool didFindAnyNewRegions;
+ do {
+ // Iterate in reverse so we root further along in the op list.
+ didFindAnyNewRegions = false;
+ for (auto &rootOp : llvm::reverse(*block)) {
+ if (!isDispatchableOp(&rootOp)) {
+ // Op should remain at the sequencer level.
+ continue;
+ }
+
+ // Attempt to find all operations, including rootOp, that can be fused.
+ // The ops will be sorted in topological order with rootOp as the last op.
+ // Worst case we may end up with a subgraph of only the rootOp.
+ auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp);
+
+ // Compute the workload based on the output shape.
+ // When variadic all output shapes match so we can just take the first.
+ auto *workload = calculateWorkload(&rootOp, rootOp.getResult(0));  // NOTE(review): assumes rootOp has at least one result.
+
+ // Try to build a dispatch region from this root.
+ if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) {
+ return failure();
+ }
+
+ // Successfully created a dispatch region from the ops and we must now
+ // start over again as we've likely trashed the whole block structure.
+ didFindAnyNewRegions = true;
+ break;  // Leave the range-for before touching the (possibly invalid) iterator again.
+ }
+ } while (didFindAnyNewRegions);  // Ends once a full scan finds no dispatchable op.
+ return success();
+}
+
+} // namespace
+
+// Identifies dispatchable ops and moves them into iree.dispatch_regions.
+// Some ops, such as call, will be deferred until following passes.
+class IdentifyDispatchRegionsPass
+ : public FunctionPass<IdentifyDispatchRegionsPass> {
+ public:
+ void runOnFunction() override {
+ auto func = getFunction();
+ for (auto &block : func) {  // Each block of the function is processed independently.
+ if (failed(identifyBlockDispatchRegions(func, &block))) {
+ return signalPassFailure();  // Stop at the first block that fails.
+ }
+ }
+ }
+};
+
+std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass() {  // Factory for the dispatch-region identification pass.
+ return std::make_unique<IdentifyDispatchRegionsPass>();  // Caller owns the new instance.
+}
+
+static PassRegistration<IdentifyDispatchRegionsPass> pass(  // File-local static registration object for this pass.
+ "iree-identify-dispatch-regions",  // Pass argument string.
+ "Conservatively identifies dispatch regions in functions.");  // Pass description string.
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp b/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp
new file mode 100644
index 0000000..6c9a48a
--- /dev/null
+++ b/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp
@@ -0,0 +1,163 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Utils/DispatchUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Builds a new iree.reduction_region whose body is a clone of
+// |invocationRegion|, created through an OpBuilder anchored at |originalOp|.
+// NOTE(review): OpBuilder(op) inserts before |op|, not after - confirm.
+// The workload is computed here from |originalOp|'s first result. Uses of
+// |originalOp|'s results are redirected to the new op; |originalOp| itself is
+// not erased and |invocationRegion| is not modified.
+LogicalResult buildReductionRegion(Operation *originalOp,
+ ArrayRef<Value *> operands,
+ ArrayRef<Value *> initialValues,
+ ArrayRef<int64_t> dimensions,
+ Region &invocationRegion) {
+ OpBuilder parentBuilder(originalOp);
+
+ // Compute the workload based on the output shape.
+ // When variadic all output shapes match so we can just take the first.
+ auto *workload = calculateWorkload(originalOp, originalOp->getResult(0));
+
+ // Build the region op and add it to the parent block.
+ SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()};
+ auto reductionRegionOp = parentBuilder.create<IREE::ReductionRegionOp>(
+ originalOp->getLoc(), resultTypes, workload, operands, initialValues,
+ dimensions);
+
+ // Clone the invocation region's contents into the new op's body.
+ BlockAndValueMapping mapping;
+ invocationRegion.cloneInto(&reductionRegionOp.getBody(), mapping);
+
+ // Replace xla_hlo.return -> iree.return.
+ OpBuilder regionBuilder(reductionRegionOp.getBody());
+ reductionRegionOp.walk([&](xla_hlo::ReturnOp returnOp) {
+ regionBuilder.setInsertionPoint(returnOp);
+ SmallVector<Value *, 4> returnValues(returnOp.getOperands());  // Copy the operands before the op is erased.
+ regionBuilder.create<IREE::ReturnOp>(returnOp.getLoc(), returnValues);
+ returnOp.erase();
+ });
+
+ // Replace usage of values with the results of the region.
+ for (int i = 0; i < originalOp->getNumResults(); ++i) {  // Result i of the new op stands in for result i of the original.
+ originalOp->getResult(i)->replaceAllUsesWith(
+ reductionRegionOp.getResult(i));
+ }
+
+ return success();
+}
+
+// Converts an xla_hlo::ReduceOp to a reduction region and inlines the target
+// computation into the region body. On success |reduceOp| has been erased.
+LogicalResult buildReductionRegionFromXLAReduceOp(xla_hlo::ReduceOp reduceOp) {
+ SmallVector<Value *, 4> operands(reduceOp.getOperands());
+ OperandAdaptor<xla_hlo::ReduceOp> adaptor(operands);  // Splits the flat operand list into operands() and init_values().
+
+ SmallVector<int64_t, 4> dimensions;
+ for (auto dim : reduceOp.dimensions().getIntValues()) {
+ dimensions.push_back(dim.getSExtValue());  // Convert each dimension value to int64_t.
+ }
+
+ // Create the iree.reduction_region.
+ if (failed(buildReductionRegion(reduceOp, adaptor.operands(),
+ adaptor.init_values(), dimensions,
+ reduceOp.body()))) {
+ return failure();  // The original op is left in place on failure.
+ }
+
+ // Remove original XLA reduction op.
+ reduceOp.erase();
+
+ return success();
+}
+
+// Identifies reduction ops and moves them into reduction regions.
+LogicalResult identifyBlockReductionRegions(FuncOp funcOp, Block *block) {  // NOTE(review): |funcOp| is not used in the body.
+ // Rescan the block until no xla_hlo.reduce op is found in it.
+ bool didFindAnyNewRegions;
+ do {
+ // Iterate in reverse so ops later in the block are converted first.
+ didFindAnyNewRegions = false;
+ for (auto &rootOp : llvm::reverse(*block)) {
+ if (auto reduceOp = dyn_cast<xla_hlo::ReduceOp>(rootOp)) {
+ if (failed(buildReductionRegionFromXLAReduceOp(reduceOp))) {
+ return failure();
+ }
+
+ // Successfully created a reduction region and erased the reduce op, so
+ // the block has changed under us and the scan must start over.
+ didFindAnyNewRegions = true;
+ break;
+ }
+ }
+ } while (didFindAnyNewRegions);
+ return success();
+}
+
+} // namespace
+
+// Identifies reduction ops and moves their targets into iree.reduction_regions.
+class IdentifyReductionRegionsPass
+ : public ModulePass<IdentifyReductionRegionsPass> {
+ public:
+ void runOnModule() override {
+ for (auto funcOp : getModule().getOps<FuncOp>()) {  // Only FuncOps directly inside the module are visited.
+ for (auto &block : funcOp) {
+ if (failed(identifyBlockReductionRegions(funcOp, &block))) {
+ return signalPassFailure();  // Stop at the first block that fails.
+ }
+ }
+ }
+ }
+};
+
+std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass() {  // Factory for the reduction-region identification pass.
+ return std::make_unique<IdentifyReductionRegionsPass>(); // NOLINT
+}
+
+static PassRegistration<IdentifyReductionRegionsPass> pass(  // File-local static registration object for this pass.
+ "iree-identify-reduction-regions",  // Pass argument string.
+ "Identifies reduction regions based on input reduction ops.");  // Pass description string.
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LegalizeInputs.cpp b/compiler/Transforms/Sequencer/LegalizeInputs.cpp
new file mode 100644
index 0000000..b8c860c
--- /dev/null
+++ b/compiler/Transforms/Sequencer/LegalizeInputs.cpp
@@ -0,0 +1,53 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/Transforms/Rewrites.h"
+#include "compiler/Transforms/Sequencer/Rewrites.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+class LegalizeInputOpsPass  // Function pass that runs a full dialect conversion with the target configured below.
+ : public FunctionPass<LegalizeInputOpsPass> {
+ public:
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());  // Presumably supplies the rewrites for xla_hlo.dot_general - verify in rewriters.h.
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<xla_hlo::XlaHloDialect, StandardOpsDialect>();
+ target.addLegalOp<FuncOp, ReturnOp>();
+ target.addIllegalOp<xla_hlo::DotGeneralOp, xla_hlo::WhileOp>();  // NOTE(review): no pattern for WhileOp is added in this file; confirm it is meant to fail conversion.
+ if (failed(applyFullConversion(getFunction(), target, patterns))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> createLegalizeInputOpsPass() {  // Factory for the input legalization function pass.
+ return std::make_unique<LegalizeInputOpsPass>();  // Caller owns the new instance.
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp
new file mode 100644
index 0000000..63408b7
--- /dev/null
+++ b/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp
@@ -0,0 +1,305 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Sequencer/HLDialect.h"
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/Sequencer/LLDialect.h"
+#include "compiler/IR/Sequencer/LLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+template <typename SrcOp>
+class SequencerLoweringPattern : public OpConversionPattern<SrcOp> {
+ public:
+ SequencerLoweringPattern(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern<SrcOp>(context), typeConverter_(typeConverter) {}
+
+ protected:
+ TypeConverter &typeConverter_;
+};
+
+// Returns an i32 scalar memref holding the offset addressed by |indices| within
+// |reference|; the element size is passed to the op in bytes (bit width / 8).
+Value *computeOffset(Location loc, Value *reference, Value *indices,
+                     OpBuilder &builder) {
+  auto referenceType = reference->getType().cast<ShapedType>();
+  auto *shapeMemRef = builder
+                          .create<IREESeq::LL::AllocHeapOp>(
+                              loc,
+                              MemRefType::get({referenceType.getRank()},
+                                              builder.getIntegerType(32)),
+                              ArrayRef<Value *>{})
+                          .getResult();
+  builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef);
+  auto *resultMemRef =
+      builder
+          .create<IREESeq::LL::AllocHeapOp>(
+              loc, MemRefType::get({}, builder.getIntegerType(32)),
+              ArrayRef<Value *>{})
+          .getResult();
+  auto elementSizeAttr = builder.getIntegerAttr(
+      builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8);
+  builder.create<IREESeq::LL::ComputeOffsetOp>(
+      loc, shapeMemRef, elementSizeAttr, indices, resultMemRef);
+  return resultMemRef;
+}
+
+// Returns a pair of (offset, length) i32 scalar memrefs describing the range
+// selected by |indices| and |lengths| within |reference|.
+std::pair<Value *, Value *> computeRange(Location loc, Value *reference,
+                                         Value *indices, Value *lengths,
+                                         OpBuilder &builder) {
+  auto referenceType = reference->getType().cast<ShapedType>();
+  auto *shapeMemRef = builder
+                          .create<IREESeq::LL::AllocHeapOp>(
+                              loc,
+                              MemRefType::get({referenceType.getRank()},
+                                              builder.getIntegerType(32)),
+                              ArrayRef<Value *>{})
+                          .getResult();
+  builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef);
+  auto *offsetMemRef =
+      builder
+          .create<IREESeq::LL::AllocHeapOp>(
+              loc, MemRefType::get({}, builder.getIntegerType(32)),
+              ArrayRef<Value *>{})
+          .getResult();
+  auto *lengthMemRef =
+      builder
+          .create<IREESeq::LL::AllocHeapOp>(
+              loc, MemRefType::get({}, builder.getIntegerType(32)),
+              ArrayRef<Value *>{})
+          .getResult();
+  auto elementSizeAttr = builder.getIntegerAttr(
+      builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8);
+  builder.create<IREESeq::LL::ComputeRangeOp>(loc, shapeMemRef, elementSizeAttr,
+                                              indices, lengths, offsetMemRef,
+                                              lengthMemRef);
+  return {offsetMemRef, lengthMemRef};
+}
+
+struct LowerSliceOpPattern
+ : public SequencerLoweringPattern<IREESeq::HL::SliceOp> {
+ using SequencerLoweringPattern::SequencerLoweringPattern;
+
+ PatternMatchResult matchAndRewrite(
+ IREESeq::HL::SliceOp op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ OperandAdaptor<IREESeq::HL::SliceOp> operandAdaptor(operands);
+ auto range = computeRange(op.getLoc(), operandAdaptor.src(),
+ operandAdaptor.indices(),
+ operandAdaptor.lengths(), rewriter);
+ rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicSliceOp>(
+ op, typeConverter_.convertType(op.getType()),
+ ArrayRef<Value *>{operandAdaptor.src(), range.first, range.second},
+ op.getAttrs());
+ return matchSuccess();
+ }
+};
+
+struct LowerShapeOpPattern
+    : public SequencerLoweringPattern<IREESeq::HL::ShapeOp> {
+  using SequencerLoweringPattern::SequencerLoweringPattern;
+
+  PatternMatchResult matchAndRewrite(
+      IREESeq::HL::ShapeOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // LL shape takes its destination memref as an operand, so allocate one and
+    // hand it to users of the HL result. All IR changes go through |rewriter|
+    // (never op.replaceAllUsesWith) so the conversion driver can track them.
+    auto shapeType =
+        MemRefType::get({op.getType().cast<ShapedType>().getRank()},
+                        rewriter.getIntegerType(64));
+    auto *shapeMemRef = rewriter.create<IREESeq::LL::AllocHeapOp>(
+        op.getLoc(), shapeType, ArrayRef<Value *>{}).getResult();
+    rewriter.create<IREESeq::LL::ShapeOp>(op.getLoc(), operands[0],
+                                          shapeMemRef);
+    rewriter.replaceOp(op, {shapeMemRef});
+    return matchSuccess();
+  }
+};
+
+struct LowerCopyOpPattern
+ : public SequencerLoweringPattern<IREESeq::HL::CopyOp> {
+ using SequencerLoweringPattern::SequencerLoweringPattern;
+
+ PatternMatchResult matchAndRewrite(
+ IREESeq::HL::CopyOp op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ OperandAdaptor<IREESeq::HL::CopyOp> operandAdaptor(operands);
+ auto *srcOffsetMemRef =
+ computeOffset(op.getLoc(), operandAdaptor.src(),
+ operandAdaptor.srcIndices(), rewriter);
+ auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(),
+ operandAdaptor.dstIndices(),
+ operandAdaptor.lengths(), rewriter);
+ rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicCopyOp>(
+ op, operandAdaptor.src(), srcOffsetMemRef, operandAdaptor.dst(),
+ dstRange.first, dstRange.second);
+ return matchSuccess();
+ }
+};
+
+struct LowerFillOpPattern
+ : public SequencerLoweringPattern<IREESeq::HL::FillOp> {
+ using SequencerLoweringPattern::SequencerLoweringPattern;
+
+ PatternMatchResult matchAndRewrite(
+ IREESeq::HL::FillOp op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ OperandAdaptor<IREESeq::HL::FillOp> operandAdaptor(operands);
+ auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(),
+ operandAdaptor.dstIndices(),
+ operandAdaptor.lengths(), rewriter);
+ rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicFillOp>(
+ op, operandAdaptor.value(), operandAdaptor.dst(), dstRange.first,
+ dstRange.second);
+ return matchSuccess();
+ }
+};
+
+struct LowerBranchOpPattern
+ : public SequencerLoweringPattern<IREESeq::HL::BranchOp> {
+ using SequencerLoweringPattern<
+ IREESeq::HL::BranchOp>::SequencerLoweringPattern;
+
+ PatternMatchResult matchAndRewrite(
+ IREESeq::HL::BranchOp op, ArrayRef<Value *> properOperands,
+ ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREESeq::LL::BranchOp>(op, destinations[0],
+ operands[0]);
+ return matchSuccess();
+ }
+};
+
+struct LowerCondCondBranchOpPattern
+ : public SequencerLoweringPattern<IREESeq::HL::CondBranchOp> {
+ using SequencerLoweringPattern<
+ IREESeq::HL::CondBranchOp>::SequencerLoweringPattern;
+
+ PatternMatchResult matchAndRewrite(
+ IREESeq::HL::CondBranchOp op, ArrayRef<Value *> properOperands,
+ ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREESeq::LL::CondBranchOp>(
+ op, properOperands[0],
+ destinations[IREESeq::HL::CondBranchOp::trueIndex],
+ operands[IREESeq::HL::CondBranchOp::trueIndex],
+ destinations[IREESeq::HL::CondBranchOp::falseIndex],
+ operands[IREESeq::HL::CondBranchOp::falseIndex]);
+ return matchSuccess();
+ }
+};
+
+// Rewrites SRC into DST with the same operands and attributes; result types are
+// mapped through the type converter (match fails with an op error otherwise).
+// Operands/results must be in the same order, attributes must share names, and
+// DST must be constructible by its default (types, operands, attrs) builder.
+template <typename SRC, typename DST>
+struct LowerIdenticalOpPattern : public SequencerLoweringPattern<SRC> {
+  using SequencerLoweringPattern<SRC>::SequencerLoweringPattern;
+
+  PatternMatchResult matchAndRewrite(
+      SRC op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 8> originalResultTypes{
+        op.getOperation()->getResultTypes()};
+    SmallVector<Type, 8> resultTypes;
+    if (failed(this->typeConverter_.convertTypes(originalResultTypes,
+                                                 resultTypes))) {
+      op.emitOpError() << "Failed to convert result types";
+      return this->matchFailure();
+    }
+    rewriter.replaceOpWithNewOp<DST>(op, resultTypes, operands, op.getAttrs());
+    return this->matchSuccess();
+  }
+};
+
+} // namespace
+
+// Lowers HL sequencer ops in every module-level function to the LL dialect.
+class LowerSequencerDialectPass : public ModulePass<LowerSequencerDialectPass> {
+ public:
+  void runOnModule() override {
+    auto *ctx = &getContext();
+    LLTypeConverter typeConverter(ctx);
+    OwningRewritePatternList patterns;
+    patterns.insert<
+        LowerIdenticalOpPattern<IREE::ConstantOp, IREESeq::LL::ConstantOp>,
+        LowerIdenticalOpPattern<IREESeq::HL::DispatchOp,
+                                IREESeq::LL::DynamicDispatchOp>,
+        LowerShapeOpPattern, LowerCopyOpPattern, LowerFillOpPattern,
+        LowerSliceOpPattern, LowerBranchOpPattern,
+        LowerCondCondBranchOpPattern>(ctx, typeConverter);
+#define IDENTICAL_OP_LOWERING(op_name) \
+  LowerIdenticalOpPattern<IREESeq::HL::op_name, IREESeq::LL::op_name>
+    patterns.insert<
+        IDENTICAL_OP_LOWERING(AllocHeapOp), IDENTICAL_OP_LOWERING(CloneOp),
+        IDENTICAL_OP_LOWERING(ReshapeOp), IDENTICAL_OP_LOWERING(CallOp),
+        IDENTICAL_OP_LOWERING(ReturnOp)>(ctx, typeConverter);
+#undef IDENTICAL_OP_LOWERING
+    mlir::populateFuncOpTypeConversionPattern(patterns, ctx, typeConverter);
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<IREELLSequencerDialect>();
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType());
+    });
+
+    // TODO(b/142791494): The conversion framework will recurse into the
+    // executable if we just call it on the top-level module. This can't be a
+    // function pass because type conversion replaces the original functions.
+    auto funcsIt = getModule().getOps<FuncOp>();
+    SmallVector<Operation *, 4> funcs(funcsIt.begin(), funcsIt.end());
+    if (failed(applyFullConversion(funcs, target, patterns, &typeConverter))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+std::unique_ptr<OpPassBase<ModuleOp>> createLowerSequencerDialectPass() {
+ return std::make_unique<LowerSequencerDialectPass>();
+}
+
+static PassRegistration<LowerSequencerDialectPass> pass(
+ "iree-lower-sequencer-dialect",
+ "Lowers the IREE HL sequencer dialect to the LL sequencer dialect.");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp
new file mode 100644
index 0000000..e3f760d
--- /dev/null
+++ b/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp
@@ -0,0 +1,204 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Sequencer/HLDialect.h"
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+template <typename T>
+class SequencerConversionPattern : public OpConversionPattern<T> {
+ using OpConversionPattern<T>::OpConversionPattern;
+};
+
+struct CallOpLowering : public SequencerConversionPattern<CallOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+
+ PatternMatchResult matchAndRewrite(
+ CallOp callOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Type, 4> resultTypes(callOp.getResultTypes());
+ rewriter.replaceOpWithNewOp<IREESeq::HL::CallOp>(callOp, callOp.getCallee(),
+ resultTypes, operands);
+
+ return matchSuccess();
+ }
+};
+
+struct CallIndirectOpLowering
+ : public SequencerConversionPattern<CallIndirectOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+
+ PatternMatchResult matchAndRewrite(
+ CallIndirectOp callOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREESeq::HL::CallIndirectOp>(
+ callOp, callOp.getCallee(), operands);
+ return matchSuccess();
+ }
+};
+
+struct ReturnOpLowering : public SequencerConversionPattern<ReturnOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+
+ PatternMatchResult matchAndRewrite(
+ ReturnOp returnOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value *, 4> newOperands;
+ newOperands.reserve(operands.size());
+ for (auto *operand : operands) {
+ newOperands.push_back(wrapAsMemRef(operand, returnOp, rewriter));
+ }
+ rewriter.replaceOpWithNewOp<IREESeq::HL::ReturnOp>(returnOp, newOperands);
+ return matchSuccess();
+ }
+};
+
+struct BranchOpLowering : public SequencerConversionPattern<BranchOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+ PatternMatchResult matchAndRewrite(
+ BranchOp branchOp, ArrayRef<Value *> properOperands,
+ ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREESeq::HL::BranchOp>(
+ branchOp, destinations[0], operands[0]);
+ return this->matchSuccess();
+ }
+};
+
+struct CondBranchOpLowering : public SequencerConversionPattern<CondBranchOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+ PatternMatchResult matchAndRewrite(
+ CondBranchOp condBranchOp, ArrayRef<Value *> properOperands,
+ ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *condValue =
+ loadAccessValue(condBranchOp.getLoc(), properOperands[0], rewriter);
+ rewriter.replaceOpWithNewOp<IREESeq::HL::CondBranchOp>(
+ condBranchOp, condValue,
+ destinations[IREESeq::HL::CondBranchOp::trueIndex],
+ operands[IREESeq::HL::CondBranchOp::trueIndex],
+ destinations[IREESeq::HL::CondBranchOp::falseIndex],
+ operands[IREESeq::HL::CondBranchOp::falseIndex]);
+ return this->matchSuccess();
+ }
+};
+
+struct AllocOpLowering : public SequencerConversionPattern<AllocOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+ PatternMatchResult matchAndRewrite(
+ AllocOp allocOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // TODO(benvanik): replace with length computation.
+ rewriter.replaceOpWithNewOp<IREESeq::HL::AllocHeapOp>(
+ allocOp, allocOp.getType(), operands);
+ return matchSuccess();
+ }
+};
+
+struct DeallocOpLowering : public SequencerConversionPattern<DeallocOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+ PatternMatchResult matchAndRewrite(
+ DeallocOp deallocOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<IREESeq::HL::DiscardOp>(deallocOp, operands[0]);
+ return matchSuccess();
+ }
+};
+
+struct LoadOpLowering : public SequencerConversionPattern<LoadOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+ PatternMatchResult matchAndRewrite(
+ LoadOp loadOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (loadOp.getMemRefType().getRank() != 0) {
+ loadOp.emitError() << "Cannot lower load of non-scalar";
+ return matchFailure();
+ }
+ ArrayRef<Value *> dimPieces;
+ auto dst = rewriter.create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(),
+ dimPieces);
+ auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {});
+ rewriter.create<IREESeq::HL::CopyOp>(loadOp.getLoc(), loadOp.getMemRef(),
+ /*srcIndices=*/emptyArrayMemref, dst,
+ /*dstIndices=*/emptyArrayMemref,
+ /*lengths=*/emptyArrayMemref);
+
+ rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst);
+
+ return matchSuccess();
+ }
+};
+
+struct StoreOpLowering : public SequencerConversionPattern<StoreOp> {
+ using SequencerConversionPattern::SequencerConversionPattern;
+ PatternMatchResult matchAndRewrite(
+ StoreOp storeOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (storeOp.getMemRefType().getRank() != 0) {
+ storeOp.emitError() << "Cannot lower store of non-scalar";
+ return matchFailure();
+ }
+
+ auto src = rewriter.create<IREE::ScalarToMemRefOp>(
+ storeOp.getLoc(), storeOp.getValueToStore());
+
+ auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {});
+ rewriter.replaceOpWithNewOp<IREESeq::HL::CopyOp>(
+ storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(),
+ /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref);
+
+ return matchSuccess();
+ }
+};
+
+} // namespace
+
+void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context) {
+ patterns.insert<
+ // Control flow.
+ CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
+ BranchOpLowering, CondBranchOpLowering,
+ // Memory management.
+ AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering>(
+ context);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp
new file mode 100644
index 0000000..720d259
--- /dev/null
+++ b/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp
@@ -0,0 +1,62 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Sequencer/HLDialect.h"
+#include "compiler/IR/Sequencer/LLDialect.h"
+#include "compiler/Transforms/Rewrites.h"
+#include "compiler/Transforms/Sequencer/Rewrites.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+class LowerToSequencerDialectPass
+ : public FunctionPass<LowerToSequencerDialectPass> {
+ public:
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ auto* ctx = &getContext();
+ xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx);
+ xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx);
+ populateLowerStdToIreePatterns(patterns, ctx);
+ populateLowerStdToSequencerPatterns(patterns, ctx);
+ populateLowerXlaToIreePatterns(patterns, ctx);
+ populateLowerXlaToSequencerPatterns(patterns, ctx);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<IREEHLSequencerDialect, IREEDialect>();
+ target.addLegalOp<FuncOp>();
+ if (failed(applyFullConversion(getFunction(), target, patterns))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> createLowerToSequencerDialectPass() {
+ return std::make_unique<LowerToSequencerDialectPass>();
+}
+
+static PassRegistration<LowerToSequencerDialectPass> pass(
+ "lower-to-iree-sequencer", "Convert all ops to the IREE sequencer dialect");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp
new file mode 100644
index 0000000..5626d6c
--- /dev/null
+++ b/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp
@@ -0,0 +1,224 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/IR/Dialect.h"
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Sequencer/HLDialect.h"
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Transforms/ConversionUtils.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/OpCreationUtils.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// TODO(suderman): tablegen this? or something a bit more flexible.
+
+#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \
+ struct XlaOpType##Lowering \
+ : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
+ using UnaryOpLowering::UnaryOpLowering; \
+ };
+
+#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \
+ struct XlaOpType##Lowering \
+ : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
+ using TernaryOpLowering::TernaryOpLowering; \
+ };
+
+UNARY_OP_LOWERING(CopyOp, IREESeq::HL::CloneOp);
+
+#undef UNARY_OP_LOWERING
+#undef TERNARY_OP_LOWERING
+
+template <typename T>
+static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter,
+ Location loc, Value *input,
+ MemRefType targetType) {
+ auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape());
+ return rewriter.create<T>(loc, targetType, input, shapeOp);
+}
+
+static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op,
+ Value *tensor) {
+ return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op,
+ rewriter);
+}
+
+template <typename SrcOp>
+class XlaOpLowering : public OpConversionPattern<SrcOp> {
+ public:
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  PatternMatchResult matchAndRewrite(
+      SrcOp op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Operands go through inputAsMemref; result 0 goes back via wrapAsTensor.
+    // rewriteInternal may return nullptr (see SliceLowering): fail the match.
+    SmallVector<Value *, 4> memrefOperands;
+    for (auto operand : operands) {
+      memrefOperands.push_back(inputAsMemref(rewriter, op, operand));
+    }
+    auto dstOp = rewriteInternal(&op, memrefOperands, rewriter);
+    if (!dstOp) return this->matchFailure();
+    rewriter.replaceOp(op, wrapAsTensor(dstOp->getResult(0), op, rewriter));
+    return this->matchSuccess();
+  }
+
+ protected:
+  virtual Operation *rewriteInternal(
+      SrcOp *op, ArrayRef<Value *> operands,
+      ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?");
+  }
+};
+
+struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> {
+ using XlaOpLowering::XlaOpLowering;
+
+ Operation *rewriteInternal(
+ xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto finalType = convertTypeToMemRef(*op);
+
+ return rewriter.create<IREESeq::HL::ConcatOp>(
+ op->getLoc(), finalType, operands,
+ rewriter.getI32IntegerAttr(op->dimension().getZExtValue()));
+ }
+};
+
+struct DynamicUpdateSliceLowering
+ : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> {
+ using XlaOpLowering::XlaOpLowering;
+
+ Operation *rewriteInternal(
+ xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto operand = operands[0];
+ auto update = operands[1];
+
+ auto updateType = update->getType().cast<ShapedType>();
+ Value *lengthConstant =
+ createArrayConstant(rewriter, op->getLoc(), updateType.getShape());
+
+ auto startIndices = makeArrayRef(operands).drop_front(2);
+ const int rank = startIndices.size();
+ llvm::SmallVector<Value *, 4> valuesToConcat;
+ valuesToConcat.reserve(startIndices.size());
+ auto type = getElementTypeOrSelf(startIndices.front());
+
+    // To build the destination offset vector, reshape each variadic start index
+    // into a one-element memref and concatenate them into a rank-sized memref.
+ for (auto index : startIndices) {
+ auto reshapedIndex = rewriter.create<IREESeq::HL::ReshapeOp>(
+ op->getLoc(), MemRefType::get({1}, type), index,
+ createArrayConstant(rewriter, op->getLoc(), {1}));
+ valuesToConcat.push_back(reshapedIndex);
+ }
+
+ auto dstOffset = rewriter
+ .create<IREESeq::HL::ConcatOp>(
+ op->getLoc(), MemRefType::get({rank}, type),
+ valuesToConcat, rewriter.getI32IntegerAttr(0))
+ .getResult();
+
+ llvm::SmallVector<int64_t, 4> zero_offset;
+ zero_offset.resize(updateType.getRank(), 0);
+ auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset);
+
+ auto copiedOperand = rewriter.create<IREESeq::HL::CloneOp>(
+ op->getLoc(), operand->getType(), operand);
+
+ rewriter
+ .create<IREESeq::HL::CopyOp>(op->getLoc(), update, srcOffset,
+ copiedOperand, dstOffset, lengthConstant)
+ .getOperation();
+
+ return copiedOperand;
+ }
+};
+
+struct SliceLowering : public XlaOpLowering<xla_hlo::SliceOp> {
+ using XlaOpLowering::XlaOpLowering;
+ Operation *rewriteInternal(
+ xla_hlo::SliceOp *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // XLA slice has value semantics, whereas the IREE slice creates a view. We
+ // lower it to a copy if all strides are one which may be transformed to a
+ // slice by later optimizations.
+ auto isNotOne = [](APInt stride) { return stride != 1; };
+ if (llvm::any_of(op->strides(), isNotOne)) {
+ op->emitRemark() << "Could not lower slice op with non-singular strides";
+ return nullptr;
+ }
+
+ auto finalType = convertTypeToMemRef(*op);
+
+ auto src = operands[0];
+ std::vector<Value *> dim_pieces;
+ auto dst = rewriter.create<IREESeq::HL::AllocHeapOp>(op->getLoc(),
+ finalType, dim_pieces);
+ auto srcIndices =
+ rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices());
+ auto lengths =
+ createArrayConstant(rewriter, op->getLoc(), finalType.getShape());
+
+ llvm::SmallVector<int64_t, 4> zero_offset;
+ zero_offset.resize(finalType.getRank(), 0);
+ auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset);
+
+ rewriter.create<IREESeq::HL::CopyOp>(op->getLoc(), src, srcIndices, dst,
+ dstIndices, lengths);
+ return dst;
+ }
+};
+
+struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> {
+ using XlaOpLowering::XlaOpLowering;
+
+ Operation *rewriteInternal(
+ xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ return createShapeTargetingOp<IREESeq::HL::ReshapeOp>(
+ rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op));
+ }
+};
+
+} // namespace
+
+void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {
+ patterns.insert<ConcatOpLowering, CopyOpLowering, DynamicUpdateSliceLowering,
+ ReshapeOpLowering, SliceLowering>(ctx);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp b/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp
new file mode 100644
index 0000000..b914d46
--- /dev/null
+++ b/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp
@@ -0,0 +1,242 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Utils/DispatchUtils.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Inserts a load from a wrapped memref (as inserted via insertDispatcherStore).
+// Returns the value in the original type.
+//
+// |op| supplies the location for the new op, |originalType| is the type the
+// value had before it was wrapped, and |builder| decides where the load is
+// created. Values whose original type was already a memref are returned
+// unchanged.
+Value *insertDispatcheeLoad(Operation *op, Type originalType, Value *value,
+ OpBuilder &builder) {
+ // If old value was a memref we don't need to change anything.
+ if (originalType.isa<MemRefType>()) {
+ return value;
+ }
+
+ auto loadInputOp =
+ builder.create<IREE::LoadInputOp>(op->getLoc(), originalType, value);
+ // Redirect all uses of |value| to the loaded result. That replacement also
+ // hits the load's own operand, so it is pointed back at |value| immediately
+ // afterwards.
+ value->replaceAllUsesWith(loadInputOp.getResult());
+ loadInputOp.setOperand(value);
+ return loadInputOp.getResult();
+}
+
+// Marshals args and results as buffers for the given region.
+// Beyond inserting the appropriate tensor-to-memref ops we avoid mutating the
+// interior of the dispatch region as much as possible.
+//
+// The original |regionOp| is replaced by a newly created DispatchRegionOp that
+// takes over its body; |regionOp| itself is erased before returning. Returns
+// failure() (after emitting an error on |regionOp|) if an output buffer cannot
+// be allocated.
+LogicalResult marshalDispatchSite(IREE::DispatchRegionOp regionOp) {
+ auto &entryBlock = regionOp.getBody().getBlocks().front();
+ // Ops for the dispatcher side are created right before the region op; ops for
+ // the dispatchee side are created at the top of the region's entry block.
+ OpBuilder dispatcherBuilder(regionOp);
+ OpBuilder dispatcheeBuilder(&entryBlock, entryBlock.begin());
+
+ // Wrap input operands and unwrap in the entry block.
+ SmallVector<Value *, 8> newArgs;
+ for (int i = 0; i < regionOp.getNumArgOperands(); ++i) {
+ // Wrap the input outside of the region.
+ auto *blockArg = entryBlock.getArgument(i);
+ Type originalType = blockArg->getType();
+ auto *originalArg = regionOp.getArgOperand(i);
+ auto *wrappedArg =
+ insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
+ newArgs.push_back(wrappedArg);
+ blockArg->setType(wrappedArg->getType());
+
+ // Unwrap the block arg value and replace all of the uses with the newly
+ // unwrapped value.
+ insertDispatcheeLoad(regionOp, originalType, blockArg, dispatcheeBuilder);
+ }
+
+ // Allocate output arguments and replace the return values with those.
+ // NOTE: |newResults| is never appended to, so the replacement region op
+ // created below has no results.
+ SmallVector<Type, 8> newResults;
+ SmallVector<std::pair<int, Value *>, 8> resultIndicesToOutputArgs;
+ SmallVector<int, 8> deadResultIndices;
+ SmallVector<std::pair<Value *, Value *>, 8> replacedResults;
+ for (int i = 0; i < regionOp.getNumResults(); ++i) {
+ auto *result = regionOp.getResult(i);
+ auto convertedType = convertTypeToMemRef(result->getType());
+
+ // Allocate output buffer in the dispatcher to pass in to the region.
+ Value *allocatedValue = allocateDispatchOutputBuffer(
+ regionOp.getLoc(), convertedType, dispatcherBuilder);
+ if (!allocatedValue) {
+ regionOp.emitError("unable to allocate result value");
+ return failure();
+ }
+ newArgs.push_back(allocatedValue);
+
+ // The output buffer becomes an extra trailing argument of the entry block.
+ auto *newBlockArg = entryBlock.addArgument(allocatedValue->getType());
+ resultIndicesToOutputArgs.push_back({i, newBlockArg});
+
+ // NOTE: right now we always replace results. If we want to allow return
+ // values we can avoid killing them here.
+ deadResultIndices.push_back(i);
+ replacedResults.push_back({result, allocatedValue});
+ }
+
+ // Remove dead results from return statements.
+ regionOp.walk([&](IREE::ReturnOp returnOp) {
+ // Replace the results we were returning with stores to output arguments.
+ OpBuilder builder(returnOp);
+ for (auto resultToArg : resultIndicesToOutputArgs) {
+ auto *value = returnOp.getOperand(resultToArg.first);
+ auto *outputArg = resultToArg.second;
+ builder.create<IREE::StoreOutputOp>(returnOp.getLoc(), value, outputArg);
+ }
+
+ // Filter out the results that are now dead.
+ // Indices are erased from highest to lowest so the positions of the
+ // remaining entries stay valid.
+ SmallVector<Value *, 8> newOperands(returnOp.getOperands());
+ for (int i = deadResultIndices.size() - 1; i >= 0; --i) {
+ newOperands.erase(newOperands.begin() + deadResultIndices[i]);
+ }
+ returnOp.getOperation()->setOperands(newOperands);
+ });
+
+ // Clone the region op with the new args/results.
+ auto newRegionOp = dispatcherBuilder.create<IREE::DispatchRegionOp>(
+ regionOp.getLoc(), newResults, regionOp.getWorkload(), newArgs);
+ newRegionOp.getBody().takeBody(regionOp.getBody());
+
+ // Marshal back the results by replacing uses of the original with loads from
+ // the new output arg.
+ for (auto &it : replacedResults) {
+ insertDispatcherLoad(regionOp, it.first, it.second, dispatcherBuilder);
+ }
+
+ // Remove original region.
+ regionOp.erase();
+
+ return success();
+}
+
+// Swaps a dispatch_region for an IREESeq::HL::DispatchOp that targets
+// |entryPoint| inside |executable|. The region's results are rewired to the
+// results of the new dispatch and the region op is erased.
+LogicalResult convertToDispatchOp(IREE::DispatchRegionOp regionOp,
+                                  IREE::MultiArchExecutableOp executable,
+                                  FuncOp entryPoint) {
+  // New ops are placed where the original region lives.
+  OpBuilder builder(regionOp);
+
+  // The dispatch op takes its workload as a memref.
+  auto *workloadMemRef =
+      wrapAsMemRef(regionOp.getWorkload(), regionOp, builder);
+
+  // Build the dispatch to the outlined entry point, forwarding the region's
+  // argument operands unchanged.
+  SmallVector<Value *, 8> dispatchOperands(regionOp.getArgOperands());
+  auto dispatchResultTypes = entryPoint.getType().getResults();
+  auto dispatchOp = builder.create<IREESeq::HL::DispatchOp>(
+      regionOp.getLoc(), executable.getName(), entryPoint.getName(),
+      workloadMemRef, dispatchResultTypes, dispatchOperands);
+
+  // Point every user of a region result at the matching dispatch result.
+  for (int resultIndex = 0; resultIndex < regionOp.getNumResults();
+       ++resultIndex) {
+    auto *oldResult = regionOp.getResult(resultIndex);
+    oldResult->replaceAllUsesWith(dispatchOp.getResult(resultIndex));
+  }
+
+  // The region has no remaining users and can go away.
+  regionOp.erase();
+  return success();
+}
+
+// Outlines a dispatch region into an iree.multi_arch_executable and replaces
+// the region with a dispatch to the outlined function.
+// |outlinedRegionOrdinal| is used to build the "_dispatch_N" name suffix.
+LogicalResult outlineDispatchRegion(IREE::DispatchRegionOp regionOp,
+                                    int outlinedRegionOrdinal) {
+  auto *context = regionOp.getContext();
+
+  // The outlined function signature mirrors the region: one operand per
+  // region argument and one result per region result.
+  SmallVector<Type, 8> argTypes;
+  for (auto *operand : regionOp.getArgOperands()) {
+    argTypes.push_back(operand->getType());
+  }
+  SmallVector<Type, 8> retTypes(regionOp.getResultTypes());
+  auto outlinedFuncType = FunctionType::get(argTypes, retTypes, context);
+
+  // Clone the region into a fresh executable.
+  auto nameSuffix = "_dispatch_" + std::to_string(outlinedRegionOrdinal);
+  IREE::MultiArchExecutableOp executableOp;
+  FuncOp exportedFunc;
+  std::tie(executableOp, exportedFunc) =
+      createRegionExecutable(regionOp, outlinedFuncType, nameSuffix);
+  exportedFunc.setAttr("iree.executable.export", UnitAttr::get(context));
+
+  // Replace the region with a dispatch to the outlined function.
+  return convertToDispatchOp(regionOp, executableOp, exportedFunc);
+}
+
+} // namespace
+
+// Module pass that marshals the I/O of every iree.dispatch_region and then
+// outlines each region into its own executable.
+class OutlineDispatchRegionsPass
+    : public ModulePass<OutlineDispatchRegionsPass> {
+ public:
+  void runOnModule() override {
+    auto module = getModule();
+    ModuleManager moduleManager(module);
+
+    // Snapshot the functions up front and iterate over the copy rather than
+    // the module's live op list.
+    auto funcRange = module.getOps<FuncOp>();
+    SmallVector<FuncOp, 4> funcOps(funcRange.begin(), funcRange.end());
+    for (auto func : funcOps) {
+      // Step 1: marshal dispatcher/dispatchee I/O. This makes everything
+      // memrefs by inserting the required stores and loads, and adds the
+      // iree.load_input/iree.store_output ops to the dispatchee.
+      auto marshalWalk = func.walk([&](IREE::DispatchRegionOp op) {
+        return failed(marshalDispatchSite(op)) ? WalkResult::interrupt()
+                                               : WalkResult::advance();
+      });
+      if (marshalWalk.wasInterrupted()) {
+        return signalPassFailure();
+      }
+
+      // Step 2: collect the (now marshaled) regions, then outline each one.
+      // The ordinal restarts at zero for every function.
+      SmallVector<IREE::DispatchRegionOp, 8> pendingRegionOps;
+      func.walk(
+          [&](IREE::DispatchRegionOp op) { pendingRegionOps.push_back(op); });
+      int ordinal = 0;
+      for (auto regionOp : pendingRegionOps) {
+        if (failed(outlineDispatchRegion(regionOp, ordinal))) {
+          return signalPassFailure();
+        }
+        ++ordinal;
+      }
+    }
+  }
+};
+
+// Creates a new instance of the pass that outlines iree.dispatch_region ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createOutlineDispatchRegionsPass() {
+ return std::make_unique<OutlineDispatchRegionsPass>();
+}
+
+// Registers the pass under the "iree-outline-dispatch-regions" flag.
+static PassRegistration<OutlineDispatchRegionsPass> pass(
+ "iree-outline-dispatch-regions",
+ "Outlines dispatch regions into standalone functions");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp b/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp
new file mode 100644
index 0000000..407f014
--- /dev/null
+++ b/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp
@@ -0,0 +1,307 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/Sequencer/HLOps.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Utils/DispatchUtils.h"
+#include "compiler/Utils/MemRefUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Computes the shape that results from reducing |windowDimension| away:
+// the shape of |input| with that one dimension dropped.
+SmallVector<int64_t, 4> calculateResultShape(Value *input,
+                                             int windowDimension) {
+  auto inputShape = input->getType().cast<ShapedType>().getShape();
+  SmallVector<int64_t, 4> reducedShape;
+  for (size_t dim = 0; dim < inputShape.size(); ++dim) {
+    // Skip the dimension being reduced; keep all others in order.
+    if (dim == windowDimension) continue;
+    reducedShape.push_back(inputShape[dim]);
+  }
+  return reducedShape;
+}
+
+// Creates an executable that holds the given elemental reduction region.
+// The executable will have an entry point taking the specified reduction values
+// and writing the results to output arguments.
+//
+// |outlinedRegionOrdinal| and |separatedReductionIndex| are only used to build
+// the "_reduce_N_dim_M" name suffix. |reductionDimension| is the dimension
+// removed from each input shape when computing the output argument types.
+// Returns the executable together with the newly created entry function (not
+// the elemental function cloned from the region).
+std::pair<IREE::MultiArchExecutableOp, FuncOp> createReductionExecutable(
+ IREE::ReductionRegionOp regionOp, int outlinedRegionOrdinal,
+ int separatedReductionIndex, int reductionDimension,
+ SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs) {
+ Builder builder(regionOp.getContext());
+
+ // Build function type matching 1:1 with the region signature.
+ // Each initial value contributes two operands and one result of its type.
+ SmallVector<Type, 8> elementalOperandTypes;
+ SmallVector<Type, 8> elementalResultTypes;
+ for (auto *arg : regionOp.getInitialValueOperands()) {
+ // (in0, in1) -> out0
+ elementalOperandTypes.push_back(arg->getType());
+ elementalOperandTypes.push_back(arg->getType());
+ elementalResultTypes.push_back(arg->getType());
+ }
+ auto elementalFunctionType = FunctionType::get(
+ elementalOperandTypes, elementalResultTypes, regionOp.getContext());
+
+ // Create the executable with the region cloned into it.
+ IREE::MultiArchExecutableOp multiArchExecutable;
+ FuncOp elementalFunc;
+ std::tie(multiArchExecutable, elementalFunc) = createRegionExecutable(
+ regionOp, elementalFunctionType,
+ "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" +
+ std::to_string(separatedReductionIndex));
+
+ // Create a new entry point that we can use with the signature for this
+ // dimension.
+ // Operand order: all inputs, then all initial values, then one output memref
+ // per region result. The entry function itself returns nothing.
+ SmallVector<Type, 8> allOperandTypes;
+ auto inputTypes =
+ llvm::map_range(inputs, [](Value *value) { return value->getType(); });
+ allOperandTypes.append(inputTypes.begin(), inputTypes.end());
+ auto initialValueTypes = llvm::map_range(
+ initialValues, [](Value *value) { return value->getType(); });
+ allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end());
+ for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
+ auto shapedType = resultType.value().cast<ShapedType>();
+ // NOTE(review): indexes |inputs| by result index, so this assumes there are
+ // at least as many inputs as region results - confirm against the op
+ // definition.
+ allOperandTypes.push_back(MemRefType::get(
+ calculateResultShape(inputs[resultType.index()], reductionDimension),
+ shapedType.getElementType()));
+ }
+ auto entryFuncType = FunctionType::get(allOperandTypes, ArrayRef<Type>{},
+ regionOp.getContext());
+ auto entryFunc =
+ FuncOp::create(regionOp.getLoc(),
+ (elementalFunc.getName() + "_entry").str(), entryFuncType);
+ entryFunc.setAttr("iree.executable.export",
+ UnitAttr::get(regionOp.getContext()));
+ // Add the entry function to the block holding the elemental function, then
+ // move it so that it sits directly before the elemental function.
+ elementalFunc.getOperation()->getBlock()->push_back(entryFunc);
+ entryFunc.getOperation()->moveBefore(elementalFunc);
+ entryFunc.setAttr("iree.executable.reduction",
+ UnitAttr::get(regionOp.getContext()));
+ // Record which function the entry point applies per element.
+ entryFunc.setAttr("iree.executable.reduction.apply",
+ builder.getSymbolRefAttr(elementalFunc));
+
+ return {multiArchExecutable, entryFunc};
+}
+
+// Converts a reduction_region into a dispatch to the outlined region function
+// for a single reduction dimension.
+// Returns the results of the reduction or empty if the construction fails.
+//
+// The returned values are the output buffers allocated here (one per region
+// result); the created dispatch op itself has no results. Ops are created with
+// |dispatcherBuilder|. |regionOp| is not erased by this function.
+SmallVector<Value *, 4> convertToDispatchOp(
+ IREE::ReductionRegionOp regionOp, IREE::MultiArchExecutableOp executable,
+ FuncOp entryFunc, int reductionDimension,
+ SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs,
+ OpBuilder &dispatcherBuilder) {
+ // Allocate output args and replace the return values with those.
+ SmallVector<Value *, 4> resultValues;
+ for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
+ // Allocate output buffer in the dispatcher to pass in to the region.
+ // The buffer shape is the matching input's shape without
+ // |reductionDimension|; the element type comes from the region result.
+ auto shapedType = resultType.value().cast<ShapedType>();
+ Value *allocatedValue = allocateDispatchOutputBuffer(
+ regionOp.getLoc(),
+ MemRefType::get(calculateResultShape(inputs[resultType.index()],
+ reductionDimension),
+ shapedType.getElementType()),
+ dispatcherBuilder);
+ if (!allocatedValue) {
+ regionOp.emitError("unable to allocate result value");
+ return {};
+ }
+ resultValues.push_back(allocatedValue);
+ }
+
+ // Calculate workload from the result shape.
+ // NOTE(review): front() requires at least one region result; a region with
+ // no results would be undefined behavior here - confirm the op verifier
+ // rules that out.
+ auto *workload =
+ wrapAsMemRef(calculateWorkload(regionOp, resultValues.front()), regionOp,
+ dispatcherBuilder);
+
+ // Create the reduce op to the executable function.
+ // Operand order matches the entry point signature: inputs, initial values,
+ // then output buffers.
+ std::vector<Value *> allOperands;
+ allOperands.insert(allOperands.end(), inputs.begin(), inputs.end());
+ allOperands.insert(allOperands.end(), initialValues.begin(),
+ initialValues.end());
+ allOperands.insert(allOperands.end(), resultValues.begin(),
+ resultValues.end());
+ dispatcherBuilder.create<IREESeq::HL::DispatchOp>(
+ regionOp.getLoc(), executable.getName(), entryFunc.getName(), workload,
+ ArrayRef<Type>{}, allOperands);
+
+ return resultValues;
+}
+
+// Outlines a reduction region into one or more iree.multi_arch_executables.
+// This separates the reduction into multiple dispatches, one for each reduction
+// dimension (thankfully XLA's operation semantics state this is ok). We then
+// special case the first dispatch such that it takes the constant initial
+// values so that we don't have to materialize a buffer for them.
+//
+// Dimensions are processed in ascending order. |outlinedRegionOrdinal| is only
+// used to name the generated executables. On success the original region op
+// is erased and its results are replaced by loads from the final output
+// buffers; on failure an error is emitted on |regionOp|.
+LogicalResult outlineReductionRegion(IREE::ReductionRegionOp regionOp,
+                                     int outlinedRegionOrdinal) {
+  // Insert at the same place as the original region.
+  OpBuilder dispatcherBuilder(regionOp);
+
+  // Wrap input operands in memrefs.
+  SmallVector<Value *, 4> initialValues{llvm::map_range(
+      regionOp.getInitialValueOperands(), [&](Value *originalArg) {
+        return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
+      })};
+  SmallVector<Value *, 4> temps{
+      llvm::map_range(regionOp.getReductionOperands(), [&](Value *originalArg) {
+        return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
+      })};
+
+  // Create one dispatch per dimension being reduced.
+  // We'll do this by chaining the original input through with the temporary
+  // reduction results. The results we end up with will be the originally
+  // requested shape and we can just substitute them.
+  if (regionOp.isWindowed()) {
+    auto windowDimensions = regionOp.window_dimensions().getValue();
+    auto windowStrides = regionOp.window_strides().getValue();
+    auto baseDilations = regionOp.base_dilations().getValue();
+    auto windowDilations = regionOp.window_dilations().getValue();
+    // Tuple layout: (window dimension, stride, base dilation, window dilation).
+    SmallVector<std::tuple<int64_t, int64_t, int64_t, int64_t>, 4>
+        sortedWindowAttrs;
+    for (uint64_t i = 0; i < windowDimensions.getNumElements(); ++i) {
+      int64_t windowDimension =
+          windowDimensions.getValue<IntegerAttr>({i}).getInt();
+      int64_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt();
+      int64_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt();
+      int64_t windowDilation =
+          windowDilations.getValue<IntegerAttr>({i}).getInt();
+      sortedWindowAttrs.push_back(
+          {windowDimension, windowStride, baseDilation, windowDilation});
+    }
+    // Sort ascending by window dimension. The comparator must be a strict
+    // weak ordering returning bool: a subtraction-style comparator would be
+    // "true" for any two differing keys and make the sort undefined.
+    llvm::sort(sortedWindowAttrs,
+               [](const std::tuple<int64_t, int64_t, int64_t, int64_t> &a,
+                  const std::tuple<int64_t, int64_t, int64_t, int64_t> &b) {
+                 return std::get<0>(a) < std::get<0>(b);
+               });
+    for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) {
+      int64_t windowDimension = std::get<0>(windowAttrs.value());
+      int64_t windowStride = std::get<1>(windowAttrs.value());
+      int64_t baseDilation = std::get<2>(windowAttrs.value());
+      int64_t windowDilation = std::get<3>(windowAttrs.value());
+      IREE::MultiArchExecutableOp multiArchExecutable;
+      FuncOp entryFunc;
+      std::tie(multiArchExecutable, entryFunc) = createReductionExecutable(
+          regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension,
+          initialValues, temps);
+      // Record the window parameters for this dimension on the entry point.
+      entryFunc.setAttr("iree.executable.reduction.padding_mode",
+                        dispatcherBuilder.getI32IntegerAttr(
+                            regionOp.padding_mode().getValue()));
+      entryFunc.setAttr("iree.executable.reduction.window_dimension",
+                        dispatcherBuilder.getI32IntegerAttr(windowDimension));
+      entryFunc.setAttr("iree.executable.reduction.window_stride",
+                        dispatcherBuilder.getI32IntegerAttr(windowStride));
+      entryFunc.setAttr("iree.executable.reduction.base_dilation",
+                        dispatcherBuilder.getI32IntegerAttr(baseDilation));
+      entryFunc.setAttr("iree.executable.reduction.window_dilation",
+                        dispatcherBuilder.getI32IntegerAttr(windowDilation));
+      // The outputs of this dispatch become the inputs of the next one.
+      temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc,
+                                  windowDimension, initialValues,
+                                  std::move(temps), dispatcherBuilder);
+      if (temps.empty()) {
+        return regionOp.emitOpError()
+               << "Failed to construct reduction for windowed dimension "
+               << windowDimension;
+      }
+    }
+  } else {
+    auto dimensions = regionOp.dimensions().getValue();
+    SmallVector<int64_t, 4> sortedDimensions;
+    for (uint64_t i = 0; i < dimensions.getNumElements(); ++i) {
+      sortedDimensions.push_back(
+          dimensions.getValue<IntegerAttr>({i}).getInt());
+    }
+    // Sort ascending; see the note on comparators above.
+    llvm::sort(sortedDimensions, [](int64_t a, int64_t b) { return a < b; });
+    for (auto dimension : llvm::enumerate(sortedDimensions)) {
+      IREE::MultiArchExecutableOp multiArchExecutable;
+      FuncOp entryFunc;
+      std::tie(multiArchExecutable, entryFunc) = createReductionExecutable(
+          regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(),
+          initialValues, temps);
+      entryFunc.setAttr("iree.executable.reduction.dimension",
+                        dispatcherBuilder.getI32IntegerAttr(dimension.value()));
+      // The outputs of this dispatch become the inputs of the next one.
+      temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc,
+                                  dimension.value(), initialValues,
+                                  std::move(temps), dispatcherBuilder);
+      if (temps.empty()) {
+        return regionOp.emitOpError()
+               << "Failed to construct reduction for dimension "
+               << dimension.value();
+      }
+    }
+  }
+  // Replace uses of the region results with loads from the final buffers.
+  for (auto it : llvm::enumerate(regionOp.getResults())) {
+    insertDispatcherLoad(regionOp, it.value(), temps[it.index()],
+                         dispatcherBuilder);
+  }
+
+  // Erase original region.
+  regionOp.erase();
+
+  return success();
+}
+
+} // namespace
+
+// Module pass that outlines every iree.reduction_region into executables.
+class OutlineReductionRegionsPass
+    : public ModulePass<OutlineReductionRegionsPass> {
+ public:
+  void runOnModule() override {
+    auto module = getModule();
+    ModuleManager moduleManager(module);
+
+    // Snapshot the functions up front and iterate over the copy rather than
+    // the module's live op list.
+    auto funcRange = module.getOps<FuncOp>();
+    SmallVector<FuncOp, 4> funcOps(funcRange.begin(), funcRange.end());
+    for (auto func : funcOps) {
+      // Gather the regions first; outlining erases them.
+      std::vector<IREE::ReductionRegionOp> pendingRegionOps;
+      func.walk([&](IREE::ReductionRegionOp op) {
+        pendingRegionOps.push_back(op);
+      });
+
+      // Outline each region; the ordinal restarts at zero per function.
+      int ordinal = 0;
+      for (auto regionOp : pendingRegionOps) {
+        if (failed(outlineReductionRegion(regionOp, ordinal))) {
+          return signalPassFailure();
+        }
+        ++ordinal;
+      }
+    }
+  }
+};
+
+// Creates a new instance of the pass that outlines iree.reduction_region ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass() {
+ return std::make_unique<OutlineReductionRegionsPass>(); // NOLINT
+}
+
+// Registers the pass under the "iree-outline-reduction-regions" flag.
+static PassRegistration<OutlineReductionRegionsPass> pass(
+ "iree-outline-reduction-regions",
+ "Outlines reduction regions into standalone functions");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/Passes.h b/compiler/Transforms/Sequencer/Passes.h
similarity index 100%
rename from iree/compiler/Transforms/Sequencer/Passes.h
rename to compiler/Transforms/Sequencer/Passes.h
diff --git a/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp b/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp
new file mode 100644
index 0000000..a5479e5
--- /dev/null
+++ b/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp
@@ -0,0 +1,149 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/DispatchUtils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Size limit (in bytes) for constants that get rematerialized. Picked
+// arbitrarily for now; measure before tuning.
+constexpr int64_t kMaxRematerializedConstantSizeInBytes = 1 * 1024;
+
+// Returns true if the constant is small enough to be worth inlining, possibly
+// several times, at the cost of binary bloat. The same fixed threshold is used
+// for all backends.
+bool isConstantSmall(ConstantOp constantOp) {
+  auto shapedType = constantOp.getType().dyn_cast<ShapedType>();
+  if (!shapedType) {
+    // Anything unshaped is assumed to be small. That holds for std today but
+    // may not for custom dialects.
+    return true;
+  }
+
+  const auto sizeInBytes = shapedType.getSizeInBits() / 8;
+  return sizeInBytes <= kMaxRematerializedConstantSizeInBytes;
+}
+
+// Returns true if constants may be moved into the dispatch region.
+// Regions that may get replaced or turned into kernel imports must not have
+// constants moved into them as they would just get lost; currently that means
+// any region containing an xla_hlo::DotOp directly in one of its blocks.
+bool canDispatchRegionContainConstants(
+    IREE::DispatchRegionOp dispatchRegionOp) {
+  for (auto &block : dispatchRegionOp.getBody()) {
+    const bool containsDot =
+        std::any_of(block.begin(), block.end(),
+                    [](Operation &op) { return isa<xla_hlo::DotOp>(&op); });
+    if (containsDot) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Rematerializes a constant inside of all dispatch regions that use it.
+// Afterward the constant is only removed if there are no other uses within the
+// non-dispatch block (such as by sequencer ops).
+LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) {
+  Value *constantValue = constantOp.getResult();
+
+  // Collect the eligible regions before mutating anything.
+  // NOTE(review): a region that uses the constant in more than one operand
+  // slot would presumably be visited once per use here - confirm whether
+  // duplicate entries are possible and harmless.
+  SmallVector<IREE::DispatchRegionOp, 4> targetRegionOps;
+  for (auto *user : constantValue->getUsers()) {
+    auto regionOp = dyn_cast<IREE::DispatchRegionOp>(user);
+    if (!regionOp) continue;
+
+    // The constant must be passed as an argument, not only as the workload.
+    const bool isArgOperand =
+        std::find(regionOp.arg_operand_begin(), regionOp.arg_operand_end(),
+                  constantValue) != regionOp.arg_operand_end();
+    if (!isArgOperand) continue;
+
+    if (!canDispatchRegionContainConstants(regionOp)) continue;
+    targetRegionOps.push_back(regionOp);
+  }
+
+  // Inline the constant into each collected region.
+  for (auto &regionOp : targetRegionOps) {
+    auto inlineResult =
+        inlineDispatchRegionOperandsUsingValue(regionOp, constantValue);
+    if (failed(inlineResult)) {
+      return failure();
+    }
+  }
+
+  // Drop the original constant once nothing refers to it any more.
+  if (constantOp.use_empty()) {
+    constantOp.erase();
+  }
+  return success();
+}
+
+} // namespace
+
+// Finds constant arguments to dispatch regions that are too small to be worth
+// putting into constant pools. This prevents things like a CSE'd scalar
+// constant of 0.0 being passed by reference to a bunch of regions. Later
+// backend-specific passes running on the dispatch regions may also be able to
+// improve their constant propagation chances by having the full constant value
+// available.
+//
+// Note that this currently only operates at the block level. Constants that are
+// pushed across branches are assumed to have been rematerialized within blocks
+// already, but if that isn't the case then this pass can be extended to do
+// that.
+class RematerializeDispatchConstantsPass
+    : public FunctionPass<RematerializeDispatchConstantsPass> {
+ public:
+  void runOnFunction() override {
+    for (auto &block : getFunction()) {
+      // Gather candidates first; rematerialization may erase constants.
+      SmallVector<ConstantOp, 8> candidates;
+      for (auto constantOp : block.getOps<ConstantOp>()) {
+        if (!isConstantSmall(constantOp)) continue;
+        candidates.push_back(constantOp);
+      }
+
+      // Walk the candidates back to front: insertion happens at the top, so
+      // this keeps the rematerialized constants in their original order.
+      for (auto it = candidates.rbegin(), end = candidates.rend(); it != end;
+           ++it) {
+        if (failed(rematerializeConstantInDispatchRegions(*it))) {
+          return signalPassFailure();
+        }
+      }
+    }
+  }
+};
+
+// Creates a new instance of the constant rematerialization function pass.
+std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass() {
+ return std::make_unique<RematerializeDispatchConstantsPass>();
+}
+
+// Registers the pass under the "iree-rematerialize-dispatch-constants" flag.
+static PassRegistration<RematerializeDispatchConstantsPass> pass(
+ "iree-rematerialize-dispatch-constants",
+ "Rematerializes small previously-CSE'd constants into dispatch regions.");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/Rewrites.h b/compiler/Transforms/Sequencer/Rewrites.h
new file mode 100644
index 0000000..c78a131
--- /dev/null
+++ b/compiler/Transforms/Sequencer/Rewrites.h
@@ -0,0 +1,40 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
+#define IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
+
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Adds rewrite patterns for lowering IREE Sequencer HL ops (iree_hl_seq.*)
+// to LL ops (iree_ll_seq.*).
+// The patterns are appended to |patterns| and created with |ctx|.
+void populateSequencerLoweringPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
+
+// Adds rewrite patterns for lowering xla_hlo ops to Sequencer HL ops.
+// The patterns are appended to |patterns| and created with |ctx|.
+void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
+
+// Adds rewrite patterns for lowering standard ops to Sequencer HL ops.
+// The patterns are appended to |patterns| and created with |ctx|.
+void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
diff --git a/compiler/Transforms/Sequencer/test/BUILD b/compiler/Transforms/Sequencer/test/BUILD
new file mode 100644
index 0000000..8a8ee90
--- /dev/null
+++ b/compiler/Transforms/Sequencer/test/BUILD
@@ -0,0 +1,16 @@
+# Tests specific to the sequencer.
+
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+    data = [
+        # Tool the lit tests run; labels are "//package:target".
+        "//tools:iree-opt",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/compiler/Transforms/test/BUILD b/compiler/Transforms/test/BUILD
new file mode 100644
index 0000000..c7b98c5
--- /dev/null
+++ b/compiler/Transforms/test/BUILD
@@ -0,0 +1,16 @@
+# Tests for common transforms.
+
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+    data = [
+        # Tool the lit tests run; labels are "//package:target".
+        "//tools:iree-opt",
+    ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/test/drop_unreachable_module_functions.mlir b/compiler/Transforms/test/drop_unreachable_module_functions.mlir
similarity index 100%
rename from iree/compiler/Transforms/test/drop_unreachable_module_functions.mlir
rename to compiler/Transforms/test/drop_unreachable_module_functions.mlir
diff --git a/iree/compiler/Translation/BUILD b/compiler/Translation/BUILD
similarity index 100%
rename from iree/compiler/Translation/BUILD
rename to compiler/Translation/BUILD
diff --git a/iree/compiler/Translation/CMakeLists.txt b/compiler/Translation/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Translation/CMakeLists.txt
rename to compiler/Translation/CMakeLists.txt
diff --git a/compiler/Translation/Interpreter/BUILD b/compiler/Translation/Interpreter/BUILD
new file mode 100644
index 0000000..17858bd
--- /dev/null
+++ b/compiler/Translation/Interpreter/BUILD
@@ -0,0 +1,31 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Library holding the interpreter executable translation.
+cc_library(
+    name = "Interpreter",
+    srcs = ["InterpreterExecutableTranslation.cpp"],
+    hdrs = ["InterpreterExecutableTranslation.h"],
+    deps = [
+        # In-repo labels use the "//package" form (exactly two slashes).
+        "//compiler/IR",
+        "//compiler/IR/Interpreter",
+        "//compiler/Serialization",
+        "//compiler/Transforms",
+        "//compiler/Transforms/Interpreter",
+        "//compiler/Utils",
+        "//schemas",
+        "@com_github_google_flatbuffers//:flatbuffers",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardDialectRegistration",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:Transforms",
+        "@local_config_mlir//:Translation",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
+    ],
+    alwayslink = 1,
+)
diff --git a/iree/compiler/Translation/Interpreter/CMakeLists.txt b/compiler/Translation/Interpreter/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Translation/Interpreter/CMakeLists.txt
rename to compiler/Translation/Interpreter/CMakeLists.txt
diff --git a/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp
new file mode 100644
index 0000000..ae0fb52
--- /dev/null
+++ b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp
@@ -0,0 +1,289 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Translation/Interpreter/InterpreterExecutableTranslation.h"
+
+#include <cstdint>
+#include <iostream>
+#include <vector>
+
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/minireflect.h"
+#include "compiler/IR/ConfigOps.h"
+#include "compiler/IR/Interpreter/OpWriters.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Serialization/VMFunctionBuilder.h"
+#include "compiler/Serialization/VMFunctionTableBuilder.h"
+#include "compiler/Serialization/VMModuleBuilder.h"
+#include "compiler/Transforms/Interpreter/Passes.h"
+#include "compiler/Transforms/Passes.h"
+#include "compiler/Utils/Macros.h"
+#include "compiler/Utils/OpUtils.h"
+#include "compiler/Utils/TranslationUtils.h"
+#include "schemas/executable_def_generated.h"
+#include "schemas/module_def_generated.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Translation.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Builds a pass pipeline that optimizes and legalizes the module to the form
+// expected by translation.
+void buildLegalizeInputPassPipeline(PassManager *passManager) {
+ // Standard passes that shake out a lot of garbage.
+ // Some may have been run prior to translation but this ensures we are always
+ // in a known state.
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createLoopFusionPass());
+ passManager->addPass(createLoopInvariantCodeMotionPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createSimplifyAffineStructuresPass());
+ passManager->addPass(createCSEPass());
+ passManager->addPass(createCanonicalizerPass());
+
+ // Eliminate ops we don't care about based on a lack of side-effects.
+ // IREE does not guarantee exception/error behavior of dead ops.
+ passManager->addPass(createAggressiveOpEliminationPass());
+
+ // Expand uses of tuples into independent args/results.
+ passManager->addPass(createConvertFromTupleCallingConventionPass());
+ passManager->addPass(createCanonicalizerPass());
+}
+
+// Builds a pass pipeline that converts functions to the iree_hl_interp dialect.
+void buildInterpreterConversionPassPipeline(PassManager *passManager) {
+ // We don't need the IREE binding ops anymore, as we match the calling
+ // convention exactly (we're the same VM).
+ passManager->addPass(createMakeExecutableABIPass());
+
+ // Convert to the memref calling convention and optimize away as many
+ // loads and stores as we can prior to progressing.
+ passManager->addPass(createConvertToMemRefCallingConventionPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+
+ // Convert various dialects to IREE opcodes and cleanup leftover conversions.
+ passManager->addPass(createLowerToInterpreterDialectPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createAggressiveOpEliminationPass());
+
+ // Widen reduction functions (that have iree.executable.reduction attrs) to
+ // use their primitive IREE ops.
+ passManager->addPass(createExpandReductionsToOpsPass());
+
+ // Convert any uses of index to int32_t (as we explicitly don't want to
+ // support dynamic index width).
+ // This also looks for other weird types (i1, etc).
+ passManager->addPass(createLegalizeTypeStoragePass());
+
+ // Perform any last-minute optimizations to trim down the IR.
+ passManager->addPass(createAggressiveOpEliminationPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createLoopFusionPass());
+ passManager->addPass(createLoopInvariantCodeMotionPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createCSEPass());
+ passManager->addPass(createCanonicalizerPass());
+
+ // Drop all functions that are not reachable.
+ passManager->addPass(createDropUnreachableExecutableFunctionsPass());
+}
+
+// Builds a pass pipeline that lowers the iree_hl_interp dialect to the
+// iree_ll_interp dialect and prepares for serialization.
+void buildInterpreterLoweringPassPipeline(PassManager *passManager) {
+ // Lower iree_hl_interp -> iree_ll_interp.
+ passManager->addPass(createLowerInterpreterDialectPass());
+
+ // Assign ordinals used by the bytecode to reference executables and
+ // functions.
+ passManager->addPass(createAssignFunctionOrdinalsPass());
+}
+
+class InterpreterTranslator {
+ public:
+ explicit InterpreterTranslator(ExecutableTranslationOptions options)
+ : options_(options) {}
+
+ const ExecutableTranslationOptions &options() const { return options_; }
+
+ std::unique_ptr<iree::ExecutableDefT> translateExecutable(
+ IREE::ExecutableOp executableOp);
+
+ private:
+ LogicalResult translateExecutableModule(IREE::ExecutableOp executableOp,
+ ModuleOp moduleOp,
+ VMModuleBuilder *moduleBuilder);
+ LogicalResult declareFunction(FuncOp function,
+ VMModuleBuilder *moduleBuilder);
+ LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder);
+
+ ExecutableTranslationOptions options_;
+};
+
+std::unique_ptr<iree::ExecutableDefT>
+InterpreterTranslator::translateExecutable(IREE::ExecutableOp executableOp) {
+ auto moduleOp = executableOp.getInnerModule();
+
+ // Run all passes to go from input to the iree_ll_interp dialect.
+ auto executableConversionPasses =
+ createPassManager(moduleOp.getContext(), options());
+ buildLegalizeInputPassPipeline(executableConversionPasses.get());
+ buildInterpreterConversionPassPipeline(executableConversionPasses.get());
+ buildInterpreterLoweringPassPipeline(executableConversionPasses.get());
+ if (failed(runPassPipeline(options(), executableConversionPasses.get(),
+ moduleOp))) {
+ executableOp.emitError() << "Failed to run conversion passes";
+ return {};
+ }
+
+ // Build the module bytecode.
+ ::flatbuffers::FlatBufferBuilder fbb;
+ VMModuleBuilder moduleBuilder(&fbb);
+ if (failed(
+ translateExecutableModule(executableOp, moduleOp, &moduleBuilder))) {
+ executableOp.emitError() << "Failed to translate executable module";
+ return {};
+ }
+ auto moduleDef = moduleBuilder.Finish();
+ if (moduleDef.IsNull()) {
+ moduleOp.emitError() << "Failed to verify completed module def";
+ return {};
+ }
+ auto bytes = moduleBuilder.Serialize(moduleDef);
+ if (bytes.empty()) {
+ moduleOp.emitError() << "Failed to serialize final module def";
+ return {};
+ }
+
+ OpBuilder builder(executableOp);
+ executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>(
+ IREE::ExecutableFormat::IreeBytecode)));
+
+ auto executableDef = std::make_unique<iree::ExecutableDefT>();
+ executableDef->format =
+ static_cast<uint32_t>(IREE::ExecutableFormat::IreeBytecode);
+ executableDef->supported_features = iree::ExecutableFeature::kDebugging;
+ executableDef->contents = std::move(bytes);
+ return executableDef;
+}
+
+LogicalResult InterpreterTranslator::translateExecutableModule(
+ IREE::ExecutableOp executableOp, ModuleOp moduleOp,
+ VMModuleBuilder *moduleBuilder) {
+ // Declare functions first so that we get stable indices during declaration
+ // (as call ops need to use the function table).
+ for (auto function : moduleOp.getOps<FuncOp>()) {
+ RETURN_IF_FAILURE(declareFunction(function, moduleBuilder));
+ }
+
+ // Define functions now that all functions have been declared.
+ for (auto function : moduleOp.getOps<FuncOp>()) {
+ RETURN_IF_FAILURE(defineFunction(function, moduleBuilder));
+ }
+
+ return success();
+}
+
+LogicalResult InterpreterTranslator::declareFunction(
+ FuncOp function, VMModuleBuilder *moduleBuilder) {
+ auto *functionTable = moduleBuilder->function_table();
+ if (functionTable->IsFunctionDeclared(function)) {
+ // Already declared.
+ return success();
+ }
+
+ LinkageType linkageType;
+ if (function.isExternal()) {
+ linkageType = LinkageType::kImport;
+ } else if (function.getAttr("iree.executable.export")) {
+ linkageType = LinkageType::kExport;
+ } else {
+ linkageType = LinkageType::kInternal;
+ }
+ if (failed(functionTable->DeclareFunction(function, linkageType))) {
+ return function.emitError() << "Unable to declare function";
+ }
+
+ // Import functions must have their definition defined here so we get their
+ // type. Internal and export functions will be defined during conversion.
+ if (linkageType == LinkageType::kImport) {
+ VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
+ moduleBuilder->fbb());
+ auto functionOffset = functionBuilder.Finish();
+ if (functionOffset.IsNull()) {
+ return function.emitError()
+ << "Failed to create import function bytecode";
+ }
+ RETURN_IF_FAILURE(
+ functionTable->DefineFunction(function, functionOffset, {}));
+ }
+
+ return success();
+}
+
+LogicalResult InterpreterTranslator::defineFunction(
+ FuncOp function, VMModuleBuilder *moduleBuilder) {
+ VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
+ moduleBuilder->fbb());
+ registerInterpreterCustomWriters(&functionBuilder);
+ RETURN_IF_FAILURE(functionBuilder.ConvertBytecode());
+ auto functionOffset = functionBuilder.Finish();
+ if (functionOffset.IsNull()) {
+ return function.emitError() << "Failed to serialize function";
+ }
+ RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction(
+ function, functionOffset, functionBuilder.source_map()));
+ return success();
+}
+
+} // namespace
+
+llvm::Optional<ExecutableTranslationResult>
+translateExecutableToInterpreterExecutable(
+ ArrayRef<IREE::ExecutableOp> executableOps,
+ ExecutableTranslationOptions options) {
+ InterpreterTranslator translator(options);
+ ExecutableTranslationResult translationResult;
+ for (auto executableOp : llvm::make_early_inc_range(executableOps)) {
+ auto executableDef = translator.translateExecutable(executableOp);
+ if (!executableDef) {
+ executableOp.emitError() << "Failed to translate one or more executables";
+ return llvm::None;
+ }
+ translationResult.executable_defs.push_back(std::move(executableDef));
+ }
+ return translationResult;
+}
+
+static ExecutableTranslationRegistration
+ InterpreterExecutableTranslationRegistration(
+ "interpreter-bytecode", translateExecutableToInterpreterExecutable);
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h
new file mode 100644
index 0000000..0e90e61
--- /dev/null
+++ b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
+#define IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
+
+#include <vector>
+
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/TranslationUtils.h"
+#include "mlir/IR/Module.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Translates an MLIR module into a bytecode interpreter executable.
+// These executables are stored as IREE modules as defined in the
+// https://github.com/google/iree/tree/master/iree/schemas/module_def.fbs
+// schema.
+llvm::Optional<ExecutableTranslationResult>
+translateExecutableToInterpreterExecutable(
+ ArrayRef<IREE::ExecutableOp> executableOps,
+ ExecutableTranslationOptions options = {});
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
diff --git a/compiler/Translation/SPIRV/AffineExprCodegen.h b/compiler/Translation/SPIRV/AffineExprCodegen.h
new file mode 100644
index 0000000..ff12b37
--- /dev/null
+++ b/compiler/Translation/SPIRV/AffineExprCodegen.h
@@ -0,0 +1,143 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- AffineExprCodegen.h -------------------------------------*- C++//-*-===//
+//
+// Code-generation for Affine Expression.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
+
+#include "compiler/Translation/SPIRV/XLAIndexPropagation.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/AffineExprVisitor.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Code generator for affine expressions.
+class AffineExprCodegen : public AffineExprVisitor<AffineExprCodegen, Value *> {
+ public:
+ explicit AffineExprCodegen(spirv::ModuleOp module,
+ IndexComputationCache &tensorIndices)
+ : builder(module.getContext()),
+ location(module.getLoc()),
+ tensorIndices(tensorIndices) {}
+
+ Value *visitAddExpr(AffineBinaryOpExpr expr) {
+ auto operand1 = getValueInternal(expr.getLHS());
+ auto operand2 = getValueInternal(expr.getRHS());
+ return builder.create<spirv::IAddOp>(location, operand1, operand2);
+ }
+ Value *visitMulExpr(AffineBinaryOpExpr expr) {
+ auto operand1 = getValueInternal(expr.getLHS());
+ auto operand2 = getValueInternal(expr.getRHS());
+ return builder.create<spirv::IMulOp>(location, operand1, operand2);
+ }
+ Value *visitModExpr(AffineBinaryOpExpr expr) {
+ auto operand1 = getValueInternal(expr.getLHS());
+ auto operand2 = getValueInternal(expr.getRHS());
+ return builder.create<spirv::SModOp>(location, operand1, operand2);
+ }
+ Value *visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ auto operand1 = getValueInternal(expr.getLHS());
+ auto operand2 = getValueInternal(expr.getRHS());
+ return builder.create<spirv::SDivOp>(location, operand1, operand2);
+ }
+ Value *visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ // TODO(ravishankarm): Implement ceil div expr codegen.
+ llvm_unreachable("Unimplemented affine AffineCeilDivExpr codegen");
+ return nullptr;
+ }
+ Value *visitConstantExpr(AffineConstantExpr expr) {
+ return builder.create<spirv::ConstantOp>(
+ location, builder.getIntegerType(32),
+ builder.getI32IntegerAttr(expr.getValue()));
+ }
+ Value *visitDimExpr(AffineDimExpr expr) {
+ return threadDimToDstValue.lookup(expr.getPosition());
+ }
+ Value *visitSymbolExpr(AffineSymbolExpr expr) {
+ // TODO(ravishankarm): Implement symbol expr codegen.
+ llvm_unreachable("Unimplemented affine AffineSymbolExpr codegen");
+ return nullptr;
+ }
+
+ /// Set the value that contains the workitem ID along a particular
+ /// dimension. 0 -> x-dimension, 1 -> y-dimension, etc.
+ void setDimDstValue(unsigned dimID, Value *value) {
+ threadDimToDstValue[dimID] = value;
+ }
+
+ /// Generates the scalar value for an affine expression.
+ Value *getValue(AffineExpr expr, OpBuilder::InsertPoint ip, Location loc) {
+ if (Value *cached = exprToDstValue.lookup(expr)) return cached;
+ location = loc;
+ builder.restoreInsertionPoint(ip);
+ // visit() may grow exprToDstValue, so no map reference is held across it.
+ Value *generated = visit(expr);
+ exprToDstValue[expr] = generated;
+ return generated;
+ }
+
+ /// Returns a list of indices of a particular tensor in the source dialect
+ /// needed within the dispatch function (obtained from the
+ /// IndexComputationCache)
+ SmallVector<AffineMap, 4> getIndices(Value *value) {
+ SmallVector<AffineMap, 4> indices;
+ for (auto &index : tensorIndices[value]) {
+ indices.push_back(index.first);
+ }
+ return indices;
+ }
+
+ /// For a given tensor in the source dialect and index, return the index of
+ /// all operands needed to compute the result.
+ ArrayRef<AffineMap> getOperandIndices(Value *value, AffineMap index) {
+ return tensorIndices[value][index];
+ }
+
+ private:
+ /// Returns the Value corresponding to the AffineExpr `expr` by either
+ /// previously generated value for the same index, or by generating the value.
+ /// This version assumes the insertion point/Location has already been set.
+ Value *getValueInternal(AffineExpr expr) {
+ if (Value *cached = exprToDstValue.lookup(expr)) return cached;
+ // visit() may grow exprToDstValue, so no map reference is held across it.
+ Value *generated = visit(expr);
+ exprToDstValue[expr] = generated;
+ return generated;
+ }
+
+ OpBuilder builder;
+
+ Location location;
+
+ /// Map from launch dimension to scalar value.
+ DenseMap<unsigned, Value *> threadDimToDstValue;
+
+ /// Cache of affine expression to scalar value. TODO(ravishankarm) : Might
+ /// need to be changed if we are handling control flow within the dispatch
+ /// function.
+ DenseMap<AffineExpr, Value *> exprToDstValue;
+
+ /// Map from tensor value in source dialect to list of indices of the tensor
+ /// needed within a workitem to compute the results of the dispatch function.
+ IndexComputationCache &tensorIndices;
+};
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
diff --git a/compiler/Translation/SPIRV/BUILD b/compiler/Translation/SPIRV/BUILD
new file mode 100644
index 0000000..287af86
--- /dev/null
+++ b/compiler/Translation/SPIRV/BUILD
@@ -0,0 +1,50 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "SPIRV",
+ srcs = [
+ "AffineExprCodegen.h",
+ "EmbeddedKernels.cpp",
+ "IREEIndexComputation.cpp",
+ "IREEToSPIRV.cpp",
+ "IREEToSPIRVPass.cpp",
+ "IndexComputation.cpp",
+ "SPIRVExecutableTranslation.cpp",
+ "SPIRVLowering.cpp",
+ "SPIRVLowering.h",
+ "XLAIndexPropagation.cpp",
+ ],
+ hdrs = [
+ "EmbeddedKernels.h",
+ "IREEIndexComputation.h",
+ "IREEToSPIRV.h",
+ "IREEToSPIRVPass.h",
+ "IndexComputation.h",
+ "SPIRVExecutableTranslation.h",
+ "XLAIndexPropagation.h",
+ ],
+ deps = [
+ "//compiler/IR",
+ "//compiler/Translation/SPIRV/Kernels",
+ "//compiler/Utils",
+ "//schemas",
+ "//schemas:spirv_executable_def_cc_fbs",
+ "@com_github_google_flatbuffers//:flatbuffers",
+ "@llvm//:support",
+ "@local_config_mlir//:IR",
+ "@local_config_mlir//:Pass",
+ "@local_config_mlir//:SPIRVDialect",
+ "@local_config_mlir//:SPIRVDialectRegistration",
+ "@local_config_mlir//:SPIRVSerialization",
+ "@local_config_mlir//:StandardDialectRegistration",
+ "@local_config_mlir//:StandardOps",
+ "@local_config_mlir//:Support",
+ "@local_config_mlir//:Transforms",
+ "@local_config_mlir//:Translation",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+ ],
+ alwayslink = 1,
+)
diff --git a/iree/compiler/Translation/SPIRV/CMakeLists.txt b/compiler/Translation/SPIRV/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Translation/SPIRV/CMakeLists.txt
rename to compiler/Translation/SPIRV/CMakeLists.txt
diff --git a/compiler/Translation/SPIRV/EmbeddedKernels.cpp b/compiler/Translation/SPIRV/EmbeddedKernels.cpp
new file mode 100644
index 0000000..1fd5b1c
--- /dev/null
+++ b/compiler/Translation/SPIRV/EmbeddedKernels.cpp
@@ -0,0 +1,219 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Translation/SPIRV/EmbeddedKernels.h"
+
+#include "compiler/Translation/SPIRV/Kernels/Kernels.h"
+#include "schemas/spirv_executable_def_generated.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Reads the SPIR-V code for the embedded kernel with the given file name.
+// If the kernel under Kernels/ is 'matmul.comp' then |kernelName| would be
+// 'matmul.spv' (because it's been compiled). Returns empty if no file matches.
+std::vector<uint32_t> readEmbeddedKernelCode(std::string kernelName) {
+ auto *fileToc = spirv_kernels::Kernels_create();
+ for (int i = 0; i < spirv_kernels::Kernels_size(); ++i) {
+ if (std::strcmp(fileToc[i].name, kernelName.c_str()) == 0) {
+ // Round the word count up so the byte copy below can never overrun.
+ std::vector<uint32_t> code((fileToc[i].size + 3) / 4);
+ std::memcpy(code.data(), fileToc[i].data, fileToc[i].size);
+ return code;
+ }
+ }
+ return {};
+}
+
+// Adds a storage buffer binding to the descriptor set layout.
+void addDescriptorSetLayoutBinding(uint32_t binding,
+ iree::VkDescriptorSetLayoutDefT *dsl) {
+ auto bindingDef = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>();
+ bindingDef->binding = binding;
+ bindingDef->descriptor_count = 1;
+ bindingDef->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
+ bindingDef->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT
+ dsl->bindings.push_back(std::move(bindingDef));
+}
+
+// Adds a specialization map entry for |constant_id| set to a 4-byte int value.
+void addSpecializationMapEntry(
+ uint32_t constant_id, uint32_t value,
+ iree::VkSpecializationInfoDefT *specializationInfoDef) {
+ auto specValue = std::make_unique<iree::VkSpecializationMapEntryDefT>();
+ specValue->constant_id = constant_id;
+ specValue->uint32_value = value;
+ specializationInfoDef->map_entries.push_back(std::move(specValue));
+}
+
+LogicalResult buildReductionExecutable(IREE::ExecutableOp executableOp,
+ FuncOp entryFuncOp,
+ iree::SpirVExecutableDefT *out_def) {
+ auto funcType = entryFuncOp.getType();
+ auto arg0 = funcType.getInput(0).cast<ShapedType>();
+ if (!arg0.getElementType().isF32()) {
+ // When we do other types we'll need other shaders.
+ return entryFuncOp.emitOpError()
+ << "Only floating point reduction is implemented";
+ }
+
+ auto module = executableOp.getInnerModule();
+ auto applyFuncAttr = entryFuncOp.getAttrOfType<SymbolRefAttr>(
+ "iree.executable.reduction.apply");
+ auto applyFuncOp = module.lookupSymbol(applyFuncAttr.getValue());
+
+ // TODO(benvanik): specialize (template on shapes/types/etc).
+ std::string kernelName = "reduce_untiled.spv";
+ llvm::Optional<uint32_t> operationId;
+ applyFuncOp->walk([&](Operation *op) {
+ if (isa<xla_hlo::AddOp>(op)) {
+ operationId = 0;
+ } else if (isa<xla_hlo::MaxOp>(op)) {
+ operationId = 1;
+ } else if (isa<xla_hlo::MinOp>(op)) {
+ operationId = 2;
+ }
+ });
+ if (!operationId.hasValue()) {
+ applyFuncOp->dump();
+ return applyFuncOp->emitOpError() << "Unsupported reduction operator";
+ }
+
+ out_def->tag = "__reduce__";
+ out_def->entry_points = {"main"};
+
+ out_def->code = readEmbeddedKernelCode(kernelName);
+
+ // arg0, arg1, ret0
+ auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
+ pipelineLayoutDef->buffer_binding_set = 0;
+ auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
+ addDescriptorSetLayoutBinding(0, dsl.get());
+ addDescriptorSetLayoutBinding(1, dsl.get());
+ addDescriptorSetLayoutBinding(2, dsl.get());
+ pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
+ out_def->pipeline_layout = std::move(pipelineLayoutDef);
+
+ // See the shader source for documentation on the values of A/B/C/R.
+ int64_t reductionDimension =
+ entryFuncOp
+ .getAttrOfType<IntegerAttr>("iree.executable.reduction.dimension")
+ .getInt();
+ uint32_t r = arg0.getDimSize(reductionDimension);
+ uint32_t a = 1;
+ for (int i = 0; i < reductionDimension; ++i) {
+ a *= arg0.getDimSize(i);
+ }
+ uint32_t b = 1;
+ for (int i = reductionDimension + 1; i < arg0.getRank(); ++i) {
+ b *= arg0.getDimSize(i);
+ }
+ uint32_t c = b;
+
+ auto specializationInfoDef =
+ std::make_unique<iree::VkSpecializationInfoDefT>();
+ addSpecializationMapEntry(/*kOperationId*/ 100, operationId.getValue(),
+ specializationInfoDef.get());
+ addSpecializationMapEntry(/*kA*/ 101, a, specializationInfoDef.get());
+ addSpecializationMapEntry(/*kB*/ 102, b, specializationInfoDef.get());
+ addSpecializationMapEntry(/*kC*/ 103, c, specializationInfoDef.get());
+ addSpecializationMapEntry(/*kR*/ 104, r, specializationInfoDef.get());
+ out_def->specialization_info = std::move(specializationInfoDef);
+
+ return success();
+}
+
+// Builds a SPIR-V executable from a well-known matmul executable.
+// |out_def| will be populated with all required information for serialization.
+LogicalResult buildMatMulExecutable(IREE::ExecutableOp executableOp,
+ FuncOp entryFuncOp, xla_hlo::DotOp dotOp,
+ iree::SpirVExecutableDefT *out_def) {
+ auto arg0 = dotOp.getOperand(0)->getType().cast<ShapedType>();
+ auto arg1 = dotOp.getOperand(1)->getType().cast<ShapedType>();
+
+ out_def->tag = "__matmul__";
+ out_def->entry_points = {"main"};
+
+ // TODO(benvanik): specialize (template on shapes/types/etc).
+ out_def->code = readEmbeddedKernelCode("matmul.spv");
+
+ // arg0, arg1, ret0
+ auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
+ pipelineLayoutDef->buffer_binding_set = 0;
+ auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
+ addDescriptorSetLayoutBinding(0, dsl.get());
+ addDescriptorSetLayoutBinding(1, dsl.get());
+ addDescriptorSetLayoutBinding(2, dsl.get());
+ pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
+ out_def->pipeline_layout = std::move(pipelineLayoutDef);
+
+ // Shapes of [arg0, arg1, ret0].
+ // arg0 = [b0, m, k]
+ // arg1 = [b0, k, n]
+ // ret0 = [b0, m, n]
+ // Note that we handle both batched (rank 3) and unbatched (rank 2).
+ uint32_t m = arg0.getRank() == 3 ? arg0.getDimSize(1) : arg0.getDimSize(0);
+ uint32_t k = arg0.getRank() == 3 ? arg0.getDimSize(2) : arg0.getDimSize(1);
+ uint32_t n = arg1.getRank() == 3 ? arg1.getDimSize(2) : arg1.getDimSize(1);
+ auto specializationInfoDef =
+ std::make_unique<iree::VkSpecializationInfoDefT>();
+ addSpecializationMapEntry(/*kMatrixM*/ 100, m, specializationInfoDef.get());
+ addSpecializationMapEntry(/*kMatrixK*/ 101, k, specializationInfoDef.get());
+ addSpecializationMapEntry(/*kMatrixN*/ 102, n, specializationInfoDef.get());
+ out_def->specialization_info = std::move(specializationInfoDef);
+
+ return success();
+}
+
+} // namespace
+
+bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp,
+ iree::SpirVExecutableDefT *out_def) {
+ auto module = executableOp.getInnerModule();
+ for (auto funcOp : module.getOps<FuncOp>()) {
+ if (funcOp.getAttr("iree.executable.reduction")) {
+ if (failed(buildReductionExecutable(executableOp, funcOp, out_def))) {
+ executableOp.emitOpError() << "Failed to splat in the reduction kernel";
+ return false;
+ }
+ return true;
+ }
+
+ for (auto &block : funcOp) {
+ for (auto &op : block) {
+ if (isa<xla_hlo::ConvOp>(&op)) {
+ executableOp.emitOpError() << "Conv not yet implemented";
+ return false;
+ } else if (auto dotOp = dyn_cast_or_null<xla_hlo::DotOp>(&op)) {
+ if (failed(buildMatMulExecutable(executableOp, funcOp, dotOp,
+ out_def))) {
+ executableOp.emitOpError()
+ << "Failed to splat in the matmul kernel";
+ return false;
+ }
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/SPIRV/EmbeddedKernels.h b/compiler/Translation/SPIRV/EmbeddedKernels.h
new file mode 100644
index 0000000..3535944
--- /dev/null
+++ b/compiler/Translation/SPIRV/EmbeddedKernels.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
+#define IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
+
+#include "compiler/IR/StructureOps.h"
+#include "flatbuffers/flatbuffers.h"
+#include "mlir/Support/LogicalResult.h"
+#include "schemas/spirv_executable_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Tries to match the |executableOp| against an embedded kernel and if matched
+// will populate |out_def| with the kernel.
+// Returns true if the kernel matched and was populated.
+bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp,
+ iree::SpirVExecutableDefT* out_def);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
diff --git a/compiler/Translation/SPIRV/IREEIndexComputation.cpp b/compiler/Translation/SPIRV/IREEIndexComputation.cpp
new file mode 100644
index 0000000..b05de91
--- /dev/null
+++ b/compiler/Translation/SPIRV/IREEIndexComputation.cpp
@@ -0,0 +1,107 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- IREEIndexComputation.cpp --------------------------------*- C++//-*-===//
+//
+// Implementation of Index Propagation for IREE statements that are used in
+// dispatch functions.
+//
+//===----------------------------------------------------------------------===//
+#include "compiler/Translation/SPIRV/IREEIndexComputation.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// IREELoadInputOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult IREELoadIndexPropagation::propagateIndexMap(
+    Operation *operation, IndexComputationCache &indexMap) const {
+  auto loadOp = cast<IREE::LoadInputOp>(operation);
+  auto result = operation->getResult(0);
+  auto src = loadOp.src();
+  auto resultType = result->getType().dyn_cast<RankedTensorType>();
+  auto srcType = src->getType().dyn_cast<MemRefType>();
+  if (!resultType || !srcType || resultType.getShape() != srcType.getShape()) {
+    return loadOp.emitError(
+        "mismatch in shape of the result tensor and source memref");
+  }
+  // Ensure `src` has an entry even when the result has no index maps.
+  indexMap[src];
+  for (auto &resultIndexMap : indexMap[operation->getResult(0)]) {
+    indexMap[src][resultIndexMap.first];  // `src` is read at the same index.
+    resultIndexMap.second.push_back(resultIndexMap.first);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// IREEStoreOutputOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult IREEStoreIndexPropagation::propagateIndexMap(
+    Operation *operation, IndexComputationCache &indexMap) const {
+  auto storeOp = cast<IREE::StoreOutputOp>(operation);
+  auto src = storeOp.src();
+  auto srcType = src->getType().dyn_cast<ShapedType>();
+  if (!srcType || !srcType.hasStaticShape()) {
+    return storeOp.emitError(
+        "can only handle store with src being tensor of static shape");
+  }
+
+  SmallVector<int64_t, 3> launchSize;
+  if (failed(getLaunchSize(operation, launchSize))) {
+    return failure();
+  }
+
+  // The launch dimensions are [x, y, z] coordinates. The reverse of this is
+  // used to determine the location of the tensor element computed by a
+  // workitem. The choice is fairly arbitrary but is done to enable the common
+  // case where consecutive workitems compute "logically" adjacent tensor
+  // elements.
+  Builder builder(storeOp.getContext());
+  SmallVector<AffineExpr, 4> affineExprs;
+  int64_t numElements = 1;
+  for (size_t i = launchSize.size(); i > 0; --i) {
+    // If launchSize along any dimension is 1, just use 0 for the index. This is
+    // not just an optimization. If you have an output of type memref<f32> which
+    // is lowered to !spv.ptr<!spv.struct<f32>, StorageBuffer> with launchSize
+    // <1>, then spv.AccessChain requires the indices to be a constant.
+    if (launchSize[i - 1] == 1) {
+      affineExprs.push_back(builder.getAffineConstantExpr(0));
+    } else {
+      affineExprs.push_back(builder.getAffineDimExpr(i - 1));
+    }
+    numElements *= launchSize[i - 1];  // Product of all launch dimensions.
+  }
+  auto launchMap = AffineMap::get(launchSize.size(), 0, affineExprs);
+
+  // The stored tensor can be a reshape of the launch dimension. It still
+  // retains the requirement that each workitem is computing a single element
+  // of the stored tensor.
+  AffineMap srcMap;
+  SmallVector<int64_t, 3> revLaunchSize(reverse(launchSize));
+  if (numElements != srcType.getNumElements() ||
+      failed(getReshapeOperandMap(builder, launchMap, revLaunchSize,
+                                  srcType.getShape(), srcMap))) {
+    return storeOp.emitError(
+        "unable to map from launch id to element to compute within a "
+        "workitem");
+  }
+  indexMap[src][srcMap];  // Seed the cache: `src` is needed at `srcMap`.
+  return success();
+}
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/SPIRV/IREEIndexComputation.h b/compiler/Translation/SPIRV/IREEIndexComputation.h
new file mode 100644
index 0000000..55def25
--- /dev/null
+++ b/compiler/Translation/SPIRV/IREEIndexComputation.h
@@ -0,0 +1,92 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- IREEIndexComputation.h ----------------------------------*- C++//-*-===//
+//
+// Index Propagation for IREE statements that are used in dispatch functions.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_H
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Translation/SPIRV/XLAIndexPropagation.h"
+#include "mlir/IR/Function.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Gets the launch size associated with the dispatch function that this op is
+/// part of, read from the function's `iree.executable.workload` attribute.
+inline LogicalResult getLaunchSize(Operation *op,
+                                   SmallVectorImpl<int64_t> &launchSize) {
+  auto funcOp = op->getParentOfType<FuncOp>();
+  if (!funcOp || !funcOp.getAttr("iree.executable.export")) {
+    return op->emitError(
+        "expected operation to be in dispatch function to get launch size");
+  }
+  auto workloadAttr =
+      funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workload");
+  if (!workloadAttr) {
+    // Must fail here: the attribute is dereferenced right below.
+    return op->emitError(
+        "unable to find workload size, missing attribute "
+        "iree.executable.workload in dispatch function");
+  }
+  launchSize.clear();
+  for (auto value : workloadAttr.getValues<APInt>()) {
+    launchSize.push_back(value.getSExtValue());
+  }
+  // Drop trailing ones. An empty workload has nothing to trim, and guarding it
+  // keeps `size() - 1` from wrapping around; nothing is erased when the scan
+  // reaches index 0.
+  auto dropFrom = launchSize.empty() ? 0 : launchSize.size() - 1;
+  while (dropFrom > 0 && launchSize[dropFrom] == 1) --dropFrom;
+  if (dropFrom > 0) {
+    launchSize.erase(launchSize.begin() + dropFrom + 1, launchSize.end());
+  }
+  return success();
+}
+
+/// Index propagation for iree.load_input operation. This operation is
+/// essentially a copy from a memref to a tensor, so each index map recorded
+/// for the result tensor is propagated unchanged to the memref operand.
+class IREELoadIndexPropagation final
+    : public IndexPropagationOp<IREE::LoadInputOp> {
+ public:
+  using IndexPropagationOp<IREE::LoadInputOp>::IndexPropagationOp;
+
+  LogicalResult propagateIndexMap(
+      Operation *operation, IndexComputationCache &indexMap) const override;
+};
+
+/// Index propagation for iree.store_output operation. The launch size is
+/// assumed to match the shape of the tensor that is being stored. This
+/// operation acts as a seed for the index propagation. Each workitem is assumed
+/// to compute a single element of this tensor. The range of the index map is
+/// the reverse of the launch dimension.
+class IREEStoreIndexPropagation final
+    : public IndexPropagationOp<IREE::StoreOutputOp> {
+ public:
+  using IndexPropagationOp<IREE::StoreOutputOp>::IndexPropagationOp;
+
+  LogicalResult propagateIndexMap(
+      Operation *operation, IndexComputationCache &indexMap) const override;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_H
diff --git a/compiler/Translation/SPIRV/IREEToSPIRV.cpp b/compiler/Translation/SPIRV/IREEToSPIRV.cpp
new file mode 100644
index 0000000..97721ae
--- /dev/null
+++ b/compiler/Translation/SPIRV/IREEToSPIRV.cpp
@@ -0,0 +1,77 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- IREEToSPIRV.cpp -----------------------------------------*- C++//-*-===//
+//
+// Translation of IREE statements in dispatch functions to SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+#include "compiler/Translation/SPIRV/IREEToSPIRV.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Lowers iree.load_input. Nothing is emitted through |builder|: the value in
+/// operands[0] is recorded in |valueCache| as the value of the op's result at
+/// |index|.
+LogicalResult IREELoadOpSPIRVLowering::lowerOperation(
+    Operation *op, OpBuilder &builder, AffineMap index,
+    ArrayRef<Value *> operands, ValueCache &valueCache) const {
+  auto loadedTensor = cast<IREE::LoadInputOp>(op).getResult();
+  valueCache.setOperandDstValue(loadedTensor, index, operands[0]);
+  return success();
+}
+
+/// Lowers iree.store_output: writes the scalar computed for `src` through a
+/// pointer into the spv.globalVariable that |inputBuffers| maps `dst` to.
+LogicalResult IREEStoreOpSPIRVLowering::lowerOperation(
+    Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
+    ValueCache &valueCache,
+    DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+    ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
+  auto storeOp = cast<IREE::StoreOutputOp>(op);
+  auto src = storeOp.src();
+  auto indices = affineExprCodegen.getIndices(src);
+  if (indices.size() != 1) {
+    return storeOp.emitError(
+        "expected to compute a single element of the tensor that is stored "
+        "into the output memref");
+  }
+  auto var = inputBuffers.lookup(storeOp.dst());  // |outputBuffers| is unused.
+  if (!var) {
+    return storeOp.emitError(
+        "unable to find spv.globalVariable that corresponds to the dst memref");
+  }
+  auto ptr = genPointerOffset(builder, storeOp.getLoc(), affineExprCodegen,
+                              indices[0], var);
+  auto scalarValue = valueCache.getOperandDstValue(src, indices[0]);
+  builder.create<spirv::StoreOp>(storeOp.getLoc(), ptr, scalarValue,
+                                 /*memory_access = */ nullptr,
+                                 /*alignment = */ nullptr);
+  return success();
+}
+
+/// Lowers IREE::ReturnOp to spv.Return. Operands are not inspected here: a
+/// return in a dispatch function lowered to SPIR-V should have none.
+LogicalResult IREEReturnOpSPIRVLowering::lowerOperation(
+    Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
+    ValueCache &valueCache,
+    DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+    ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
+  builder.create<spirv::ReturnOp>(op->getLoc());
+  return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/SPIRV/IREEToSPIRV.h b/compiler/Translation/SPIRV/IREEToSPIRV.h
new file mode 100644
index 0000000..ada74b7
--- /dev/null
+++ b/compiler/Translation/SPIRV/IREEToSPIRV.h
@@ -0,0 +1,69 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- IREEToSPIRV.h -------------------------------------------*- C++//-*-===//
+//
+// Translation of IREE statements in dispatch functions to SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Translation/SPIRV/SPIRVLowering.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Translation of the iree.load_input operation (index-map based overload).
+class IREELoadOpSPIRVLowering final
+    : public SPIRVOpLowering<IREE::LoadInputOp> {
+ public:
+  using SPIRVOpLowering<IREE::LoadInputOp>::SPIRVOpLowering;
+
+  LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+                               AffineMap index, ArrayRef<Value *> operands,
+                               ValueCache &valueCache) const override;
+};
+
+/// Translation of the iree.return operation into spv.Return.
+class IREEReturnOpSPIRVLowering final : public SPIRVOpLowering<IREE::ReturnOp> {
+ public:
+  using SPIRVOpLowering<IREE::ReturnOp>::SPIRVOpLowering;
+
+  LogicalResult lowerOperation(
+      Operation *op, OpBuilder &builder, AffineExprCodegen &codegen,
+      ValueCache &valueCache,
+      DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+      ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override;
+};
+
+/// Translation of the iree.store_output operation into a spv.Store.
+class IREEStoreOpSPIRVLowering final
+    : public SPIRVOpLowering<IREE::StoreOutputOp> {
+ public:
+  using SPIRVOpLowering<IREE::StoreOutputOp>::SPIRVOpLowering;
+
+  LogicalResult lowerOperation(
+      Operation *op, OpBuilder &builder, AffineExprCodegen &codegen,
+      ValueCache &valueCache,
+      DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+      ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
diff --git a/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp b/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp
new file mode 100644
index 0000000..5672738
--- /dev/null
+++ b/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp
@@ -0,0 +1,196 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- IREEToSPIRVPass.cpp -------------------------------------*- C++//-*-===//
+//
+// Pass to translate iree executables for vulkan-spirv.
+//
+//===----------------------------------------------------------------------===//
+#include "compiler/Translation/SPIRV/IREEToSPIRVPass.h"
+
+#include "compiler/Translation/SPIRV/IREEIndexComputation.h"
+#include "compiler/Translation/SPIRV/IREEToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+class IREEToSPIRVPass : public ModulePass<IREEToSPIRVPass> {
+  void runOnModule() override;  // Lowers exported dispatch functions.
+};
+
+} // namespace
+
+/// Builds one spv.module and lowers every exported dispatch function into it.
+void IREEToSPIRVPass::runOnModule() {
+  auto module = getModule();
+  OpBuilder builder(module.getBodyRegion());
+
+  // Initialize the index computation.
+  IndexPropagationList<IndexPropagationOp<ConstantOp>,
+                       // IREE-specific ops:
+                       IndexPropagationOp<IREE::ReturnOp>,
+                       IREELoadIndexPropagation, IREEStoreIndexPropagation,
+                       // Standard dialect unary elementwise ops:
+                       NoBroadcastPwOpIndexPropagation<SIToFPOp>,
+                       NoBroadcastPwOpIndexPropagation<SignExtendIOp>,
+                       // Standard dialect binary elementwise ops:
+                       NoBroadcastPwOpIndexPropagation<AddFOp>,
+                       NoBroadcastPwOpIndexPropagation<AddIOp>,
+                       NoBroadcastPwOpIndexPropagation<AndOp>,
+                       NoBroadcastPwOpIndexPropagation<CmpFOp>,
+                       NoBroadcastPwOpIndexPropagation<CmpIOp>,
+                       NoBroadcastPwOpIndexPropagation<DivFOp>,
+                       NoBroadcastPwOpIndexPropagation<DivISOp>,
+                       NoBroadcastPwOpIndexPropagation<DivIUOp>,
+                       NoBroadcastPwOpIndexPropagation<MulFOp>,
+                       NoBroadcastPwOpIndexPropagation<MulIOp>,
+                       NoBroadcastPwOpIndexPropagation<OrOp>,
+                       NoBroadcastPwOpIndexPropagation<RemFOp>,
+                       NoBroadcastPwOpIndexPropagation<RemISOp>,
+                       NoBroadcastPwOpIndexPropagation<RemIUOp>,
+                       NoBroadcastPwOpIndexPropagation<SubFOp>,
+                       NoBroadcastPwOpIndexPropagation<SubIOp>,
+                       NoBroadcastPwOpIndexPropagation<TruncateIOp>,
+                       NoBroadcastPwOpIndexPropagation<XOrOp>,
+                       NoBroadcastPwOpIndexPropagation<ZeroExtendIOp>,
+                       // XLA unary elementwise ops:
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::AbsOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::CeilOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::ConvertOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::CosOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::ExpOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::FloorOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::LogOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::NegOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::RsqrtOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::SignOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::TanhOp>,
+                       // XLA binary elementwise ops:
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::AddOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::AndOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::DivOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::MaxOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::MinOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::MulOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::SubOp>,
+                       // XLA other ops:
+                       // TODO(ravishankarm): conv, dot.
+                       // TODO(ravishankarm): gather.
+                       // TODO(ravishankarm): pad.
+                       // TODO(ravishankarm): slice.
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::CopyOp>,
+                       ReshapeOpIndexPropagation<xla_hlo::ReshapeOp>,
+                       NoBroadcastPwOpIndexPropagation<xla_hlo::SelectOp>,
+                       XLABroadcastOpIndexPropagation,
+                       XLABroadcastInDimOpIndexPropagation,
+                       XLAReverseOpIndexPropagation,
+                       XLATransposeOpIndexPropagation>
+      indexPropagation;
+
+  // Initialize the SPIR-V code generator.
+  SPIRVCodegen<
+      ConstantOpSPIRVLowering,
+      // IREE-specific ops:
+      IREELoadOpSPIRVLowering, IREEReturnOpSPIRVLowering,
+      IREEStoreOpSPIRVLowering,
+      // Standard dialect unary elementwise ops:
+      // Standard dialect binary elementwise ops:
+      SPIRVPwOpLowering<AddFOp, spirv::FAddOp>,
+      SPIRVPwOpLowering<DivFOp, spirv::FDivOp>,
+      SPIRVPwOpLowering<MulFOp, spirv::FMulOp>,
+      SPIRVPwOpLowering<SubFOp, spirv::FSubOp>,
+      SPIRVPwOpLowering<AddIOp, spirv::IAddOp>,
+      SPIRVPwOpLowering<DivISOp, spirv::SDivOp>,
+      SPIRVPwOpLowering<MulIOp, spirv::IMulOp>,
+      SPIRVPwOpLowering<SubIOp, spirv::ISubOp>,
+      // XLA unary elementwise ops:
+      SPIRVPwOpLowering<xla_hlo::AbsOp, spirv::GLSLSAbsOp, spirv::GLSLFAbsOp>,
+      SPIRVPwOpLowering<xla_hlo::CeilOp, spirv::GLSLCeilOp>,
+      // TODO(ravishankarm): xla_hlo::ConvertOp
+      SPIRVPwOpLowering<xla_hlo::CosOp, spirv::GLSLCosOp>,
+      SPIRVPwOpLowering<xla_hlo::ExpOp, spirv::GLSLExpOp>,
+      SPIRVPwOpLowering<xla_hlo::FloorOp, spirv::GLSLFloorOp>,
+      SPIRVPwOpLowering<xla_hlo::LogOp, spirv::GLSLLogOp>,
+      SPIRVPwOpLowering<xla_hlo::NegOp, spirv::FNegateOp>,
+      SPIRVPwOpLowering<xla_hlo::RsqrtOp, spirv::GLSLInverseSqrtOp>,
+      SPIRVPwOpLowering<xla_hlo::SignOp, spirv::GLSLSSignOp,
+                        spirv::GLSLFSignOp>,
+      SPIRVPwOpLowering<xla_hlo::TanhOp, spirv::GLSLTanhOp>,
+      // XLA binary elementwise ops:
+      SPIRVPwOpLowering<xla_hlo::AddOp, spirv::IAddOp, spirv::FAddOp>,
+      SPIRVPwOpLowering<xla_hlo::AndOp, spirv::LogicalAndOp>,
+      SPIRVPwOpLowering<xla_hlo::DivOp, spirv::FDivOp>,
+      SPIRVPwOpLowering<xla_hlo::MaxOp, spirv::GLSLSMaxOp, spirv::GLSLFMaxOp>,
+      SPIRVPwOpLowering<xla_hlo::MinOp, spirv::GLSLSMinOp, spirv::GLSLFMinOp>,
+      SPIRVPwOpLowering<xla_hlo::MulOp, spirv::IMulOp, spirv::FMulOp>,
+      SPIRVPwOpLowering<xla_hlo::SubOp, spirv::ISubOp, spirv::FSubOp>,
+      // XLA other ops:
+      CmpFOpSPIRVLowering,
+      SPIRVPwOpLowering<xla_hlo::SelectOp, spirv::SelectOp>,
+      SPIRVIndexOpLowering<xla_hlo::BroadcastOp>,
+      SPIRVIndexOpLowering<xla_hlo::BroadcastInDimOp>,
+      SPIRVIndexOpLowering<xla_hlo::CopyOp>,
+      SPIRVIndexOpLowering<xla_hlo::ReshapeOp>,
+      SPIRVIndexOpLowering<xla_hlo::ReverseOp>,
+      SPIRVIndexOpLowering<xla_hlo::TransposeOp>>
+      spirvCodegen;
+
+  // Create a spirv.module Op.
+  auto spvModule = builder.create<spirv::ModuleOp>(
+      module.getLoc(),
+      builder.getI32IntegerAttr(
+          static_cast<int32_t>(spirv::AddressingModel::Logical)),
+      builder.getI32IntegerAttr(
+          static_cast<int32_t>(spirv::MemoryModel::GLSL450)));
+  SmallVector<StringRef, 2> caps;
+  caps.push_back(spirv::stringifyCapability(spirv::Capability::Shader));
+  spvModule.setAttr("capabilities", builder.getStrArrayAttr(caps));
+  SmallVector<StringRef, 2> exts;
+  exts.push_back("SPV_KHR_storage_buffer_storage_class");
+  spvModule.setAttr("extensions", builder.getStrArrayAttr(exts));
+
+  for (auto funcOp : module.getOps<FuncOp>()) {
+    // TODO(ravishankarm): FuncOps in executable that are not dispatch functions
+    // are not lowered to SPIR-V. Fix this limitation.
+    if (!funcOp.getAttr("iree.executable.export")) continue;
+
+    IndexComputationCache indexMap;
+    if (failed(indexPropagation.propagate(funcOp.getBody(), indexMap))) {
+      return signalPassFailure();
+    }
+    // dumpIndexCache(indexMap);
+
+    ValueCache valueCache;
+    AffineExprCodegen affineExprCodegen(spvModule, indexMap);
+    if (failed(spirvCodegen.codegen(spvModule, funcOp, affineExprCodegen,
+                                    valueCache))) {
+      return signalPassFailure();
+    }
+  }
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> createIREEToSPIRVPass() {
+  return std::make_unique<IREEToSPIRVPass>();  // Ownership goes to the caller.
+}
+static PassRegistration<IREEToSPIRVPass> pass(
+    "convert-iree-to-spirv",  // Registered pass argument.
+    "Convert IREE dispatch functions to SPIR-V dialect");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h b/compiler/Translation/SPIRV/IREEToSPIRVPass.h
similarity index 100%
rename from iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h
rename to compiler/Translation/SPIRV/IREEToSPIRVPass.h
diff --git a/compiler/Translation/SPIRV/IndexComputation.cpp b/compiler/Translation/SPIRV/IndexComputation.cpp
new file mode 100644
index 0000000..877fab0
--- /dev/null
+++ b/compiler/Translation/SPIRV/IndexComputation.cpp
@@ -0,0 +1,269 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- IndexComputation.cpp ------------------------------------*- C++//-*-===//
+//
+// For an IREE dispatch function, compute the map from workitem ID to index of
+// tensor computed within that workitem.
+//
+//===----------------------------------------------------------------------===//
+#include "compiler/Translation/SPIRV/IndexComputation.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/raw_ostream.h"
+
+static llvm::cl::opt<bool> doAffineExprSimplify(
+    "simplify-spirv-affine-exprs",
+    llvm::cl::desc("Simplify affine expressions during code-generation."),
+    llvm::cl::init(true));  // Simplification is on by default.
+
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// Reshape Utility Functions
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Handles shapes for scalars. The shape of a scalar is represented as an
+/// empty vector, i.e. {}. It is easier to do index propagation when the scalar
+/// is treated as a vector of size 1, so {} is mapped to {1}; any other shape
+/// is returned as an unchanged copy.
+inline SmallVector<int64_t, 4> handleIfScalar(ArrayRef<int64_t> shape) {
+  if (shape.empty()) {
+    return {1};
+  }
+  return SmallVector<int64_t, 4>(shape.begin(), shape.end());
+}
+
+/// Reshapes are often used to either add or remove a dimension of size 1.
+/// This matches |resultShape| against |operandShape| from the last dimension
+/// backwards and fills |operandExprs| accordingly: unit dimensions found only
+/// in the operand get a constant 0, unit result dimensions are dropped.
+inline LogicalResult getAffineExprForAddOrRemoveDimension(
+    Builder &builder, ArrayRef<AffineExpr> resultExprs,
+    ArrayRef<int64_t> resultShape, ArrayRef<int64_t> operandShape,
+    SmallVectorImpl<AffineExpr> &operandExprs) {
+  auto resultIndex = resultShape.size();
+  auto operandIndex = operandShape.size();
+  operandExprs.resize(operandShape.size());
+  // Try to match up the dimensions of the operand and result by ignoring any
+  // dimensions of size 1 that are introduced. A size of -1 fails the match.
+  while (resultIndex > 0 && operandIndex > 0) {
+    if (resultShape[resultIndex - 1] == -1 ||
+        operandShape[operandIndex - 1] == -1) {
+      return failure();
+    }
+    if (resultShape[resultIndex - 1] == operandShape[operandIndex - 1]) {
+      operandExprs[operandIndex - 1] = resultExprs[resultIndex - 1];
+      resultIndex--;
+      operandIndex--;
+      continue;
+    }
+    if (resultShape[resultIndex - 1] == 1) {
+      // This is a unit dimension that exists only on the result. The affine
+      // expression corresponding to this dimension is dropped.
+      resultIndex--;
+      continue;
+    }
+    if (operandShape[operandIndex - 1] == 1) {
+      // This is a dimension of size 1 that exists only on the operand. Use a
+      // constant expr 0 to index it.
+      operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
+      operandIndex--;
+      continue;
+    }
+    return failure();
+  }
+  // Any remaining dimensions should be 1.
+  while (resultIndex > 0) {
+    if (resultShape[resultIndex - 1] != 1) {
+      return failure();
+    }
+    resultIndex--;
+  }
+  while (operandIndex > 0) {
+    if (operandShape[operandIndex - 1] != 1) {
+      return failure();
+    }
+    // This is a leftover operand dimension of size 1 with no counterpart in
+    // the result. Add a constant expression 0.
+    operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
+    operandIndex--;
+  }
+  return success();
+}
+
+/// Constructs the strides of an array assuming a row-major packed layout: the
+/// last dimension gets stride 1 and every other dimension gets the product of
+/// the sizes of the dimensions after it. Fails on a negative dimension.
+// TODO(ravishankarm): This assumes the shape are static. When using dynamic
+// shapes, parameters of each dimension can be used to construct AffineExpr for
+// strides along each dimension. Note that multiplying two symbolic constants is
+// technically not affine, but you could use another symbol to represent the
+// product, so it should be still representable as affine exprs.
+inline LogicalResult getRowMajorPackedStrides(
+    Builder &builder, ArrayRef<int64_t> shape,
+    SmallVectorImpl<AffineExpr> &strides) {
+  strides.resize(shape.size());
+  int64_t runningStride = 1;
+  for (size_t pos = shape.size(); pos > 0; --pos) {
+    const int64_t extent = shape[pos - 1];
+    // TODO(ravishankarm) : Better error message.
+    if (extent < 0) return failure();
+    strides[pos - 1] = builder.getAffineConstantExpr(runningStride);
+    runningStride *= extent;
+  }
+  return success();
+}
+
+/// Linearizes the index of the result position accessed using the shape of the
+/// result tensor and delinearizes it to get the position of the operand.
+inline LogicalResult getAffineExprForReshape(
+    Builder &builder, unsigned numDims, unsigned numSymbols,
+    ArrayRef<AffineExpr> resultExprs, ArrayRef<int64_t> resultShape,
+    ArrayRef<int64_t> operandShape, SmallVectorImpl<AffineExpr> &operandExprs) {
+  // To linearize the index, assume that the memory is laid out in
+  // packed-row-major layout based on the shape.
+  // TODO(ravishankarm) : When there is stride information, use that to map from
+  // index to memory location.
+  SmallVector<AffineExpr, 4> resultStrides;
+  if (failed(getRowMajorPackedStrides(builder, resultShape, resultStrides))) {
+    return failure();
+  }
+  AffineExpr linearizedExpr;
+  for (auto index : enumerate(resultExprs)) {
+    auto val = getAffineBinaryOpExpr(AffineExprKind::Mul, index.value(),
+                                     resultStrides[index.index()]);
+    if (doAffineExprSimplify) {
+      val = simplifyAffineExpr(val, numDims, numSymbols);
+    }
+    linearizedExpr = (index.index() ? getAffineBinaryOpExpr(AffineExprKind::Add,
+                                                            linearizedExpr, val)
+                                    : val);
+    if (doAffineExprSimplify) {  // Simplify the whole sum, not only `val`.
+      linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols);
+    }
+  }
+
+  // Unlinearize the index, assuming row-major-packed layout.
+  // TODO(ravishankarm) : When there is stride information, use that to map from
+  // memory location to index.
+  SmallVector<AffineExpr, 4> operandStrides;
+  if (failed(getRowMajorPackedStrides(builder, operandShape, operandStrides))) {
+    return failure();
+  }
+  operandExprs.resize(operandStrides.size());
+  for (auto stride : enumerate(operandStrides)) {
+    if (stride.index() == operandStrides.size() - 1) {
+      operandExprs[stride.index()] = linearizedExpr;  // What is left over.
+      break;
+    }
+    auto expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, linearizedExpr,
+                                      stride.value());
+    operandExprs[stride.index()] =
+        (doAffineExprSimplify ? simplifyAffineExpr(expr, numDims, numSymbols)
+                              : expr);
+
+    linearizedExpr = getAffineBinaryOpExpr(AffineExprKind::Mod, linearizedExpr,
+                                           stride.value());
+    if (doAffineExprSimplify) {
+      linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols);
+    }
+  }
+  return success();
+}
+} // namespace
+
+LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap,
+                                   ArrayRef<int64_t> resultShapeRef,
+                                   ArrayRef<int64_t> operandShapeRef,
+                                   AffineMap &operandIndexMap) {
+  const unsigned numDims = resultIndexMap.getNumDims();
+  const unsigned numSymbols = resultIndexMap.getNumSymbols();
+  auto resultExprs = resultIndexMap.getResults();
+  auto resultShape = handleIfScalar(resultShapeRef);  // {} becomes {1}.
+  auto operandShape = handleIfScalar(operandShapeRef);
+  assert(resultShape.size() == resultExprs.size() &&
+         "Ranks of the Domain of index map and result must be the same");
+  SmallVector<AffineExpr, 4> operandExprs;
+  if (failed(getAffineExprForAddOrRemoveDimension(
+          builder, resultExprs, resultShape, operandShape, operandExprs))) {
+    if (failed(getAffineExprForReshape(builder, numDims, numSymbols,
+                                       resultExprs, resultShape, operandShape,
+                                       operandExprs))) {
+      return failure();  // Neither strategy could map this reshape.
+    }
+  }
+  assert(operandExprs.size() == operandShape.size() &&
+         "expected as many exprs for the operand as the rank of the operand");
+  operandIndexMap = AffineMap::get(numDims, numSymbols, operandExprs);
+  return success();
+}
+
+LogicalResult IndexPropagation::propagateIndexMap(
+    Operation *op, IndexComputationCache &indexMap) const {
+  if (op->getNumResults() == 0) {
+    // Nothing to do for an op without results.
+    return success();
+  }
+  if (op->getNumResults() != 1) {
+    return op->emitError(
+        "default index propagation handles case with a single-return value");
+  }
+  // Create the cache entry of every operand before walking the result's.
+  for (auto arg : op->getOperands()) {
+    indexMap[arg];
+  }
+  for (auto &resultIndexMap : indexMap[op->getResult(0)]) {
+    SmallVector<AffineMap, 4> operandIndices;
+    if (failed(this->propagateIndexMap(op, resultIndexMap.first,
+                                       operandIndices))) {
+      return failure();
+    }
+    assert(operandIndices.size() == op->getNumOperands() &&
+           "Expected as many indices as operands");
+    for (auto arg : enumerate(op->getOperands())) {
+      indexMap[arg.value()][operandIndices[arg.index()]];  // Request it.
+      resultIndexMap.second.push_back(operandIndices[arg.index()]);
+    }
+  }
+  return success();
+}
+
+void dumpIndexCache(IndexComputationCache &indexMap) {
+  for (auto &el : indexMap) {  // One cache entry per iteration.
+    // llvm::errs() << "Value : " << *(el.first);
+    // llvm::errs().flush();
+    if (isa<OpResult>(el.first)) {
+      llvm::errs() << "Operation : " << el.first->getDefiningOp()->getName();
+    } else if (isa<BlockArgument>(el.first)) {
+      llvm::errs() << "BlockArgument";
+    }
+    for (auto &used : el.second) {  // Index map, then its operand maps.
+      llvm::errs() << "\n\t" << used.first << " : [";
+      std::string sep = "";  // Empty before the first operand, ", " after.
+      for (auto &operand : used.second) {
+        llvm::errs() << sep << operand;
+        sep = ", ";
+      }
+      llvm::errs() << "]";
+    }
+    llvm::errs() << "\n";
+  }
+  llvm::errs() << "\n";
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation.h b/compiler/Translation/SPIRV/IndexComputation.h
similarity index 100%
rename from iree/compiler/Translation/SPIRV/IndexComputation.h
rename to compiler/Translation/SPIRV/IndexComputation.h
diff --git a/iree/compiler/Translation/SPIRV/Kernels/BUILD b/compiler/Translation/SPIRV/Kernels/BUILD
similarity index 100%
rename from iree/compiler/Translation/SPIRV/Kernels/BUILD
rename to compiler/Translation/SPIRV/Kernels/BUILD
diff --git a/iree/compiler/Translation/SPIRV/Kernels/CMakeLists.txt b/compiler/Translation/SPIRV/Kernels/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Translation/SPIRV/Kernels/CMakeLists.txt
rename to compiler/Translation/SPIRV/Kernels/CMakeLists.txt
diff --git a/iree/compiler/Translation/SPIRV/Kernels/matmul.comp b/compiler/Translation/SPIRV/Kernels/matmul.comp
similarity index 100%
rename from iree/compiler/Translation/SPIRV/Kernels/matmul.comp
rename to compiler/Translation/SPIRV/Kernels/matmul.comp
diff --git a/iree/compiler/Translation/SPIRV/Kernels/reduce_untiled.comp b/compiler/Translation/SPIRV/Kernels/reduce_untiled.comp
similarity index 100%
rename from iree/compiler/Translation/SPIRV/Kernels/reduce_untiled.comp
rename to compiler/Translation/SPIRV/Kernels/reduce_untiled.comp
diff --git a/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl b/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl
new file mode 100644
index 0000000..f6e224d
--- /dev/null
+++ b/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl
@@ -0,0 +1,32 @@
+"""Utilities for handling hand-written SPIR-V files."""
+
+load("//:build_defs.google.bzl", "iree_glsl_vulkan")
+load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+
+def spirv_kernel_cc_library(name, srcs):
+    """Compiles GLSL files into SPIR-V binaries and embeds them in a cc_library.
+
+    Args:
+      name: cc_library name to depend on. The compiled binaries are also
+        exposed through a filegroup named `name + "_files"`.
+      srcs: a list of GLSL source files.
+    """
+    spv_files = []
+    for src in srcs:
+        # Name each compile target after the component preceding the file
+        # extension, e.g. "matmul.comp" -> "matmul".
+        spv_name = src.split(".")[-2]
+        iree_glsl_vulkan(name = spv_name, srcs = [src])
+        spv_files.append(spv_name + ".spv")
+    native.filegroup(
+        name = name + "_files",
+        srcs = spv_files,
+    )
+    cc_embed_data(
+        name = name,
+        srcs = spv_files,
+        cc_file_output = name + ".cc",
+        h_file_output = name + ".h",
+        cpp_namespace = "mlir::iree_compiler::spirv_kernels",
+        flatten = True,
+    )
diff --git a/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
new file mode 100644
index 0000000..5ab78c7
--- /dev/null
+++ b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
@@ -0,0 +1,314 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Translation/SPIRV/SPIRVExecutableTranslation.h"
+
+#include <cstdint>
+#include <iostream>
+#include <map>
+#include <vector>
+
+#include "flatbuffers/flatbuffers.h"
+#include "compiler/Translation/SPIRV/EmbeddedKernels.h"
+#include "compiler/Translation/SPIRV/IREEToSPIRVPass.h"
+#include "compiler/Utils/OpUtils.h"
+#include "compiler/Utils/TranslationUtils.h"
+#include "schemas/executable_def_generated.h"
+#include "schemas/spirv_executable_def_generated.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Translation.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+class SPIRVTranslator {  // Translates IREE executables into SPIR-V executables.
+ public:
+ explicit SPIRVTranslator(ExecutableTranslationOptions options)
+ : options_(options) {}
+ // Options used when creating and running the SPIR-V pass pipeline.
+ const ExecutableTranslationOptions &options() const { return options_; }
+
+ // Translates |executableOp| and returns a populated ExecutableDef in the
+ // SpirV format, or nullptr if translation is unsuccessful.
+ std::unique_ptr<iree::ExecutableDefT> translateExecutable(
+ IREE::ExecutableOp executableOp);
+
+ private:
+ // Returns the entry point names of |executableOp|, indexed by export ordinal.
+ std::vector<std::string> populateEntryPointNames(
+ IREE::ExecutableOp executableOp);
+
+ // Translates the inner module of |executableOp| into the SPIR-V dialect and
+ // returns the serialized code words or empty if translation failed.
+ std::vector<uint32_t> translateAndSerializeShaderModule(
+ IREE::ExecutableOp executableOp);
+
+ // Returns a pipeline layout built from the bindings of |spirvModuleOp|.
+ std::unique_ptr<iree::VkPipelineLayoutDefT> populatePipelineLayout(
+ spirv::ModuleOp spirvModuleOp);
+
+ ExecutableTranslationOptions options_;
+};
+
+std::unique_ptr<iree::ExecutableDefT> SPIRVTranslator::translateExecutable(
+    IREE::ExecutableOp executableOp) {
+  iree::SpirVExecutableDefT spirvDef;
+
+  // A matching embedded kernel (such as matmul) takes priority; the kernel is
+  // generated from the executable contents only when no such match exists.
+  const bool usedEmbeddedKernel =
+      tryEmbeddedKernelRewrite(executableOp, &spirvDef);
+  if (!usedEmbeddedKernel) {
+    // The sequencer and runtime refer to entry points by ordinal, so the names
+    // are listed here for use in VkShaderModuleCreateInfo.
+    spirvDef.entry_points = populateEntryPointNames(executableOp);
+
+    // Code generation modifies the executable; the metadata it leaves behind
+    // is what the remaining SpirVExecutableDef fields are extracted from.
+    spirvDef.code = translateAndSerializeShaderModule(executableOp);
+    if (spirvDef.code.empty()) {
+      executableOp.emitError()
+          << "Failed to translate and serialize SPIR-V executable";
+      return nullptr;
+    }
+
+    // Reflect against the first spv.module (if any) to derive the pipeline
+    // layout from its bindings; the runtime creates the VkPipelineLayout from
+    // it.
+    auto spirvModules = executableOp.getBlock().getOps<spirv::ModuleOp>();
+    auto firstModuleIt = spirvModules.begin();
+    if (firstModuleIt != spirvModules.end()) {
+      spirv::ModuleOp spirvModuleOp = *firstModuleIt;
+      spirvDef.pipeline_layout = populatePipelineLayout(spirvModuleOp);
+      if (!spirvDef.pipeline_layout) {
+        spirvModuleOp.emitError()
+            << "Failed to generate pipeline for SPIR-V module";
+        return nullptr;
+      }
+    }
+  }
+
+  // Pack the definition into a flatbuffer; finishing the buffer adds the
+  // header that is used to verify the contents at runtime.
+  ::flatbuffers::FlatBufferBuilder fbb;
+  iree::FinishSpirVExecutableDefBuffer(
+      fbb, iree::SpirVExecutableDef::Pack(fbb, &spirvDef));
+  const uint8_t *packedBegin = fbb.GetBufferPointer();
+  std::vector<uint8_t> packedBytes(packedBegin, packedBegin + fbb.GetSize());
+
+  // Record the format on the op as well as in the returned definition.
+  OpBuilder builder(executableOp);
+  const auto spirvFormat = IREE::ExecutableFormat::SpirV;
+  executableOp.setAttr(
+      "format", builder.getI32IntegerAttr(static_cast<int32_t>(spirvFormat)));
+
+  auto executableDef = std::make_unique<iree::ExecutableDefT>();
+  executableDef->format = static_cast<uint32_t>(spirvFormat);
+  executableDef->contents = std::move(packedBytes);
+  return executableDef;
+}
+
+std::vector<std::string> SPIRVTranslator::populateEntryPointNames(
+    IREE::ExecutableOp executableOp) {
+  std::vector<std::string> names;  // names[i] belongs to export ordinal i.
+  for (auto funcOp : executableOp.getInnerModule().getOps<FuncOp>()) {
+    if (!funcOp.getAttr("iree.executable.export")) continue;
+    auto ordinalAttr = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal");
+    if (!ordinalAttr || ordinalAttr.getInt() < 0) {
+      funcOp.emitError() << "exported function requires a valid iree.ordinal";
+      continue;
+    }
+    names.resize(std::max<size_t>(names.size(), ordinalAttr.getInt() + 1));
+    names[ordinalAttr.getInt()] = funcOp.getName().str();  // May be sparse.
+  }
+  return names;
+}
+
+std::vector<uint32_t> SPIRVTranslator::translateAndSerializeShaderModule(
+ IREE::ExecutableOp executableOp) {
+ auto module = executableOp.getInnerModule();
+
+ // We can use the workload hint to know what the expected dispatch workload
+ // is. If we want to remap this to make more sense for the operations we are
+ // performing we can do that here.
+ //
+ // Note that workloads are computed per entry point. There may be some
+ // dimensions of the workload that are static (in which case workloadAttr will
+ // have non-dynamic dims) and others that need to be taken from an argument
+ // shape (in which case workloadRef is the argument ordinal to take dynamic
+ // dimensions from).
+ // TODO(benvanik): make it just an arg instead? iree.workload special op?
+ // TODO(benvanik): instead of FuncOp have an iree.entry_point op with these.
+ for (auto funcOp : module.getOps<FuncOp>()) {
+ // TODO(ravishankarm): FuncOps in executable that are not dispatch functions
+ // are not lowered to SPIR-V. Fix this limitation.
+ if (!funcOp.getAttr("iree.executable.export")) continue;
+ auto workloadAttr =
+ funcOp.getAttrOfType<ElementsAttr>("iree.executable.workload");
+ auto workloadRefAttr =
+ funcOp.getAttrOfType<IntegerAttr>("iree.executable.workload_ref");
+ std::array<int32_t, 3> staticWorkloadDims = {-1, -1, -1};
+ if (workloadAttr) {
+ for (unsigned i = 0; i < 3; ++i) {
+ if (auto dimAttr =
+ workloadAttr.getValue({i}).dyn_cast_or_null<IntegerAttr>()) {
+ staticWorkloadDims[i] = dimAttr.getInt();
+ }
+ }
+ }
+ std::array<BlockArgument *, 3> dynamicWorkloadDimRefs;
+ if (workloadRefAttr) {
+ for (unsigned i = 0; i < 3; ++i) {
+ if (staticWorkloadDims[i] == -1) {
+ dynamicWorkloadDimRefs[i] =
+ funcOp.getArgument(workloadRefAttr.getInt());
+ }
+ }
+ }
+
+ // Now staticWorkloadDims will have non-negative values for known dimensions
+ // and any dim with -1 will need to be pulled from the corresponding shape
+ // dimension of dynamicWorkloadDimRefs.
+ // NOTE(review): neither array is read yet; static dims leave refs unset.
+ // TODO(b/137868263): use this information to map from workgroup to
+ // invocation and perform indexing.
+ }
+
+ // Lower module to spirv::ModuleOp (HLO legalization, then IREE-to-SPIR-V).
+ auto spirvGenPasses = createPassManager(module.getContext(), options());
+ spirvGenPasses->addPass(xla_hlo::createLegalizeToStdPass());
+ spirvGenPasses->addPass(createIREEToSPIRVPass());
+ if (failed(runPassPipeline(options(), spirvGenPasses.get(), module))) {
+ executableOp.emitError() << "Failed to generate spv.module";
+ return {};
+ }
+
+ auto spvModules = module.getOps<spirv::ModuleOp>();
+ if (std::distance(spvModules.begin(), spvModules.end()) != 1) {
+ executableOp.emitError()
+ << "Expected a single spv.module for an IREE executable op";
+ return {};
+ }
+
+ // Serialize the spirv::ModuleOp into the binary that we will embed in the
+ // final flatbuffer.
+ std::vector<uint32_t> spvBinaries;
+ for (auto spvModule : spvModules) {
+ SmallVector<uint32_t, 256> spvBinary;
+ if (failed(spirv::serialize(spvModule, spvBinary))) {
+ executableOp.emitError() << "Failed to serialize spv.module";
+ return {};
+ }
+ spvBinaries.insert(spvBinaries.end(), spvBinary.begin(), spvBinary.end());
+
+ // Clone the module into executableOp directly, before its last operation.
+ auto clonedModule = spvModule.clone();
+ executableOp.getBlock().getOperations().insert(
+ std::prev(executableOp.getBlock().getOperations().end()), clonedModule);
+ }
+ // Remove the original code now that the spv.module has been cloned out.
+ module.erase();
+
+ return spvBinaries;
+}
+
+std::unique_ptr<iree::VkPipelineLayoutDefT>
+SPIRVTranslator::populatePipelineLayout(spirv::ModuleOp spirvModuleOp) {
+  // NOTE: the layout is derived under assumptions based on the expected ABI
+  // of the runtime. Supporting more general shaders with more complex I/O
+  // would need a better way to communicate this through the
+  // VkPipelineLayoutDef.
+  using BindingMap = std::map<int32_t, spirv::GlobalVariableOp>;
+
+  // Group the variables as descriptor_set -> binding -> variable. Ordered
+  // maps make it easy to write the descriptors out in a logical order, even
+  // though this is not strictly required.
+  std::map<int32_t, BindingMap> bindingsBySet;
+  // Largest descriptor set ordinal seen; stays -1 if there is none.
+  int64_t highestSetOrdinal = -1;
+  for (auto globalVar :
+       spirvModuleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+    auto setAttr = globalVar.getAttrOfType<IntegerAttr>("descriptor_set");
+    auto bindingAttr = globalVar.getAttrOfType<IntegerAttr>("binding");
+    // Variables lacking either attribute are not something the runtime cares
+    // about.
+    if (setAttr && bindingAttr) {
+      highestSetOrdinal = std::max(setAttr.getInt(), highestSetOrdinal);
+      bindingsBySet[setAttr.getInt()][bindingAttr.getInt()] = globalVar;
+    }
+  }
+
+  // Create one layout def per descriptor set that has bindings; ordinals
+  // without any bindings are left as empty entries.
+  auto layoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
+  layoutDef->buffer_binding_set = 0;
+  layoutDef->descriptor_set_layouts.resize(highestSetOrdinal + 1);
+  for (auto &setAndBindings : bindingsBySet) {
+    auto setLayout = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
+    for (auto &bindingAndVar : setAndBindings.second) {
+      auto bindingDef =
+          std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>();
+      bindingDef->binding = bindingAndVar.first;
+      bindingDef->descriptor_count = 1;
+      // TODO(benvanik): pull from type info.
+      bindingDef->descriptor_type = 7;  // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
+      bindingDef->stage_flags = 0x00000020;  // VK_SHADER_STAGE_COMPUTE_BIT
+      setLayout->bindings.push_back(std::move(bindingDef));
+    }
+
+    layoutDef->descriptor_set_layouts[setAndBindings.first] =
+        std::move(setLayout);
+  }
+
+  return layoutDef;
+}
+
+} // namespace
+
+llvm::Optional<ExecutableTranslationResult>
+translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps,
+                                     ExecutableTranslationOptions options) {
+  ExecutableTranslationResult result;
+  SPIRVTranslator translator(options);
+  for (auto executableOp : llvm::make_early_inc_range(executableOps)) {
+    if (auto translatedDef = translator.translateExecutable(executableOp)) {
+      result.executable_defs.push_back(std::move(translatedDef));
+      continue;
+    }
+    executableOp.emitError() << "Failed to translate one or more executables";
+    return llvm::None;  // One failure discards every earlier result.
+  }
+  return result;
+}
+
+static ExecutableTranslationRegistration SPIRVExecutableTranslationRegistration(
+ "vulkan-spirv", translateExecutableToSPIRVExecutable);  // NOTE(review): presumably the backend name — confirm.
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h
new file mode 100644
index 0000000..89fe483
--- /dev/null
+++ b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
+#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
+
+#include <vector>
+
+#include "compiler/IR/StructureOps.h"
+#include "compiler/Utils/TranslationUtils.h"
+#include "mlir/IR/Module.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Translates each of |executableOps| into a SPIR-V executable, returning
+// llvm::None if any of them fails. Executables are stored as FlatBuffers in the
+// https://github.com/google/iree/tree/master/iree/schemas/spirv_executable_def.fbs
+// schema.
+llvm::Optional<ExecutableTranslationResult>
+translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps,
+ ExecutableTranslationOptions options = {});
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
diff --git a/compiler/Translation/SPIRV/SPIRVLowering.cpp b/compiler/Translation/SPIRV/SPIRVLowering.cpp
new file mode 100644
index 0000000..23f87a7
--- /dev/null
+++ b/compiler/Translation/SPIRV/SPIRVLowering.cpp
@@ -0,0 +1,131 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- SPIRVLowering.cpp ---------------------------------------*- C++//-*-===//
+//
+// SPIR-V Code-generation for XLA-HLO Ops within IREE Dispatch functions
+//
+//===----------------------------------------------------------------------===//
+#include "compiler/Translation/SPIRV/SPIRVLowering.h"
+
+namespace mlir {
+namespace iree_compiler {
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+LogicalResult ConstantOpSPIRVLowering::lowerOperation(
+    Operation *op, OpBuilder &builder, AffineMap index, ArrayRef<Value *>,
+    ValueCache &valueCache) const {
+  auto constOp = cast<ConstantOp>(op);
+  auto splatAttr = constOp.value().dyn_cast<DenseElementsAttr>();
+  if (!(splatAttr && splatAttr.isSplat())) {
+    return op->emitError(
+        "unhandled constant lowering unless value is a splat dense element "
+        "attribute");
+  }
+  // The scalar has the result type itself, or its element type when shaped.
+  auto *result = constOp.getResult();
+  Type scalarType = result->getType();
+  if (!scalarType.isIntOrFloat()) {
+    auto shapedType = scalarType.dyn_cast<ShapedType>();
+    if (!shapedType) {
+      return op->emitError("unhandled result type of constant : ") << scalarType;
+    }
+    scalarType = shapedType.getElementType();
+  }
+  Attribute splatValue = splatAttr.getSplatValue();
+  auto scalarConst = builder.create<spirv::ConstantOp>(op->getLoc(), scalarType,
+                                                       splatValue);
+  valueCache.setOperandDstValue(result, index, scalarConst.getResult());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+LogicalResult CmpFOpSPIRVLowering::lowerOperation(
+    Operation *op, OpBuilder &builder, AffineMap index,
+    ArrayRef<Value *> operands, ValueCache &valueCache) const {
+  if (operands.size() != 2) {
+    return op->emitError("expected two operands in spir-v lowering of CmpFOp");
+  }
+  auto cmpAttr = op->getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
+  if (!cmpAttr) {
+    return op->emitError("expected CmpFOp to contain ")
+           << CmpFOp::getPredicateAttrName() << " attribute";
+  }
+  auto i1Type = builder.getI1Type();
+  Operation *cmpOp = nullptr;
+  switch (static_cast<CmpFPredicate>(cmpAttr.getInt())) {
+// Emits the SPIR-V comparison op that corresponds to one CmpF predicate.
+#define LOWER_PREDICATE(predicate, SpirvOpTy)                            \
+  case CmpFPredicate::predicate:                                         \
+    cmpOp = builder.create<SpirvOpTy>(op->getLoc(), i1Type, operands[0], \
+                                      operands[1]);                      \
+    break;
+
+    LOWER_PREDICATE(OEQ, spirv::FOrdEqualOp);
+    LOWER_PREDICATE(OGE, spirv::FOrdGreaterThanEqualOp);
+    LOWER_PREDICATE(OGT, spirv::FOrdGreaterThanOp);
+    LOWER_PREDICATE(OLE, spirv::FOrdLessThanEqualOp);
+    LOWER_PREDICATE(OLT, spirv::FOrdLessThanOp);
+    LOWER_PREDICATE(ONE, spirv::FOrdNotEqualOp);
+    LOWER_PREDICATE(UEQ, spirv::FUnordEqualOp);
+    LOWER_PREDICATE(UGE, spirv::FUnordGreaterThanEqualOp);
+    LOWER_PREDICATE(UGT, spirv::FUnordGreaterThanOp);
+    LOWER_PREDICATE(ULE, spirv::FUnordLessThanEqualOp);
+    LOWER_PREDICATE(ULT, spirv::FUnordLessThanOp);
+    LOWER_PREDICATE(UNE, spirv::FUnordNotEqualOp);
+
+#undef LOWER_PREDICATE
+    default:
+      return op->emitError("unhandled predicate attribute for SPIR-V lowering");
+  }
+  // Record the comparison result as the scalar value of the op's result.
+  valueCache.setOperandDstValue(op->getResult(0), index, cmpOp->getResult(0));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+LogicalResult ReturnOpSPIRVLowering::lowerOperation(
+    Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
+    ValueCache &valueCache,
+    DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+    ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
+  auto returnOp = cast<ReturnOp>(op);
+  if (returnOp.getNumOperands() != 1) {
+    return returnOp.emitError(
+        "unhandled lowering of return statement with multiple returns");
+  }
+  auto *returnedTensor = returnOp.getOperand(0);
+  auto indexMaps = affineExprCodegen.getIndices(returnedTensor);
+  if (indexMaps.size() != 1) {
+    return returnOp.emitError(
+        "expected to compute a single element of the return tensor");
+  }
+  assert(outputBuffers.size() == 1 && "Expected a single output buffer");
+  auto loc = returnOp.getLoc();
+  auto outputVar = outputBuffers[0];  // Lvalue: genPointerOffset takes a ref.
+  auto *elementPtr = genPointerOffset(builder, loc, affineExprCodegen,
+                                      indexMaps[0], outputVar);
+  // Store the scalar cached for that element at its spot in the output buffer.
+  auto *scalar = valueCache.getOperandDstValue(returnedTensor, indexMaps[0]);
+  builder.create<spirv::StoreOp>(
+      loc, elementPtr, scalar, /*memory_access=*/nullptr, /*alignment=*/nullptr);
+  builder.create<spirv::ReturnOp>(loc);
+  return success();
+}
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/SPIRV/SPIRVLowering.h b/compiler/Translation/SPIRV/SPIRVLowering.h
new file mode 100644
index 0000000..eacb551
--- /dev/null
+++ b/compiler/Translation/SPIRV/SPIRVLowering.h
@@ -0,0 +1,591 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- SPIRVLowering.h -----------------------------------------*- C++//-*-===//
+//
+// SPIR-V Code-generation for tensor operations within IREE Dispatch functions
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
+
+#include "compiler/Translation/SPIRV/AffineExprCodegen.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/StringExtras.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+class ValueCache {
+ public:
+  // Returns the scalar cached for `value` at `index`, or nullptr if none.
+  Value *getOperandDstValue(Value *value, AffineMap index) {
+    auto it = convertedValueMap.find(value);  // No copy of the inner map.
+    return it == convertedValueMap.end() ? nullptr : it->second.lookup(index);
+  }
+  void setOperandDstValue(Value *value, AffineMap index, Value *scalar) {
+    convertedValueMap[value][index] = scalar;  // Overwrites any old entry.
+  }
+ private:
+  DenseMap<Value *, DenseMap<AffineMap, Value *>> convertedValueMap;
+};
+
+/// Base class for lowering a tensor operation in the dispatch function to
+/// SPIR-V ops.
+class SPIRVLowering {
+ public:
+ virtual ~SPIRVLowering() = default;
+ virtual StringRef getOpName() = 0;  // Name of the op this lowering handles.
+ /// This method (in the derived class) should generate the scalar operation
+ /// corresponding to the tensor operation `op` to generate the value of the
+ /// result tensor at a particular `index`. The scalar values of the operands
+ /// needed to compute this value are passed in within `operands`. The method
+ /// has to insert the scalar result value of the generated operation into the
+ /// `valueCache`. The default implementation returns failure().
+ virtual LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+ AffineMap index,
+ ArrayRef<Value *> operands,
+ ValueCache &valueCache) const {
+ return failure();
+ }
+
+ /// This method (in the derived class) should generate the scalar operations
+ /// corresponding to the tensor operation `op`. This should be implemented
+ /// when the `op` has no result value, typically store operations and return
+ /// operations. The default implementation returns failure().
+ virtual LogicalResult lowerOperation(
+ Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
+ ValueCache &valueCache,
+ DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+ ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
+ return failure();
+ }
+};
+
+/// Base class that implements getOpName() for lowerings of operation `OpTy`.
+/// Relies on the implicit destructor, which is virtual through SPIRVLowering.
+template <typename OpTy>
+class SPIRVOpLowering : public SPIRVLowering {
+ public:
+  using SPIRVLowering::SPIRVLowering;
+  StringRef getOpName() override { return OpTy::getOperationName(); }
+};
+
+/// SPIR-V lowering for ConstantOp. Only splat dense constants are handled.
+class ConstantOpSPIRVLowering final : public SPIRVOpLowering<ConstantOp> {
+ public:
+ using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+ LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+ AffineMap index, ArrayRef<Value *> operands,
+ ValueCache &valueCache) const override;
+};
+
+/// SPIR-V lowering for CmpFOp, dispatching on its predicate attribute.
+class CmpFOpSPIRVLowering final : public SPIRVOpLowering<CmpFOp> {
+ public:
+ using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
+ // Expects exactly two scalar operands.
+ LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+ AffineMap index, ArrayRef<Value *> operands,
+ ValueCache &valueCache) const override;
+};
+
+/// SPIR-V lowering for Min/Max operations: compares the two operands
+/// (`CmpFOpTy` for float elements, `CmpOpTy` otherwise) and selects one.
+template <typename OpTy, typename CmpOpTy, typename CmpFOpTy>
+class CmpSelectOpSPIRVLowering final : public SPIRVOpLowering<OpTy> {
+ public:
+  using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+  LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+                               AffineMap index, ArrayRef<Value *> operands,
+                               ValueCache &valueCache) const override {
+    if (op->getNumOperands() != 2) {
+      return op->emitError(
+          "unhandled SPIR-V lowering for more than 2 operands");
+    }
+    assert(operands.size() == op->getNumOperands() &&
+           "expected as many operands for the replacement as the original "
+           "instruction");
+    auto resultTy = cast<OpTy>(op).getResult()->getType();
+    auto shapedTy = resultTy.template dyn_cast<ShapedType>();
+    if (!shapedTy) {
+      return op->emitError(
+          "unhandled lowering of operations that don't return a ShapedType");
+    }
+    // Float elements require the floating-point flavor of the comparison.
+    auto loc = op->getLoc();
+    auto i1Ty = builder.getI1Type();
+    ArrayRef<NamedAttribute> noAttrs;
+    Operation *compareOp = nullptr;
+    if (shapedTy.getElementType().template isa<FloatType>()) {
+      compareOp = builder.create<CmpFOpTy>(loc, i1Ty, operands, noAttrs);
+    } else {
+      compareOp = builder.create<CmpOpTy>(loc, i1Ty, operands, noAttrs);
+    }
+    // The comparison result picks which operand becomes the result.
+    auto selectOp = builder.create<spirv::SelectOp>(
+        loc, operands[0]->getType(), compareOp->getResult(0), operands[0],
+        operands[1]);
+    valueCache.setOperandDstValue(op->getResult(0), index,
+                                  selectOp.getResult());
+    return success();
+  }
+};
+
+/// General template for emitting the scalar instruction that corresponds to a
+/// point-wise operation. Assumes that the original instruction has a single
+/// result value of type ShapedType.
+/// TODO(ravishankarm) : XLA-HLO uses the same operation for integer and float
+/// tensors, so an additional op type can be supplied as a template parameter
+/// (`FloatOpTy`, used for non-integer element types). Find a better way to do
+/// this.
+template <typename OpTy, typename ReplacementOpTy,
+          typename FloatOpTy = ReplacementOpTy>
+class SPIRVPwOpLowering final : public SPIRVOpLowering<OpTy> {
+ public:
+  using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+
+  LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+                               AffineMap index,
+                               ArrayRef<Value *> scalarOperands,
+                               ValueCache &valueCache) const override {
+    // TODO(ravishankarm) : This check should really be a static_assert. See if
+    // that can be changed.
+    if (op->getNumOperands() == 0) {
+      return op->emitError("expected op to have at least one operand");
+    }
+
+    auto tensorResult = cast<OpTy>(op).getResult();
+    auto shapedTy = tensorResult->getType().template dyn_cast<ShapedType>();
+    if (!shapedTy) {
+      return op->emitError(
+          "unhandled lowering of operations that don't return a ShapedType");
+    }
+
+    // Integer elements are lowered to ReplacementOpTy and all other element
+    // types to FloatOpTy.
+    auto elementTy = shapedTy.getElementType();
+    auto loc = op->getLoc();
+    ArrayRef<NamedAttribute> noAttrs;
+    Operation *scalarOp = nullptr;
+    if (elementTy.template isa<IntegerType>()) {
+      auto intOp = builder.create<ReplacementOpTy>(loc, elementTy,
+                                                   scalarOperands, noAttrs);
+      scalarOp = intOp.getOperation();
+    } else {
+      auto floatOp =
+          builder.create<FloatOpTy>(loc, elementTy, scalarOperands, noAttrs);
+      scalarOp = floatOp.getOperation();
+    }
+    if (!scalarOp) {
+      return op->emitError("unable to lower operation");
+    }
+
+    valueCache.setOperandDstValue(tensorResult, index, scalarOp->getResult(0));
+    return success();
+  }
+};
+
+/// General template for lowering index transformation instructions such as
+/// transpose, which have a single operand and a single result. No instruction
+/// is emitted: the operand's scalar is recorded as the result's scalar.
+template <typename OpTy>
+class SPIRVIndexOpLowering final : public SPIRVOpLowering<OpTy> {
+ public:
+  using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+
+  LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
+                               AffineMap index,
+                               ArrayRef<Value *> scalarOperands,
+                               ValueCache &valueCache) const override {
+    const bool hasSingleOperand = op->getNumOperands() == 1;
+    if (!hasSingleOperand) {
+      return op->emitError(
+          "unhandled lowering of index transformation operation with multiple "
+          "operands");
+    }
+    auto tensorResult = cast<OpTy>(op).getResult();
+    valueCache.setOperandDstValue(tensorResult, index, scalarOperands[0]);
+    return success();
+  }
+};
+
+/// Generates a spv.AccessChain instruction that yields the pointer to the
+/// location of the spv.globalVariable `var` addressed by `indexMap`.
+inline Value *genPointerOffset(OpBuilder &builder, Location loc,
+                               AffineExprCodegen &affineExprCodegen,
+                               AffineMap indexMap,
+                               spirv::GlobalVariableOp &var) {
+  auto baseAddress = builder.create<spirv::AddressOfOp>(
+      loc, var.type(), builder.getSymbolRefAttr(var.sym_name()));
+  // The variable must be a pointer to a struct with exactly one member.
+  auto pointeeType = var.type().cast<spirv::PointerType>().getPointeeType();
+  assert(pointeeType.isa<spirv::StructType>() &&
+         "expected variable type to be a spv.ptr<spv.struct<...>>");
+  auto structType = pointeeType.cast<spirv::StructType>();
+  assert(structType.getNumElements() == 1 &&
+         "expected variable type to be a spv.ptr of spv.struct with a single "
+         "element");
+  auto memberType = structType.getElementType(0);
+  const bool memberIsArray = memberType.isa<spirv::ArrayType>() ||
+                             memberType.isa<spirv::RuntimeArrayType>();
+
+  // The index map of a scalar already maps to the 0-th element, whereas for
+  // an array it maps to the accessed position. Arrays therefore need a
+  // leading 0 that indexes into the struct.
+  SmallVector<Value *, 2> chainIndices;
+  if (memberIsArray) {
+    auto zero = builder.create<spirv::ConstantOp>(
+        loc, builder.getIntegerType(32), builder.getI32IntegerAttr(0));
+    chainIndices.push_back(zero);
+  }
+  for (auto indexExpr : indexMap.getResults()) {
+    chainIndices.push_back(affineExprCodegen.getValue(
+        indexExpr, builder.saveInsertionPoint(), loc));
+  }
+  return builder.create<spirv::AccessChainOp>(loc, baseAddress, chainIndices);
+}
+
+/// Lowers return statements during SPIR-V code generation.
+class ReturnOpSPIRVLowering : public SPIRVOpLowering<ReturnOp> {
+ public:
+ using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+ // Stores the single returned element into the single output buffer.
+ LogicalResult lowerOperation(
+ Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
+ ValueCache &valueCache,
+ DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
+ ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override;
+};
+
+/// Class to drive the SPIRV code-generation.
+template <typename... Ts>
+class SPIRVCodegen {
+ using OpCodegenListT = llvm::StringMap<std::unique_ptr<SPIRVLowering>>;
+
+ public:
+ explicit SPIRVCodegen() { insert(); }
+
+ LogicalResult codegen(spirv::ModuleOp &spirvModule, FuncOp &fn,
+ AffineExprCodegen &affineExprCodegen,
+ ValueCache &valueCache) {
+ if (fn.getBlocks().size() != 1) {
+ return emitError(
+ fn.getLoc(),
+ "unimplemeneted handling multiple blocks within a function");
+ }
+
+ OpBuilder builder(spirvModule.body());
+ // Create the entry function and generate global invocation ID. Creates a
+ // global variable for all inputs and output tensors.
+ return createEntryFn(builder, fn, affineExprCodegen, valueCache);
+ }
+
+ private:
+ /// Builds the SPIR-V entry function for `fn`: creates one global variable per
+ /// argument and result, lowers the body, then calls createEntryPoint().
+ LogicalResult createEntryFn(OpBuilder &builder, FuncOp &fn,
+ AffineExprCodegen &affineExprCodegen,
+ ValueCache &valueCache) {
+ auto loc = fn.getLoc();
+ // TODO(ravishankarm) : This should actually be part of the SPIR-V
+ // conversion framework in MLIR core. Move it there.
+ auto convertType = [&loc](Type t,
+ spirv::PointerType &varType) -> LogicalResult {
+ auto shapedType = t.dyn_cast<ShapedType>();
+ if (!shapedType) {
+ return emitError(loc, "expected ShapedType argument");
+ }
+ auto elementType = shapedType.getElementType();
+ if (!elementType.isIntOrFloat()) {
+ return emitError(loc, "unhandled element type ")
+ << elementType << " while lowering to SPIR-V";
+ }
+ int64_t stride = elementType.getIntOrFloatBitWidth() / 8;  // In bytes.
+ for (auto dim : reverse(shapedType.getShape())) {
+ if (dim <= 0) {
+ return emitError(loc, "expected tensor dimensions to be non-zero");
+ }
+ elementType = spirv::ArrayType::get(
+ elementType, dim,
+ static_cast<spirv::ArrayType::LayoutInfo>(stride));
+ stride *= dim;  // Stride of the next (outer) dimension.
+ }
+ // TODO(ravishankarm): Verify that the type of the variable passes
+ // spirv-val.
+ varType = spirv::PointerType::get(
+ spirv::StructType::get(elementType,
+ static_cast<spirv::StructType::LayoutInfo>(0)),
+ spirv::StorageClass::StorageBuffer);
+ return success();
+ };
+
+ // Convert function arguments and return values to spirv::GlobalVariables.
+ // All of them use descriptor set 0; arguments bind to their argument number
+ // and results bind to (number of arguments + result number).
+ auto fnType = fn.getType();
+ auto descriptorSetAttrName = convertToSnakeCase(
+ stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingAttrName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+ for (auto argType : enumerate(fnType.getInputs())) {
+ spirv::PointerType varType;
+ if (failed(convertType(argType.value(), varType))) {
+ return failure();
+ }
+ auto varName =
+ fn.getName().str() + "_arg_" + std::to_string(argType.index());
+ auto var = builder.create<spirv::GlobalVariableOp>(
+ loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr);
+ // Set descriptor_set to 0.
+ var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0));
+ // Set binding to argument number.
+ var.setAttr(bindingAttrName, builder.getI32IntegerAttr(argType.index()));
+
+ inputArgToVariable[fn.getArgument(argType.index())] = var;
+ }
+ for (auto resType : enumerate(fnType.getResults())) {
+ spirv::PointerType varType;
+ if (failed(convertType(resType.value(), varType))) {
+ return failure();
+ }
+ auto varName =
+ fn.getName().str() + "_res_" + std::to_string(resType.index());
+ auto var = builder.create<spirv::GlobalVariableOp>(
+ loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr);
+ // Set descriptor_set to 0.
+ var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0));
+ // Set binding to (result number + num arguments)
+ var.setAttr(
+ bindingAttrName,
+ builder.getI32IntegerAttr(fnType.getNumInputs() + resType.index()));
+
+ resultIndexToVariable.push_back(var);
+ }
+
+ auto entryFnType =
+ builder.getFunctionType(ArrayRef<Type>(), ArrayRef<Type>());
+ auto entryFn = builder.create<FuncOp>(loc, fn.getName(), entryFnType,
+ ArrayRef<NamedAttribute>());
+
+ // Start a scope to create an insertion guard to reset the builder once the
+ // function is lowered.
+ {
+ OpBuilder::InsertionGuard funcInsertGuard(builder);
+ builder.setInsertionPointToStart(entryFn.addEntryBlock());
+
+ // Create the Global invocation ID.
+ if (failed(createGlobalInvocationID(builder, fn.getLoc(),
+ affineExprCodegen))) {
+ return failure();
+ }
+
+ if (failed(lowerFunction(builder, fn, entryFn, affineExprCodegen,
+ valueCache))) {
+ return failure();
+ }
+ }
+
+ // Create the entry point instructions for the entry function.
+ if (failed(createEntryPoint(builder, loc, entryFn))) {
+ return failure();
+ }
+ return success();
+ }
+
+  /// Declares the module-level `globalInvocationID` builtin (a 3 x i32 vector
+  /// input) and binds its x, y, z components to dims 0, 1, 2 of the codegen.
+  LogicalResult createGlobalInvocationID(OpBuilder &builder, Location loc,
+                                         AffineExprCodegen &affineExprCodegen) {
+    // The variable itself is created in the body of the enclosing module.
+    auto enclosingModule = builder.getInsertionBlock()
+                               ->getParentOp()
+                               ->getParentOfType<spirv::ModuleOp>();
+    OpBuilder moduleBuilder(enclosingModule.body());
+    auto i32Type = builder.getIntegerType(32);
+    auto idType = VectorType::get(3, i32Type);
+    auto ptrIdType =
+        spirv::PointerType::get(idType, spirv::StorageClass::Input);
+    auto invocationIdVar = moduleBuilder.create<spirv::GlobalVariableOp>(
+        loc, TypeAttr::get(ptrIdType),
+        builder.getStringAttr("globalInvocationID"), nullptr);
+    auto builtInAttrName =
+        convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+    invocationIdVar.setAttr(
+        builtInAttrName,
+        builder.getStringAttr(
+            spirv::stringifyBuiltIn(spirv::BuiltIn::GlobalInvocationId)));
+    // Record the variable in the list used to build the entry point interface.
+    interface.push_back(builder.getSymbolRefAttr(invocationIdVar.sym_name()));
+    // Load the vector at the builder's insertion point and split it per dim.
+    auto idPtr = builder.create<spirv::AddressOfOp>(
+        loc, ptrIdType,
+        builder.getSymbolRefAttr(invocationIdVar.getOperation()));
+    auto id = builder.create<spirv::LoadOp>(loc, idType, idPtr, nullptr,
+                                            nullptr);
+    for (int32_t dim = 0; dim < 3; ++dim) {
+      auto component = builder.create<spirv::CompositeExtractOp>(
+          loc, i32Type, id,
+          builder.getArrayAttr(builder.getI32IntegerAttr(dim)));
+      affineExprCodegen.setDimDstValue(dim, component);
+    }
+    return success();
+  }
+
+  /// Loads, for every index map recorded for `origArg`, the element of the
+  /// global variable backing `origArg` at that index and stores the loaded
+  /// scalar in `valueCache`.
+  LogicalResult initArgValues(OpBuilder &builder, Location loc,
+                              AffineExprCodegen &affineExprCodegen,
+                              ValueCache &valueCache, Value *origArg) {
+    for (auto accessMap : affineExprCodegen.getIndices(origArg)) {
+      auto backingVar = inputArgToVariable.lookup(origArg);
+      if (!backingVar) {
+        return emitError(
+            loc, "undefined SPIR-V global variable for tensor argument");
+      }
+      auto elementPtr = genPointerOffset(builder, loc, affineExprCodegen,
+                                         accessMap, backingVar);
+      auto pointerType =
+          elementPtr->getType().template cast<spirv::PointerType>();
+      auto loaded = builder.create<spirv::LoadOp>(
+          loc, pointerType.getPointeeType(), elementPtr,
+          /*memory_access=*/nullptr, /*alignment=*/nullptr);
+      valueCache.setOperandDstValue(origArg, accessMap, loaded);
+    }
+    return success();
+  }
+
+  /// Emits the spv.EntryPoint op (GLCompute, with the recorded interface) and
+  /// the LocalSize {1, 1, 1} execution mode for `entryFn`; clears `interface`.
+  LogicalResult createEntryPoint(OpBuilder &builder, Location loc,
+                                 FuncOp entryFn) {
+    auto executionModel = builder.getI32IntegerAttr(
+        static_cast<int32_t>(spirv::ExecutionModel::GLCompute));
+    builder.create<spirv::EntryPointOp>(loc, executionModel,
+                                        builder.getSymbolRefAttr(entryFn),
+                                        builder.getArrayAttr(interface));
+    auto localSizeMode = builder.getI32IntegerAttr(
+        static_cast<int32_t>(spirv::ExecutionMode::LocalSize));
+    builder.create<spirv::ExecutionModeOp>(
+        loc, builder.getSymbolRefAttr(entryFn), localSizeMode,
+        builder.getI32ArrayAttr({1, 1, 1}));
+    interface.clear();
+    return success();
+  }
+
+  /// Lowers the body of `fn` from the original dialect into the SPIR-V dialect
+  /// at the insertion point of `builder`. `entryFn` is not used here.
+  LogicalResult lowerFunction(OpBuilder &builder, FuncOp fn, FuncOp entryFn,
+                              AffineExprCodegen &affineExprCodegen,
+                              ValueCache &valueCache) {
+    // First load, for each argument, the values at every index the dispatch
+    // function needs for its computation.
+    const auto fnLoc = fn.getLoc();
+    for (auto argument : fn.getArguments()) {
+      auto loaded = initArgValues(builder, fnLoc, affineExprCodegen,
+                                  valueCache, argument);
+      if (failed(loaded)) return failure();
+    }
+
+    // Then lower the operations one at a time, block by block.
+    for (auto &block : fn) {
+      for (auto &operation : block) {
+        auto lowered =
+            lowerOperation(builder, affineExprCodegen, valueCache, &operation);
+        if (failed(lowered)) return failure();
+      }
+    }
+    return success();
+  }
+
+  /// Lowers one op of the dispatch function through its registered lowering:
+  /// directly if it has no result, else once per index of its single result.
+  LogicalResult lowerOperation(OpBuilder &builder,
+                               AffineExprCodegen &affineExprCodegen,
+                               ValueCache &valueCache, Operation *op) {
+    auto opName = op->getName().getStringRef();
+    auto loweringIt = opCodegenList.find(opName);
+    if (loweringIt == opCodegenList.end()) {
+      return op->emitError("unhandled codegen");
+    }
+    auto &lowering = loweringIt->second;
+    const auto numResults = op->getNumResults();
+    if (numResults > 1) {
+      return op->emitError("unhandled codegen for multiple result values");
+    }
+    if (numResults == 0) {
+      return lowering->lowerOperation(op, builder, affineExprCodegen,
+                                      valueCache, inputArgToVariable,
+                                      resultIndexToVariable);
+    }
+    // Single result: gather the scalar operands needed at each result index.
+    auto resultTensor = op->getResult(0);
+    auto resultIndices = affineExprCodegen.getIndices(resultTensor);
+    for (auto &resultIndex : resultIndices) {
+      auto operandIndices =
+          affineExprCodegen.getOperandIndices(resultTensor, resultIndex);
+      SmallVector<Value *, 2> scalarOperands;
+      for (auto operand : llvm::enumerate(op->getOperands())) {
+        auto scalar = valueCache.getOperandDstValue(
+            operand.value(), operandIndices[operand.index()]);
+        if (!scalar) {
+          return op->emitError("argument ")
+                 << operand.index() << " has no scalar value";
+        }
+        scalarOperands.push_back(scalar);
+      }
+      if (failed(lowering->lowerOperation(op, builder, resultIndex,
+                                          scalarOperands, valueCache))) {
+        return failure();
+      }
+    }
+    return success();
+  }
+
+  /// Registers one lowering instance per type in `Ts`, keyed by its op name.
+  void insert() {
+    std::vector<std::unique_ptr<SPIRVLowering>> lowerings;
+    // Expand the pack through a dummy array initializer.
+    using expander = int[];
+    (void)expander{0, (lowerings.emplace_back(std::make_unique<Ts>()), 0)...};
+    for (auto &lowering : lowerings)
+      opCodegenList.try_emplace(lowering->getOpName(), std::move(lowering));
+  }
+
+ /// Registered lowerings from tensor operations to SPIR-V, keyed by the
+ /// operation name; filled by insert().
+ OpCodegenListT opCodegenList;
+
+ /// Symbol references to the interface variables of the entry function;
+ /// passed to the spv.EntryPoint op and cleared by createEntryPoint().
+ SmallVector<Attribute, 4> interface;
+
+ /// Mapping from argument of the dispatch function in tensor dialect to the
+ /// corresponding spv.globalVariable.
+ DenseMap<Value *, spirv::GlobalVariableOp> inputArgToVariable;
+
+ /// List of spv.globalVariables created for tensors returned by the dispatch
+ /// function in tensor dialects, in result order.
+ SmallVector<spirv::GlobalVariableOp, 1> resultIndexToVariable;
+
+ /// GlobalInvocationID variable. NOTE(review): unused within this class.
+ spirv::GlobalVariableOp globalInvocationID;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
diff --git a/compiler/Translation/SPIRV/XLAIndexPropagation.cpp b/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
new file mode 100644
index 0000000..7f30579
--- /dev/null
+++ b/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
@@ -0,0 +1,126 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- XLAIndexPropagation.cpp ---------------------------------*- C++//-*-===//
+//
+// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all
+// tensors needed to produce the value of the result tensors at a particular
+// index.
+//
+//===----------------------------------------------------------------------===//
+
+#include "compiler/Translation/SPIRV/XLAIndexPropagation.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// BroadcastInDimOp
+//===----------------------------------------------------------------------===//
+
+// Maps an index of the result to the index of the operand: keeps the result
+// expressions of the dimensions listed in broadcast_dimensions. A scalar
+// operand (no dimensions listed) is always accessed with the constant index 0.
+LogicalResult XLABroadcastInDimOpIndexPropagation::propagateIndexMap(
+    Operation *operation, AffineMap resultIndex,
+    SmallVectorImpl<AffineMap> &indexMap) const {
+  auto broadcastOp = cast<xla_hlo::BroadcastInDimOp>(operation);
+  auto broadcastDim = broadcastOp.broadcast_dimensions();
+
+  SmallVector<AffineExpr, 4> exprs;
+  if (broadcastDim) {
+    auto dimensions = broadcastDim->getValues<int64_t>();
+    for (auto resultExpr : enumerate(resultIndex.getResults())) {
+      auto resultDim = static_cast<int64_t>(resultExpr.index());
+      bool isOperandDim = llvm::any_of(
+          dimensions, [resultDim](int64_t dim) { return dim == resultDim; });
+      if (isOperandDim) {
+        exprs.push_back(resultExpr.value());
+      }
+    }
+  }
+  if (exprs.empty()) {
+    // Missing *or empty* broadcast_dimensions: the operand is a scalar, so all
+    // indices map to the same element (mirrors XLABroadcastOpIndexPropagation).
+    Builder builder(operation->getContext());
+    exprs.push_back(builder.getAffineConstantExpr(0));
+  }
+  indexMap.push_back(AffineMap::get(resultIndex.getNumDims(),
+                                    resultIndex.getNumSymbols(), exprs));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+// Broadcast op: the operand index is the resultIndex without its first N
+// expressions, N being the number of elements of the broadcast_sizes attribute.
+LogicalResult XLABroadcastOpIndexPropagation::propagateIndexMap(
+    Operation *operation, AffineMap resultIndex,
+    SmallVectorImpl<AffineMap> &indexMap) const {
+  auto broadcastSizes =
+      cast<xla_hlo::BroadcastOp>(operation).broadcast_sizes();
+  const auto numDropped = broadcastSizes.getType().getShape()[0];
+  const auto resultRank =
+      operation->getResult(0)->getType().cast<ShapedType>().getRank();
+
+  // Keep only the expressions from position numDropped up to the result rank.
+  SmallVector<AffineExpr, 4> exprs;
+  for (auto i : llvm::seq<size_t>(numDropped, resultRank)) {
+    exprs.push_back(resultIndex.getResult(i));
+  }
+  if (exprs.empty()) {
+    // Nothing is left, i.e. a scalar: index it with the constant expr 0.
+    Builder builder(operation->getContext());
+    exprs.push_back(builder.getAffineConstantExpr(0));
+  }
+  indexMap.push_back(AffineMap::get(resultIndex.getNumDims(),
+                                    resultIndex.getNumSymbols(), exprs));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ReverseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult XLAReverseOpIndexPropagation::propagateIndexMap(
+    Operation *op, AffineMap resultIndex,
+    SmallVectorImpl<AffineMap> &indexMap) const {
+  // Collect the `dimensions` attribute into a set for propagateIndexMapImpl.
+  DenseSet<unsigned> reversedDims;
+  for (auto dim : cast<xla_hlo::ReverseOp>(op).dimensions()) {
+    reversedDims.insert(dim.getZExtValue());
+  }
+  return propagateIndexMapImpl(op, reversedDims, resultIndex, indexMap);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult XLATransposeOpIndexPropagation::propagateIndexMap(
+    Operation *op, AffineMap resultIndex,
+    SmallVectorImpl<AffineMap> &indexMap) const {
+  // Read the `permutation` attribute into a vector and hand it to
+  // propagateIndexMapImpl.
+  SmallVector<unsigned, 4> permutation;
+  for (auto entry : cast<xla_hlo::TransposeOp>(op).permutation()) {
+    permutation.push_back(entry.getZExtValue());
+  }
+  return propagateIndexMapImpl(op, permutation, resultIndex, indexMap);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/SPIRV/XLAIndexPropagation.h b/compiler/Translation/SPIRV/XLAIndexPropagation.h
new file mode 100644
index 0000000..e3baa21
--- /dev/null
+++ b/compiler/Translation/SPIRV/XLAIndexPropagation.h
@@ -0,0 +1,112 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- XLAIndexPropagation.h -----------------------------------*- C++//-*-===//
+//
+// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all
+// tensors needed to produce the value of the result tensors at a particular
+// index.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
+
+#include "compiler/Translation/SPIRV/IndexComputation.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Function.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+class XLABroadcastInDimOpIndexPropagation final
+ : public IndexPropagationOp<xla_hlo::BroadcastInDimOp> {
+ public:
+ using IndexPropagationOp<xla_hlo::BroadcastInDimOp>::IndexPropagationOp;
+ // Appends to `operandIndices` the operand index map for `resultIndex`.
+ LogicalResult propagateIndexMap(
+ Operation *operation, AffineMap resultIndex,
+ SmallVectorImpl<AffineMap> &operandIndices) const override;
+};
+
+// Index propagation for xla_hlo::BroadcastOp: drops the first N expressions of
+// the resultIndex, where N is the number of elements in broadcast_sizes.
+class XLABroadcastOpIndexPropagation final
+ : public IndexPropagationOp<xla_hlo::BroadcastOp> {
+ public:
+ using IndexPropagationOp<xla_hlo::BroadcastOp>::IndexPropagationOp;
+
+ LogicalResult propagateIndexMap(
+ Operation *operation, AffineMap resultIndex,
+ SmallVectorImpl<AffineMap> &operandIndices) const override;
+};
+
+/// Index propagation for return ops. Each thread is assumed to compute the
+/// value of one element of the returned tensor.
+template <typename OpTy>
+class ReturnOpIndexPropagation : public IndexPropagationOp<OpTy> {
+ public:
+  using IndexPropagationOp<OpTy>::IndexPropagationOp;
+
+  LogicalResult propagateIndexMap(
+      Operation *operation, IndexComputationCache &indexMap) const override {
+    if (operation->getNumOperands() != 1) {
+      return operation->emitError("unhandled multiple return values");
+    }
+    auto returnedTensor = operation->getOperand(0);
+    auto tensorType = returnedTensor->getType().cast<RankedTensorType>();
+    auto rank = tensorType.getRank();
+    if (rank > 3) {
+      return operation->emitError("unhandled return tensor of dimension ")
+             << tensorType.getShape().size();
+    }
+    // One affine dim per tensor dimension (at most three), listed from dim
+    // (rank - 1) down to dim 0.
+    Builder builder(operation->getContext());
+    SmallVector<AffineExpr, 4> dimExprs;
+    for (size_t remaining = rank; remaining > 0; --remaining) {
+      dimExprs.push_back(builder.getAffineDimExpr(remaining - 1));
+    }
+    // The looked-up entry is discarded: the statement is kept for its effect.
+    indexMap[returnedTensor][AffineMap::get(rank, 0, dimExprs)];
+    return success();
+  }
+};
+
+/// Index propagation for XLA Reverse (xla_hlo::ReverseOp).
+class XLAReverseOpIndexPropagation final
+ : public ReverseOpIndexPropagation<xla_hlo::ReverseOp> {
+ public:
+ using ReverseOpIndexPropagation<
+ xla_hlo::ReverseOp>::ReverseOpIndexPropagation;
+ LogicalResult propagateIndexMap(
+ Operation *op, AffineMap resultIndex,
+ SmallVectorImpl<AffineMap> &indexMap) const override;
+};
+
+/// Index propagation for XLA Transpose (xla_hlo::TransposeOp).
+class XLATransposeOpIndexPropagation final
+ : public TransposeOpIndexPropagation<xla_hlo::TransposeOp> {
+ public:
+ using TransposeOpIndexPropagation<
+ xla_hlo::TransposeOp>::TransposeOpIndexPropagation;
+ LogicalResult propagateIndexMap(
+ Operation *op, AffineMap resultIndex,
+ SmallVectorImpl<AffineMap> &indexMap) const override;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
diff --git a/compiler/Translation/SPIRV/test/BUILD b/compiler/Translation/SPIRV/test/BUILD
new file mode 100644
index 0000000..c7b98c5
--- /dev/null
+++ b/compiler/Translation/SPIRV/test/BUILD
@@ -0,0 +1,16 @@
+# Tests for the SPIR-V translation.
+
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+ data = [
+ "///tools:iree-opt",
+ ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/SPIRV/test/exp_test.mlir b/compiler/Translation/SPIRV/test/exp_test.mlir
similarity index 100%
rename from iree/compiler/Translation/SPIRV/test/exp_test.mlir
rename to compiler/Translation/SPIRV/test/exp_test.mlir
diff --git a/iree/compiler/Translation/SPIRV/test/simple_test.mlir b/compiler/Translation/SPIRV/test/simple_test.mlir
similarity index 100%
rename from iree/compiler/Translation/SPIRV/test/simple_test.mlir
rename to compiler/Translation/SPIRV/test/simple_test.mlir
diff --git a/compiler/Translation/Sequencer/BUILD b/compiler/Translation/Sequencer/BUILD
new file mode 100644
index 0000000..cea0b27
--- /dev/null
+++ b/compiler/Translation/Sequencer/BUILD
@@ -0,0 +1,33 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "Sequencer",
+ srcs = ["SequencerModuleTranslation.cpp"],
+ hdrs = ["SequencerModuleTranslation.h"],
+ deps = [
+ "///base:status",
+ "///compiler/IR",
+ "///compiler/IR/Sequencer",
+ "///compiler/Serialization",
+ "///compiler/Transforms",
+ "///compiler/Transforms/Sequencer",
+ "///compiler/Utils",
+ "///hal:executable_format",
+ "///schemas",
+ "@com_github_google_flatbuffers//:flatbuffers",
+ "@llvm//:support",
+ "@local_config_mlir//:IR",
+ "@local_config_mlir//:Pass",
+ "@local_config_mlir//:StandardDialectRegistration",
+ "@local_config_mlir//:Support",
+ "@local_config_mlir//:Transforms",
+ "@local_config_mlir//:Translation",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
+ "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
+ ],
+ alwayslink = 1,
+)
diff --git a/iree/compiler/Translation/Sequencer/CMakeLists.txt b/compiler/Translation/Sequencer/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Translation/Sequencer/CMakeLists.txt
rename to compiler/Translation/Sequencer/CMakeLists.txt
diff --git a/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp b/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
new file mode 100644
index 0000000..a3d6834
--- /dev/null
+++ b/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
@@ -0,0 +1,505 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h"
+
+#include <cstdint>
+#include <iostream>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/minireflect.h"
+#include "base/status.h"
+#include "compiler/IR/ConfigOps.h"
+#include "compiler/IR/Sequencer/OpWriters.h"
+#include "compiler/IR/StructureOps.h"
+#include "compiler/IR/Types.h"
+#include "compiler/Serialization/VMFunctionBuilder.h"
+#include "compiler/Serialization/VMFunctionTableBuilder.h"
+#include "compiler/Serialization/VMModuleBuilder.h"
+#include "compiler/Transforms/Passes.h"
+#include "compiler/Transforms/Sequencer/Passes.h"
+#include "compiler/Utils/Macros.h"
+#include "compiler/Utils/OpUtils.h"
+#include "compiler/Utils/TranslationUtils.h"
+#include "hal/executable_format.h"
+#include "schemas/executable_def_generated.h"
+#include "schemas/executable_table_def_generated.h"
+#include "schemas/module_def_generated.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Translation.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Appends to `passManager` (without running it) the passes that optimize and
+// legalize the module into the form expected by partitioning.
+void buildLegalizeInputPassPipeline(PassManager *passManager) {
+ // Convert to the subset of XLA HLO and Standard dialects supported as IREE
+ // input. In particular, move from XLA HLO to standard control flow.
+ passManager->addPass(xla_hlo::createLegalizeControlFlowPass());
+ passManager->addPass(createLegalizeInputOpsPass());
+
+ // Standard passes that shake out a lot of garbage.
+ // Some may have been run prior to translation but this ensures we are always
+ // in a known state.
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createLoopFusionPass());
+ passManager->addPass(createLoopInvariantCodeMotionPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createSimplifyAffineStructuresPass());
+ passManager->addPass(createCSEPass());
+ passManager->addPass(createCanonicalizerPass());
+
+ // Expand uses of tuples into independent args/results.
+ passManager->addPass(createConvertFromTupleCallingConventionPass());
+ passManager->addPass(createCanonicalizerPass());
+}
+
+// Appends to `passManager` (without running it) the passes that partition the
+// module into sequencer functions and executables ready to be translated.
+void buildPartitioningPassPipeline(PassManager *passManager) {
+ // Find reduction ops and create iree.reduction_regions. We do this prior to
+ // performing dispatch region identification so that we can build as big of
+ // fused reduction regions as possible. The remaining ops will be put into
+ // dispatch regions.
+ passManager->addPass(createIdentifyReductionRegionsPass());
+ passManager->addPass(createCSEPass());
+
+ // Create all of the dispatch regions, CSE their workloads, and fold.
+ passManager->addPass(createIdentifyDispatchRegionsPass());
+ passManager->addPass(createCSEPass());
+ passManager->addPass(createFoldCompatibleDispatchRegionsPass());
+
+ // Note that as we are rematerializing things here it's critical we do not run
+ // the canonicalizer/CSE between now and when we outline - otherwise it'll
+ // undo all of our work!
+ passManager->addPass(createRematerializeDispatchConstantsPass());
+
+ // Outline the dispatch regions into their own functions. This separates the
+ // sequencer functions performing dispatches from the dispatchees.
+ passManager->addPass(createOutlineDispatchRegionsPass());
+ passManager->addPass(createOutlineReductionRegionsPass());
+
+ // Cleanup identity sequencer tensor-to-memref ops that clutter up the IR.
+ passManager->addPass(createCanonicalizerPass());
+
+ // Drop all functions that are no longer reachable.
+ // This is important as many of the functions remaining are probably
+ // dispatchable and unused now that we've outlined them into executables.
+ passManager->addPass(createDropUnreachableModuleFunctionsPass());
+
+ // Drop all unused executables.
+ // Note that we need to have dropped unreachable functions first otherwise
+ // references could keep executables that are unreachable from exported
+ // functions alive.
+ passManager->addPass(createDropUnusedExecutablesPass());
+}
+
+// Appends to `passManager` (without running it) the passes that convert
+// sequencer functions to the iree_seq.hl dialect.
+void buildSequencerConversionPassPipeline(PassManager *passManager) {
+ passManager->addPass(createConvertToMemRefCallingConventionPass());
+
+ // Convert ops that are supported by the sequencer directly to the sequencer
+ // dialect. The ops that remain should be only those that can be moved into
+ // dispatch regions.
+ passManager->addPass(createLowerToSequencerDialectPass());
+
+ // Cleanup identity sequencer tensor-to-memref ops and other memory accesses
+ // that clutter up the IR.
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+
+ // Eliminate ops we don't care about based on a lack of side-effects.
+ // IREE does not guarantee exception/error behavior of dead ops.
+ passManager->addPass(createAggressiveOpEliminationPass());
+
+ // Perform any last-minute optimizations to trim down the IR.
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+ passManager->addPass(createCSEPass());
+}
+
+// Appends to `passManager` (without running it) the passes that lower the
+// iree_seq.hl dialect to the iree_seq.ll dialect and prepare for serialization.
+void buildSequencerLoweringPassPipeline(PassManager *passManager) {
+ // Lower iree_hl_seq -> iree_ll_seq.
+ passManager->addPass(createLowerSequencerDialectPass());
+ passManager->addPass(createCanonicalizerPass());
+ passManager->addPass(createMemRefDataFlowOptPass());
+ passManager->addPass(createAggressiveOpEliminationPass());
+
+ // Assign ordinals used by the bytecode to reference executables and
+ // functions.
+ passManager->addPass(createAssignFunctionOrdinalsPass());
+ passManager->addPass(createAssignExecutableOrdinalsPass());
+
+ // Plumb workload information down into executable entry points. This allows
+ // the backends to calculate their workgroup sizes, indexing, etc.
+ passManager->addPass(createAssignExecutableWorkloadAttrsPass());
+}
+
+// Inserts one iree.executable_target_config op per target backend: every
+// registered backend if options.target_backends is empty, else the matches.
+void insertTargetConfigOps(const ModuleTranslationOptions &options,
+ OpBuilder &builder) {
+ llvm::StringSet<> targetBackends;
+ if (options.target_backends.empty()) {
+ // Add all backends when none are explicitly provided.
+ targetBackends.insert(getExecutableTranslationRegistry().keys().begin(),
+ getExecutableTranslationRegistry().keys().end());
+ } else {
+ for (auto &targetBackend : options.target_backends) {
+ for (auto &matchedBackend :
+ matchExecutableTranslationBackendNames(targetBackend)) {
+ targetBackends.insert(matchedBackend);
+ }
+ }
+ }
+ for (auto &targetBackend : targetBackends) {
+ builder.create<IREE::ExecutableTargetConfigOp>(builder.getUnknownLoc(),
+ targetBackend.getKey());
+ }
+}
+
+class SequencerTranslator {
+ public:
+ explicit SequencerTranslator(ModuleTranslationOptions options)
+ : options_(options) {}
+
+ const ModuleTranslationOptions &options() const { return options_; }
+ // Runs the pass pipelines on `module`; returns {} as soon as a step fails.
+ std::vector<uint8_t> translateModule(ModuleOp module);
+
+ private:
+ LogicalResult translateMultiArchExecutable(
+ IREE::MultiArchExecutableOp executableOp, VMModuleBuilder *moduleBuilder);
+
+ LogicalResult translateSequencerModule(ModuleOp module,
+ VMModuleBuilder *moduleBuilder);
+ LogicalResult declareFunction(FuncOp function,
+ VMModuleBuilder *moduleBuilder);
+ LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder);
+
+ ModuleTranslationOptions options_;  // Copy of the constructor argument.
+};
+
+// Lowers |module| through the partitioning, sequencer conversion and sequencer
+// lowering pipelines (in that order), translates its multi-arch executables
+// and functions, and serializes the result. Every failure emits an error on
+// |module| and yields an empty vector.
+std::vector<uint8_t> SequencerTranslator::translateModule(ModuleOp module) {
+  // Stage 1: legalize the input and partition it, as one combined pipeline.
+  auto partitionPipeline = createPassManager(module.getContext(), options());
+  buildLegalizeInputPassPipeline(partitionPipeline.get());
+  buildPartitioningPassPipeline(partitionPipeline.get());
+  if (failed(runPassPipeline(options(), partitionPipeline.get(), module))) {
+    module.emitError() << "Failed to run partitioning passes";
+    return {};
+  }
+
+  // Stage 2: sequencer-specific conversion.
+  auto conversionPipeline = createPassManager(module.getContext(), options());
+  buildSequencerConversionPassPipeline(conversionPipeline.get());
+  if (failed(runPassPipeline(options(), conversionPipeline.get(), module))) {
+    module.emitError() << "Failed to run sequencer conversion passes";
+    return {};
+  }
+
+  // Stage 3: lower sequencer functions to their final form.
+  auto loweringPipeline = createPassManager(module.getContext(), options());
+  buildSequencerLoweringPassPipeline(loweringPipeline.get());
+  if (failed(runPassPipeline(options(), loweringPipeline.get(), module))) {
+    module.emitError() << "Failed to run sequencer lowering passes";
+    return {};
+  }
+
+  // Translate every executable before the functions so that the executable
+  // formats present are known and can be queried for any additional
+  // processing they require (such as to support better types/etc).
+  ::flatbuffers::FlatBufferBuilder fbb;
+  VMModuleBuilder moduleBuilder(&fbb);
+  for (auto executableOp : module.getOps<IREE::MultiArchExecutableOp>()) {
+    if (failed(translateMultiArchExecutable(executableOp, &moduleBuilder))) {
+      module.emitError() << "Failed to translate multi-arch-executable";
+      return {};
+    }
+  }
+
+  // Emit the function bytecode.
+  if (failed(translateSequencerModule(module, &moduleBuilder))) {
+    module.emitError() << "Unable to translate sequencer module";
+    return {};
+  }
+
+  // Finish and serialize the module definition.
+  auto finishedModuleDef = moduleBuilder.Finish();
+  if (finishedModuleDef.IsNull()) {
+    module.emitError() << "Failed to verify completed module def";
+    return {};
+  }
+  auto serializedBytes = moduleBuilder.Serialize(finishedModuleDef);
+  if (serializedBytes.empty()) {
+    module.emitError() << "Failed to serialize final module def";
+    return {};
+  }
+  return serializedBytes;
+}
+
+// Translates the template executable of |multiArchExecutableOp| with every
+// selected backend and adds the resulting multi-arch executable definition to
+// the executable table of |moduleBuilder|.
+//
+// The template is the first nested iree.executable whose format is
+// Unspecified. If there is no such executable, or if it ends up with no target
+// config ops, nothing is translated and success is returned.
+//
+// Side effects on the IR: target config ops are inserted into and later erased
+// from the template, and one clone of the template per target config is added
+// to |multiArchExecutableOp| before its terminator.
+//
+// Fails if a backend named by a target config is not registered or if a
+// backend translator returns no result.
+LogicalResult SequencerTranslator::translateMultiArchExecutable(
+    IREE::MultiArchExecutableOp multiArchExecutableOp,
+    VMModuleBuilder *moduleBuilder) {
+  auto &fbb = *moduleBuilder->fbb();
+
+  // Find the unspecified executable. This is the template from which we will
+  // translate to other targets.
+  IREE::ExecutableOp templateExecutableOp;
+  for (auto executableOp :
+       multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
+    if (executableOp.format() ==
+        static_cast<uint32_t>(IREE::ExecutableFormat::Unspecified)) {
+      templateExecutableOp = executableOp;
+      break;
+    }
+  }
+  if (!templateExecutableOp) {
+    // Fine for there to be no unspecified executable - just ignore.
+    return success();
+  }
+
+  // Count the exported functions; this becomes the entry point count of the
+  // serialized multi-arch executable.
+  int entryPointCount = 0;
+  for (auto func : templateExecutableOp.getInnerModule().getOps<FuncOp>()) {
+    if (func.getAttr("iree.executable.export")) {
+      ++entryPointCount;
+    }
+  }
+
+  // For now we just add target config ops based on options. In the future we
+  // could do this earlier via an analysis pass determining which targets should
+  // be used for each executable.
+  OpBuilder configBuilder(templateExecutableOp);
+  configBuilder.setInsertionPointToStart(&templateExecutableOp.getBlock());
+  insertTargetConfigOps(options(), configBuilder);
+
+  // Find all target configs and bucket them into the backends that will
+  // translate them. This way we can batch the translations and possibly enable
+  // backends to dedupe some things.
+  DenseMap<StringRef, std::vector<IREE::ExecutableTargetConfigOp>>
+      backendTargetConfigOps;
+  for (auto targetConfigOp : templateExecutableOp.getBlock()
+                                 .getOps<IREE::ExecutableTargetConfigOp>()) {
+    auto &targetConfigOps = backendTargetConfigOps[targetConfigOp.backend()];
+    targetConfigOps.push_back(targetConfigOp);
+  }
+  if (backendTargetConfigOps.empty()) {
+    // There are no target configs - which likely means we've already translated
+    // this in a previous pass.
+    return success();
+  }
+
+  ExecutableTranslationOptions translationOptions;
+  translationOptions.CopyFrom(options());
+
+  // Invoke each backend translator on the template executables to produce new
+  // executables. The backends may produce any number of executables that we
+  // then merge back in to the iree.multi_arch_executable and the module
+  // flatbuffer.
+  std::vector<std::unique_ptr<iree::ExecutableDefT>> translatedExecutableDefs;
+  // Iterate by const reference: the map is not modified inside the loop and
+  // iterating by value would copy each backend's vector of config ops.
+  for (const auto &it : backendTargetConfigOps) {
+    const auto &backendKey = it.first;
+    const auto &targetConfigOps = it.second;
+
+    // Find the translator to use in the registry. It must have been linked in
+    // and the name must match what is used in the registration macro.
+    auto translateExecutableFn =
+        getExecutableTranslationRegistry().lookup(backendKey);
+    if (!translateExecutableFn) {
+      return multiArchExecutableOp.emitError()
+             << "No registered backend found for target '" << backendKey.str()
+             << "'; ensure it is linked in to your binary (have: "
+             << llvm::join(getExecutableTranslationRegistry().keys(), ", ")
+             << ")";
+    }
+
+    // Clone the executable for each config so that the translator is allowed to
+    // modify it in-place.
+    // We also need to strip all of the other configs so that the translator
+    // backend only sees the one for each of its configs.
+    OpBuilder builder(&multiArchExecutableOp.getBlock());
+    builder.setInsertionPoint(multiArchExecutableOp.getBlock().getTerminator());
+    SmallVector<IREE::ExecutableOp, 4> clonedExecutableOps;
+    for (auto targetConfigOp : targetConfigOps) {
+      auto executableCloneOp = cast<IREE::ExecutableOp>(
+          builder.clone(*templateExecutableOp.getOperation()));
+      // Early-inc iteration because ops are erased while walking the block.
+      for (auto existingTargetConfigOp : llvm::make_early_inc_range(
+               executableCloneOp.getBlock()
+                   .getOps<IREE::ExecutableTargetConfigOp>())) {
+        existingTargetConfigOp.erase();
+      }
+      // Distinct name so the outer |configBuilder| is not shadowed.
+      OpBuilder cloneConfigBuilder(executableCloneOp);
+      cloneConfigBuilder.setInsertionPointToStart(
+          &executableCloneOp.getBlock());
+      cloneConfigBuilder.clone(*targetConfigOp.getOperation());
+      clonedExecutableOps.push_back(executableCloneOp);
+    }
+
+    // Perform translation on all of the backend-specific targets.
+    // Note that the results here may not have the same number of executables we
+    // started with if the backend either couldn't satisfy some of the requests
+    // or decided to dedupe or expand certain ones.
+    auto translationResults =
+        translateExecutableFn(clonedExecutableOps, translationOptions);
+    if (!translationResults.hasValue()) {
+      return multiArchExecutableOp.emitError()
+             << "Failed to translate executable with backend " << backendKey;
+    }
+    for (auto &executableDef : translationResults.getValue().executable_defs) {
+      translatedExecutableDefs.push_back(std::move(executableDef));
+    }
+  }
+
+  // Remove configs from the template executable so that if we are called again
+  // we don't re-translate.
+  for (auto targetConfigOp : llvm::make_early_inc_range(
+           templateExecutableOp.getBlock()
+               .getOps<IREE::ExecutableTargetConfigOp>())) {
+    targetConfigOp.erase();
+  }
+
+  // Create multi-arch executable with all of the target-specific executables.
+  iree::MultiArchExecutableDefT maedf;
+  maedf.name = multiArchExecutableOp.getName();
+  maedf.entry_point_count = entryPointCount;
+  maedf.executables = std::move(translatedExecutableDefs);
+  auto maedfOffset = iree::MultiArchExecutableDef::Pack(fbb, &maedf);
+  RETURN_IF_FAILURE(
+      moduleBuilder->executable_table()->AddMultiArchExecutable(maedfOffset));
+
+  return success();
+}
+
+// Emits all functions of |module| into |moduleBuilder| in two passes and
+// stops at the first failure.
+LogicalResult SequencerTranslator::translateSequencerModule(
+    ModuleOp module, VMModuleBuilder *moduleBuilder) {
+  // Pass 1: declare every function up front. Declaration has to be complete
+  // before any body is converted so that indices are stable (call ops need to
+  // use the function table).
+  for (auto funcOp : module.getOps<FuncOp>()) {
+    RETURN_IF_FAILURE(declareFunction(funcOp, moduleBuilder));
+  }
+
+  // Pass 2: convert each function body to bytecode and define it.
+  for (auto funcOp : module.getOps<FuncOp>()) {
+    RETURN_IF_FAILURE(defineFunction(funcOp, moduleBuilder));
+  }
+
+  return success();
+}
+
+// Declares |function| in the function table of |moduleBuilder|. Does nothing
+// if it is already declared. Linkage is import for external functions, export
+// for functions carrying the "iree.module.export" attribute, and internal
+// otherwise. Imports are additionally defined right away.
+LogicalResult SequencerTranslator::declareFunction(
+    FuncOp function, VMModuleBuilder *moduleBuilder) {
+  auto *functionTable = moduleBuilder->function_table();
+  if (functionTable->IsFunctionDeclared(function)) {
+    return success();  // Nothing to do: declared earlier.
+  }
+
+  const LinkageType linkageType =
+      function.isExternal()
+          ? LinkageType::kImport
+          : (function.getAttr("iree.module.export") ? LinkageType::kExport
+                                                    : LinkageType::kInternal);
+  if (failed(functionTable->DeclareFunction(function, linkageType))) {
+    return function.emitError()
+           << "Unable to declare function " << function.getName();
+  }
+
+  // Internal and export functions will be defined during conversion.
+  if (linkageType != LinkageType::kImport) {
+    return success();
+  }
+
+  // Import functions must have their definition defined here so we get their
+  // type.
+  VMFunctionBuilder importBuilder(function, moduleBuilder->function_table(),
+                                  moduleBuilder->fbb());
+  auto importOffset = importBuilder.Finish();
+  if (importOffset.IsNull()) {
+    return function.emitError() << "Failed to create import function bytecode";
+  }
+  RETURN_IF_FAILURE(functionTable->DefineFunction(function, importOffset, {}));
+
+  return success();
+}
+
+// Converts the body of |function| to bytecode and records the definition,
+// together with the builder's source map, in the function table of
+// |moduleBuilder|.
+LogicalResult SequencerTranslator::defineFunction(
+    FuncOp function, VMModuleBuilder *moduleBuilder) {
+  VMFunctionBuilder bytecodeBuilder(function, moduleBuilder->function_table(),
+                                    moduleBuilder->fbb());
+  // Custom writers must be registered before the bytecode is converted.
+  registerSequencerCustomWriters(&bytecodeBuilder);
+  RETURN_IF_FAILURE(bytecodeBuilder.ConvertBytecode());
+
+  auto bytecodeOffset = bytecodeBuilder.Finish();
+  if (bytecodeOffset.IsNull()) {
+    return function.emitError() << "Failed to convert function to bytecode";
+  }
+
+  RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction(
+      function, bytecodeOffset, bytecodeBuilder.source_map()));
+  return success();
+}
+
+} // namespace
+
+// Translates |module| with the given |options| and returns the serialized
+// module bytes, or an empty vector if translation fails.
+std::vector<uint8_t> translateMlirToIreeSequencerModule(
+    ModuleOp module, ModuleTranslationOptions options) {
+  return SequencerTranslator(options).translateModule(module);
+}
+
+// Translates |module| using default-constructed translation options and writes
+// the serialized bytes to |output|. Fails, without writing anything, when the
+// translator returns no bytes.
+LogicalResult translateMlirToIreeSequencerModuleFile(
+    ModuleOp module, llvm::raw_ostream &output) {
+  ModuleTranslationOptions options;
+  SequencerTranslator translator(options);
+  auto serializedModule = translator.translateModule(module);
+  if (serializedModule.empty()) {
+    return emitError(UnknownLoc::get(module.getContext()),
+                     "failed to translate module");
+  }
+
+  const char *rawBytes =
+      reinterpret_cast<const char *>(serializedModule.data());
+  output.write(rawBytes, serializedModule.size());
+  return success();
+}
+
+// Registers translateMlirToIreeSequencerModuleFile under the translation name
+// "mlir-to-iree-module". Registration happens when this static object is
+// constructed.
+static TranslateFromMLIRRegistration MlirToIreeSequencerModuleTranslate(
+ "mlir-to-iree-module", translateMlirToIreeSequencerModuleFile);
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Translation/Sequencer/SequencerModuleTranslation.h b/compiler/Translation/Sequencer/SequencerModuleTranslation.h
new file mode 100644
index 0000000..7bd6c37
--- /dev/null
+++ b/compiler/Translation/Sequencer/SequencerModuleTranslation.h
@@ -0,0 +1,36 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
+#define IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
+
+#include <vector>
+
+#include "compiler/Utils/TranslationUtils.h"
+#include "mlir/IR/Module.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Translates an MLIR module in a compatible IREE input dialect (such as XLA HLO
+// and/or Std) into an IREE Module. Executables will be lowered based on the
+// provided configuration; when |options| is omitted a value-initialized
+// ModuleTranslationOptions is used.
+// Returns the serialized module bytes, or an empty vector on translation
+// failure.
+std::vector<uint8_t> translateMlirToIreeSequencerModule(
+ ModuleOp module, ModuleTranslationOptions options = {});
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
diff --git a/compiler/Utils/BUILD b/compiler/Utils/BUILD
new file mode 100644
index 0000000..bb09cb6
--- /dev/null
+++ b/compiler/Utils/BUILD
@@ -0,0 +1,41 @@
+# Utilities for working with IREE MLIR types.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Utility library shared by the compiler. The two in-repository dependencies
+# use absolute "//package" labels; the previous "///compiler/IR" and
+# "///schemas" spellings had a stray third slash, which is not a valid label.
+cc_library(
+    name = "Utils",
+    srcs = [
+        "DispatchUtils.cpp",
+        "MemRefUtils.cpp",
+        "ModuleUtils.cpp",
+        "OpCreationUtils.cpp",
+        "OpUtils.cpp",
+        "TranslationUtils.cpp",
+        "TypeConversionUtils.cpp",
+    ],
+    hdrs = [
+        "DispatchUtils.h",
+        "Macros.h",
+        "MemRefUtils.h",
+        "ModuleUtils.h",
+        "OpCreationUtils.h",
+        "OpUtils.h",
+        "TranslationUtils.h",
+        "TypeConversionUtils.h",
+    ],
+    deps = [
+        "//compiler/IR",
+        "//schemas",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardOps",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:TransformUtils",
+        "@local_config_mlir//:Transforms",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+    ],
+)
diff --git a/iree/compiler/Utils/CMakeLists.txt b/compiler/Utils/CMakeLists.txt
similarity index 100%
rename from iree/compiler/Utils/CMakeLists.txt
rename to compiler/Utils/CMakeLists.txt
diff --git a/compiler/Utils/DispatchUtils.cpp b/compiler/Utils/DispatchUtils.cpp
new file mode 100644
index 0000000..461d12a
--- /dev/null
+++ b/compiler/Utils/DispatchUtils.cpp
@@ -0,0 +1,726 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/DispatchUtils.h"
+
+#include "compiler/IR/Ops.h"
+#include "compiler/Utils/TypeConversionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Computes a 3-element workload (index 0 = x, 1 = y, 2 = z) from the type of
+// |baseOperand| and materializes it as a constant tensor<3xi32> created with a
+// builder positioned at |op|.
+//
+// * Non-shaped types produce {1, 1, 1}.
+// * Trailing unit dimensions are dropped (at least one dimension is kept).
+// * Up to three remaining dimensions map innermost-first onto x, y, z.
+// * With more than three, the innermost two become x and y and all leading
+//   dimensions are multiplied together into z, e.g. 2x3x4x5 -> {5, 4, 6}.
+//
+// Returns nullptr after emitting an op error if the shape is not static.
+Value *calculateWorkload(Operation *op, Value *baseOperand) {
+  OpBuilder builder(op);
+
+  std::array<int32_t, 3> workload = {1, 1, 1};
+
+  // TODO(b/139353314): lookup/calculate based on type/etc.
+  auto resultType = baseOperand->getType();
+  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+    if (!shapedType.hasStaticShape()) {
+      op->emitOpError() << "Dynamic shapes not yet supported";
+      return nullptr;
+    }
+    auto shape = shapedType.getShape();
+    // Drop the trailing ones from the shape.
+    while (shape.size() > 1 && shape.back() == 1) {
+      shape = shape.drop_back();
+    }
+    if (shape.size() <= 3) {
+      // Maps to XYZ (possibly with 1's for unused dimensions).
+      for (auto dim : enumerate(shape)) {
+        workload[shape.size() - 1 - dim.index()] = dim.value();
+      }
+    } else {
+      // Need to flatten the shape to fit XYZ. For now we just squash from LHS:
+      // only the leading dimensions are folded into Z. The two innermost
+      // dimensions are assigned to Y and X below, so they must not be
+      // multiplied into Z as well (doing so over-counted the workload).
+      workload[2] = 1;
+      for (auto dim : shape.drop_back(2)) {
+        workload[2] *= dim;
+      }
+      workload[1] = shape[shape.size() - 2];
+      workload[0] = shape.back();
+    }
+  }
+
+  // TODO(b/139353314): optimize workload layout.
+
+  auto constantType = RankedTensorType::get({3}, builder.getIntegerType(32));
+  return builder.create<ConstantOp>(
+      op->getLoc(), constantType,
+      DenseIntElementsAttr::get<int32_t>(constantType, workload));
+}
+
+// Returns true when |func| consists of exactly two ops in its first block: an
+// iree.dispatch_region followed by a return that returns the region's results
+// unchanged and in the same order.
+bool isTriviallyDispatchable(FuncOp func) {
+  if (func.empty()) return false;
+
+  auto &entryBlock = func.front();
+  if (entryBlock.getOperations().size() != 2) return false;
+
+  auto regionOp = dyn_cast<IREE::DispatchRegionOp>(entryBlock.front());
+  auto returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+  if (!regionOp || !returnOp) return false;
+  if (regionOp.getNumResults() != returnOp.getNumOperands()) return false;
+
+  // Every returned value has to be the region result at the same position.
+  for (int resultIdx = 0; resultIdx < regionOp.getNumResults(); ++resultIdx) {
+    if (regionOp.getResult(resultIdx) != returnOp.getOperand(resultIdx)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+namespace {
+
+// Computes the values flowing into and out of the ops in |opSet|:
+//  * |capturedValues| receives every operand whose defining op is not in the
+//    set (this includes values with no defining op).
+//  * |escapingValues| receives every result that has a user outside the set.
+// Both outputs are only appended to. Always returns success.
+//
+// NOTE(review): insertion order follows the iteration order of |opSet|, a
+// pointer-keyed hash set, so it presumably is not stable across runs - confirm
+// whether callers need a deterministic order.
+LogicalResult analyzeOpRangeValues(
+    const llvm::SmallDenseSet<Operation *> &opSet,
+    llvm::SetVector<Value *> *capturedValues,
+    llvm::SetVector<Value *> *escapingValues) {
+  for (auto *op : opSet) {
+    for (auto *value : op->getOperands()) {
+      // Membership is tested with count() (a hash lookup) instead of scanning
+      // the whole set for every operand.
+      if (!opSet.count(value->getDefiningOp())) {
+        // Op is using a value not in the ops set, ensure we capture it.
+        capturedValues->insert(value);
+      }
+    }
+    for (auto *value : op->getResults()) {
+      for (auto &use : value->getUses()) {
+        if (!opSet.count(use.getOwner())) {
+          // An op outside of the ops set is using the value, needs to escape.
+          // One outside user is enough; the set vector holds each value once.
+          escapingValues->insert(value);
+          break;
+        }
+      }
+    }
+  }
+  return success();
+}
+
+} // namespace
+
+// Moves |ops| into a new iree.dispatch_region created in |parentBlock| at the
+// position of the last op in |ops|.
+//
+// Values used by |ops| but defined outside of them become operands of the
+// region (and arguments of its block); values defined by |ops| and used
+// outside of them become results of the region, and those outside uses are
+// redirected to the results. The original ops are erased in reverse order.
+//
+// |workload| is passed through to the region op. Fails with an error on |func|
+// if |ops| is empty, since there is then no position or location to use.
+LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock,
+                                  Value *workload, ArrayRef<Operation *> ops) {
+  // Guard: ops.back() and opLocs.back() below require at least one op.
+  if (ops.empty()) {
+    return func.emitError()
+           << "Cannot build a dispatch region from an empty set of ops";
+  }
+
+  // Fused location with all ops.
+  SmallVector<Location, 16> opLocs;
+  for (auto *op : ops) {
+    opLocs.push_back(op->getLoc());
+  }
+  auto regionLoc = FusedLoc::get(opLocs, func.getContext());
+
+  // Get a list of values that we need to capture and values that escape the
+  // region and need to be returned.
+  llvm::SmallDenseSet<Operation *> opSet;
+  opSet.reserve(ops.size());
+  opSet.insert(ops.begin(), ops.end());
+  llvm::SetVector<Value *> capturedValues;
+  llvm::SetVector<Value *> escapingValues;
+  if (failed(analyzeOpRangeValues(opSet, &capturedValues, &escapingValues))) {
+    return failure();
+  }
+  SmallVector<Type, 8> escapingTypes;
+  for (auto *value : escapingValues) escapingTypes.push_back(value->getType());
+
+  // Build the region op and add it to the parent block.
+  OpBuilder parentBuilder(parentBlock);
+  parentBuilder.setInsertionPoint(ops.back());
+  auto dispatchRegionOp = parentBuilder.create<IREE::DispatchRegionOp>(
+      regionLoc, escapingTypes, workload, capturedValues.getArrayRef());
+
+  // Create the block and setup the arg mapping for captured values.
+  auto *regionBlock = new Block();
+  dispatchRegionOp.getBody().push_back(regionBlock);
+  OpBuilder regionBuilder(regionBlock);
+  BlockAndValueMapping mapping;
+  for (auto *capturedValue : capturedValues) {
+    auto *blockArg = regionBlock->addArgument(capturedValue->getType());
+    mapping.map(capturedValue, blockArg);
+  }
+
+  // Clone ops into the new region block.
+  for (auto *op : ops) {
+    // Note that this updates the mapping with the new values (so at the end
+    // we have those new values).
+    regionBuilder.clone(*op, mapping);
+  }
+
+  // Return results (as we need a terminator in our block).
+  // These are all of the values that escape our region.
+  SmallVector<Value *, 8> resultValues;
+  for (auto *oldValue : escapingValues) {
+    resultValues.push_back(mapping.lookupOrDefault(oldValue));
+  }
+  regionBuilder.create<IREE::ReturnOp>(opLocs.back(), resultValues);
+
+  // Replace usage of values with the results of the region.
+  for (int i = 0; i < escapingValues.size(); ++i) {
+    escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i));
+  }
+
+  // Remove original ops from the parent region. Reverse order, so that an op
+  // is erased before the earlier ops in |ops| that it may use.
+  for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
+    (*it)->erase();
+  }
+
+  return success();
+}
+
+namespace {
+
+// Creates a new iree.return immediately before |returnOp| whose operands are
+// the existing operands followed by |newOperands|, then erases |returnOp|.
+// Always returns success.
+LogicalResult appendReturnOperands(IREE::ReturnOp returnOp,
+                                   ArrayRef<Value *> newOperands) {
+  // Gather the combined operand list: existing operands first.
+  SmallVector<Value *, 8> combinedOperands;
+  combinedOperands.reserve(returnOp.getNumOperands() + newOperands.size());
+  combinedOperands.append(returnOp.operand_begin(), returnOp.operand_end());
+  combinedOperands.append(newOperands.begin(), newOperands.end());
+
+  // The replacement goes right in front of the return it replaces.
+  OpBuilder builder(returnOp);
+  builder.create<IREE::ReturnOp>(returnOp.getLoc(), combinedOperands);
+
+  // The original return is now redundant.
+  returnOp.erase();
+  return success();
+}
+
+// Replaces |regionOp| with a new dispatch region, created immediately before
+// it, that has |newArgs| appended to its arg operands and one extra result per
+// entry of |newResults| (only their types are used here).
+//
+// The body of |regionOp| is moved into the new op, uses of the original
+// results are redirected to the new results at the same indices, and
+// |regionOp| is erased, so the caller's handle must not be used afterwards.
+// |otherLoc| is fused with the original location. Returns the new op.
+IREE::DispatchRegionOp appendRegionArgsAndResults(
+    IREE::DispatchRegionOp &regionOp, ArrayRef<Value *> newArgs,
+    ArrayRef<Value *> newResults, Location otherLoc) {
+  // Insert prior to the original region.
+  OpBuilder builder(regionOp);
+
+  // Location is original region + new region location (both probably fused).
+  SmallVector<Location, 2> fusedLocs = {regionOp.getLoc(), otherLoc};
+  auto fusedLoc = FusedLoc::get(fusedLocs, regionOp.getContext());
+
+  // Clone with new results.
+  SmallVector<Value *, 8> operands;
+  operands.append(regionOp.getArgOperands().begin(),
+                  regionOp.getArgOperands().end());
+  operands.append(newArgs.begin(), newArgs.end());
+  SmallVector<Type, 8> resultTypes;
+  resultTypes.append(regionOp.result_type_begin(), regionOp.result_type_end());
+  for (auto *newResult : newResults) {
+    resultTypes.push_back(newResult->getType());
+  }
+  auto newRegionOp = builder.create<IREE::DispatchRegionOp>(
+      fusedLoc, resultTypes, regionOp.getWorkload(), operands,
+      regionOp.getAttrs());
+  newRegionOp.getBody().takeBody(regionOp.getBody());
+
+  // Replace uses of original values with the new values.
+  for (int i = 0; i < regionOp.getNumResults(); ++i) {
+    regionOp.getResult(i)->replaceAllUsesWith(newRegionOp.getResult(i));
+  }
+
+  // Erase the original region.
+  regionOp.erase();
+
+  return newRegionOp;
+}
+
+// Removes results that are not used from the dispatch region.
+//
+// A replacement op carrying only the results that still have uses is created
+// immediately before |regionOp|; the body is moved over, the inner iree.return
+// is trimmed to the matching operands, and |regionOp| is erased. There may be
+// unused ops left in the region but DCE should take care of that later.
+//
+// Returns the new operation, or a null op (after emitting an error on the
+// enclosing function) if the first block of the region does not end in an
+// iree.return.
+IREE::DispatchRegionOp removeUnusedResults(IREE::DispatchRegionOp regionOp) {
+  // Find return value within the region.
+  auto &regionBlock = regionOp.getBody().getBlocks().front();
+  auto returnOp = dyn_cast<IREE::ReturnOp>(regionBlock.getTerminator());
+  if (!returnOp) {
+    regionBlock.getParent()->getParentOfType<FuncOp>().emitError()
+        << "Block does not contain an iree.return op";
+    // Bail out: everything below dereferences |returnOp|.
+    return nullptr;
+  }
+
+  // Calculate new return values.
+  SmallVector<Type, 8> newReturnTypes;
+  SmallVector<Value *, 8> newReturnValues;
+  SmallVector<Value *, 8> newRegionResults;
+  for (int i = 0; i < returnOp.getNumOperands(); ++i) {
+    auto *resultValue = regionOp.getResult(i);
+    if (!resultValue->use_empty()) {
+      // Still has uses so we will preserve it.
+      newReturnTypes.push_back(resultValue->getType());
+      newReturnValues.push_back(returnOp.getOperand(i));
+      newRegionResults.push_back(resultValue);
+    }
+  }
+
+  // Update return op operands. We can do this in-place as we are only shrinking
+  // the list.
+  returnOp.getOperation()->setOperands(newReturnValues);
+
+  // Insert prior to the original region.
+  OpBuilder builder(regionOp);
+
+  // Clone with new results.
+  SmallVector<Value *, 8> operands(regionOp.getArgOperands());
+  auto newRegionOp = builder.create<IREE::DispatchRegionOp>(
+      regionOp.getLoc(), newReturnTypes, regionOp.getWorkload(), operands,
+      regionOp.getAttrs());
+  newRegionOp.getBody().takeBody(regionOp.getBody());
+
+  // Replace uses of original values with the new values. The i-th preserved
+  // result maps onto the i-th result of the new op.
+  for (int i = 0; i < newRegionResults.size(); ++i) {
+    newRegionResults[i]->replaceAllUsesWith(newRegionOp.getResult(i));
+  }
+
+  // Erase the original region.
+  regionOp.erase();
+
+  return newRegionOp;
+}
+
+// Reports whether the workloads of |lhs| and |rhs| are compatible. Currently
+// this means both regions use the very same workload value.
+bool areDispatchRegionWorkloadsCompatible(IREE::DispatchRegionOp &lhs,
+                                          IREE::DispatchRegionOp &rhs) {
+  // TODO(benvanik): more sophisticated checking; right now it's just identical.
+  auto lhsWorkload = lhs.getWorkload();
+  auto rhsWorkload = rhs.getWorkload();
+  return lhsWorkload == rhsWorkload;
+}
+
+// Conservatively reports whether |value| may depend on |op|.
+// Only works if the operations are within the same block.
+//
+// Values with no defining op and values defined before |op| never depend on
+// it. Every other value - |op|'s own results and anything defined after |op| -
+// is reported as dependent.
+//
+// NOTE(review): the final fall-through returns true, so the operand walk does
+// not change the answer. This looks intentional rather than a typo: merging
+// inserts the combined region at the position of |op|, so an operand defined
+// after |op| presumably cannot be used there regardless of what it depends
+// on. Confirm before "fixing" the last return to false.
+bool doesValueDependOnOperation(Value *value, Operation *op) {
+  auto *definingOp = value->getDefiningOp();
+  if (!definingOp) {
+    return false;
+  }
+  if (definingOp == op) {
+    return true;
+  }
+  if (definingOp->isBeforeInBlock(op)) {
+    // Can't depend on |op| as it is defined prior to it.
+    return false;
+  }
+  for (auto *operand : definingOp->getOperands()) {
+    if (doesValueDependOnOperation(operand, op)) {
+      return true;
+    }
+  }
+  return true;
+}
+
+// Returns true if any arg operand of |rhs| depends on |lhs| other than by
+// being a direct result of |lhs|. Direct uses of |lhs| results are allowed;
+// anything else reported by doesValueDependOnOperation blocks a merge.
+bool areDispatchRegionsTransitivelyDependent(IREE::DispatchRegionOp &lhs,
+                                             IREE::DispatchRegionOp &rhs) {
+  return llvm::any_of(rhs.getArgOperands(), [&](Value *arg) {
+    // Transitively dependent - boo - can't merge yet.
+    return arg->getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs);
+  });
+}
+
+// Returns true if |regionOp| may take part in merging: its body must consist
+// of a single block and must not contain an xla_hlo dot or conv op.
+//
+// The single-block restriction exists because our merge isn't very smart and
+// will not preserve the CFG right now. We can fix this when needed.
+bool isDispatchRegionMergable(IREE::DispatchRegionOp &regionOp) {
+  // Disallow merging of dispatch regions containing matmuls and other big ops.
+  // We do this to allow backends to lower the big op as entirely isolated such
+  // that substituting library calls is easier.
+  for (auto &block : regionOp.getBody().getBlocks()) {
+    for (auto &op : block) {
+      if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
+        return false;
+      }
+    }
+  }
+  return regionOp.getBody().getBlocks().size() == 1;
+}
+
+// Merges |rhs| into |lhs| and returns the new |lhs| op.
+// Precondition: !areDispatchRegionsTransitivelyDependent
+//
+// The ops of |rhs|'s first block (except its terminator) are cloned in front
+// of |lhs|'s return, the return and the region op are rebuilt with the extra
+// operands/results, uses of |rhs| results are redirected, and both original
+// region ops are erased. Results that end up unused are dropped at the end.
+// Returns a null op if the return operands could not be appended.
+IREE::DispatchRegionOp mergeDispatchRegions(IREE::DispatchRegionOp &lhs,
+ IREE::DispatchRegionOp &rhs) {
+ // Only the first block of each region is considered.
+ auto &lhsBlock = lhs.getBody().front();
+ auto &rhsBlock = rhs.getBody().front();
+
+ // Find the values used as return values in the lhs.
+ // We'll need to replace the uses in rhs with these.
+ auto lhsReturnOp = cast<IREE::ReturnOp>(lhsBlock.getTerminator());
+ SmallVector<Value *, 8> lhsReturnValues;
+ lhsReturnValues.reserve(lhsReturnOp.getNumOperands());
+ lhsReturnValues.append(lhsReturnOp.operand_begin(),
+ lhsReturnOp.operand_end());
+
+ // Find the values used as return values in the rhs.
+ // We'll add these to the results of the lhs region.
+ auto rhsReturnOp = cast<IREE::ReturnOp>(rhsBlock.getTerminator());
+ SmallVector<Value *, 8> rhsReturnValues;
+ rhsReturnValues.reserve(rhsReturnOp.getNumOperands());
+ rhsReturnValues.append(rhsReturnOp.operand_begin(),
+ rhsReturnOp.operand_end());
+
+ // Compute new args. |mapping| translates rhs block arguments to the values
+ // that stand in for them inside the lhs block; |newArgs| collects the outer
+ // operands that have to be added to the merged region op.
+ BlockAndValueMapping mapping;
+ SmallVector<Value *, 8> newArgs;
+ for (int rhsOpIdx = 0; rhsOpIdx < rhs.getNumArgOperands(); ++rhsOpIdx) {
+ bool didElide = false;
+ // Find if the rhs arg already exists on the lhs and dedupe.
+ for (int lhsOpIdx = 0; lhsOpIdx < lhs.getNumArgOperands(); ++lhsOpIdx) {
+ if (rhs.getArgOperand(rhsOpIdx) == lhs.getArgOperand(lhsOpIdx)) {
+ mapping.map(rhsBlock.getArgument(rhsOpIdx),
+ lhsBlock.getArgument(lhsOpIdx));
+ didElide = true;
+ break;
+ }
+ }
+ // Find if the arg has a direct dependency on the results of the lhs.
+ // (This scan also runs when the arg was already deduped above.)
+ for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
+ ++lhsResultIdx) {
+ if (rhs.getArgOperand(rhsOpIdx) == lhs.getResult(lhsResultIdx)) {
+ // Direct dependency; can elide. We'll skip adding it to the new region
+ // args and instead just remap it later.
+ mapping.map(rhsBlock.getArgument(rhsOpIdx),
+ lhsReturnValues[lhsResultIdx]);
+ didElide = true;
+ break;
+ }
+ }
+ if (!didElide) {
+ // Add to the lhs block.
+ // NOTE(review): the +1 presumably skips the workload operand, making this
+ // the same value as rhs.getArgOperand(rhsOpIdx) - confirm.
+ auto *oldArg = rhs.getOperand(rhsOpIdx + 1);
+ auto *newArg = lhsBlock.addArgument(oldArg->getType());
+ mapping.map(rhsBlock.getArgument(rhsOpIdx), newArg);
+ newArgs.push_back(oldArg);
+ }
+ }
+
+ OpBuilder regionBuilder(&lhsBlock);
+
+ // Copy ops (replacing any args as needed).
+ // Note that we need to insert prior to the terminator.
+ regionBuilder.setInsertionPoint(lhsReturnOp);
+ for (auto &op : rhsBlock) {
+ // Note that this updates the mapping with the new values (so at the end
+ // we have those new values).
+ //
+ // We avoid the return op here as we have already merged it above.
+ if (!op.isKnownTerminator()) {
+ regionBuilder.clone(op, mapping);
+ }
+ }
+
+ // Compute new results and add to both region and return op.
+ SmallVector<Value *, 8> newResults;
+ for (auto *rhsResult : rhsReturnValues) {
+ newResults.push_back(mapping.lookupOrDefault(rhsResult));
+ }
+ if (failed(appendReturnOperands(lhsReturnOp, newResults))) {
+ return nullptr;
+ }
+ // This erases the op behind |lhs|; only |newRegionOp| is valid afterwards.
+ auto newRegionOp =
+ appendRegionArgsAndResults(lhs, newArgs, newResults, rhs.getLoc());
+
+ // Replace uses of original values with the new values. The results that came
+ // from rhs follow the original lhs results, hence the index offset.
+ for (int i = 0; i < rhs.getNumResults(); ++i) {
+ rhs.getResult(i)->replaceAllUsesWith(
+ newRegionOp.getResult(lhsReturnValues.size() + i));
+ }
+
+ // Remove rhs region.
+ rhs.erase();
+
+ // Remove results from the lhs that aren't used anymore as they may have been
+ // elided when we merged as only the rhs was using them.
+ newRegionOp = removeUnusedResults(newRegionOp);
+
+ return newRegionOp;
+}
+
+} // namespace
+
+// Greedily merges the mergable iree.dispatch_region ops that are direct
+// children of |parentBlock|.
+//
+// Regions rejected by isDispatchRegionMergable are reported with a remark and
+// left alone. Each remaining region is merged with the first later region that
+// has a compatible workload and is not transitively dependent on it; after a
+// merge the same slot is tried again so that chains collapse into one region.
+// Returns failure if a merge does not produce a region. |func| is not used.
+LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock) {
+  SmallVector<IREE::DispatchRegionOp, 8> mergableRegions;
+  for (auto &op : *parentBlock) {
+    if (auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op)) {
+      if (isDispatchRegionMergable(regionOp)) {
+        mergableRegions.push_back(regionOp);
+      } else {
+        regionOp.emitRemark(
+            "Unable to merge into following iree.dispatch_regions; "
+            "contains non-trivial control flow");
+      }
+    }
+  }
+  for (int i = 0; i < mergableRegions.size(); ++i) {
+    // Null entries are regions that were already merged into an earlier one.
+    if (!mergableRegions[i]) continue;
+    auto &lhs = mergableRegions[i];
+    for (int j = i + 1; j < mergableRegions.size(); ++j) {
+      if (!mergableRegions[j]) continue;
+      auto &rhs = mergableRegions[j];
+      if (!areDispatchRegionWorkloadsCompatible(lhs, rhs) ||
+          areDispatchRegionsTransitivelyDependent(lhs, rhs)) {
+        continue;
+      }
+      if (!isDispatchRegionMergable(rhs)) {
+        // TODO(b/134675461): support non-trivial control flow.
+        rhs.emitRemark(
+            "Unable to merge into previous iree.dispatch_region; "
+            "contains non-trivial control flow");
+        // Skip this candidate; without this the region was merged anyway
+        // right after being reported as unmergable.
+        continue;
+      }
+      mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
+      if (!mergableRegions[i]) {
+        return failure();
+      }
+      mergableRegions[j] = nullptr;
+      --i;  // Try again to see if there are subsequent regions to merge.
+      break;
+    }
+  }
+
+  return success();
+}
+
+namespace {
+
+// Recursively clones the given |sourceOp| and returns the newly cloned op.
+//
+// The ops defining the operands of |sourceOp| are cloned first (depth-first)
+// through |builder|, so their clones are recorded in |mapping| by the time
+// |sourceOp| itself is cloned. Operands without a defining op (such as block
+// arguments) are skipped: there is nothing to clone for them.
+//
+// |mapping| is not consulted before cloning, so an op reachable through
+// several different operand ops is cloned once per path.
+Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder,
+                              BlockAndValueMapping *mapping) {
+  // Note that we dedupe required operands in the case of multiple arguments
+  // coming from the same source operation.
+  SmallPtrSet<Operation *, 4> operandOps;
+  for (auto *operand : sourceOp->getOperands()) {
+    // getDefiningOp() is null for values not produced by an op; recursing on
+    // null would dereference it.
+    if (auto *definingOp = operand->getDefiningOp()) {
+      operandOps.insert(definingOp);
+    }
+  }
+  for (auto *operandOp : operandOps) {
+    recursivelyCloneOp(operandOp, builder, mapping);
+  }
+  return builder.clone(*sourceOp, *mapping);
+}
+
+// Clones the op tree that produces |sourceValue| to the start of
+// |targetBlock| and returns the cloned counterpart of |sourceValue|.
+//
+// |mapping| is consulted first: a value that is already mapped (for example a
+// block argument or a previously cloned ancestor) is returned as-is. Otherwise
+// the tree is cloned and |mapping| is updated along the way.
+//
+// NOTE(review): |sourceValue| is assumed to be an op result; an unmapped block
+// argument has no defining op to clone - confirm callers never pass one.
+Value *cloneOpTreeIntoBlock(Value *sourceValue, Block *targetBlock,
+                            BlockAndValueMapping *mapping) {
+  // Reuse an existing clone; multiple arguments may reference the same tree.
+  auto *alreadyCloned = mapping->lookupOrNull(sourceValue);
+  if (alreadyCloned) {
+    return alreadyCloned;
+  }
+
+  OpBuilder builder(targetBlock);
+  builder.setInsertionPointToStart(targetBlock);
+  auto *definingOp = sourceValue->getDefiningOp();
+  auto *clonedOp = recursivelyCloneOp(definingOp, builder, mapping);
+
+  // Ops may have several results: hand back the clone's result at the same
+  // position that |sourceValue| occupies on the original op.
+  auto resultPosition = std::find(definingOp->result_begin(),
+                                  definingOp->result_end(), sourceValue);
+  int resultIndex = std::distance(definingOp->result_begin(), resultPosition);
+  return clonedOp->getResult(resultIndex);
+}
+
+} // namespace
+
+// Replaces every argument of |dispatchRegionOp| that is fed by |value| with a
+// clone of |value| (and of the ops producing it) placed inside the region's
+// entry block, then removes the now-unused operands and block arguments.
+// Always returns success(); a region that does not capture |value| is left
+// untouched.
+LogicalResult inlineDispatchRegionOperandsUsingValue(
+    IREE::DispatchRegionOp dispatchRegionOp, Value *value) {
+  // Ascending indices of the region arguments that capture |value|.
+  SmallVector<unsigned, 4> matchingArgs;
+  for (auto indexedArg : llvm::enumerate(dispatchRegionOp.getArgOperands())) {
+    if (indexedArg.value() != value) continue;
+    matchingArgs.push_back(indexedArg.index());
+  }
+  // Not captured at all: the call was unnecessary but harmless.
+  if (matchingArgs.empty()) return success();
+
+  // Materialize |value| and its producers inside the entry block.
+  auto &entryBlock = dispatchRegionOp.getBody().getBlocks().front();
+  BlockAndValueMapping mapping;
+  auto *inlinedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping);
+
+  // Point all users of the captured block arguments at the inlined value.
+  for (unsigned argIndex : matchingArgs) {
+    entryBlock.getArgument(argIndex)->replaceAllUsesWith(inlinedValue);
+  }
+
+  // Erase back to front so that the remaining (smaller) indices stay valid.
+  for (unsigned argIndex : llvm::reverse(matchingArgs)) {
+    auto *regionOperation = dispatchRegionOp.getOperation();
+    regionOperation->eraseOperand(
+        dispatchRegionOp.mapArgOperandToOpOperand(argIndex));
+    entryBlock.eraseArgument(argIndex);
+  }
+
+  return success();
+}
+
+namespace {
+
+// Recursively finds all functions reachable through CallOps nested within
+// |rootFunc| and adds them to the |reachableFuncs| set.
+//
+// Callees are resolved against the module enclosing |rootFunc|, so |rootFunc|
+// must be attached to a module for any call to resolve. Every callee must
+// carry the iree.dispatchable attribute. Returns failure() - after emitting an
+// error - if any call, directly or in a reachable callee, cannot be resolved
+// or targets a function that is not dispatchable.
+//
+// Note that indirect calls are not supported, however we don't allow those in
+// dispatch regions anyway so they should not be present here.
+LogicalResult findReachableFunctions(Operation *rootFunc,
+                                     llvm::SetVector<FuncOp> &reachableFuncs) {
+  auto parentModule = rootFunc->getParentOfType<ModuleOp>();
+  bool allCallsValid = true;
+  rootFunc->walk([&](CallOp op) {
+    // A detached |rootFunc| has no module and an unknown symbol has no
+    // definition; report both instead of dereferencing a null op.
+    auto callee = parentModule
+                      ? parentModule.lookupSymbol<FuncOp>(op.getCallee())
+                      : FuncOp();
+    if (!callee) {
+      allCallsValid = false;
+      op.emitError() << "unable to resolve callee " << op.getCallee();
+      return;
+    }
+    if (!callee.getAttr("iree.dispatchable")) {
+      allCallsValid = false;
+      rootFunc->emitError() << callee.getName().str() << " is not dispatchable";
+      return;
+    }
+    if (reachableFuncs.insert(callee)) {
+      // First time this callee is seen: scan it too and propagate failures
+      // found in its own calls.
+      if (failed(findReachableFunctions(callee, reachableFuncs))) {
+        allCallsValid = false;
+      }
+    }
+  });
+  return success(allCallsValid);
+}
+
+} // namespace
+
+// Outlines region 0 of |op| into a new function of |functionType| and wraps
+// it in a new iree.multi_arch_executable inserted before the function that
+// contains |op|.
+//
+// The created nesting is: multi-arch executable > executable (format
+// Unspecified) > inner module > outlined function plus clones of every
+// function reachable from the region. Names are derived from the parent
+// function name, "_rgn"/"_ex" and |symbolSuffix|.
+//
+// Returns the multi-arch executable and the outlined function. Problems found
+// while gathering reachable functions are emitted as diagnostics only; the
+// signature offers no way to report them to the caller.
+std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable(
+    Operation *op, FunctionType functionType, StringRef symbolSuffix) {
+  // Create the function and take the region body directly.
+  // NOTE: this will get uniquified if we have multiple in the same block.
+  auto parentFunc = op->getParentOfType<FuncOp>();
+  std::string functionName =
+      (parentFunc.getName().str() + "_rgn" + symbolSuffix).str();
+  auto outlinedFunc = FuncOp::create(op->getLoc(), functionName, functionType);
+  BlockAndValueMapping mapping;
+  op->getRegion(0).cloneInto(&outlinedFunc.getBody(), mapping);
+
+  // Gather all reachable functions. The walk is rooted at |op| rather than at
+  // |outlinedFunc|: the outlined function is not attached to any module yet,
+  // so callee symbols cannot be resolved starting from it, whereas |op| still
+  // lives in the source module and holds the calls the body was cloned from.
+  // NOTE(review): this walks every region of |op|; only region 0 is outlined.
+  llvm::SetVector<FuncOp> reachableFuncs;
+  (void)findReachableFunctions(op, reachableFuncs);
+
+  // Create the multi-arch executable that will contain the outlined region.
+  // NOTE: this will get uniquified if we have multiple in the same block.
+  auto parentModule = parentFunc.getParentOfType<ModuleOp>();
+  OpBuilder parentModuleBuilder(parentModule);
+  parentModuleBuilder.setInsertionPoint(parentFunc);
+  std::string executableName =
+      (parentFunc.getName().str() + "_ex" + symbolSuffix).str();
+  auto multiArchExecutable =
+      parentModuleBuilder.create<IREE::MultiArchExecutableOp>(
+          outlinedFunc.getLoc(), executableName);
+
+  // Create the executable op initially unspecified so that later
+  // transformations can compile it to various formats.
+  OpBuilder multiArchExecutableBuilder(multiArchExecutable);
+  multiArchExecutableBuilder.setInsertionPointToStart(
+      &multiArchExecutable.getBlock());
+  auto executable = multiArchExecutableBuilder.create<IREE::ExecutableOp>(
+      outlinedFunc.getLoc(), IREE::ExecutableFormat::Unspecified);
+
+  // Create the inner ModuleOp that contains the original functions. We need
+  // to provide this shim as some ops (like std.call) look for the
+  // containing module to provide symbol resolution.
+  OpBuilder executableBuilder(executable);
+  executableBuilder.setInsertionPointToStart(&executable.getBlock());
+  auto innerModule = executableBuilder.create<ModuleOp>(outlinedFunc.getLoc());
+
+  // TODO(b/137674142): make an ExecutableEntryPointOp and convert the
+  // entry thunk into that format.
+  innerModule.push_back(outlinedFunc);
+
+  // Copy all reachable functions into the executable.
+  // Linker passes may dedupe these later on.
+  for (auto reachableFunc : reachableFuncs) {
+    auto clonedFunc = reachableFunc.clone();
+    clonedFunc.removeAttr("iree.dispatchable");
+    innerModule.push_back(clonedFunc);
+  }
+
+  return std::make_pair(multiArchExecutable, outlinedFunc);
+}
+
+// Produces a memref view of |value| for passing into a dispatch.
+//   - null            -> returns null.
+//   - memref          -> returned unchanged.
+//   - tensor          -> wrapped in an IREE::TensorToMemRefOp.
+//   - everything else -> boxed: a memref is allocated (type chosen by
+//                        convertTypeToMemRef) and |value| is stored into it.
+// New ops are created with |builder| at the location of |op|.
+Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder) {
+  if (!value) return nullptr;
+
+  // TODO(benvanik): ensure indices make sense.
+  auto valueType = value->getType();
+  if (valueType.isa<MemRefType>()) {
+    // Already a memref: nothing to convert.
+    return value;
+  }
+  if (valueType.isa<TensorType>()) {
+    return builder.create<IREE::TensorToMemRefOp>(op->getLoc(), value)
+        .getResult();
+  }
+
+  // Box the value: allocate storage and store the value into it.
+  auto boxStorage =
+      builder.create<AllocOp>(op->getLoc(), convertTypeToMemRef(valueType));
+  builder.create<StoreOp>(op->getLoc(), value, boxStorage, ArrayRef<Value *>{});
+  return boxStorage;
+}
+
+// Converts |allocatedValue| (a memref) back to the type of |originalValue|.
+//   - original was a memref -> |allocatedValue| is returned as-is and no uses
+//                              are rewritten.
+//   - original was a tensor -> an IREE::MemRefToTensorOp is inserted.
+//   - everything else       -> a LoadOp unboxes the value.
+// In the last two cases all uses of |originalValue| are replaced with the new
+// value. New ops are created with |builder| at the location of |op|.
+Value *insertDispatcherLoad(Operation *op, Value *originalValue,
+                            Value *allocatedValue, OpBuilder &builder) {
+  auto originalType = originalValue->getType();
+  if (originalType.isa<MemRefType>()) return allocatedValue;
+
+  if (originalType.isa<TensorType>()) {
+    auto tensorValue =
+        builder.create<IREE::MemRefToTensorOp>(op->getLoc(), allocatedValue)
+            .getResult();
+    originalValue->replaceAllUsesWith(tensorValue);
+    return tensorValue;
+  }
+
+  // Unbox the value with a load from the wrapping memref.
+  auto unboxOp =
+      builder.create<LoadOp>(op->getLoc(), allocatedValue, ArrayRef<Value *>{});
+  originalValue->replaceAllUsesWith(unboxOp);
+  return unboxOp;
+}
+
+// Creates an AllocOp of memref |type| at |loc| using |builder| and returns it.
+// No dynamic dimension operands are supplied to the allocation.
+// TODO(benvanik): enough information to walk into dispatch region and compute
+// shape when not static.
+Value *allocateDispatchOutputBuffer(Location loc, MemRefType type,
+                                    OpBuilder &builder) {
+  // TODO(benvanik): allocation algorithm:
+  // - synthesize shape logic (magic) [[ for now assume fixed shapes ]]
+  // - insert shape logic above region
+  // - rely on folding to merge multiple calculations together
+  // - unranked = death, need to be able to alloc shape outputs
+  // - insert alloc
+  SmallVector<Value *, 4> dynamicDims;  // Intentionally empty for now.
+  auto allocOp = builder.create<AllocOp>(loc, type, dynamicDims);
+  return allocOp;
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Utils/DispatchUtils.h b/compiler/Utils/DispatchUtils.h
new file mode 100644
index 0000000..02a8d50
--- /dev/null
+++ b/compiler/Utils/DispatchUtils.h
@@ -0,0 +1,92 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities for dispatch region and function manipulation.
+// These are shared between all dispatchable types such as the standard
+// iree.dispatch_region as well as dispatch-related types like
+// iree.reduction_region.
+
+#ifndef IREE_COMPILER_UTILS_DISPATCHUTILS_H_
+#define IREE_COMPILER_UTILS_DISPATCHUTILS_H_
+
+#include <utility>
+
+#include "compiler/IR/Ops.h"
+#include "compiler/IR/StructureOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Calculates the workload for |op| based on the op type.
+Value *calculateWorkload(Operation *op, Value *baseOperand);
+
+// Returns true if the func is trivially dispatchable, meaning that:
+// - it contains a single block
+// - it contains a single dispatch region
+// - it contains a return op directly returning the dispatch region results
+bool isTriviallyDispatchable(FuncOp func);
+
+// Builds a new iree.dispatch_region with the given |ops|.
+// The region will capture all required values and return all values used
+// outside of the |ops| provided. The region will be inserted at the location of
+// the last operation in the set.
+//
+// All |ops| must be compatible with the |workload| specified as they will all
+// be dispatched with the same workgroup structure.
+// TODO(benvanik): ensure we want to insert at end. Maybe front?
+LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock,
+ Value *workload, ArrayRef<Operation *> ops);
+
+// Merges multiple dispatch regions within a block into the same region,
+// if possible. Operations may be reordered if it's possible to merge more while
+// still obeying data dependencies.
+LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock);
+
+// Inlines use of the given |value| from outside of a dispatch region to inside
+// of it and removes the argument. Supports multiple arguments that reference
+// |value| and will clone the entire value tree.
+LogicalResult inlineDispatchRegionOperandsUsingValue(
+ IREE::DispatchRegionOp dispatchRegionOp, Value *value);
+
+// Creates an iree.multi_arch_executable containing an iree.executable with an
+// exported function containing the body region of |op|. Created executables
+// will be named for their original function concatenated with |symbolSuffix|.
+std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable(
+ Operation *op, FunctionType functionType, StringRef symbolSuffix);
+
+// Inserts a conversion of an arbitrary |value| to a memref, possibly by way of
+// wrapping in an allocation.
+// Returns a new memref containing the value or an alias to |value|.
+Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder);
+
+// Inserts a load from a wrapped memref.
+// Returns the value in the original type or an alias to the |value| memref.
+Value *insertDispatcherLoad(Operation *op, Value *originalValue,
+ Value *allocatedValue, OpBuilder &builder);
+
+// Allocates a buffer of memref |type| at |loc| using |builder| and returns the
+// allocated value. The implementation supplies no dynamic dimension operands
+// to the allocation, so |type| is presumably expected to be statically shaped
+// (TODO confirm).
+// TODO(benvanik): enough information to walk into dispatch region and compute
+// shape when not static.
+Value *allocateDispatchOutputBuffer(Location loc, MemRefType type,
+ OpBuilder &builder);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_DISPATCHUTILS_H_
diff --git a/iree/compiler/Utils/Macros.h b/compiler/Utils/Macros.h
similarity index 100%
rename from iree/compiler/Utils/Macros.h
rename to compiler/Utils/Macros.h
diff --git a/compiler/Utils/MemRefUtils.cpp b/compiler/Utils/MemRefUtils.cpp
new file mode 100644
index 0000000..c97b610
--- /dev/null
+++ b/compiler/Utils/MemRefUtils.cpp
@@ -0,0 +1,94 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/MemRefUtils.h"
+
+#include <cassert>
+
+#include "compiler/IR/Ops.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+// Returns the memref that |value| was loaded from when |value| is produced
+// directly by a LoadOp, or null otherwise (including for block arguments).
+// |useOp| is currently unused.
+Value *resolveValueToSourceMemRef(Value *value, Operation *useOp) {
+  // TODO(benvanik): implement this for real; this is naive but enough for our
+  // simple load patterns.
+  auto loadOp = dyn_cast_or_null<LoadOp>(value->getDefiningOp());
+  if (!loadOp) return nullptr;
+  // TODO(benvanik): support views.
+  return loadOp.getMemRef();
+}
+
+// Returns |value| as a tensor when the first result of |srcOp| is a tensor;
+// otherwise returns |value| unchanged. If |value| is itself the result of an
+// IREE::TensorToMemRefOp the original tensor operand is returned instead of
+// inserting a round-trip cast; otherwise an IREE::MemRefToTensorOp is created
+// with |builder| at the location of |srcOp|.
+Value *wrapAsTensor(Value *value, Operation *srcOp, OpBuilder &builder) {
+  if (!srcOp->getResult(0)->getType().isa<TensorType>()) {
+    return value;
+  }
+  auto *producer = value->getDefiningOp();
+  if (isa_and_nonnull<IREE::TensorToMemRefOp>(producer)) {
+    // Undo the existing tensor->memref cast rather than stacking another.
+    return producer->getOperand(0);
+  }
+  return builder.create<IREE::MemRefToTensorOp>(srcOp->getLoc(), value)
+      .getResult();
+}
+
+// Returns |value| as a memref when it is a tensor; otherwise returns |value|
+// unchanged. If |value| is itself the result of an IREE::MemRefToTensorOp the
+// original memref operand is returned instead of inserting a round-trip cast;
+// otherwise an IREE::TensorToMemRefOp is created with |builder| at the
+// location of |srcOp|.
+Value *wrapAsMemRef(Value *value, Operation *srcOp, OpBuilder &builder) {
+  if (!value->getType().isa<TensorType>()) {
+    return value;
+  }
+  auto *producer = value->getDefiningOp();
+  if (isa_and_nonnull<IREE::MemRefToTensorOp>(producer)) {
+    // Undo the existing memref->tensor cast rather than stacking another.
+    return producer->getOperand(0);
+  }
+  return builder.create<IREE::TensorToMemRefOp>(srcOp->getLoc(), value)
+      .getResult();
+}
+
+// Returns a memref (or tensor) through which |operand| can be accessed.
+//   - memref/tensor operands are returned unchanged.
+//   - an operand produced by a LoadOp from a rank-0 memref of the same element
+//     type yields that source memref.
+//   - anything else is boxed: a rank-0 memref is allocated at |location| and
+//     |operand| is stored into it.
+Value *loadAccessValue(Location location, Value *operand, OpBuilder &builder) {
+  auto operandType = operand->getType();
+  if (operandType.isa<MemRefType>() || operandType.isa<TensorType>()) {
+    return operand;
+  }
+
+  auto scalarMemRefType = MemRefType::get({}, operandType);
+  auto loadOp = dyn_cast_or_null<LoadOp>(operand->getDefiningOp());
+  // TODO(benvanik): handle creating views.
+  if (loadOp && loadOp.getMemRefType() == scalarMemRefType) {
+    return loadOp.getMemRef();
+  }
+
+  auto *boxedValue =
+      builder.create<AllocOp>(location, scalarMemRefType).getResult();
+  builder.create<StoreOp>(location, operand, boxedValue, ArrayRef<Value *>{});
+  return boxedValue;
+}
+
+// Converts |result| back to a value of |originalType|.
+// When the original type was a memref or a tensor, |result| is returned
+// unchanged; otherwise a LoadOp is created with |builder| at |location| to
+// read the value out of |result| and the loaded value is returned.
+Value *loadResultValue(Location location, const Type &originalType,
+                       Value *result, OpBuilder &builder) {
+  // Both cases pass the value through untouched, so they share one branch
+  // (this also removes the previously unused dyn_cast result).
+  if (originalType.isa<MemRefType>() || originalType.isa<TensorType>()) {
+    return result;
+  }
+
+  auto loadOp = builder.create<LoadOp>(location, result, ArrayRef<Value *>{});
+  return loadOp.getResult();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Utils/MemRefUtils.h b/compiler/Utils/MemRefUtils.h
similarity index 100%
rename from iree/compiler/Utils/MemRefUtils.h
rename to compiler/Utils/MemRefUtils.h
diff --git a/compiler/Utils/ModuleUtils.cpp b/compiler/Utils/ModuleUtils.cpp
new file mode 100644
index 0000000..439b3fc
--- /dev/null
+++ b/compiler/Utils/ModuleUtils.cpp
@@ -0,0 +1,98 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/ModuleUtils.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "mlir/IR/Function.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Inserts into |funcs| every FuncOp of |module| that carries an attribute
+// named |attrName| (the attribute's value is not inspected).
+void findFunctionsWithAttr(ModuleOp module, const char *attrName,
+                           llvm::SetVector<FuncOp> &funcs) {
+  for (auto candidate : module.getOps<FuncOp>()) {
+    if (!candidate.getAttr(attrName)) continue;
+    funcs.insert(candidate);
+  }
+}
+
+// Inserts functions reachable directly from |func| to |usedFuncs|.
+//
+// Every op found directly in the blocks of |func| (ops nested in inner
+// regions are not visited) that carries a symbol-reference attribute named
+// "callee" is treated as a call. Callees seen for the first time are also
+// appended to |toSearch| so the caller can continue the traversal. Callees
+// that do not resolve to a FuncOp in |module| are ignored.
+void insertUsedFunctions(ModuleOp module, FuncOp func,
+                         DenseSet<FuncOp> *usedFuncs,
+                         std::vector<FuncOp> *toSearch) {
+  auto onCalledFunction = [&](StringRef calleeName) {
+    auto calleeFunc = module.lookupSymbol<FuncOp>(calleeName);
+    // An unresolved symbol yields a null FuncOp; recording it would make the
+    // caller iterate the blocks of a null function later on.
+    if (!calleeFunc) return;
+    if (usedFuncs->insert(calleeFunc).second) {
+      // New function found! Add to queue for searching.
+      toSearch->push_back(calleeFunc);
+    }
+  };
+  for (auto &block : func) {
+    for (auto &op : block) {
+      // TODO(benvanik): replace with iree_hl.call check.
+      auto calleeAttr = op.getAttr("callee");
+      if (!calleeAttr) continue;
+      // Only symbol references name functions; skip any other attribute kind
+      // instead of asserting in cast<>.
+      if (auto calleeRef = calleeAttr.dyn_cast<SymbolRefAttr>()) {
+        onCalledFunction(calleeRef.getValue());
+      }
+    }
+  }
+}
+
+// Returns the set of all functions transitively used by |rootFuncs|, as
+// discovered by insertUsedFunctions. The returned set includes the roots
+// themselves.
+DenseSet<FuncOp> findUsedFunctions(ModuleOp module,
+                                   ArrayRef<FuncOp> rootFuncs) {
+  DenseSet<FuncOp> usedFuncs;
+  usedFuncs.insert(rootFuncs.begin(), rootFuncs.end());
+
+  // Worklist traversal; entries are taken from the back (most recent first).
+  std::vector<FuncOp> worklist(rootFuncs.begin(), rootFuncs.end());
+  while (!worklist.empty()) {
+    FuncOp nextFunc = worklist.back();
+    worklist.pop_back();
+    insertUsedFunctions(module, nextFunc, &usedFuncs, &worklist);
+  }
+  return usedFuncs;
+}
+
+} // namespace
+
+// Erases every function in |module| that is not reachable from a root
+// function. Roots are the functions carrying at least one of the attributes
+// named in |keepAttrs|.
+void dropUnusedFunctions(ModuleOp module, ArrayRef<const char *> keepAttrs) {
+  // Find all of the exported functions we'll treat as roots.
+  llvm::SetVector<FuncOp> rootFuncs;
+  for (auto keepAttr : keepAttrs) {
+    findFunctionsWithAttr(module, keepAttr, rootFuncs);
+  }
+
+  // Find the full set of all used functions reachable from the given rootFuncs.
+  // This set will contain the rootFuncs.
+  auto usedFuncs = findUsedFunctions(module, rootFuncs.getArrayRef());
+
+  // Collect the dead functions first so that the module is not mutated while
+  // it is being iterated. Membership is tested with the set's hashed lookup
+  // rather than a linear scan over all of its elements.
+  std::vector<FuncOp> deadFuncs;
+  for (auto func : module.getOps<FuncOp>()) {
+    if (usedFuncs.count(func) == 0) {
+      deadFuncs.push_back(func);
+    }
+  }
+  for (auto func : deadFuncs) {
+    func.erase();
+  }
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Utils/ModuleUtils.h b/compiler/Utils/ModuleUtils.h
similarity index 100%
rename from iree/compiler/Utils/ModuleUtils.h
rename to compiler/Utils/ModuleUtils.h
diff --git a/compiler/Utils/OpCreationUtils.cpp b/compiler/Utils/OpCreationUtils.cpp
new file mode 100644
index 0000000..7d5bcca
--- /dev/null
+++ b/compiler/Utils/OpCreationUtils.cpp
@@ -0,0 +1,45 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/OpCreationUtils.h"
+
+#include <cstdint>
+
+#include "compiler/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Builds a dense integer elements attribute holding |elements| as a rank-1
+// tensor of 64-bit integers with one dimension of size elements.size().
+ElementsAttr elementsAttrFromArray(OpBuilder &builder,
+                                   ArrayRef<int64_t> elements) {
+  auto elementType = builder.getIntegerType(64);
+  auto arrayType = RankedTensorType::get(elements.size(), elementType);
+  return DenseIntElementsAttr::get(arrayType, elements);
+}
+
+} // namespace
+
+// Creates an IREE::ConstantOp at |loc| whose value is |elements| encoded as a
+// rank-1 tensor of 64-bit integers.
+IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc,
+                                     llvm::ArrayRef<int64_t> elements) {
+  return builder.create<IREE::ConstantOp>(
+      loc, elementsAttrFromArray(builder, elements));
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Utils/OpCreationUtils.h b/compiler/Utils/OpCreationUtils.h
new file mode 100644
index 0000000..1ad01bc
--- /dev/null
+++ b/compiler/Utils/OpCreationUtils.h
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utility functions related to the creation of new operations. Where possible,
+// use custom builders. These helpers are for situations where a custom builder
+// is not appropriate.
+
+#ifndef IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
+#define IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
+
+#include <cstdint>
+
+#include "compiler/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Creates an IREE::ConstantOp at |loc| using |builder| whose value holds
+// |elements| as a rank-1 tensor of 64-bit integers.
+IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc,
+ llvm::ArrayRef<int64_t> elements);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
diff --git a/compiler/Utils/OpUtils.cpp b/compiler/Utils/OpUtils.cpp
new file mode 100644
index 0000000..df2c8f2
--- /dev/null
+++ b/compiler/Utils/OpUtils.cpp
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/OpUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Erases every operation in |deadOperations|, front to back, together with any
+// producers that become dead as a consequence.
+//
+// Before an op is erased, each op that defines one of its operands is queued
+// for removal when all uses of all of that producer's results belong to the op
+// being erased. Operands that are block arguments have no producer and are
+// skipped. |deadOperations| is empty when this returns.
+void removeDeadOperations(llvm::SetVector<Operation *> &deadOperations) {
+  // Returns true if every use of every result of |producer| is owned by |user|.
+  auto isOnlyUsedBy = [](Operation *producer, Operation *user) {
+    for (auto *result : producer->getResults()) {
+      for (auto &use : result->getUses()) {
+        if (use.getOwner() != user) return false;
+      }
+    }
+    return true;
+  };
+
+  while (!deadOperations.empty()) {
+    auto *op = deadOperations.front();
+    deadOperations.erase(deadOperations.begin());
+    for (auto *operand : op->getOperands()) {
+      // Block arguments are not produced by an op; nothing to queue.
+      auto *producer = operand->getDefiningOp();
+      if (!producer) continue;
+      // TODO(benvanik): add check for op side effects.
+      // Checking all results (not just |operand|) keeps a multi-result
+      // producer alive while any of its other results is still in use.
+      if (isOnlyUsedBy(producer, op)) {
+        deadOperations.insert(producer);
+      }
+    }
+    op->erase();
+  }
+}
+
+// Replaces uses of |oldValue| with |newValue| in every user for which
+// |userOp|->isBeforeInBlock(user) holds, i.e. users that come after |userOp|
+// in its block. Other uses are left untouched.
+void replaceSubsequentUses(Operation *userOp, Value *oldValue,
+                           Value *newValue) {
+  auto uses = oldValue->getUses();
+  for (auto it = uses.begin(), end = uses.end(); it != end;) {
+    auto &use = *it;
+    // Advance first: use.set() unlinks |use| from |oldValue|'s use list (and
+    // links it into |newValue|'s), which would derail an iterator that still
+    // points at it.
+    ++it;
+    if (userOp->isBeforeInBlock(use.getOwner())) {
+      use.set(newValue);
+    }
+  }
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Utils/OpUtils.h b/compiler/Utils/OpUtils.h
similarity index 100%
rename from iree/compiler/Utils/OpUtils.h
rename to compiler/Utils/OpUtils.h
diff --git a/compiler/Utils/TranslationUtils.cpp b/compiler/Utils/TranslationUtils.cpp
new file mode 100644
index 0000000..a9ec1ca
--- /dev/null
+++ b/compiler/Utils/TranslationUtils.cpp
@@ -0,0 +1,139 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/TranslationUtils.h"
+
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Returns the static registry of translator names to translation functions.
+// The registry is a function-local static, so it is constructed the first
+// time this function is called and the same instance is returned on every
+// call.
+llvm::StringMap<TranslateExecutableFn>
+ &getMutableExecutableTranslationRegistry() {
+ static llvm::StringMap<TranslateExecutableFn> registry;
+ return registry;
+}
+
+// Returns true if the given |value| matches |pattern| (normal * and ? rules).
+//   '*' matches any run of characters, including an empty one.
+//   '?' matches exactly one character.
+//   Any other character matches only itself. There is no escape syntax.
+// The whole of |value| must be matched.
+bool matchPattern(StringRef value, StringRef pattern) {
+  size_t valueIndex = 0;
+  size_t patternIndex = 0;
+  // Backtracking state for the most recent '*': where to resume in the
+  // pattern (just past the star) and how much of |value| it has absorbed.
+  bool sawStar = false;
+  size_t resumePatternIndex = 0;
+  size_t resumeValueIndex = 0;
+
+  while (valueIndex < value.size()) {
+    const bool havePattern = patternIndex < pattern.size();
+    if (havePattern && pattern[patternIndex] == '*') {
+      // Checked before the literal comparison so that a '*' character in
+      // |value| is never mistaken for an exact match of the wildcard.
+      sawStar = true;
+      resumePatternIndex = ++patternIndex;
+      resumeValueIndex = valueIndex;
+    } else if (havePattern && (pattern[patternIndex] == '?' ||
+                               pattern[patternIndex] == value[valueIndex])) {
+      ++patternIndex;
+      ++valueIndex;
+    } else if (sawStar) {
+      // Mismatch: let the last '*' absorb one more character and retry.
+      patternIndex = resumePatternIndex;
+      valueIndex = ++resumeValueIndex;
+    } else {
+      return false;
+    }
+  }
+
+  // |value| is exhausted; only trailing '*'s may remain in the pattern. In
+  // particular a leftover '?' still needs a character and must fail.
+  while (patternIndex < pattern.size() && pattern[patternIndex] == '*') {
+    ++patternIndex;
+  }
+  return patternIndex == pattern.size();
+}
+
+// Force enables IR printing on the |passManager|.
+// IR is printed to llvm::dbgs() after every pass except the ones named
+// "FunctionVerifier" and "ModuleVerifier"; nothing is printed before passes
+// and module-scope printing is off. Multithreading is disabled afterwards.
+void enableIRPrinting(PassManager *passManager) {
+  auto shouldPrintAfter = [](Pass *pass) {
+    auto passName = pass->getName();
+    return passName != "FunctionVerifier" && passName != "ModuleVerifier";
+  };
+  passManager->enableIRPrinting(/*shouldPrintBeforePass=*/{},
+                                /*shouldPrintAfterPass=*/shouldPrintAfter,
+                                /*printModuleScope=*/false, llvm::dbgs());
+  passManager->disableMultithreading();
+}
+
+} // namespace
+
+// Registers |fn| as the executable translation function for |name|.
+// Registering a name that already exists is a fatal error (reported through
+// llvm::report_fatal_error). |fn| must be non-empty; this is asserted.
+ExecutableTranslationRegistration::ExecutableTranslationRegistration(
+    llvm::StringRef name, const TranslateExecutableFn &fn) {
+  auto &registry = getMutableExecutableTranslationRegistry();
+  if (registry.find(name) != registry.end()) {
+    llvm::report_fatal_error(
+        "Attempting to overwrite an existing translation function");
+  }
+  assert(fn && "Attempting to register an empty translation function");
+  registry[name] = fn;
+}
+
+// Returns a read-only reference to the same registry instance that
+// getMutableExecutableTranslationRegistry() exposes for registration.
+const llvm::StringMap<TranslateExecutableFn>
+ &getExecutableTranslationRegistry() {
+ return getMutableExecutableTranslationRegistry();
+}
+
+// Returns the names of all registered executable translation backends that
+// match |pattern| according to matchPattern. Names are returned in the
+// iteration order of the registry.
+std::vector<std::string> matchExecutableTranslationBackendNames(
+    llvm::StringRef pattern) {
+  std::vector<std::string> matchingNames;
+  for (auto &registration : getExecutableTranslationRegistry()) {
+    auto backendName = registration.getKey();
+    if (!matchPattern(backendName, pattern)) continue;
+    matchingNames.push_back(backendName.str());
+  }
+  return matchingNames;
+}
+
+// Creates a PassManager for |ctx|.
+// The pass manager command line options are registered and applied first;
+// IR printing is then force-enabled when |translationOptions.print_mlir| is
+// set.
+std::unique_ptr<PassManager> createPassManager(
+    MLIRContext *ctx, const TranslationOptions &translationOptions) {
+  std::unique_ptr<PassManager> manager(new PassManager(ctx));
+
+  // Command line configuration (IR printing/timing/etc) goes first...
+  registerPassManagerCLOptions();
+  applyPassManagerCLOptions(*manager);
+
+  // ...and the programmatic options are layered on top of it.
+  const bool forcePrinting = translationOptions.print_mlir;
+  if (forcePrinting) {
+    enableIRPrinting(manager.get());
+  }
+
+  return manager;
+}
+
+// Runs |passManager| on |module| and returns whether the run succeeded.
+// When |translationOptions.print_mlir| is set the module is dumped before the
+// run and again after a successful run (a failed run returns immediately).
+LogicalResult runPassPipeline(const TranslationOptions &translationOptions,
+                              PassManager *passManager, ModuleOp module) {
+  auto dumpIfRequested = [&]() {
+    if (translationOptions.print_mlir) {
+      module.dump();
+    }
+  };
+
+  dumpIfRequested();
+
+  // Run on the module.
+  auto runResult = passManager->run(module);
+  if (failed(runResult)) {
+    return failure();
+  }
+
+  dumpIfRequested();
+  return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/Utils/TranslationUtils.h b/compiler/Utils/TranslationUtils.h
new file mode 100644
index 0000000..a96a428
--- /dev/null
+++ b/compiler/Utils/TranslationUtils.h
@@ -0,0 +1,109 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
+#define IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
+
+#include <functional>
+#include <memory>
+
+#include "compiler/IR/StructureOps.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/PassManager.h"
+#include "schemas/executable_def_generated.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Common translation options for diagnostics and debugging.
+struct TranslationOptions {
+ // Enables MLIR IR printing during translation.
+ // This can be specified via the -print-ir-before-all and -print-ir-after-all
+ // command line flags or overridden programmatically via this flag.
+ bool print_mlir = false;
+
+ // Copies the options declared in this struct (currently only print_mlir)
+ // from |other|. Fields added by derived option structs are not copied.
+ void CopyFrom(const TranslationOptions &other) {
+ print_mlir = other.print_mlir;
+ }
+};
+
+// Options for iree.module translation for diagnostics and debugging.
+// Inherits the common options (print_mlir) from TranslationOptions.
+struct ModuleTranslationOptions : public TranslationOptions {
+ // Defines which backend translators will be used to translate executables.
+ // If empty then all linked in translators will be used.
+ // TODO(benvanik): extend to allow specifying entire config blobs via mlir.
+ std::vector<std::string> target_backends;
+};
+
+// Options for iree.executable translation for diagnostics and debugging.
+// Target configuration is sourced from the iree.target_config op within the
+// iree.executable.
+// Adds no fields of its own; only the inherited TranslationOptions apply.
+struct ExecutableTranslationOptions : public TranslationOptions {};
+
+// Results of a translation operation.
+// May contain zero or more executable defs depending on translation options,
+// defined target configs, and support.
+struct ExecutableTranslationResult {
+ // Produced executable definitions; each one is owned by this result.
+ std::vector<std::unique_ptr<iree::ExecutableDefT>> executable_defs;
+};
+
+// Registered function that given a set of |executableOps| containing one
+// or more iree.executables will produce zero or more serialized executables.
+//
+// Each iree.executable provided contains one iree.executable_target_config with
+// backend-specific translation information. The translator can decide whether
+// to translate each independently, group them together, etc.
+//
+// The provided |executableOps| can be mutated by the callee and will be
+// preserved for debugging after translation. If any executable in
+// |executableOps| is not used by the translator then it should be erased.
+using TranslateExecutableFn =
+ std::function<llvm::Optional<ExecutableTranslationResult>(
+ ArrayRef<IREE::ExecutableOp> executableOps,
+ ExecutableTranslationOptions options)>;
+
+// Registers an executable translation function.
+// Constructing an instance adds |fn| to the translator registry under |name|.
+// Registering a |name| that is already present is a fatal error, and |fn| must
+// be non-empty (asserted).
+struct ExecutableTranslationRegistration {
+ ExecutableTranslationRegistration(llvm::StringRef name,
+ const TranslateExecutableFn &fn);
+};
+
+// Returns a read-only reference to the translator registry.
+const llvm::StringMap<TranslateExecutableFn>
+ &getExecutableTranslationRegistry();
+
+// Returns executable translation backend names matching the given pattern.
+// This accepts wildcards for any delimited value. For example, 'foo-*-bar' will
+// match 'foo-123-bar' and 'foo-456-bar' and 'foo-10?' will match 'foo-101' and
+// 'foo-102'.
+std::vector<std::string> matchExecutableTranslationBackendNames(
+ llvm::StringRef pattern);
+
+// Creates a new pass manager initialized with the given options.
+// Pass manager command line options are registered and applied first; IR
+// printing is then force-enabled when |translationOptions.print_mlir| is set.
+std::unique_ptr<PassManager> createPassManager(
+ MLIRContext *ctx, const TranslationOptions &translationOptions);
+
+// Runs an initialized set of passes on the given module.
+// When |translationOptions.print_mlir| is set the module is dumped before the
+// run and again after a successful run. Returns failure if the pass manager
+// run fails.
+LogicalResult runPassPipeline(const TranslationOptions &translationOptions,
+ PassManager *passManager, ModuleOp module);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
diff --git a/compiler/Utils/TypeConversionUtils.cpp b/compiler/Utils/TypeConversionUtils.cpp
new file mode 100644
index 0000000..cf416d0
--- /dev/null
+++ b/compiler/Utils/TypeConversionUtils.cpp
@@ -0,0 +1,74 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compiler/Utils/TypeConversionUtils.h"
+
+#include <cassert>
+
+#include "compiler/IR/Ops.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Legalizes |type|: index and i1 become fixed-width integers, memref element
+// types and function signatures are legalized recursively, and anything else
+// is returned unchanged.
+Type legalizeType(Type type) {
+  auto *context = type.getContext();
+  if (type.isIndex()) return IntegerType::get(kIndexBitWidth, context);
+  if (type.isInteger(1)) return IntegerType::get(kBoolBitWidth, context);
+  if (auto memRefType = type.dyn_cast<MemRefType>()) {
+    return MemRefType::get(memRefType.getShape(),
+                           legalizeType(memRefType.getElementType()));
+  }
+  if (auto functionType = type.dyn_cast<FunctionType>()) {
+    llvm::SmallVector<Type, 4> inputs, results;
+    for (Type input : functionType.getInputs())
+      inputs.push_back(legalizeType(input));
+    for (Type result : functionType.getResults())
+      results.push_back(legalizeType(result));
+    return FunctionType::get(inputs, results, context);
+  }
+  return type;
+}
+
+Type LLTypeConverter::convertType(Type type) { return legalizeType(type); }
+
+// Returns the memref form of |type|: scalars become rank-0 memrefs, ranked
+// tensors keep their shape and element type, memrefs are returned as-is.
+MemRefType convertTypeToMemRef(Type type) {
+  if (type.isIntOrIndexOrFloat()) return MemRefType::get({}, type, {}, 0);
+  if (auto memRefType = type.dyn_cast<MemRefType>()) return memRefType;
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+    return MemRefType::get(tensorType.getShape(),
+                           tensorType.getElementType());
+  }
+  llvm_unreachable("Unconvertable type");
+}
+
+MemRefType convertTypeToMemRef(Value *value) {
+ return convertTypeToMemRef(value->getType());
+}
+
+Type MemRefTypeConverter::convertType(Type type) {
+ return convertTypeToMemRef(type);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Utils/TypeConversionUtils.h b/compiler/Utils/TypeConversionUtils.h
similarity index 100%
rename from iree/compiler/Utils/TypeConversionUtils.h
rename to compiler/Utils/TypeConversionUtils.h
diff --git a/hal/BUILD b/hal/BUILD
new file mode 100644
index 0000000..aeaf38d
--- /dev/null
+++ b/hal/BUILD
@@ -0,0 +1,377 @@
+# HAL (Hardware Abstraction Layer).
+# Subdirectories contain implementations for different hardware and
+# software backends.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+    name = "allocator",
+    srcs = ["allocator.cc"],
+    hdrs = ["allocator.h"],
+    deps = [
+        ":buffer",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "api",
+    srcs = ["api.cc"],
+    hdrs = [
+        "api.h",
+        "api_detail.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":api_hdrs",
+        ":buffer",
+        ":buffer_view",
+        ":fence",
+        ":heap_buffer",
+        ":semaphore",
+        "//base:api",
+        "//base:api_util",
+        "//base:shape",
+        "//base:tracing",
+        "@com_google_absl//absl/base:core_headers",
+    ],
+)
+
+cc_library(
+    name = "api_hdrs",
+    hdrs = ["api.h"],
+    deps = [
+        "//base:api_hdrs",
+    ],
+)
+
+cc_library(
+    name = "buffer",
+    srcs = ["buffer.cc"],
+    hdrs = ["buffer.h"],
+    deps = [
+        ":resource",
+        "//base:bitfield",
+        "//base:logging",
+        "//base:source_location",
+        "//base:status",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@com_google_absl//absl/types:variant",
+    ],
+)
+
+cc_test(
+    name = "buffer_test",
+    srcs = [
+        "buffer_mapping_test.cc",
+        "buffer_test.cc",
+    ],
+    deps = [
+        ":buffer",
+        ":heap_buffer",
+        "//base:status",
+        "//base:status_matchers",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "buffer_view",
+    srcs = ["buffer_view.cc"],
+    hdrs = ["buffer_view.h"],
+    deps = [
+        ":buffer",
+        "//base:shape",
+        "//base:source_location",
+        "//base:status",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_test(
+    name = "buffer_view_test",
+    srcs = [
+        "buffer_view_test.cc",
+    ],
+    deps = [
+        ":buffer",
+        ":buffer_view",
+        ":heap_buffer",
+        "//base:status",
+        "//base:status_matchers",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "buffer_view_string_util",
+    srcs = ["buffer_view_string_util.cc"],
+    hdrs = ["buffer_view_string_util.h"],
+    deps = [
+        ":allocator",
+        ":buffer_view",
+        ":heap_buffer",
+        "//base:source_location",
+        "//base:status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
+cc_test(
+    name = "buffer_view_string_util_test",
+    srcs = ["buffer_view_string_util_test.cc"],
+    deps = [
+        ":buffer_view_string_util",
+        "//base:status",
+        "//base:status_matchers",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "command_buffer",
+    srcs = ["command_buffer.cc"],
+    hdrs = ["command_buffer.h"],
+    deps = [
+        ":allocator",
+        ":buffer",
+        ":buffer_view",
+        ":event",
+        ":executable",
+        ":resource",
+        "//base:bitfield",
+        "//base:shape",
+        "//base:status",
+        "@com_google_absl//absl/base:core_headers",
+    ],
+)
+
+cc_library(
+    name = "command_buffer_validation",
+    srcs = ["command_buffer_validation.cc"],
+    hdrs = ["command_buffer_validation.h"],
+    deps = [
+        ":command_buffer",
+        "//base:logging",
+        "//base:status",
+    ],
+)
+
+cc_library(
+    name = "command_queue",
+    hdrs = ["command_queue.h"],
+    deps = [
+        ":command_buffer",
+        ":fence",
+        ":semaphore",
+        "//base:bitfield",
+        "//base:status",
+        "//base:time",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "deferred_buffer",
+    srcs = ["deferred_buffer.cc"],
+    hdrs = ["deferred_buffer.h"],
+    deps = [
+        ":allocator",
+        ":buffer",
+        "//base:status",
+    ],
+)
+
+cc_test(
+    name = "deferred_buffer_test",
+    srcs = ["deferred_buffer_test.cc"],
+    deps = [
+        ":deferred_buffer",
+        ":heap_buffer",
+        "//base:status_matchers",
+        "//hal/testing:mock_allocator",
+        "@com_google_absl//absl/memory",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "device",
+    hdrs = ["device.h"],
+    deps = [
+        ":allocator",
+        ":buffer",
+        ":command_queue",
+        ":device_info",
+        ":event",
+        ":executable_cache",
+        ":semaphore",
+        "//base:status",
+        "//base:time",
+        "@com_google_absl//absl/time",
+    ],
+)
+
+cc_library(
+    name = "device_info",
+    hdrs = ["device_info.h"],
+    deps = [
+        "//base:bitfield",
+        "@com_google_absl//absl/base:core_headers",
+    ],
+)
+
+cc_library(
+    name = "device_manager",
+    srcs = ["device_manager.cc"],
+    hdrs = ["device_manager.h"],
+    deps = [
+        ":allocator",
+        ":buffer",
+        ":command_queue",
+        ":device",
+        ":device_placement",
+        ":executable_format",
+        ":fence",
+        ":heap_buffer",
+        "//base:source_location",
+        "//base:status",
+        "//base:time",
+        "//base:tracing",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+ name = "device_placement",
+ hdrs = ["device_placement.h"],
+)
+
+cc_library(
+    name = "driver",
+    hdrs = ["driver.h"],
+    deps = [
+        ":device",
+        ":device_info",
+        "//base:status",
+    ],
+)
+
+cc_library(
+    name = "driver_registry",
+    srcs = ["driver_registry.cc"],
+    hdrs = ["driver_registry.h"],
+    deps = [
+        ":driver",
+        "//base:init",
+        "//base:status",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+cc_library(
+ name = "event",
+ hdrs = ["event.h"],
+ deps = [
+ ":resource",
+ ],
+)
+
+cc_library(
+ name = "executable",
+ hdrs = ["executable.h"],
+ deps = [":resource"],
+)
+
+cc_library(
+    name = "executable_cache",
+    srcs = ["executable_cache.cc"],
+    hdrs = ["executable_cache.h"],
+    deps = [
+        ":executable",
+        ":executable_format",
+        ":executable_spec",
+        "//base:bitfield",
+        "//base:ref_ptr",
+        "//base:status",
+    ],
+)
+
+cc_library(
+ name = "executable_format",
+ hdrs = ["executable_format.h"],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ ],
+)
+
+cc_library(
+ name = "executable_spec",
+ hdrs = ["executable_spec.h"],
+ deps = [
+ ":executable_format",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+    name = "fence",
+    hdrs = ["fence.h"],
+    deps = [
+        ":resource",
+        "//base:status",
+    ],
+)
+
+cc_library(
+    name = "heap_buffer",
+    srcs = ["heap_buffer.cc"],
+    hdrs = ["heap_buffer.h"],
+    deps = [
+        ":allocator",
+        ":buffer",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal/host:host_buffer",
+        "@com_google_absl//absl/base:core_headers",
+    ],
+)
+
+cc_library(
+    name = "resource",
+    hdrs = ["resource.h"],
+    deps = [
+        "//base:ref_ptr",
+    ],
+)
+
+cc_library(
+ name = "semaphore",
+ hdrs = ["semaphore.h"],
+ deps = [
+ ":resource",
+ "@com_google_absl//absl/types:variant",
+ ],
+)
+
+cc_library(
+ name = "stack_trace",
+ hdrs = ["stack_trace.h"],
+)
diff --git a/iree/hal/CMakeLists.txt b/hal/CMakeLists.txt
similarity index 100%
rename from iree/hal/CMakeLists.txt
rename to hal/CMakeLists.txt
diff --git a/hal/allocator.cc b/hal/allocator.cc
new file mode 100644
index 0000000..57c18c2
--- /dev/null
+++ b/hal/allocator.cc
@@ -0,0 +1,77 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/allocator.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <string>
+#include <utility>
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+
+namespace iree {
+namespace hal {
+
+bool Allocator::CanUseBuffer(Buffer* buffer,
+ BufferUsageBitfield intended_usage) const {
+ return CanUseBufferLike(buffer->allocator(), buffer->memory_type(),
+ buffer->usage(), intended_usage);
+}
+
+// Returns |source_buffer| unchanged when it is already constant and usable for
+// |buffer_usage|; otherwise copies its contents into a newly allocated buffer.
+StatusOr<ref_ptr<Buffer>> Allocator::AllocateConstant(
+    BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) {
+  if (AnyBitSet(source_buffer->usage() & BufferUsage::kConstant) &&
+      CanUseBuffer(source_buffer.get(), buffer_usage)) {
+    return source_buffer;  // Already usable by the device as-is.
+  }
+
+  IREE_TRACE_SCOPE0("Allocator::AllocateConstant");
+
+  // The new buffer must be mappable so the copy below can write into it, and
+  // it is treated as constant once that write has completed.
+  buffer_usage |= BufferUsage::kMapping;
+  buffer_usage |= BufferUsage::kConstant;
+  const MemoryTypeBitfield memory_type =
+      MemoryType::kDeviceLocal | MemoryType::kHostVisible;
+  ASSIGN_OR_RETURN(auto constant_buffer, Allocate(memory_type, buffer_usage,
+                                                  source_buffer->byte_length()));
+  ASSIGN_OR_RETURN(auto source_contents,
+                   source_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+  RETURN_IF_ERROR(constant_buffer->WriteData(0, source_contents.data(),
+                                             source_contents.byte_length()));
+  return constant_buffer;
+}
+
+// Read-only wrap: forwards to WrapMutable permitting MemoryAccess::kRead only.
+StatusOr<ref_ptr<Buffer>> Allocator::Wrap(
+    MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+    const void* data, size_t data_length) {
+  return WrapMutable(memory_type, MemoryAccess::kRead, buffer_usage,
+                     const_cast<void*>(data), data_length);
+}
+
+StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable(
+ MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage, void* data, size_t data_length) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Allocator does not support wrapping host memory";
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/allocator.h b/hal/allocator.h
new file mode 100644
index 0000000..b3dba23
--- /dev/null
+++ b/hal/allocator.h
@@ -0,0 +1,138 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_ALLOCATOR_H_
+#define IREE_HAL_ALLOCATOR_H_
+
+#include <cstddef>
+#include <memory>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+// Allocates buffers for a particular device memory space.
+//
+// Buffers allocated are only guaranteed to work with the driver that the
+// allocator services. Any attempt to use buffers on drivers they were not
+// allocated from must first be checked with CanUseBuffer.
+//
+// Thread-safe.
+class Allocator {
+ public:
+ virtual ~Allocator() = default;
+
+ // Returns true if the device can use the given buffer for the provided usage.
+ // For buffers allocated from this allocator it's expected that the result
+ // will always be true. For buffers that originate from another allocator
+ // there may be limited support for cross-device usage.
+ //
+ // Returning false indicates that the buffer must be transferred externally
+ // into a buffer compatible with the device this allocator services.
+ bool CanUseBuffer(Buffer* buffer, BufferUsageBitfield intended_usage) const;
+ virtual bool CanUseBufferLike(Allocator* source_allocator,
+ MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage) const = 0;
+
+ // Returns true if the allocator can allocate a buffer with the given
+ // attributes.
+ virtual bool CanAllocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) const = 0;
+
+ // Adjusts allocation parameters to be compatible with the allocator.
+ // Certain allocators may require particular memory types to function. By
+ // adjusting the parameters prior to allocation callers can be sure they are
+ // able to successfully Allocate a buffer later on with the same parameters.
+ virtual Status MakeCompatible(MemoryTypeBitfield* memory_type,
+ BufferUsageBitfield* buffer_usage) const {
+ return OkStatus();
+ }
+
+ // Allocates a buffer from the allocator.
+ // Fails if the memory type requested for the given usage cannot be serviced.
+ // Callers can use CanAllocate to decide their memory use strategy.
+ //
+ // The memory type of the buffer returned may differ from the requested value
+ // if the device can provide more functionality; for example, if requesting
+ // MemoryType::kHostVisible but the memory is really host cached you may get
+ // a buffer back with MemoryType::kHostVisible | MemoryType::kHostCached. The
+ // only requirement is that the buffer satisfy the required bits.
+ virtual StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) = 0;
+
+ // Allocates a buffer from the allocator for use as a constant value.
+ // The provided |source_buffer| may be returned if the device can use it
+ // directly and otherwise will be copied.
+ virtual StatusOr<ref_ptr<Buffer>> AllocateConstant(
+ BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer);
+
+ // Wraps an existing host heap allocation in a buffer.
+ // Ownership of the host allocation remains with the caller and the memory
+ // must remain valid for so long as the Buffer may be in use.
+ // Will have MemoryType::kHostLocal in most cases and may not be usable
+ // by the device.
+ //
+ // The inference optimizer makes assumptions about buffer aliasing based on
+ // Buffer instances and because of this wrapping the same host buffer in
+ // multiple Buffers will create potential memory aliasing issues that can be
+ // difficult to track down. There's no checking as to whether a host buffer
+ // has already been wrapped so it's best for callers to ensure this is never
+ // possible (the simplest way being to never use Wrap and always just allocate
+ // new Buffers).
+ //
+ // Fails if the allocator cannot access host memory in this way.
+ StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ const void* data, size_t data_length);
+ virtual StatusOr<ref_ptr<Buffer>> WrapMutable(
+ MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage, void* data, size_t data_length);
+ template <typename T>
+ StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ absl::Span<const T> data);
+ template <typename T>
+ StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage,
+ absl::Span<T> data);
+};
+
+// Inline functions and template definitions follow:
+
+template <typename T>
+StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ absl::Span<const T> data) {
+ return Wrap(memory_type, buffer_usage, data.data(), data.size() * sizeof(T));
+}
+
+template <typename T>
+StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable(
+ MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage, absl::Span<T> data) {
+ return WrapMutable(memory_type, allowed_access, buffer_usage, data.data(),
+ data.size() * sizeof(T));
+}
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_ALLOCATOR_H_
diff --git a/hal/api.cc b/hal/api.cc
new file mode 100644
index 0000000..b103406
--- /dev/null
+++ b/hal/api.cc
@@ -0,0 +1,439 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/api.h"
+
+#include "base/api.h"
+#include "base/api_util.h"
+#include "base/shape.h"
+#include "base/tracing.h"
+#include "hal/api_detail.h"
+#include "hal/buffer.h"
+#include "hal/buffer_view.h"
+#include "hal/fence.h"
+#include "hal/heap_buffer.h"
+#include "hal/semaphore.h"
+
+namespace iree {
+namespace hal {
+
+//===----------------------------------------------------------------------===//
+// iree::hal::Buffer
+//===----------------------------------------------------------------------===//
+
+// Returns through |out_buffer| the result of Buffer::Subspan on |buffer| for
+// the range [|byte_offset|, |byte_offset| + |byte_length|). |out_buffer| is
+// NULL on failure. |allocator| is not used by this implementation.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_subspan(
+    iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
+    iree_device_size_t byte_length, iree_allocator_t allocator,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_subspan");
+
+  if (!out_buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_buffer = nullptr;
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+
+  // Hold a reference on the source buffer for the duration of the call.
+  auto source = add_ref(reinterpret_cast<Buffer*>(buffer));
+  IREE_API_ASSIGN_OR_RETURN(auto subspan,
+                            Buffer::Subspan(source, byte_offset, byte_length));
+
+  // The reference held by |subspan| is handed over to the caller.
+  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(subspan.release());
+
+  return IREE_STATUS_OK;
+}
+
+// Adds a reference to |buffer|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_retain(iree_hal_buffer_t* buffer) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_retain");
+  if (!buffer) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Buffer*>(buffer)->AddReference();
+  return IREE_STATUS_OK;
+}
+
+// Drops a reference to |buffer|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_release(iree_hal_buffer_t* buffer) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_release");
+  if (!buffer) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Buffer*>(buffer)->ReleaseReference();
+  return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_device_size_t
+iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) {
+ IREE_TRACE_SCOPE0("iree_hal_buffer_byte_length");
+ const auto* handle = reinterpret_cast<const Buffer*>(buffer);
+ CHECK(handle) << "NULL buffer handle";
+ return handle->byte_length();
+}
+
+// Fills |byte_length| bytes of |buffer| starting at |byte_offset| with the
+// byte value 0. NULL |buffer| yields IREE_STATUS_INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
+                     iree_device_size_t byte_length) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_zero");
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* target = reinterpret_cast<Buffer*>(buffer);
+  IREE_API_RETURN_IF_ERROR(target->Fill8(byte_offset, byte_length, 0));
+  return IREE_STATUS_OK;
+}
+
+// Forwards to Buffer::Fill to fill |byte_length| bytes of |buffer| starting at
+// |byte_offset| from |pattern|. NULL |buffer| yields INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
+                     iree_device_size_t byte_length, const void* pattern,
+                     iree_host_size_t pattern_length) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_fill");
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* target = reinterpret_cast<Buffer*>(buffer);
+  IREE_API_RETURN_IF_ERROR(
+      target->Fill(byte_offset, byte_length, pattern, pattern_length));
+  return IREE_STATUS_OK;
+}
+
+// Reads |data_length| bytes of |buffer| starting at |source_offset| into
+// |target_buffer|. NULL |buffer| yields IREE_STATUS_INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_read_data(
+    iree_hal_buffer_t* buffer, iree_device_size_t source_offset,
+    void* target_buffer, iree_device_size_t data_length) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_read_data");
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* source = reinterpret_cast<Buffer*>(buffer);
+  IREE_API_RETURN_IF_ERROR(
+      source->ReadData(source_offset, target_buffer, data_length));
+  return IREE_STATUS_OK;
+}
+
+// Writes |data_length| bytes from |source_buffer| into |buffer| starting at
+// |target_offset|. NULL |buffer| yields IREE_STATUS_INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_write_data(
+    iree_hal_buffer_t* buffer, iree_device_size_t target_offset,
+    const void* source_buffer, iree_device_size_t data_length) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_write_data");
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* target = reinterpret_cast<Buffer*>(buffer);
+  IREE_API_RETURN_IF_ERROR(
+      target->WriteData(target_offset, source_buffer, data_length));
+  return IREE_STATUS_OK;
+}
+
+// Maps the range [|element_offset|, |element_offset| + |element_length|) of
+// |buffer| as bytes with the requested |memory_access| and describes it in
+// |out_mapped_memory|, which is zeroed first and stays zeroed on failure.
+// The mapping object lives in |out_mapped_memory->reserved| until unmapped.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_map(
+    iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
+    iree_device_size_t element_offset, iree_device_size_t element_length,
+    iree_hal_mapped_memory_t* out_mapped_memory) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_map");
+
+  if (!out_mapped_memory) return IREE_STATUS_INVALID_ARGUMENT;
+  std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory));
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
+
+  IREE_API_ASSIGN_OR_RETURN(
+      auto mapping, buffer_handle->MapMemory<uint8_t>(
+                        static_cast<MemoryAccessBitfield>(memory_access),
+                        element_offset, element_length));
+
+  static_assert(sizeof(iree_hal_mapped_memory_t::reserved) >=
+                    sizeof(MappedMemory<uint8_t>),
+                "C mapped memory struct must have large enough storage for the "
+                "matching C++ struct");
+  // |reserved| is raw zeroed storage holding no live object, so the mapping is
+  // constructed in place rather than move-assigned into it.
+  auto* mapping_storage = new (out_mapped_memory->reserved)
+      MappedMemory<uint8_t>(std::move(mapping));
+
+  out_mapped_memory->contents = {const_cast<uint8_t*>(mapping_storage->data()),
+                                 mapping_storage->size()};
+
+  return IREE_STATUS_OK;
+}
+
+// Resets the C++ mapping stored in |mapped_memory->reserved| and then zeroes
+// the whole struct. NULL arguments yield IREE_STATUS_INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_unmap(
+    iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_unmap");
+  if (!buffer || !mapped_memory) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  auto* mapping =
+      reinterpret_cast<MappedMemory<uint8_t>*>(mapped_memory->reserved);
+  mapping->reset();
+  std::memset(mapped_memory, 0, sizeof(*mapped_memory));
+  return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::hal::HeapBuffer
+//===----------------------------------------------------------------------===//
+
+// Allocates a buffer of |allocation_size| bytes through HeapBuffer::Allocate
+// and returns it via |out_buffer|, which is NULL on failure. A zero
+// |allocation_size| is rejected. |contents_allocator| and |allocator| are not
+// consulted by this implementation.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate(
+    iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
+    iree_host_size_t allocation_size, iree_allocator_t contents_allocator,
+    iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate");
+
+  if (!out_buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_buffer = nullptr;
+  if (!allocation_size) return IREE_STATUS_INVALID_ARGUMENT;
+
+  auto heap_buffer = HeapBuffer::Allocate(
+      static_cast<MemoryTypeBitfield>(memory_type),
+      static_cast<BufferUsageBitfield>(usage), allocation_size);
+
+  // The reference held by |heap_buffer| is handed over to the caller.
+  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(
+      static_cast<Buffer*>(heap_buffer.release()));
+
+  return IREE_STATUS_OK;
+}
+
+// Creates a buffer from |contents| through HeapBuffer::AllocateCopy and
+// returns it via |out_buffer|, which is NULL on failure. Empty |contents| are
+// rejected. |memory_type|, |contents_allocator| and |allocator| are not
+// consulted by this implementation.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy(
+    iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
+    iree_hal_memory_access_t allowed_access, iree_byte_span_t contents,
+    iree_allocator_t contents_allocator, iree_allocator_t allocator,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate_copy");
+
+  if (!out_buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_buffer = nullptr;
+  if (!contents.data || !contents.data_length) return IREE_STATUS_INVALID_ARGUMENT;
+
+  auto heap_buffer = HeapBuffer::AllocateCopy(
+      static_cast<BufferUsageBitfield>(usage),
+      static_cast<MemoryAccessBitfield>(allowed_access), contents.data,
+      contents.data_length);
+
+  // The reference held by |heap_buffer| is handed over to the caller.
+  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(heap_buffer.release());
+  return IREE_STATUS_OK;
+}
+
+// Wraps the memory described by |contents| through HeapBuffer::WrapMutable
+// and returns the buffer via |out_buffer|, which is NULL on failure. Empty
+// |contents| are rejected. |allocator| is not consulted here.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap(
+    iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t usage, iree_byte_span_t contents,
+    iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE0("iree_hal_heap_buffer_wrap");
+
+  if (!out_buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_buffer = nullptr;
+  if (!contents.data || !contents.data_length) return IREE_STATUS_INVALID_ARGUMENT;
+
+  auto wrapped =
+      HeapBuffer::WrapMutable(static_cast<MemoryTypeBitfield>(memory_type),
+                              static_cast<MemoryAccessBitfield>(allowed_access),
+                              static_cast<BufferUsageBitfield>(usage),
+                              contents.data, contents.data_length);
+
+  // The reference held by |wrapped| is handed over to the caller.
+  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(wrapped.release());
+
+  return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::hal::BufferView
+//===----------------------------------------------------------------------===//
+
+// Creates a buffer view over |buffer| with the given |shape| and
+// |element_size|, allocating the view itself from |allocator|. The view takes
+// its own reference on |buffer|. |out_buffer_view| is NULL on failure.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create(
+    iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size,
+    iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_create");
+
+  if (!out_buffer_view) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_buffer_view = nullptr;
+
+  if (!buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  if (shape.rank > kMaxRank || element_size <= 0) {
+    return IREE_STATUS_OUT_OF_RANGE;
+  }
+
+  // Obtain storage from |allocator| and construct the view struct in place.
+  iree_hal_buffer_view* view = nullptr;
+  IREE_API_RETURN_IF_API_ERROR(allocator.alloc(
+      allocator.self, sizeof(*view), reinterpret_cast<void**>(&view)));
+  new (view) iree_hal_buffer_view();
+  view->allocator = allocator;
+
+  view->impl.buffer = add_ref(reinterpret_cast<Buffer*>(buffer));
+  view->impl.shape = {shape.dims, shape.rank};
+  view->impl.element_size = element_size;
+
+  *out_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(view);
+  return IREE_STATUS_OK;
+}
+
+// Adds a reference to |buffer_view|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_retain");
+  if (!buffer_view) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<iree_hal_buffer_view*>(buffer_view)->AddReference();
+  return IREE_STATUS_OK;
+}
+
+// Drops a reference to |buffer_view|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_release");
+  if (!buffer_view) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<iree_hal_buffer_view*>(buffer_view)->ReleaseReference();
+  return IREE_STATUS_OK;
+}
+
+// Points |buffer_view| at |buffer| with a new |shape| and |element_size|.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_view_assign(
+    iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer,
+    iree_shape_t shape, int8_t element_size) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_assign");
+  auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view);
+  if (!handle || !buffer) return IREE_STATUS_INVALID_ARGUMENT;
+  if (shape.rank > kMaxRank || element_size <= 0)
+    return IREE_STATUS_OUT_OF_RANGE;
+  handle->impl.buffer = add_ref(reinterpret_cast<Buffer*>(buffer));
+  handle->impl.shape = {shape.dims, shape.rank};
+  handle->impl.element_size = element_size;
+  return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_reset");
+  if (!buffer_view) return IREE_STATUS_INVALID_ARGUMENT;
+  reinterpret_cast<iree_hal_buffer_view*>(buffer_view)->impl.buffer.reset();
+  return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer(
+ const iree_hal_buffer_view_t* buffer_view) {
+ IREE_TRACE_SCOPE0("iree_hal_buffer_view_buffer");
+ const auto* handle =
+ reinterpret_cast<const iree_hal_buffer_view*>(buffer_view);
+ CHECK(handle) << "NULL buffer_view handle";
+ return reinterpret_cast<iree_hal_buffer_t*>(handle->impl.buffer.get());
+}
+
+// Converts the shape of |buffer_view| into |out_shape| via ToApiShape.
+// |out_shape->rank| is set to 0 before |buffer_view| is validated.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_view_shape(
+    const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_shape");
+
+  if (!out_shape) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  out_shape->rank = 0;
+
+  const auto* handle =
+      reinterpret_cast<const iree_hal_buffer_view*>(buffer_view);
+  if (!handle) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  return ToApiShape(handle->impl.shape, out_shape);
+}
+
+IREE_API_EXPORT int8_t
+iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) {
+ IREE_TRACE_SCOPE0("iree_hal_buffer_view_element_size");
+ const auto* handle =
+ reinterpret_cast<const iree_hal_buffer_view*>(buffer_view);
+ CHECK(handle) << "NULL buffer_view handle";
+ return handle->impl.element_size;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::hal::Semaphore
+//===----------------------------------------------------------------------===//
+
+// Adds a reference to |semaphore|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore) {
+  IREE_TRACE_SCOPE0("iree_hal_semaphore_retain");
+  if (!semaphore) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Semaphore*>(semaphore)->AddReference();
+  return IREE_STATUS_OK;
+}
+
+// Drops a reference to |semaphore|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t
+iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore) {
+  IREE_TRACE_SCOPE0("iree_hal_semaphore_release");
+  if (!semaphore) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Semaphore*>(semaphore)->ReleaseReference();
+  return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::hal::Fence
+//===----------------------------------------------------------------------===//
+
+// Adds a reference to |fence|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t iree_hal_fence_retain(iree_hal_fence_t* fence) {
+  IREE_TRACE_SCOPE0("iree_hal_fence_retain");
+  if (!fence) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Fence*>(fence)->AddReference();
+  return IREE_STATUS_OK;
+}
+
+// Drops a reference to |fence|. NULL is rejected with INVALID_ARGUMENT.
+IREE_API_EXPORT iree_status_t iree_hal_fence_release(iree_hal_fence_t* fence) {
+  IREE_TRACE_SCOPE0("iree_hal_fence_release");
+  if (!fence) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Fence*>(fence)->ReleaseReference();
+  return IREE_STATUS_OK;
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/api.h b/hal/api.h
new file mode 100644
index 0000000..4d305f9
--- /dev/null
+++ b/hal/api.h
@@ -0,0 +1,366 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_HAL_API_H_
+#define IREE_HAL_API_H_
+
+#include <stdint.h>
+
+#include "base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Types and Enums
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_buffer iree_hal_buffer_t;
+typedef struct iree_hal_buffer_view iree_hal_buffer_view_t;
+typedef struct iree_hal_semaphore iree_hal_semaphore_t;
+typedef struct iree_hal_fence iree_hal_fence_t;
+
+// Reference to a buffer's mapped memory.
+typedef struct {
+ // Contents of the buffer. Behavior is undefined if an access is performed
+ // whose type was not specified during mapping.
+ iree_byte_span_t contents;
+
+ // Used internally - do not modify.
+ uint64_t reserved[8];
+} iree_hal_mapped_memory_t;
+
+// A bitfield specifying properties for a memory type.
+typedef enum {
+ IREE_HAL_MEMORY_TYPE_NONE = 0,
+
+ // Memory is lazily allocated by the device and only exists transiently.
+ // This is the optimal mode for memory used only within a single command
+ // buffer. Transient buffers, even if they have
+ // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE set, should be treated as device-local
+ // and opaque as they may have no memory attached to them outside of the time
+ // they are being evaluated on devices.
+ //
+ // This flag can be treated as a hint in most cases; allocating a buffer with
+ // it set _may_ return the same as if it had not been set. Certain allocation
+ // routines may use the hint to more tightly control reuse or defer wiring the
+ // memory.
+ IREE_HAL_MEMORY_TYPE_TRANSIENT = 1 << 0,
+
+ // Memory allocated with this type can be mapped for host access using
+ // iree_hal_buffer_map.
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE = 1 << 1,
+
+ // The host cache management commands MappedMemory::Flush and
+ // MappedMemory::Invalidate are not needed to flush host writes
+ // to the device or make device writes visible to the host, respectively.
+ IREE_HAL_MEMORY_TYPE_HOST_COHERENT = 1 << 2,
+
+ // Memory allocated with this type is cached on the host. Host memory
+ // accesses to uncached memory are slower than to cached memory, however
+ // uncached memory is always host coherent. MappedMemory::Flush must be used
+ // to ensure the device has visibility into any changes made on the host and
+ // Invalidate must be used to ensure the host has visibility into any changes
+ // made on the device.
+ IREE_HAL_MEMORY_TYPE_HOST_CACHED = 1 << 3,
+
+ // Memory is accessible as normal host allocated memory.
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL =
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT,
+
+ // Memory allocated with this type is visible to the device for execution.
+ // Being device visible does not mean the same thing as
+ // IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL. Though an allocation may be visible to
+ // the device and therefore useable for execution it may require expensive
+ // mapping or implicit transfers.
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE = 1 << 4,
+
+ // Memory allocated with this type is the most efficient for device access.
+ // Devices may support using memory that is not device local via
+ // IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE but doing so can incur non-trivial
+ // performance penalties. Device local memory, on the other hand, is
+ // guaranteed to be fast for all operations.
+ IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL =
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | (1 << 5),
+} iree_hal_memory_type_t;
+
+// A bitfield specifying how memory will be accessed in a mapped memory region.
+typedef enum {
+ // Memory is not mapped.
+ IREE_HAL_MEMORY_ACCESS_NONE = 0,
+ // Memory will be read.
+ // If a buffer is only mapped for reading it may still be possible to write to
+ // it but the results will be undefined (as it may present coherency issues).
+ IREE_HAL_MEMORY_ACCESS_READ = 1 << 0,
+ // Memory will be written.
+ // If a buffer is only mapped for writing it may still be possible to read
+ // from it but the results will be undefined or incredibly slow (as it may
+ // be mapped by the driver as uncached).
+ IREE_HAL_MEMORY_ACCESS_WRITE = 1 << 1,
+ // Memory will be discarded prior to mapping.
+ // The existing contents will be undefined after mapping and must be written
+ // to ensure validity.
+ IREE_HAL_MEMORY_ACCESS_DISCARD = 1 << 2,
+ // Memory will be discarded and completely overwritten in a single operation.
+ IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE =
+ IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD,
+ // Memory may have any operation performed on it.
+ IREE_HAL_MEMORY_ACCESS_ALL = IREE_HAL_MEMORY_ACCESS_READ |
+ IREE_HAL_MEMORY_ACCESS_WRITE |
+ IREE_HAL_MEMORY_ACCESS_DISCARD,
+} iree_hal_memory_access_t;
+
+// Bitfield that defines how a buffer is intended to be used.
+// Usage allows the driver to appropriately place the buffer for more
+// efficient operations of the specified types.
+typedef enum {
+ IREE_HAL_BUFFER_USAGE_NONE = 0,
+
+ // The buffer, once defined, will not be mapped or updated again.
+ // This should be used for uniform parameter values such as runtime
+ // constants for executables. Doing so may allow drivers to inline values or
+ // represent them in command buffers more efficiently (avoiding memory reads
+ // or swapping, etc).
+ IREE_HAL_BUFFER_USAGE_CONSTANT = 1 << 0,
+
+ // The buffer can be used as the source or target of a transfer command
+ // (CopyBuffer, UpdateBuffer, etc).
+ //
+ // If |IREE_HAL_BUFFER_USAGE_MAPPING| is not specified drivers may safely
+ // assume that the host may never need visibility of this buffer as all
+ // accesses will happen via command buffers.
+ IREE_HAL_BUFFER_USAGE_TRANSFER = 1 << 1,
+
+ // The buffer can be mapped by the host application for reading and writing.
+ //
+ // As mapping may require placement in special address ranges or system
+ // calls to enable visibility the driver can use the presence (or lack of)
+ // this flag to perform allocation-type setup and avoid initial mapping
+ // overhead.
+ IREE_HAL_BUFFER_USAGE_MAPPING = 1 << 2,
+
+ // The buffer can be provided as an input or output to an executable.
+ // Buffers of this type may be directly used by drivers during dispatch.
+ IREE_HAL_BUFFER_USAGE_DISPATCH = 1 << 3,
+
+ // Buffer may be used for any operation.
+ IREE_HAL_BUFFER_USAGE_ALL = IREE_HAL_BUFFER_USAGE_TRANSFER |
+ IREE_HAL_BUFFER_USAGE_MAPPING |
+ IREE_HAL_BUFFER_USAGE_DISPATCH,
+} iree_hal_buffer_usage_t;
+
+//===----------------------------------------------------------------------===//
+// iree::hal::Buffer
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Returns a reference to a subspan of the |buffer|.
+// If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer after
+// |byte_offset| (possibly 0) will be selected.
+//
+// The parent buffer will remain alive for the lifetime of the subspan
+// returned. If the subspan is a small portion this may cause additional
+// memory to remain allocated longer than required.
+//
+// Returns the given |buffer| if the requested span covers the entire range.
+// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan(
+ iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
+ iree_device_size_t byte_length, iree_allocator_t allocator,
+ iree_hal_buffer_t** out_buffer);
+
+// Retains the given |buffer| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_retain(iree_hal_buffer_t* buffer);
+
+// Releases the given |buffer| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_release(iree_hal_buffer_t* buffer);
+
+// Returns the size in bytes of the buffer.
+IREE_API_EXPORT iree_device_size_t IREE_API_CALL
+iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer);
+
+// Sets a range of the buffer to binary zero.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
+ iree_device_size_t byte_length);
+
+// Sets a range of the buffer to the given value.
+// Only |pattern_length| values with 1, 2, or 4 bytes are supported.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
+ iree_device_size_t byte_length, const void* pattern,
+ iree_host_size_t pattern_length);
+
+// Reads a block of data from the buffer at the given offset.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data(
+ iree_hal_buffer_t* buffer, iree_device_size_t source_offset,
+ void* target_buffer, iree_device_size_t data_length);
+
+// Writes a block of byte data into the buffer at the given offset.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data(
+ iree_hal_buffer_t* buffer, iree_device_size_t target_offset,
+ const void* source_buffer, iree_device_size_t data_length);
+
+// Maps the buffer to be accessed as a host pointer into |out_mapped_memory|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map(
+ iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
+ iree_device_size_t element_offset, iree_device_size_t element_length,
+ iree_hal_mapped_memory_t* out_mapped_memory);
+
+// Unmaps the buffer as was previously mapped to |mapped_memory|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_unmap(
+ iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::hal::HeapBuffer
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Allocates a zeroed host heap buffer of the given size.
+// The buffer contents will be allocated with |contents_allocator| while
+// |allocator| is used for the iree_hal_buffer_t.
+//
+// Returns a buffer allocated with malloc that may not be usable by devices
+// without copies. |memory_type| should be set to
+// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases.
+// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate(
+ iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
+ iree_host_size_t allocation_size, iree_allocator_t contents_allocator,
+ iree_allocator_t allocator, iree_hal_buffer_t** out_buffer);
+
+// Allocates a host heap buffer with a copy of the given data.
+// The buffer contents will be allocated with |contents_allocator| while
+// |allocator| is used for the iree_hal_buffer_t.
+//
+// Returns a buffer allocated with malloc that may not be usable by devices
+// without copies. |memory_type| should be set to
+// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases.
+// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy(
+ iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
+ iree_hal_memory_access_t allowed_access, iree_byte_span_t contents,
+ iree_allocator_t contents_allocator, iree_allocator_t allocator,
+ iree_hal_buffer_t** out_buffer);
+
+// Wraps an existing host heap allocation in a buffer.
+// Ownership of the host allocation remains with the caller and the memory
+// must remain valid for so long as the iree_hal_buffer_t may be in use.
+//
+// Returns a buffer allocated with malloc that may not be usable by devices
+// without copies. |memory_type| should be set to
+// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases.
+// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap(
+ iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t usage, iree_byte_span_t contents,
+ iree_allocator_t allocator, iree_hal_buffer_t** out_buffer);
+
+// TODO(benvanik): add a wrap that takes an allocator just for the buffer.
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::hal::BufferView
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a buffer view with the given |buffer|, which may be nullptr.
+// |out_buffer_view| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_create(
+ iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size,
+ iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view);
+
+// Retains the given |buffer_view| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view);
+
+// Releases the given |buffer_view| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view);
+
+// Sets the buffer view to point at the new |buffer| with the given metadata.
+// To clear a buffer_view to empty use iree_hal_buffer_view_reset.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_assign(
+ iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer,
+ iree_shape_t shape, int8_t element_size);
+
+// Resets the buffer view to have an empty buffer and shape.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view);
+
+// Returns the buffer underlying the buffer view.
+// The caller must retain the returned buffer if they want to continue using it.
+IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL
+iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view);
+
+// Returns the shape of the buffer view in |out_shape|.
+// If there is not enough space in |out_shape| to store all dimensions then
+// IREE_STATUS_OUT_OF_RANGE is returned and |out_shape|.rank is set to the rank.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape(
+ const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape);
+
+// Returns the size of each element in the buffer view in bytes.
+IREE_API_EXPORT int8_t IREE_API_CALL
+iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::hal::Semaphore
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Retains the given |semaphore| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore);
+
+// Releases the given |semaphore| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::hal::Fence
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Retains the given |fence| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_fence_retain(iree_hal_fence_t* fence);
+
+// Releases the given |fence| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_fence_release(iree_hal_fence_t* fence);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_API_H_
diff --git a/hal/api_detail.h b/hal/api_detail.h
new file mode 100644
index 0000000..09b6230
--- /dev/null
+++ b/hal/api_detail.h
@@ -0,0 +1,42 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Additional definitions for internal users of the api. This should only
+// be included from internal implementation files.
+
+#ifndef IREE_HAL_API_DETAIL_H_
+#define IREE_HAL_API_DETAIL_H_
+
+#include "hal/api.h"
+#include "hal/buffer_view.h"
+
+namespace iree {
+namespace hal {
+
+// In the API, buffer views are ref objects, and this allows parts of the
+// API outside of the HAL to work with them.
+struct iree_hal_buffer_view : public RefObject<iree_hal_buffer_view> {
+ BufferView impl;
+ iree_allocator_t allocator;
+
+ static void Delete(iree_hal_buffer_view* ptr) {
+ ptr->impl.buffer.reset();
+ ptr->allocator.free(ptr->allocator.self, ptr);
+ }
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif
diff --git a/hal/buffer.cc b/hal/buffer.cc
new file mode 100644
index 0000000..bfac15d
--- /dev/null
+++ b/hal/buffer.cc
@@ -0,0 +1,549 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/buffer.h"
+
+#include <algorithm>
+#include <atomic>
+#include <cstdint>
+#include <cstring>
+#include <sstream>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/variant.h"
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+
+#if HAS_IREE_BUFFER_DEBUG_NAME
+namespace {
+// Used for diagnostic purposes only as a default buffer name.
+std::atomic<int> next_buffer_id_{0};
+} // namespace
+#endif // HAS_IREE_BUFFER_DEBUG_NAME
+
+std::string MemoryTypeString(MemoryTypeBitfield memory_type) {
+ return FormatBitfieldValue(memory_type,
+ {
+ // Combined:
+ {MemoryType::kHostLocal, "kHostLocal"},
+ {MemoryType::kDeviceLocal, "kDeviceLocal"},
+ // Separate:
+ {MemoryType::kTransient, "kTransient"},
+ {MemoryType::kHostVisible, "kHostVisible"},
+ {MemoryType::kHostCoherent, "kHostCoherent"},
+ {MemoryType::kHostCached, "kHostCached"},
+ {MemoryType::kDeviceVisible, "kDeviceVisible"},
+ });
+}
+
+std::string MemoryAccessString(MemoryAccessBitfield memory_access) {
+ return FormatBitfieldValue(memory_access,
+ {
+ // Combined:
+ {MemoryAccess::kAll, "kAll"},
+ {MemoryAccess::kDiscardWrite, "kDiscardWrite"},
+ // Separate:
+ {MemoryAccess::kRead, "kRead"},
+ {MemoryAccess::kWrite, "kWrite"},
+ {MemoryAccess::kDiscard, "kDiscard"},
+ });
+}
+
+std::string BufferUsageString(BufferUsageBitfield buffer_usage) {
+ return FormatBitfieldValue(buffer_usage,
+ {
+ // Combined:
+ {BufferUsage::kAll, "kAll"},
+ // Separate:
+ {BufferUsage::kConstant, "kConstant"},
+ {BufferUsage::kTransfer, "kTransfer"},
+ {BufferUsage::kMapping, "kMapping"},
+ {BufferUsage::kDispatch, "kDispatch"},
+ });
+}
+
+// Special router for buffers that just reference other buffers.
+// We keep this out of the base Buffer so that it's a bit easier to track
+// delegation.
+class SubspanBuffer : public Buffer {
+ public:
+ SubspanBuffer(ref_ptr<Buffer> parent_buffer, device_size_t byte_offset,
+ device_size_t byte_length)
+ : Buffer(parent_buffer->allocator(), parent_buffer->memory_type(),
+ parent_buffer->allowed_access(), parent_buffer->usage(),
+ parent_buffer->allocation_size(), byte_offset, byte_length) {
+ allocated_buffer_ = parent_buffer.get();
+ parent_buffer_ = std::move(parent_buffer);
+ }
+
+ protected:
+ Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern, device_size_t pattern_length) override {
+ return parent_buffer_->FillImpl(byte_offset, byte_length, pattern,
+ pattern_length);
+ }
+
+ Status ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) override {
+ return parent_buffer_->ReadDataImpl(source_offset, data, data_length);
+ }
+
+ Status WriteDataImpl(device_size_t target_offset, const void* data,
+ device_size_t data_length) override {
+ return parent_buffer_->WriteDataImpl(target_offset, data, data_length);
+ }
+
+ Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
+ device_size_t source_offset,
+ device_size_t data_length) override {
+ return parent_buffer_->CopyDataImpl(target_offset, source_buffer,
+ source_offset, data_length);
+ }
+
+ Status MapMemoryImpl(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void** out_data) override {
+ return parent_buffer_->MapMemoryImpl(mapping_mode, memory_access,
+ local_byte_offset, local_byte_length,
+ out_data);
+ }
+
+ Status UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data) override {
+ return parent_buffer_->UnmapMemoryImpl(local_byte_offset, local_byte_length,
+ data);
+ }
+
+ Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override {
+ return parent_buffer_->InvalidateMappedMemoryImpl(local_byte_offset,
+ local_byte_length);
+ }
+
+ Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override {
+ return parent_buffer_->FlushMappedMemoryImpl(local_byte_offset,
+ local_byte_length);
+ }
+};
+
+// static
+StatusOr<ref_ptr<Buffer>> Buffer::Subspan(const ref_ptr<Buffer>& buffer,
+ device_size_t byte_offset,
+ device_size_t byte_length) {
+ RETURN_IF_ERROR(buffer->CalculateRange(byte_offset, byte_length, &byte_offset,
+ &byte_length));
+ if (byte_offset == 0 && byte_length == buffer->byte_length()) {
+ // Asking for the same buffer.
+ return add_ref(buffer);
+ }
+
+ // To avoid heavy nesting of subspans that just add indirection we go to the
+ // parent buffer directly. If we wanted better accounting (to track where
+ // buffers came from) we'd want to avoid this but I'm not sure that's worth
+ // the super deep indirection that could arise.
+ if (buffer->allocated_buffer() != buffer.get()) {
+ CHECK(buffer->parent_buffer_);
+ return Buffer::Subspan(buffer->parent_buffer_, byte_offset, byte_length);
+ } else {
+ return {make_ref<SubspanBuffer>(add_ref(buffer), byte_offset, byte_length)};
+ }
+}
+
+// static
+Buffer::Overlap Buffer::TestOverlap(
+ Buffer* lhs_buffer, device_size_t lhs_offset, device_size_t lhs_length,
+ Buffer* rhs_buffer, device_size_t rhs_offset, device_size_t rhs_length) {
+ if (lhs_buffer->allocated_buffer() != rhs_buffer->allocated_buffer()) {
+ // Not even the same buffers.
+ return Overlap::kDisjoint;
+ }
+ // Resolve offsets into the underlying allocation.
+ device_size_t lhs_alloc_offset = lhs_buffer->byte_offset() + lhs_offset;
+ device_size_t rhs_alloc_offset = rhs_buffer->byte_offset() + rhs_offset;
+ device_size_t lhs_alloc_length = lhs_length == kWholeBuffer
+ ? lhs_buffer->byte_length() - lhs_offset
+ : lhs_length;
+ device_size_t rhs_alloc_length = rhs_length == kWholeBuffer
+ ? rhs_buffer->byte_length() - rhs_offset
+ : rhs_length;
+ if (!lhs_alloc_length || !rhs_alloc_length) {
+ return Overlap::kDisjoint;
+ }
+ if (lhs_alloc_offset == rhs_alloc_offset &&
+ lhs_alloc_length == rhs_alloc_length) {
+ return Overlap::kComplete;
+ }
+ return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset &&
+ rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset
+ ? Overlap::kPartial
+ : Overlap::kDisjoint;
+}
+
+// static
+bool Buffer::DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset,
+ device_size_t lhs_length, Buffer* rhs_buffer,
+ device_size_t rhs_offset, device_size_t rhs_length) {
+ return TestOverlap(lhs_buffer, lhs_offset, lhs_length, rhs_buffer, rhs_offset,
+ rhs_length) != Overlap::kDisjoint;
+}
+
+Buffer::Buffer(Allocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
+ device_size_t allocation_size, device_size_t byte_offset,
+ device_size_t byte_length)
+ : allocated_buffer_(const_cast<Buffer*>(this)),
+ allocator_(allocator),
+ memory_type_(memory_type),
+ allowed_access_(allowed_access),
+ usage_(usage),
+ allocation_size_(allocation_size),
+ byte_offset_(byte_offset),
+ byte_length_(byte_length) {
+#if HAS_IREE_BUFFER_DEBUG_NAME
+ // Default name for logging.
+ // It'd be nice to defer this until it's required but that would require
+ // synchronization or something.
+ const char* debug_name_prefix = "";
+ if ((memory_type_ & MemoryType::kHostLocal) == MemoryType::kHostLocal) {
+ debug_name_prefix = "host_buffer_";
+ } else if ((memory_type_ & MemoryType::kDeviceLocal) ==
+ MemoryType::kDeviceLocal) {
+ // TODO(benvanik): include allocator ID to differentiate devices.
+ debug_name_prefix = "device_buffer_";
+ }
+ debug_name_ = absl::StrCat(debug_name_prefix, next_buffer_id_++);
+#endif // HAS_IREE_BUFFER_DEBUG_NAME
+}
+
+Buffer* Buffer::allocated_buffer() const noexcept {
+ Buffer* allocated_buffer = allocated_buffer_;
+ while (allocated_buffer != this &&
+ allocated_buffer != allocated_buffer->allocated_buffer()) {
+ allocated_buffer = allocated_buffer->allocated_buffer();
+ }
+ return allocated_buffer;
+}
+
+std::string Buffer::DebugString() const {
+ std::ostringstream stream;
+ stream << allocated_buffer()->debug_name() << "["
+ << (allocation_size() == kWholeBuffer
+ ? "?"
+ : std::to_string(allocation_size()))
+ << "].";
+ if (AnyBitSet(memory_type() & MemoryType::kTransient)) stream << "Z";
+ if ((memory_type() & MemoryType::kHostLocal) == MemoryType::kHostLocal) {
+ stream << "h";
+ } else {
+ if (AnyBitSet(memory_type() & MemoryType::kHostVisible)) stream << "v";
+ if (AnyBitSet(memory_type() & MemoryType::kHostCoherent)) stream << "x";
+ if (AnyBitSet(memory_type() & MemoryType::kHostCached)) stream << "c";
+ }
+ if ((memory_type() & MemoryType::kDeviceLocal) == MemoryType::kDeviceLocal) {
+ stream << "D";
+ } else {
+ if (AnyBitSet(memory_type() & MemoryType::kDeviceVisible)) stream << "V";
+ }
+ stream << ".";
+ if (AnyBitSet(usage() & BufferUsage::kConstant)) stream << "c";
+ if (AnyBitSet(usage() & BufferUsage::kTransfer)) stream << "t";
+ if (AnyBitSet(usage() & BufferUsage::kMapping)) stream << "m";
+ if (AnyBitSet(usage() & BufferUsage::kDispatch)) stream << "d";
+ if (byte_offset_ || byte_length_ != allocation_size_) {
+ stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1)
+ << ")";
+ }
+ return stream.str();
+}
+
+std::string Buffer::DebugStringShort() const {
+ // TODO(benvanik): figure out what's most useful here. Maybe a long variant?
+ std::ostringstream stream;
+ stream << allocated_buffer()->debug_name() << "["
+ << (allocation_size() == kWholeBuffer
+ ? "?"
+ : std::to_string(allocation_size()))
+ << "]";
+ if (byte_offset_ || byte_length_ != allocation_size_) {
+ stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1)
+ << ")";
+ }
+ return stream.str();
+}
+
+Status Buffer::ValidateCompatibleMemoryType(
+ MemoryTypeBitfield memory_type) const {
+ if ((memory_type_ & memory_type) != memory_type) {
+ // Missing one or more bits.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "Buffer memory type is not compatible with the requested "
+ "operation; buffer has "
+ << MemoryTypeString(memory_type_) << ", operation requires "
+ << MemoryTypeString(memory_type);
+ }
+ return OkStatus();
+}
+
+Status Buffer::ValidateAccess(MemoryAccessBitfield memory_access) const {
+ if (!AnyBitSet(memory_access &
+ (MemoryAccess::kRead | MemoryAccess::kWrite))) {
+ // No actual access bits defined.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Memory access must specify one or more of kRead or kWrite";
+ } else if ((allowed_access_ & memory_access) != memory_access) {
+ // Every requested access bit must be allowed by the buffer.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "The buffer does not support the requested access type; buffer "
+ "allows "
+ << MemoryAccessString(allowed_access_) << ", operation requires "
+ << MemoryAccessString(memory_access);
+ }
+ return OkStatus();
+}
+
+Status Buffer::ValidateUsage(BufferUsageBitfield usage) const {
+ if ((usage_ & usage) != usage) {
+ // Missing one or more bits.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "Requested usage was not specified when the buffer was "
+ "allocated; buffer allows "
+ << BufferUsageString(usage_) << ", operation requires "
+ << BufferUsageString(usage);
+ }
+ return OkStatus();
+}
+
+Status Buffer::CalculateRange(device_size_t base_offset,
+ device_size_t max_length, device_size_t offset,
+ device_size_t length,
+ device_size_t* out_adjusted_offset,
+ device_size_t* out_adjusted_length) {
+ // Check if the start of the range runs off the end of the buffer.
+ if (offset > max_length) {
+ *out_adjusted_offset = 0;
+ if (out_adjusted_length) *out_adjusted_length = 0;
+ return OutOfRangeErrorBuilder(IREE_LOC)
+ << "Attempted to access an address off the end of the valid buffer "
+ "range (offset="
+ << offset << ", length=" << length
+ << ", buffer byte_length=" << max_length << ")";
+ }
+
+ // kWholeBuffer can only be resolved when the caller accepts an adjusted length.
+ if (length == kWholeBuffer && !out_adjusted_length) {
+ *out_adjusted_offset = 0;
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "kWholeBuffer may only be used with buffer ranges, not external "
+ "pointer ranges";
+ }
+
+ // Calculate the real ranges adjusted for our region within the allocation.
+ device_size_t adjusted_offset = base_offset + offset;
+ device_size_t adjusted_length =
+ length == kWholeBuffer ? max_length - offset : length;
+ if (adjusted_length == 0) {
+ // Fine to have a zero length.
+ *out_adjusted_offset = adjusted_offset;
+ if (out_adjusted_length) *out_adjusted_length = adjusted_length;
+ return OkStatus();
+ }
+
+ // Check if the end runs over the allocation; offset <= max_length so no wrap.
+ device_size_t end = offset + adjusted_length - 1; // Diagnostic only.
+ if (adjusted_length > max_length - offset) {
+ *out_adjusted_offset = 0;
+ if (out_adjusted_length) *out_adjusted_length = 0;
+ return OutOfRangeErrorBuilder(IREE_LOC)
+ << "Attempted to access an address outside of the valid buffer "
+ "range (offset="
+ << offset << ", adjusted_length=" << adjusted_length
+ << ", end=" << end << ", buffer byte_length=" << max_length << ")";
+ }
+
+ *out_adjusted_offset = adjusted_offset;
+ if (out_adjusted_length) *out_adjusted_length = adjusted_length;
+ return OkStatus();
+}
+
+Status Buffer::CalculateRange(device_size_t offset, device_size_t length,
+ device_size_t* out_adjusted_offset,
+ device_size_t* out_adjusted_length) const {
+ return CalculateRange(byte_offset_, byte_length_, offset, length,
+ out_adjusted_offset, out_adjusted_length);
+}
+
+// Validates |offset|/|length| within a domain of |max_length| bytes.
+// Forwards to the static CalculateRange with a base offset of 0, so the
+// adjusted offset is not shifted by any buffer's byte_offset_.
+Status Buffer::CalculateLocalRange(device_size_t max_length,
+ device_size_t offset, device_size_t length,
+ device_size_t* out_adjusted_offset,
+ device_size_t* out_adjusted_length) {
+ return CalculateRange(0, max_length, offset, length, out_adjusted_offset,
+ out_adjusted_length);
+}
+
+// Validates a fill request and forwards it to FillImpl.
+//
+// |byte_offset| and |byte_length| are overwritten with the values produced by
+// CalculateRange; those adjusted values are what get alignment-checked and
+// handed to FillImpl. |pattern_length| must be 1, 2, or 4. A zero-length fill
+// succeeds without calling FillImpl.
+Status Buffer::Fill(device_size_t byte_offset, device_size_t byte_length,
+                    const void* pattern, device_size_t pattern_length) {
+  // If not host visible we'll need to issue command buffers.
+  RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
+  RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+  RETURN_IF_ERROR(
+      CalculateRange(byte_offset, byte_length, &byte_offset, &byte_length));
+
+  // Reject unsupported pattern widths before they are used as a divisor below.
+  switch (pattern_length) {
+    case 1:
+    case 2:
+    case 4:
+      break;
+    default:
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Fill patterns must be 1, 2, or 4 bytes";
+  }
+
+  // Both the start and the extent must be a whole number of pattern elements.
+  const bool is_offset_aligned = (byte_offset % pattern_length) == 0;
+  const bool is_length_aligned = (byte_length % pattern_length) == 0;
+  if (!is_offset_aligned || !is_length_aligned) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Attempting to fill a range with " << pattern_length
+           << " byte values that is not "
+              "aligned (offset="
+           << byte_offset << ", length=" << byte_length << ")";
+  }
+
+  if (byte_length == 0) return OkStatus();  // No-op.
+
+  // We can turn all-zero values into single-byte fills as that can be much
+  // faster on devices (doing a fill8 vs fill32).
+  const uint32_t zero_pattern = 0;
+  const bool is_all_zero =
+      std::memcmp(pattern, &zero_pattern, pattern_length) == 0;
+  if (is_all_zero) pattern_length = 1;
+
+  return FillImpl(byte_offset, byte_length, pattern, pattern_length);
+}
+
+// Validates a host read of |data_length| bytes at |source_offset| and forwards
+// it to ReadDataImpl using the offset produced by CalculateRange.
+// A zero-length read succeeds without calling ReadDataImpl.
+Status Buffer::ReadData(device_size_t source_offset, void* data,
+                        device_size_t data_length) {
+  // If not host visible we'll need to issue command buffers.
+  RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
+  RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+
+  // No adjusted-length output is requested here, so kWholeBuffer is rejected
+  // by CalculateRange.
+  device_size_t adjusted_source_offset = 0;
+  RETURN_IF_ERROR(
+      CalculateRange(source_offset, data_length, &adjusted_source_offset));
+
+  if (data_length != 0) {
+    return ReadDataImpl(adjusted_source_offset, data, data_length);
+  }
+  return OkStatus();  // No-op.
+}
+
+// Validates a host write of |data_length| bytes at |target_offset| and
+// forwards it to WriteDataImpl using the offset produced by CalculateRange.
+// A zero-length write succeeds without calling WriteDataImpl.
+Status Buffer::WriteData(device_size_t target_offset, const void* data,
+                         device_size_t data_length) {
+  // If not host visible we'll need to issue command buffers.
+  RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
+  RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+
+  // No adjusted-length output is requested here, so kWholeBuffer is rejected
+  // by CalculateRange.
+  device_size_t adjusted_target_offset = 0;
+  RETURN_IF_ERROR(
+      CalculateRange(target_offset, data_length, &adjusted_target_offset));
+
+  if (data_length != 0) {
+    return WriteDataImpl(adjusted_target_offset, data, data_length);
+  }
+  return OkStatus();  // No-op.
+}
+
+// Copies |data_length| bytes from |source_buffer| at |source_offset| into this
+// buffer at |target_offset|.
+//
+// Both buffers are checked for host visibility, mapping usage, and the access
+// required (write on this buffer, read on the source), and both ranges are
+// validated with CalculateRange. If |data_length| is kWholeBuffer the copy
+// length is the smaller of the two adjusted lengths. Zero-length copies return
+// OkStatus() without calling CopyDataImpl.
+//
+// Returns InvalidArgument if the source and target are the same Buffer object
+// and the two ranges overlap. Ranges that merely touch end-to-start are not
+// considered overlapping.
+Status Buffer::CopyData(device_size_t target_offset, Buffer* source_buffer,
+                        device_size_t source_offset,
+                        device_size_t data_length) {
+  RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
+  RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+  RETURN_IF_ERROR(
+      source_buffer->ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  RETURN_IF_ERROR(source_buffer->ValidateAccess(MemoryAccess::kRead));
+  RETURN_IF_ERROR(source_buffer->ValidateUsage(BufferUsage::kMapping));
+
+  // We need to validate both buffers.
+  device_size_t source_data_length = data_length;
+  device_size_t target_data_length = data_length;
+  device_size_t adjusted_source_offset = 0;
+  RETURN_IF_ERROR(source_buffer->CalculateRange(
+      source_offset, source_data_length, &adjusted_source_offset,
+      &source_data_length));
+  RETURN_IF_ERROR(CalculateRange(target_offset, target_data_length,
+                                 &target_offset, &target_data_length));
+  device_size_t adjusted_data_length = 0;
+  if (data_length == kWholeBuffer) {
+    // Whole buffer copy requested - that could mean either, so take the min.
+    adjusted_data_length = std::min(source_data_length, target_data_length);
+  } else {
+    // Specific length requested - validate that we have matching lengths.
+    CHECK_EQ(source_data_length, target_data_length);
+    adjusted_data_length = source_data_length;
+  }
+
+  // Elide zero length copies.
+  if (adjusted_data_length == 0) {
+    return OkStatus();
+  }
+
+  // Check for overlap. Both ranges are half-open, [offset, offset + length),
+  // so the comparisons must be strict: a range that begins exactly where the
+  // other one ends shares no bytes with it and is a legal copy.
+  if (this == source_buffer &&
+      adjusted_source_offset < target_offset + adjusted_data_length &&
+      target_offset < adjusted_source_offset + adjusted_data_length) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Source and target ranges overlap within the same buffer";
+  }
+
+  // NOTE(review): |target_offset| is the adjusted value here while
+  // |source_offset| is forwarded as the caller supplied it; presumably
+  // CopyDataImpl resolves the source range against |source_buffer| itself -
+  // confirm against the implementations.
+  return CopyDataImpl(target_offset, source_buffer, source_offset,
+                      adjusted_data_length);
+}
+
+// Validates a mapping request and forwards it to MapMemoryImpl.
+//
+// |byte_offset| and |byte_length| are in/out: on entry they hold the requested
+// range and, once CalculateRange succeeds, they hold the adjusted range that
+// is passed to MapMemoryImpl. |*out_data| is cleared up front so that callers
+// never observe a stale pointer when any validation step fails.
+Status Buffer::MapMemory(MappingMode mapping_mode,
+                         MemoryAccessBitfield memory_access,
+                         device_size_t* byte_offset, device_size_t* byte_length,
+                         void** out_data) {
+  *out_data = nullptr;
+  RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  RETURN_IF_ERROR(ValidateAccess(memory_access));
+  RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+  RETURN_IF_ERROR(
+      CalculateRange(*byte_offset, *byte_length, byte_offset, byte_length));
+  return MapMemoryImpl(mapping_mode, memory_access, *byte_offset, *byte_length,
+                       out_data);
+}
+
+// Checks that the buffer is host visible and was allocated with mapping usage,
+// then forwards to UnmapMemoryImpl. The range is passed through unmodified.
+Status Buffer::UnmapMemory(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data) {
+ RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+ RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+ // NOTE: local_byte_offset/local_byte_length are already adjusted.
+ return UnmapMemoryImpl(local_byte_offset, local_byte_length, data);
+}
+
+// Validates an invalidate request and forwards it to
+// InvalidateMappedMemoryImpl. Buffers whose memory type has kHostCoherent set
+// are rejected with PermissionDenied as they do not need invalidation.
+Status Buffer::InvalidateMappedMemory(device_size_t local_byte_offset,
+                                      device_size_t local_byte_length) {
+  RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
+  if (!AnyBitSet(memory_type_ & MemoryType::kHostCoherent)) {
+    RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+    // NOTE: local_byte_offset/local_byte_length are already adjusted.
+    return InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length);
+  }
+  return PermissionDeniedErrorBuilder(IREE_LOC)
+         << "Buffer memory type is coherent and invalidation is not required";
+}
+
+// Checks that the buffer memory type is compatible with
+// kHostVisible | kHostCached and that it was allocated with mapping usage,
+// then forwards to FlushMappedMemoryImpl. The range is passed through
+// unmodified.
+Status Buffer::FlushMappedMemory(device_size_t local_byte_offset,
+ device_size_t local_byte_length) {
+ RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible |
+ MemoryType::kHostCached));
+ RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
+ // NOTE: local_byte_offset/local_byte_length are already adjusted.
+ return FlushMappedMemoryImpl(local_byte_offset, local_byte_length);
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/buffer.h b/hal/buffer.h
new file mode 100644
index 0000000..5de808d
--- /dev/null
+++ b/hal/buffer.h
@@ -0,0 +1,903 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Allocated memory buffer wrapper type and utilities.
+//
+// Buffers are the basic unit of memory used by the inference system. They may
+// be allocated such that they are accessible from the host (normal C++ code
+// running on the main CPU), a particular device (such as an accelerator) or
+// family of devices, or from some mix of all of those.
+//
+// The type of memory a buffer is allocated within has implications on its
+// performance and lifetime. For example if an application attempts to use a
+// host-allocated buffer (MemoryType::kHostLocal) on an accelerator with
+// discrete memory the accelerator may either be unable to access the memory or
+// take a non-trivial performance hit when attempting to do so (involving
+// setting up kernel mappings, doing DMA transfers, etc). Likewise, trying to
+// access a device-allocated buffer (MemoryType::kDeviceLocal) may incur similar
+// overhead or not be possible at all. This may be due to restrictions in the
+// memory visibility, address spaces, mixed endianness or pointer widths,
+// and other weirdness.
+//
+// The memory types (defined by a bitfield of MemoryType values) that a
+// particular context (host or device) may use vary from device to device and
+// must be queried by the application when allocating buffers. It's strongly
+// recommended that the most specific memory type be set as possible. For
+// example allocating a buffer with MemoryType::kHostCoherent even when it will
+// never be used in a way that requires coherency may occupy address space
+// reservations or memory mapping that would otherwise not be needed.
+//
+// As buffers may sometimes not be accessible from the host the base Buffer type
+// does not allow for direct void* access and instead buffers must be either
+// manipulated using utility functions (such as ReadData or WriteData) or by
+// mapping them into a host-accessible address space via MapMemory. Buffer must
+// be unmapped before any command may use it.
+//
+// Buffers may map (roughly) 1:1 with an allocation either from the host heap or
+// a device. Buffer::Subspan can be used to reference subspans of buffers like
+// absl::Span - though unlike absl::Span the returned Buffer holds a reference
+// to the parent buffer.
+
+#ifndef IREE_HAL_BUFFER_H_
+#define IREE_HAL_BUFFER_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/types/span.h"
+#include "absl/types/variant.h"
+#include "base/bitfield.h"
+#include "base/logging.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "hal/resource.h"
+
+// Only enable debug names in non-opt modes (unless the user forces it on).
+#if !defined(NDEBUG) && !defined(HAS_IREE_BUFFER_DEBUG_NAME)
+#define HAS_IREE_BUFFER_DEBUG_NAME 1
+#endif // !NDEBUG
+
+namespace iree {
+
+// std::size_t equivalent sized for device address spaces (64-bit unsigned).
+// As the device may have a larger memory address space than the host we treat
+// all byte offsets as this type instead of the host-specified size_t.
+using device_size_t = uint64_t;
+
+// Sentinel length with all bits set. When used as a length value in functions
+// it causes the length to be the entire remaining buffer from the specified
+// offset.
+constexpr device_size_t kWholeBuffer = ~0ull;
+
+} // namespace iree
+
+namespace iree {
+namespace hal {
+
+class Allocator;
+template <typename T>
+class MappedMemory;
+
+// A bitfield specifying properties for a memory type.
+enum class MemoryType : uint32_t {
+ kNone = 0,
+
+ // Memory is lazily allocated by the device and only exists transiently.
+ // This is the optimal mode for memory used only within a single command
+ // buffer. Transient buffers, even if they have kHostVisible set, should be
+ // treated as device-local and opaque as they may have no memory attached to
+ // them outside of the time they are being evaluated on devices.
+ //
+ // This flag can be treated as a hint in most cases; allocating a buffer with
+ // it set _may_ return the same as if it had not been set. Certain allocation
+ // routines may use the hint to more tightly control reuse or defer wiring the
+ // memory.
+ kTransient = 1 << 0,
+
+ // Memory allocated with this type can be mapped for host access using
+ // Buffer::MapMemory.
+ kHostVisible = 1 << 1,
+
+ // The host cache management commands MappedMemory::Flush and
+ // MappedMemory::Invalidate are not needed to flush host writes
+ // to the device or make device writes visible to the host, respectively.
+ kHostCoherent = 1 << 2,
+
+ // Memory allocated with this type is cached on the host. Host memory
+ // accesses to uncached memory are slower than to cached memory, however
+ // uncached memory is always host coherent. MappedMemory::Flush must be used
+ // to ensure the device has visibility into any changes made on the host and
+ // Invalidate must be used to ensure the host has visibility into any changes
+ // made on the device.
+ kHostCached = 1 << 3,
+
+ // Memory is accessible as normal host allocated memory.
+ // This is the combination of kHostVisible and kHostCoherent.
+ kHostLocal = kHostVisible | kHostCoherent,
+
+ // Memory allocated with this type is visible to the device for execution.
+ // Being device visible does not mean the same thing as kDeviceLocal. Though
+ // an allocation may be visible to the device and therefore useable for
+ // execution it may require expensive mapping or implicit transfers.
+ kDeviceVisible = 1 << 4,
+
+ // Memory allocated with this type is the most efficient for device access.
+ // Devices may support using memory that is not device local via
+ // kDeviceVisible but doing so can incur non-trivial performance penalties.
+ // Device local memory, on the other hand, is guaranteed to be fast for all
+ // operations.
+ // Always includes the kDeviceVisible bit in addition to its own bit (1 << 5).
+ kDeviceLocal = kDeviceVisible | (1 << 5),
+};
+IREE_BITFIELD(MemoryType);
+using MemoryTypeBitfield = MemoryType;
+// Returns a string describing |memory_type| (defined outside this header).
+std::string MemoryTypeString(MemoryTypeBitfield memory_type);
+
+// A bitfield specifying how memory will be accessed in a mapped memory region.
+enum class MemoryAccess : uint32_t {
+ // Memory is not mapped.
+ kNone = 0,
+
+ // Memory will be read.
+ // If a buffer is only mapped for reading it may still be possible to write to
+ // it but the results will be undefined (as it may present coherency issues).
+ kRead = 1 << 0,
+
+ // Memory will be written.
+ // If a buffer is only mapped for writing it may still be possible to read
+ // from it but the results will be undefined or incredibly slow (as it may
+ // be mapped by the driver as uncached).
+ kWrite = 1 << 1,
+
+ // Memory will be discarded prior to mapping.
+ // The existing contents will be undefined after mapping and must be written
+ // to ensure validity.
+ kDiscard = 1 << 2,
+
+ // Memory will be discarded and completely overwritten in a single operation.
+ kDiscardWrite = kWrite | kDiscard,
+
+ // Memory may have any operation performed on it (read, write, and discard).
+ kAll = kRead | kWrite | kDiscard,
+};
+IREE_BITFIELD(MemoryAccess);
+using MemoryAccessBitfield = MemoryAccess;
+// Returns a string describing |memory_access| (defined outside this header).
+std::string MemoryAccessString(MemoryAccessBitfield memory_access);
+
+// Bitfield that defines how a buffer is intended to be used.
+// Usage allows the driver to appropriately place the buffer for more
+// efficient operations of the specified types.
+enum class BufferUsage {
+ kNone = 0,
+
+ // The buffer, once defined, will not be mapped or updated again.
+ // This should be used for uniform parameter values such as runtime
+ // constants for executables. Doing so may allow drivers to inline values or
+ // represent them in command buffers more efficiently (avoiding memory reads
+ // or swapping, etc).
+ kConstant = 1 << 0,
+
+ // The buffer can be used as the source or target of a transfer command
+ // (CopyBuffer, UpdateBuffer, etc).
+ //
+ // If |kMapping| is not specified drivers may safely assume that the host
+ // may never need visibility of this buffer as all accesses will happen via
+ // command buffers.
+ kTransfer = 1 << 1,
+
+ // The buffer can be mapped by the host application for reading and writing.
+ //
+ // As mapping may require placement in special address ranges or system
+ // calls to enable visibility the driver can use the presence (or lack of)
+ // this flag to perform allocation-type setup and avoid initial mapping
+ // overhead.
+ kMapping = 1 << 2,
+
+ // The buffer can be provided as an input or output to an executable.
+ // Buffers of this type may be directly used by drivers during dispatch.
+ kDispatch = 1 << 3,
+
+ // Buffer may be used for any transfer, mapping, or dispatch operation.
+ // Note that this mask does not include kConstant.
+ kAll = kTransfer | kMapping | kDispatch,
+};
+IREE_BITFIELD(BufferUsage);
+using BufferUsageBitfield = BufferUsage;
+// Returns a string describing |buffer_usage| (defined outside this header).
+std::string BufferUsageString(BufferUsageBitfield buffer_usage);
+
+// A memory buffer.
+// Buffers have a specific memory_type that is used to describe the capabilities
+// and behavior of the backing memory of the buffer. Buffers may be any mix of
+// host-accessible, host-coherent, or device-accessible for various usages.
+// Depending on these memory types the buffers may be mapped for access on the
+// host as memory though certain restrictions may be imposed.
+//
+// See MemoryType for more information about the types and what operations they
+// support.
+class Buffer : public Resource {
+ public:
+ // Returns a reference to a subspan of the buffer.
+ // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after
+ // |byte_offset| (possibly 0) will be selected.
+ //
+ // The parent buffer will remain alive for the lifetime of the subspan
+ // returned. If the subspan is a small portion this may cause additional
+ // memory to remain allocated longer than required.
+ //
+ // Returns the given |buffer| if the requested span covers the entire range.
+ static StatusOr<ref_ptr<Buffer>> Subspan(const ref_ptr<Buffer>& buffer,
+ device_size_t byte_offset,
+ device_size_t byte_length);
+
+ // Overlap test results.
+ enum class Overlap {
+ // No overlap between the two buffers.
+ kDisjoint,
+ // Partial overlap between the two buffers.
+ kPartial,
+ // Complete overlap between the two buffers (they are the same).
+ kComplete,
+ };
+
+ // Tests whether the given buffers overlap, including support for subspans.
+ // kWholeBuffer may be used for |lhs_length| and/or |rhs_length| to use the
+ // lengths of those buffers, respectively.
+ static Overlap TestOverlap(Buffer* lhs_buffer, device_size_t lhs_offset,
+ device_size_t lhs_length, Buffer* rhs_buffer,
+ device_size_t rhs_offset,
+ device_size_t rhs_length);
+
+ // Returns true if the two buffer ranges overlap at all.
+ static bool DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset,
+ device_size_t lhs_length, Buffer* rhs_buffer,
+ device_size_t rhs_offset, device_size_t rhs_length);
+
+ // Disallow copies (as copying requires real work).
+ Buffer(const Buffer&) = delete;
+ Buffer& operator=(const Buffer&) = delete;
+
+ ~Buffer() override = default;
+
+#if HAS_IREE_BUFFER_DEBUG_NAME
+ // Optionally populated name useful for logging a persistent name for the
+ // buffer.
+ absl::string_view debug_name() const { return debug_name_; }
+ void set_debug_name(std::string debug_name) {
+ debug_name_ = std::move(debug_name);
+ }
+#else
+ // Debug names are compiled out: the getter always returns an empty string and
+ // the setter discards its argument.
+ absl::string_view debug_name() const { return ""; }
+ void set_debug_name(std::string debug_name) {}
+#endif // HAS_IREE_BUFFER_DEBUG_NAME
+
+ // Memory allocator this buffer was allocated from.
+ // May be nullptr if the buffer has no particular allocator and should be
+ // assumed to be allocated from the host heap.
+ // When this buffer is not its own allocated buffer the query is forwarded to
+ // allocated_buffer_.
+ constexpr Allocator* allocator() const {
+ return allocated_buffer_ == this ? allocator_
+ : allocated_buffer_->allocator();
+ }
+
+ // Memory type this buffer is allocated from.
+ MemoryTypeBitfield memory_type() const { return memory_type_; }
+
+ // Memory access operations allowed on the buffer.
+ MemoryAccessBitfield allowed_access() const { return allowed_access_; }
+
+ // Bitfield describing how the buffer is to be used.
+ BufferUsageBitfield usage() const { return usage_; }
+
+ // Returns the underlying buffer that represents the allocated memory for the
+ // Buffer. In most cases this is the buffer itself but for buffer subspan
+ // references it will point to the parent buffer.
+ Buffer* allocated_buffer() const noexcept;
+
+ // Size of the resource memory allocation in bytes.
+ // This may be rounded up from the originally requested size or the ideal
+ // size for the resource based on device restrictions.
+ // When this buffer is not its own allocated buffer the query is forwarded to
+ // allocated_buffer_.
+ constexpr device_size_t allocation_size() const {
+ return allocated_buffer_ == this ? allocation_size_
+ : allocated_buffer_->allocation_size();
+ }
+
+ // Range within the underlying allocation this buffer occupies.
+ // For buffers that map 1:1 with an allocation this should be
+ // [0, allocation_size()), however may still differ if the allocation needed
+ // to be aligned.
+ //
+ // The offset is most often manipulated by Subspan, however it's important to
+ // note that the offset may not be what was passed to Subspan as it refers to
+ // the offset in the original ancestor buffer, not the buffer from which the
+ // subspan was taken.
+ constexpr device_size_t byte_offset() const noexcept { return byte_offset_; }
+ constexpr device_size_t byte_length() const noexcept { return byte_length_; }
+
+ // TODO(benvanik): add debug_name.
+
+ // Returns a longer debug string describing the buffer and its attributes.
+ std::string DebugString() const;
+ // Returns a short debug string describing the buffer.
+ std::string DebugStringShort() const;
+
+ // Sets a range of the buffer to the given value.
+ // This requires that the resource was allocated with
+ // MemoryType::kHostVisible and BufferUsage::kMapping.
+ // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after
+ // |byte_offset| (possibly 0) will be filled.
+ //
+ // The |byte_offset| and |byte_length| must be aligned to the size of the fill
+ // value. Multi-byte values will be written in host order for host buffers and
+ // device order for device buffers.
+ //
+ // Only |pattern_length| values with 1, 2, or 4 bytes are supported.
+ //
+ // Fails if the write could not be performed; either the bounds are out of
+ // range or the memory type does not support writing in this way.
+ Status Fill(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern, device_size_t pattern_length);
+ // Typed fill helpers for 8-, 16-, and 32-bit values. The overloads without a
+ // range take only the value.
+ // NOTE(review): definitions are not visible in this part of the header;
+ // presumably they forward to Fill and the range-less forms cover the whole
+ // buffer - confirm.
+ template <typename T>
+ Status Fill8(device_size_t byte_offset, device_size_t byte_length, T value);
+ template <typename T>
+ Status Fill16(device_size_t byte_offset, device_size_t byte_length, T value);
+ template <typename T>
+ Status Fill32(device_size_t byte_offset, device_size_t byte_length, T value);
+ template <typename T>
+ Status Fill8(T value);
+ template <typename T>
+ Status Fill16(T value);
+ template <typename T>
+ Status Fill32(T value);
+
+ // Reads a block of byte data from the resource at the given offset.
+ // This requires that the resource was allocated with
+ // MemoryType::kHostVisible and BufferUsage::kMapping.
+ //
+ // Fails if the read could not be performed; either the bounds are out of
+ // range or the memory type does not support reading in this way.
+ Status ReadData(device_size_t source_offset, void* data,
+ device_size_t data_length);
+
+ // Writes a block of byte data into the resource at the given offset.
+ // This requires that the resource was allocated with
+ // MemoryType::kHostVisible and BufferUsage::kMapping.
+ //
+ // Fails if the write could not be performed; either the bounds are out of
+ // range or the memory type does not support writing in this way.
+ Status WriteData(device_size_t target_offset, const void* data,
+ device_size_t data_length);
+
+ // Copies data from the provided source_buffer into the buffer.
+ // This requires that the resource was allocated with
+ // MemoryType::kHostVisible and BufferUsage::kMapping.
+ // The source and destination may be the same buffer but the ranges must not
+ // overlap (a la memcpy).
+ //
+ // Fails if the write could not be performed; either the bounds are out of
+ // range or the memory type does not support writing in this way.
+ Status CopyData(device_size_t target_offset, Buffer* source_buffer,
+ device_size_t source_offset, device_size_t data_length);
+ // Convenience overload: copies from offset 0 of |source_buffer| using
+ // kWholeBuffer as the length.
+ Status CopyData(device_size_t target_offset, Buffer* source_buffer) {
+ return CopyData(target_offset, source_buffer, 0, kWholeBuffer);
+ }
+
+ // Maps the resource memory for direct access from the host.
+ // This requires that the resource was allocated with
+ // MemoryType::kHostVisible and BufferUsage::kMapping.
+ //
+ // If MemoryType::kHostCoherent was not specified then explicit
+ // Invalidate and Flush calls must be used to control visibility of the data
+ // on the device. If MemoryType::kHostCached is not set callers must not
+ // attempt to read from the mapped memory as doing so may produce undefined
+ // results and/or ultra slow reads.
+ //
+ // If the MemoryAccess::kDiscard bit is set when mapping for writes the caller
+ // guarantees that they will be overwriting all data in the mapped range. This
+ // is used as a hint to the device that the prior contents are no longer
+ // required and can enable optimizations that save on synchronization and
+ // readback. Note however that it is strictly a hint and the contents are not
+ // guaranteed to be zeroed during mapping.
+ //
+ // This allows mapping the memory as a C++ type. Care must be taken to ensure
+ // the data layout in C++ matches the expected data layout in the executables
+ // that consume this data. For simple primitives like uint8_t or float this is
+ // usually not a problem however struct packing may have many restrictions.
+ //
+ // The returned mapping should be unmapped when it is no longer required.
+ // Unmapping does not implicitly flush.
+ //
+ // Fails if the memory could not be mapped due to mapping exhaustion, invalid
+ // arguments, or unsupported memory types.
+ //
+ // Example:
+ // ASSIGN_OR_RETURN(auto mapping,
+ // buffer->MapMemory<MyStruct>(MemoryAccess::kWrite));
+ // mapping[5].foo = 3;
+ // std::memcpy(mapping.data(), source_data, mapping.size());
+ // mapping.reset();
+ template <typename T>
+ StatusOr<MappedMemory<T>> MapMemory(
+ MemoryAccessBitfield memory_access, device_size_t element_offset = 0,
+ device_size_t element_length = kWholeBuffer);
+
+ protected:
+ template <typename T>
+ friend class MappedMemory;
+
+ // Defines the mode of a MapMemory operation.
+ enum class MappingMode {
+ // The call to MapMemory will always be matched with UnmapMemory.
+ kScoped,
+ };
+
+ Buffer(Allocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
+ device_size_t allocation_size, device_size_t byte_offset,
+ device_size_t byte_length);
+
+ // Allows subclasses to override the allowed access bits.
+ // This should only be done when known safe by the allocation scheme.
+ void set_allowed_access(MemoryAccessBitfield allowed_access) {
+ allowed_access_ = allowed_access;
+ }
+
+ // Sets a range of the buffer to the given value.
+ // State and parameters have already been validated. For the >8bit variants
+ // the offset and length have already been validated to be aligned to the
+ // natural alignment of the type.
+ virtual Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern,
+ device_size_t pattern_length) = 0;
+
+ // Reads a block of byte data from the resource at the given offset.
+ // State and parameters have already been validated.
+ virtual Status ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) = 0;
+
+ // Writes a block of byte data into the resource at the given offset.
+ // State and parameters have already been validated.
+ virtual Status WriteDataImpl(device_size_t target_offset, const void* data,
+ device_size_t data_length) = 0;
+
+ // Copies a block of byte data into the resource at the given offset.
+ // State and parameters have already been validated.
+ virtual Status CopyDataImpl(device_size_t target_offset,
+ Buffer* source_buffer,
+ device_size_t source_offset,
+ device_size_t data_length) = 0;
+
+ // Maps memory directly.
+ // The output data pointer will be properly aligned to the start of the data.
+ // |local_byte_offset| and |local_byte_length| are the adjusted values that
+ // should map into the local space of the buffer.
+ //
+ // Fails if the memory could not be mapped (invalid access type, invalid
+ // range, or unsupported memory type).
+ // State and parameters have already been validated.
+ virtual Status MapMemoryImpl(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void** out_data) = 0;
+
+ // Unmaps previously mapped memory.
+ // No-op if the memory is not mapped. As this is often used in destructors
+ // we can't rely on failures here propagating with anything but CHECK/DCHECK.
+ // State and parameters have already been validated.
+ virtual Status UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void* data) = 0;
+
+ // Invalidates ranges of non-coherent memory from the host caches.
+ // Use this before reading from non-coherent memory.
+ // This guarantees that device writes to the memory ranges provided are
+ // visible on the host.
+ // This is only required for memory types without kHostCoherent set.
+ // State and parameters have already been validated.
+ virtual Status InvalidateMappedMemoryImpl(
+ device_size_t local_byte_offset, device_size_t local_byte_length) = 0;
+
+ // Flushes ranges of non-coherent memory from the host caches.
+ // Use this after writing to non-coherent memory.
+ // This guarantees that host writes to the memory ranges provided are made
+ // available for device access.
+ // This is only required for memory types without kHostCoherent set.
+ // State and parameters have already been validated.
+ virtual Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) = 0;
+
+ // Validates the given buffer range and adjusts the offset and length if the
+ // provided length is kWholeBuffer or the buffer is offset within its
+ // allocation. This calculates the range in the given domain without adjusting
+ // to any particular buffer base offsets.
+ static Status CalculateLocalRange(device_size_t max_length,
+ device_size_t offset, device_size_t length,
+ device_size_t* out_adjusted_offset,
+ device_size_t* out_adjusted_length);
+
+ private:
+ friend class Allocator;
+
+ // This is not great and deserves cleanup.
+ friend class DeferredBuffer;
+ friend class SubspanBuffer;
+ friend class HeapBuffer;
+
+ // Maps memory directly.
+ // The byte offset and byte length may be adjusted for device alignment.
+ // The output data pointer will be properly aligned to the start of the data.
+ // Fails if the memory could not be mapped (invalid access type, invalid
+ // range, or unsupported memory type).
+ Status MapMemory(MappingMode mapping_mode, MemoryAccessBitfield memory_access,
+ device_size_t* byte_offset, device_size_t* byte_length,
+ void** out_data);
+
+ // Unmaps previously mapped memory.
+ // No-op if the memory is not mapped. As this is often used in destructors
+ // we can't rely on failures here propagating with anything but CHECK/DCHECK.
+ Status UnmapMemory(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data);
+
+ // Invalidates ranges of non-coherent memory from the host caches.
+ // Use this before reading from non-coherent memory.
+ // This guarantees that device writes to the memory ranges provided are
+ // visible on the host.
+ // This is only required for memory types without kHostCoherent set.
+ Status InvalidateMappedMemory(device_size_t local_byte_offset,
+ device_size_t local_byte_length);
+
+ // Flushes ranges of non-coherent memory from the host caches.
+ // Use this after writing to non-coherent memory.
+ // This guarantees that host writes to the memory ranges provided are made
+ // available for device access.
+ // This is only required for memory types without kHostCoherent set.
+ Status FlushMappedMemory(device_size_t local_byte_offset,
+ device_size_t local_byte_length);
+
+ // Returns a failure if the memory type the buffer was allocated from is not
+ // compatible with the given type.
+ Status ValidateCompatibleMemoryType(MemoryTypeBitfield memory_type) const;
+ // Returns a failure if the buffer memory type or usage disallows the given
+ // access type.
+ Status ValidateAccess(MemoryAccessBitfield memory_access) const;
+ // Returns a failure if the buffer was not allocated for the given usage.
+ Status ValidateUsage(BufferUsageBitfield usage) const;
+ // Validates the given buffer range and optionally adjusts the offset and
+ // length if the provided length is kWholeBuffer or the buffer is offset
+ // within its allocation.
+ // |out_adjusted_length| may be omitted (nullptr), in which case kWholeBuffer
+ // is not accepted as a length.
+ static Status CalculateRange(device_size_t base_offset,
+ device_size_t max_length, device_size_t offset,
+ device_size_t length,
+ device_size_t* out_adjusted_offset,
+ device_size_t* out_adjusted_length = nullptr);
+ Status CalculateRange(device_size_t offset, device_size_t length,
+ device_size_t* out_adjusted_offset,
+ device_size_t* out_adjusted_length = nullptr) const;
+
+ // Points to either this or parent_buffer_.get().
+ Buffer* allocated_buffer_ = nullptr;
+
+ Allocator* allocator_ = nullptr;
+ MemoryTypeBitfield memory_type_ = MemoryType::kNone;
+ MemoryAccessBitfield allowed_access_ = MemoryAccess::kNone;
+ BufferUsageBitfield usage_ = BufferUsage::kNone;
+
+ device_size_t allocation_size_ = 0;
+ device_size_t byte_offset_ = 0;
+ device_size_t byte_length_ = 0;
+
+#if HAS_IREE_BUFFER_DEBUG_NAME
+ // Friendly name for the buffer used in DebugString. May be set by the app or
+ // auto generated.
+ std::string debug_name_;
+#endif // HAS_IREE_BUFFER_DEBUG_NAME
+
+ // Defined when this buffer is a subspan of another buffer.
+ ref_ptr<Buffer> parent_buffer_;
+};
+
+// RAII handle for a mapped range of a Buffer, exposed as elements of type T.
+// The mapping stays active (and retains the buffer) until reset or destroyed.
+template <typename T>
+class MappedMemory {
+ public:
+ using unspecified_bool_type = const T* MappedMemory<T>::*;
+
+ MappedMemory() = default;
+ MappedMemory(MemoryAccessBitfield access, ref_ptr<Buffer> buffer,
+ device_size_t byte_offset, device_size_t byte_length,
+ device_size_t element_size, T* data);
+
+ // Allow moving but disallow copying as the mapping is stateful.
+ MappedMemory(MappedMemory&& rhs) noexcept;
+ MappedMemory& operator=(MappedMemory&& rhs) noexcept;
+ MappedMemory(const MappedMemory&) = delete;
+ MappedMemory& operator=(const MappedMemory&) = delete;
+
+ ~MappedMemory();
+
+ // The buffer resource that this mapping references.
+ const ref_ptr<Buffer>& buffer() const noexcept { return buffer_; }
+ // Offset, in bytes, into the resource allocation.
+ // This value is *informative only*, as it may vary from device to device.
+ device_size_t byte_offset() const noexcept { return byte_offset_; }
+ // Length, in bytes, of the resource mapping.
+ // This may be larger than the originally requested length due to alignment.
+ // This value is *informative only*, as it may vary from device to device.
+ device_size_t byte_length() const noexcept { return byte_length_; }
+
+ // True if the mapping contains no elements (including when unmapped).
+ bool empty() const noexcept { return element_size_ == 0; }
+ // The number of elements of type T covered by the mapping.
+ size_t size() const noexcept { return static_cast<size_t>(element_size_); }
+
+ // Returns a read-only pointer to the mapped memory.
+ // This will be nullptr if the mapping failed or the mapping is not readable.
+ const T* data() const noexcept;
+ absl::Span<const T> contents() const noexcept { return {data(), size()}; }
+
+ // Returns a mutable pointer to the mapped memory.
+ // This will be nullptr if the mapping failed or the mapping is not writable.
+ // If the mapping was not made with read access it may still be possible to
+ // read from this memory but behavior is undefined.
+ T* mutable_data() noexcept;
+ absl::Span<T> mutable_contents() noexcept { return {mutable_data(), size()}; }
+
+ // Equivalent to absl::Span::subspan(); may return a 0-length span.
+ // Fails if the buffer is not mapped, is not mapped for the requested access,
+ // or the requested element range falls outside of the mapping.
+ StatusOr<absl::Span<const T>> Subspan(
+ device_size_t element_offset = 0,
+ device_size_t element_length = kWholeBuffer) const noexcept;
+ StatusOr<absl::Span<T>> MutableSubspan(
+ device_size_t element_offset = 0,
+ device_size_t element_length = kWholeBuffer) noexcept;
+
+ // Accesses an element in the mapped memory; no access or bounds checking.
+ // Must be called with a valid index in [0, size()).
+ const T& operator[](device_size_t i) const noexcept { return data_[i]; }
+
+ // Invalidates a range of non-coherent elements from the host caches.
+ Status Invalidate(device_size_t element_offset = 0,
+ device_size_t element_length = kWholeBuffer) const;
+
+ // Flushes a range of non-coherent elements from the host caches.
+ Status Flush(device_size_t element_offset = 0,
+ device_size_t element_length = kWholeBuffer);
+
+ // Unmaps the mapped memory.
+ // The memory will not be implicitly flushed when unmapping.
+ void reset();
+
+ private:
+ Status ValidateAccess(MemoryAccessBitfield memory_access) const;
+ Status CalculateDataRange(device_size_t element_offset,
+ device_size_t element_length,
+ device_size_t* out_adjusted_element_offset,
+ device_size_t* out_adjusted_element_length) const;
+
+ MemoryAccessBitfield access_ = MemoryAccess::kNone;
+ ref_ptr<Buffer> buffer_;
+ device_size_t byte_offset_ = 0;
+ device_size_t byte_length_ = 0;
+ device_size_t element_size_ = 0; // Count of T elements, not a byte size.
+ T* data_ = nullptr;
+};
+
+// Inline functions and template definitions follow:
+
+template <typename T>
+Status Buffer::Fill8(device_size_t byte_offset, device_size_t byte_length,
+ T value) {
+ auto sized_value = reinterpret_cast<uint8_t*>(&value);
+ return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value));
+}
+
+template <typename T>
+Status Buffer::Fill16(device_size_t byte_offset, device_size_t byte_length,
+ T value) {
+ auto sized_value = reinterpret_cast<uint16_t*>(&value);
+ return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value));
+}
+
+template <typename T>
+Status Buffer::Fill32(device_size_t byte_offset, device_size_t byte_length,
+ T value) {
+ auto sized_value = reinterpret_cast<uint32_t*>(&value);
+ return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value));
+}
+
+template <typename T>
+Status Buffer::Fill8(T value) {
+ return Fill8(0, kWholeBuffer, value);
+}
+
+template <typename T>
+Status Buffer::Fill16(T value) {
+ return Fill16(0, kWholeBuffer, value);
+}
+
+template <typename T>
+Status Buffer::Fill32(T value) {
+ return Fill32(0, kWholeBuffer, value);
+}
+
+template <typename T>
+StatusOr<MappedMemory<T>> Buffer::MapMemory(MemoryAccessBitfield memory_access,
+ device_size_t element_offset,
+ device_size_t element_length) {
+ device_size_t byte_offset = element_offset * sizeof(T);
+ device_size_t byte_length = element_length == kWholeBuffer
+ ? kWholeBuffer
+ : element_length * sizeof(T);
+ void* data = nullptr;
+ RETURN_IF_ERROR(MapMemory(MappingMode::kScoped, memory_access, &byte_offset,
+ &byte_length, &data));
+ return MappedMemory<T>{
+ memory_access, add_ref(this), byte_offset,
+ byte_length, byte_length / sizeof(T), static_cast<T*>(data)};
+}
+
+template <typename T>
+MappedMemory<T>::MappedMemory(MemoryAccessBitfield access,
+ ref_ptr<Buffer> buffer, device_size_t byte_offset,
+ device_size_t byte_length,
+ device_size_t element_size, T* data)
+ : access_(access),
+ buffer_(std::move(buffer)),
+ byte_offset_(byte_offset),
+ byte_length_(byte_length),
+ element_size_(element_size),
+ data_(data) {}
+
+template <typename T>
+MappedMemory<T>::MappedMemory(MappedMemory<T>&& rhs) noexcept
+ : access_(rhs.access_),
+ buffer_(std::move(rhs.buffer_)),
+ byte_offset_(rhs.byte_offset_),
+ byte_length_(rhs.byte_length_),
+ element_size_(rhs.element_size_),
+ data_(rhs.data_) {
+ rhs.access_ = MemoryAccess::kNone;
+ rhs.buffer_.reset();
+ rhs.byte_offset_ = 0;
+ rhs.byte_length_ = 0;
+ rhs.element_size_ = 0;
+ rhs.data_ = nullptr;
+}
+
+template <typename T>
+MappedMemory<T>& MappedMemory<T>::operator=(MappedMemory<T>&& rhs) noexcept {
+ if (this != &rhs) {
+ reset();
+ access_ = rhs.access_;
+ buffer_ = std::move(rhs.buffer_);
+ byte_offset_ = rhs.byte_offset_;
+ byte_length_ = rhs.byte_length_;
+ element_size_ = rhs.element_size_;
+ data_ = rhs.data_;
+
+ rhs.access_ = MemoryAccess::kNone;
+ rhs.buffer_.reset();
+ rhs.byte_offset_ = 0;
+ rhs.byte_length_ = 0;
+ rhs.element_size_ = 0;
+ rhs.data_ = nullptr;
+ }
+ return *this;
+}
+
+template <typename T>
+MappedMemory<T>::~MappedMemory() {
+ // Unmap if still mapped; reset() drops any unmap error as we can't report it.
+ reset();
+}
+
+template <typename T>
+const T* MappedMemory<T>::data() const noexcept {
+ if (!data_ || !AnyBitSet(access_ & MemoryAccess::kRead)) {
+ return nullptr;
+ }
+ return data_;
+}
+
+template <typename T>
+T* MappedMemory<T>::mutable_data() noexcept {
+ if (!data_ || !AnyBitSet(access_ & MemoryAccess::kWrite)) {
+ return nullptr;
+ }
+ return data_;
+}
+
+template <typename T>
+Status MappedMemory<T>::ValidateAccess(
+ MemoryAccessBitfield memory_access) const {
+ if (!data_) {
+ return FailedPreconditionErrorBuilder(IREE_LOC) << "Buffer is not mapped";
+ } else if (!AnyBitSet(access_ & memory_access)) {
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "Buffer is not mapped for the desired access";
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status MappedMemory<T>::CalculateDataRange(
+ device_size_t element_offset, device_size_t element_length,
+ device_size_t* out_adjusted_element_offset,
+ device_size_t* out_adjusted_element_length) const {
+ RETURN_IF_ERROR(Buffer::CalculateLocalRange(
+ element_size_ * sizeof(T), element_offset * sizeof(T),
+ element_length == kWholeBuffer ? kWholeBuffer
+ : element_length * sizeof(T),
+ out_adjusted_element_offset, out_adjusted_element_length));
+ *out_adjusted_element_offset /= sizeof(T);
+ *out_adjusted_element_length /= sizeof(T);
+ return OkStatus();
+}
+
+template <typename T>
+inline StatusOr<absl::Span<const T>> MappedMemory<T>::Subspan(
+ device_size_t element_offset, device_size_t element_length) const noexcept {
+ RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
+ RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
+ &element_offset, &element_length));
+ return absl::Span<const T>(data_ + element_offset, element_length);
+}
+
+template <typename T>
+inline StatusOr<absl::Span<T>> MappedMemory<T>::MutableSubspan(
+ device_size_t element_offset, device_size_t element_length) noexcept {
+ RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
+ RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
+ &element_offset, &element_length));
+ return absl::Span<T>(data_ + element_offset, element_length);
+}
+
+template <typename T>
+Status MappedMemory<T>::Invalidate(device_size_t element_offset,
+ device_size_t element_length) const {
+ RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
+ RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
+ &element_offset, &element_length));
+ if (!element_length) return OkStatus();
+ return buffer_->InvalidateMappedMemory(
+ byte_offset_ + element_offset * sizeof(T), element_length * sizeof(T));
+}
+
+template <typename T>
+Status MappedMemory<T>::Flush(device_size_t element_offset,
+ device_size_t element_length) {
+ RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
+ RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
+ &element_offset, &element_length));
+ if (!element_length) return OkStatus();
+ return buffer_->FlushMappedMemory(byte_offset_ + element_offset * sizeof(T),
+ element_length * sizeof(T));
+}
+
+template <typename T>
+void MappedMemory<T>::reset() {
+ if (!buffer_) return;
+ // TODO(benvanik): better handling of errors? may be fine to always warn.
+ buffer_->UnmapMemory(byte_offset_, byte_length_, data_).IgnoreError();
+ buffer_.reset();
+ access_ = MemoryAccess::kNone;
+ byte_offset_ = 0;
+ byte_length_ = 0;
+ element_size_ = 0;
+ data_ = nullptr;
+}
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_BUFFER_H_
diff --git a/hal/buffer_mapping_test.cc b/hal/buffer_mapping_test.cc
new file mode 100644
index 0000000..1bffbaa
--- /dev/null
+++ b/hal/buffer_mapping_test.cc
@@ -0,0 +1,539 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for the MappedMemory RAII wrapper.
+// This uses a mock buffer implementation such that it is only testing
+// MappedMemory and not any real underlying memory mapping behavior.
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+class Allocator;
+
+namespace {
+
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Return;
+using ::testing::SetArgPointee;
+
+static void* const kValidPtr = reinterpret_cast<void*>(0xBEEFCAFEF00D1234ull);
+
+class MockBuffer : public Buffer {
+ public:
+ using Buffer::MappingMode;
+
+ MockBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
+ device_size_t allocation_size)
+ : Buffer(allocator, memory_type, allowed_access, usage, allocation_size,
+ 0, allocation_size) {}
+
+ MOCK_METHOD4(FillImpl,
+ Status(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern, device_size_t pattern_length));
+
+ MOCK_METHOD3(ReadDataImpl, Status(device_size_t source_offset, void* data,
+ device_size_t data_length));
+ MOCK_METHOD3(WriteDataImpl,
+ Status(device_size_t target_offset, const void* data,
+ device_size_t data_length));
+ MOCK_METHOD4(CopyDataImpl,
+ Status(device_size_t target_offset, Buffer* source_buffer,
+ device_size_t source_offset, device_size_t data_length));
+
+ MOCK_METHOD5(MapMemoryImpl,
+ Status(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length, void** out_data));
+ MOCK_METHOD3(UnmapMemoryImpl,
+ Status(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data));
+ MOCK_METHOD2(InvalidateMappedMemoryImpl,
+ Status(device_size_t local_byte_offset,
+ device_size_t local_byte_length));
+ MOCK_METHOD2(FlushMappedMemoryImpl, Status(device_size_t local_byte_offset,
+ device_size_t local_byte_length));
+};
+
+TEST(MemoryMappingTest, MapWholeBuffer) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mapping,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mapping.reset();
+}
+
+TEST(MemoryMappingTest, MapPartialBuffer) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 4, 12, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mapping,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 4, 12));
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(4, 12, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mapping.reset();
+}
+
+TEST(MemoryMappingTest, EmptyHandle) {
+ MappedMemory<uint8_t> mm_a;
+ MappedMemory<uint8_t> mm_b;
+ mm_a = std::move(mm_b);
+ EXPECT_EQ(nullptr, mm_a.buffer());
+ EXPECT_EQ(0, mm_a.byte_offset());
+ EXPECT_EQ(0, mm_a.byte_length());
+ EXPECT_TRUE(mm_a.empty());
+ EXPECT_EQ(0, mm_a.size());
+ EXPECT_EQ(nullptr, mm_a.data());
+ EXPECT_EQ(nullptr, mm_a.mutable_data());
+ EXPECT_TRUE(IsFailedPrecondition(mm_a.Subspan().status()));
+ EXPECT_TRUE(IsFailedPrecondition(mm_a.MutableSubspan().status()));
+ EXPECT_TRUE(IsFailedPrecondition(mm_a.Invalidate()));
+ EXPECT_TRUE(IsFailedPrecondition(mm_a.Flush()));
+ mm_a.reset();
+}
+
+TEST(MemoryMappingTest, MoveHandle) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_a,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Should be able to move the handle around without having any calls.
+ auto mm_b = std::move(mm_a);
+ mm_a = std::move(mm_b);
+ mm_b = std::move(mm_a);
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_b.reset();
+}
+
+TEST(MemoryMappingTest, ReadOnlyAccess) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kRead, BufferUsage::kAll, 128);
+
+ // Should succeed to map for reading.
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_r,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Non-mutable access is fine.
+ EXPECT_EQ(kValidPtr, mm_r.data());
+ ASSERT_OK_AND_ASSIGN(auto span, mm_r.Subspan());
+ (void)span;
+
+ // Read-only mappings should not be able to get mutable access.
+ EXPECT_EQ(nullptr, mm_r.mutable_data());
+ EXPECT_TRUE(IsPermissionDenied(mm_r.MutableSubspan().status()));
+
+ // Read-only mappings should not be able to call Flush.
+ EXPECT_TRUE(IsPermissionDenied(mm_r.Flush()));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+
+ // Should fail to map for writing.
+ EXPECT_TRUE(IsPermissionDenied(
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status()));
+}
+
+TEST(MemoryMappingTest, ReadWriteAccess) {
+ auto buffer = std::make_shared<MockBuffer>(
+ nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kRead | MemoryAccess::kWrite, BufferUsage::kAll, 128);
+
+ // Should succeed to map for reading and/or writing.
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead | MemoryAccess::kWrite,
+ 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(
+ auto mm_rw,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kWrite));
+
+ // Everything valid.
+ EXPECT_EQ(kValidPtr, mm_rw.data());
+ ASSERT_OK_AND_ASSIGN(auto span, mm_rw.Subspan());
+ EXPECT_EQ(kValidPtr, mm_rw.mutable_data());
+ ASSERT_OK_AND_ASSIGN(span, mm_rw.MutableSubspan());
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_rw.reset();
+
+ // Should fail to map for discard.
+ EXPECT_TRUE(IsPermissionDenied(
+ buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status()));
+}
+
+TEST(MemoryMappingTest, WriteOnlyAccess) {
+ auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kWrite,
+ BufferUsage::kAll, 128);
+
+ // Should succeed to map for writing.
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_w,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+
+ // Mutable access is valid.
+ EXPECT_EQ(kValidPtr, mm_w.mutable_data());
+ ASSERT_OK_AND_ASSIGN(auto span, mm_w.MutableSubspan());
+ (void)span;
+
+ // Write-only mappings should not be able to get non-mutable access.
+ EXPECT_EQ(nullptr, mm_w.data());
+ EXPECT_TRUE(IsPermissionDenied(mm_w.Subspan().status()));
+
+ // Write-only mappings should not be able to call Invalidate.
+ EXPECT_TRUE(IsPermissionDenied(mm_w.Invalidate()));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+
+ // Should fail to map for reading.
+ EXPECT_TRUE(IsPermissionDenied(
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status()));
+
+ // Should fail to map for discard.
+ EXPECT_TRUE(IsPermissionDenied(
+ buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status()));
+}
+
+TEST(MemoryMappingTest, WriteDiscardAccess) {
+ auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kDiscardWrite,
+ BufferUsage::kAll, 128);
+
+ // Should succeed to map for writing with discard.
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kDiscardWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_dw,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite));
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_dw.reset();
+
+ // Should also be ok to map for just writing.
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_w,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+
+ // Should fail to map for reading.
+ EXPECT_TRUE(IsPermissionDenied(
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status()));
+}
+
+TEST(MemoryMappingTest, Subspan) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_r,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Request some valid ranges and ensure the byte offsets are correct.
+ ASSERT_OK_AND_ASSIGN(auto ss, mm_r.Subspan());
+ EXPECT_EQ(kValidPtr, ss.data());
+ EXPECT_EQ(128, ss.size());
+ ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, 2));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
+ EXPECT_EQ(2, ss.size());
+ ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, kWholeBuffer));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
+ EXPECT_EQ(28, ss.size());
+
+ // Zero length ranges are fine.
+ ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(0, 0));
+ EXPECT_EQ(kValidPtr, ss.data());
+ EXPECT_TRUE(ss.empty());
+ ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, 0));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
+ EXPECT_TRUE(ss.empty());
+ ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, kWholeBuffer));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
+ EXPECT_TRUE(ss.empty());
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+}
+
+TEST(MemoryMappingTest, SubspanOutOfRange) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_r,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Try some invalid ranges that would overrun the span.
+ EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 0).status()));
+ EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 2).status()));
+ EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, kWholeBuffer).status()));
+ EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(100, 1234).status()));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+}
+
+TEST(MemoryMappingTest, MutableSubspan) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_w,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+
+ // Request some valid ranges and ensure the byte offsets are correct.
+ ASSERT_OK_AND_ASSIGN(auto ss, mm_w.MutableSubspan());
+ EXPECT_EQ(kValidPtr, ss.data());
+ EXPECT_EQ(128, ss.size());
+ ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, 2));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
+ EXPECT_EQ(2, ss.size());
+ ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, kWholeBuffer));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
+ EXPECT_EQ(28, ss.size());
+
+ // Zero length ranges are fine.
+ ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(0, 0));
+ EXPECT_EQ(kValidPtr, ss.data());
+ EXPECT_TRUE(ss.empty());
+ ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, 0));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
+ EXPECT_TRUE(ss.empty());
+ ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, kWholeBuffer));
+ EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
+ EXPECT_TRUE(ss.empty());
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+}
+
+TEST(MemoryMappingTest, MutableSubspanOutOfRange) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_w,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+
+ // Try some invalid ranges that would overrun the span.
+ EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 0).status()));
+ EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 2).status()));
+ EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, kWholeBuffer).status()));
+ EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(100, 1234).status()));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+}
+
+TEST(MemoryMappingTest, ElementOperator) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_r,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Just verify we are getting the expected pointer back.
+ EXPECT_EQ(kValidPtr, &mm_r[0]);
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+}
+
+TEST(MemoryMappingTest, Invalidate) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_r,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Invalidate a few ways.
+ EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(0, 128))
+ .WillOnce(Return(OkStatus()));
+ EXPECT_OK(mm_r.Invalidate());
+ EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 2))
+ .WillOnce(Return(OkStatus()));
+ EXPECT_OK(mm_r.Invalidate(100, 2));
+ EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 28))
+ .WillOnce(Return(OkStatus()));
+ EXPECT_OK(mm_r.Invalidate(100, kWholeBuffer));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+}
+
+TEST(MemoryMappingTest, InvalidateOutOfRange) {
+ auto buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_r,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+
+ // Try to invalidate invalid ranges.
+ EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 0)));
+ EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 12345)));
+ EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, kWholeBuffer)));
+ EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1, 1234)));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+}
+
+TEST(MemoryMappingTest, InvalidateBadMode) {
+ // Invalidate is not required on coherent memory and must be rejected there.
+ auto coherent_buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*coherent_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kRead, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(
+ auto mm_r, coherent_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+ EXPECT_TRUE(IsPermissionDenied(mm_r.Invalidate()));
+ EXPECT_CALL(*coherent_buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_r.reset();
+}
+
+TEST(MemoryMappingTest, Flush) {
+ auto buffer = std::make_shared<MockBuffer>(
+ nullptr, MemoryType::kHostVisible | MemoryType::kHostCached,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_w,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+
+ // Flush a few ways.
+ EXPECT_CALL(*buffer, FlushMappedMemoryImpl(0, 128))
+ .WillOnce(Return(OkStatus()));
+ EXPECT_OK(mm_w.Flush());
+ EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 2))
+ .WillOnce(Return(OkStatus()));
+ EXPECT_OK(mm_w.Flush(100, 2));
+ EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 28))
+ .WillOnce(Return(OkStatus()));
+ EXPECT_OK(mm_w.Flush(100, kWholeBuffer));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+}
+
+TEST(MemoryMappingTest, FlushOutOfRange) {
+ auto buffer = std::make_shared<MockBuffer>(
+ nullptr, MemoryType::kHostVisible | MemoryType::kHostCached,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(auto mm_w,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+
+ // Try to flush invalid ranges.
+ EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 0)));
+ EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 12345)));
+ EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, kWholeBuffer)));
+ EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1, 1234)));
+
+ EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+}
+
+TEST(MemoryMappingTest, FlushBadMode) {
+ // Flush is not required on uncached memory and must be rejected there.
+ auto uncached_buffer =
+ std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible,
+ MemoryAccess::kAll, BufferUsage::kAll, 128);
+ EXPECT_CALL(*uncached_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
+ MemoryAccess::kWrite, 0, 128, _))
+ .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
+ ASSERT_OK_AND_ASSIGN(
+ auto mm_w, uncached_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+ EXPECT_TRUE(IsPermissionDenied(mm_w.Flush()));
+ EXPECT_CALL(*uncached_buffer, UnmapMemoryImpl(0, 128, kValidPtr))
+ .WillOnce(Return(OkStatus()));
+ mm_w.reset();
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/buffer_test.cc b/hal/buffer_test.cc
new file mode 100644
index 0000000..31f0448
--- /dev/null
+++ b/hal/buffer_test.cc
@@ -0,0 +1,1000 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for the shared buffer functionality and host heap buffers.
+// This does not test device-specific buffer implementations; see the device
+// code for associated tests.
+
+#include "hal/buffer.h"
+
+#include <vector>
+
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/heap_buffer.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+using ::testing::_;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Not;
+
+TEST(BufferTest, Allocate) {
+ auto buffer =
+ HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 14);
+ EXPECT_NE(nullptr, buffer->allocator());
+ EXPECT_EQ(MemoryAccess::kAll, buffer->allowed_access());
+ EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
+ EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
+
+ // The allocation must hold at least the 14 requested bytes. The host does
+ // not pad today, but other implementations may round the size up.
+ EXPECT_LE(14, buffer->allocation_size());
+ EXPECT_EQ(0, buffer->byte_offset());
+ EXPECT_EQ(14, buffer->byte_length());
+
+ // Data should be zeroed by default.
+ std::vector<uint8_t> zero_data(buffer->allocation_size());
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(zero_data));
+}
+
+TEST(BufferTest, AllocateZeroLength) {
+ auto buffer =
+ HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 0);
+ EXPECT_NE(nullptr, buffer->allocator());
+ EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
+ EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
+ EXPECT_EQ(0, buffer->allocation_size());
+}
+
+TEST(BufferTest, AllocateCopy) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ EXPECT_NE(nullptr, buffer->allocator());
+ EXPECT_LE(src_data.size(), buffer->allocation_size());
+
+ // Data should have been copied.
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Modify the source data and ensure it is not reflected in the buffer.
+ src_data[0] = 0x88;
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Not(Eq(src_data)));
+}
+
+TEST(BufferTest, AllocateCopyZeroLength) {
+ std::vector<uint8_t> src_data;
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ EXPECT_NE(nullptr, buffer->allocator());
+ EXPECT_EQ(0, buffer->allocation_size());
+}
+
+TEST(BufferTest, AllocateCopyTyped) {
+ std::vector<int32_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ absl::MakeConstSpan(src_data));
+ EXPECT_NE(nullptr, buffer->allocator());
+ EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
+ EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
+ EXPECT_LE(src_data.size() * sizeof(int32_t), buffer->allocation_size());
+
+ // Data should have been copied (sizes are in bytes, not elements).
+ std::vector<int32_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(),
+ actual_data.size() * sizeof(int32_t)));
+ EXPECT_THAT(actual_data, Eq(src_data));
+}
+
+TEST(BufferTest, WrapConstant) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer = HeapBuffer::Wrap(MemoryType::kHostLocal,
+ BufferUsage::kTransfer | BufferUsage::kMapping,
+ absl::MakeConstSpan(src_data));
+ EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
+ EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
+ EXPECT_EQ(src_data.size(), buffer->allocation_size());
+
+ // src_data and buffer should match after the wrapping.
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Modify the source data directly.
+ src_data[0] = 123;
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Attempts to modify the buffer should fail.
+ std::vector<uint8_t> new_data = {3, 2, 1, 0};
+ EXPECT_TRUE(IsPermissionDenied(
+ buffer->WriteData(0, new_data.data(), new_data.size())));
+}
+
+TEST(BufferTest, WrapMutable) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer = HeapBuffer::WrapMutable(
+ MemoryType::kHostLocal, MemoryAccess::kAll,
+ BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data));
+ EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
+ EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
+ EXPECT_EQ(src_data.size(), buffer->allocation_size());
+
+ // src_data and buffer should match after the wrapping.
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Modify the source data directly.
+ src_data[0] = 123;
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Modify the source data via the Buffer and ensure reflected in src_data.
+ std::vector<uint8_t> new_data = {3, 2, 1, 0};
+ EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size()));
+ EXPECT_THAT(src_data, Eq(new_data));
+}
+
+TEST(BufferTest, WrapExternal) {
+ // This is not fully supported yet, but does let us verify that the validation
+ // of memory types is working.
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer = HeapBuffer::Wrap(MemoryType::kDeviceLocal, BufferUsage::kAll,
+ absl::MakeConstSpan(src_data));
+ EXPECT_EQ(MemoryType::kDeviceLocal, buffer->memory_type());
+
+ // Should fail (for now) as the buffer is not host visible.
+ EXPECT_TRUE(IsPermissionDenied(buffer->Fill8(0, kWholeBuffer, 0x99u)));
+}
+
+TEST(BufferTest, DoesOverlap) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto parent_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+
+ // Ranges in the same buffer overlap only when their bytes intersect.
+ EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1,
+ parent_buffer.get(), 1, 1));
+ EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1,
+ parent_buffer.get(), 0, 1));
+
+ // Zero length buffers never overlap.
+ EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 1, 1,
+ parent_buffer.get(), 1, 0));
+
+ // Subspans should offset within their allocation.
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer_0,
+ Buffer::Subspan(parent_buffer, 1, 2));
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer_1,
+ Buffer::Subspan(parent_buffer, 2, 2));
+ EXPECT_FALSE(Buffer::DoesOverlap(subspan_buffer_0.get(), 0, 1,
+ subspan_buffer_1.get(), 0, 1));
+ EXPECT_TRUE(Buffer::DoesOverlap(subspan_buffer_0.get(), 1, 1,
+ subspan_buffer_1.get(), 0, 1));
+
+ // Mixing subspans and normal buffers.
+ EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1,
+ subspan_buffer_0.get(), 0, 1));
+ EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 1, 2,
+ subspan_buffer_0.get(), 1, 1));
+
+ // Independent buffers should not be able to overlap.
+ auto other_buffer = HeapBuffer::Allocate(BufferUsage::kAll, 128);
+ EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, kWholeBuffer,
+ other_buffer.get(), 0, kWholeBuffer));
+}
+
+TEST(BufferTest, Subspan) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto parent_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(parent_buffer);
+
+ // Create a subspan of the buffer.
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
+ Buffer::Subspan(parent_buffer, 1, 2));
+ ASSERT_TRUE(subspan_buffer);
+ EXPECT_EQ(1, subspan_buffer->byte_offset());
+ EXPECT_EQ(2, subspan_buffer->byte_length());
+
+ // Modifications to either buffer should appear in the other.
+ EXPECT_OK(subspan_buffer->Fill8(1, kWholeBuffer, 0xFFu));
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xFF, 3));
+
+ // Subspans should be able to create subspans.
+ // NOTE: byte_offset() is reported relative to the original buffer.
+ ASSERT_OK_AND_ASSIGN(auto subsubspan_buffer,
+ Buffer::Subspan(subspan_buffer, 1, 1));
+ ASSERT_TRUE(subsubspan_buffer);
+ EXPECT_EQ(2, subsubspan_buffer->byte_offset());
+ EXPECT_EQ(1, subsubspan_buffer->byte_length());
+
+ // Zero length subspans are fine.
+ ASSERT_OK_AND_ASSIGN(auto zero_subspan_buffer,
+ Buffer::Subspan(parent_buffer, 0, 0));
+ ASSERT_TRUE(zero_subspan_buffer);
+ EXPECT_EQ(0, zero_subspan_buffer->byte_offset());
+ EXPECT_EQ(0, zero_subspan_buffer->byte_length());
+
+ // Subspan with kWholeBuffer should get the remaining size (or zero).
+ ASSERT_OK_AND_ASSIGN(auto whole_subspan_buffer,
+ Buffer::Subspan(parent_buffer, 1, kWholeBuffer));
+ ASSERT_TRUE(whole_subspan_buffer);
+ EXPECT_EQ(1, whole_subspan_buffer->byte_offset());
+ EXPECT_EQ(3, whole_subspan_buffer->byte_length());
+
+ // Zero length subspans at the very end of a subspan are fine.
+ ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, 0));
+ ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, kWholeBuffer));
+}
+
+TEST(BufferTest, SubspanIdentity) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto parent_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+
+ // Asking for a subspan of the entire buffer should return the same buffer.
+ // Mostly an optimization.
+ EXPECT_EQ(parent_buffer.get(),
+ Buffer::Subspan(parent_buffer, 0, kWholeBuffer).ValueOrDie().get());
+ EXPECT_EQ(parent_buffer.get(),
+ Buffer::Subspan(parent_buffer, 0, 4).ValueOrDie().get());
+}
+
+TEST(BufferTest, SubspanOutOfRange) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto parent_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(parent_buffer);
+
+ // Create a subspan of the buffer.
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
+ Buffer::Subspan(parent_buffer, 1, 2));
+ ASSERT_TRUE(subspan_buffer);
+ EXPECT_EQ(1, subspan_buffer->byte_offset());
+ EXPECT_EQ(2, subspan_buffer->byte_length());
+
+ // Try to make subspans from invalid ranges.
+ EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 5, 0).status()));
+ EXPECT_TRUE(
+ IsOutOfRange(Buffer::Subspan(parent_buffer, 5, kWholeBuffer).status()));
+ EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 4, 1).status()));
+ EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 0, 123).status()));
+ EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 1, 2).status()));
+ EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 0, 44).status()));
+}
+
+TEST(BufferTest, Fill8) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5);
+ ASSERT_TRUE(buffer);
+
+ // Data should be zeroed by default.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0));
+
+ // Fill with a sentinel.
+ EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u));
+
+ // Verify data.
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
+
+ // Zero fills are fine.
+ EXPECT_OK(buffer->Fill8(0, 0, 0x44u));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
+
+ // Fill the remaining parts of the buffer by using kWholeBuffer.
+ EXPECT_OK(buffer->Fill8(2, kWholeBuffer, 0x55u));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55));
+
+ // Fill a small region of the buffer.
+ EXPECT_OK(buffer->Fill8(1, 1, 0x66u));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55));
+
+ // Whole buffer helper.
+ EXPECT_OK(buffer->Fill8(0x99u));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99));
+}
+
+TEST(BufferTest, Fill8OutOfRange) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5);
+ ASSERT_TRUE(buffer);
+
+ // Fill with a sentinel.
+ EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u));
+
+ // Try to fill with invalid ranges (bad length, bad offset, and both).
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 444, 0x44u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 1, 0x44u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill8(0, 444, 0x44u)));
+
+ // Ensure nothing happened with the bad ranges.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
+}
+
+TEST(BufferTest, Fill8BadMode) {
+ // Fail to fill buffers not supporting mapping.
+ auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ EXPECT_TRUE(
+ IsPermissionDenied(nonmapping_buffer->Fill8(0, kWholeBuffer, 0x99u)));
+
+ // Fail to fill constant buffers.
+ std::vector<uint8_t> const_data = {1, 2, 3};
+ auto constant_buffer =
+ HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping,
+ absl::MakeConstSpan(const_data));
+ EXPECT_TRUE(
+ IsPermissionDenied(constant_buffer->Fill8(0, kWholeBuffer, 0x99u)));
+}
+
+TEST(BufferTest, Fill8Subspan) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5);
+ ASSERT_TRUE(buffer);
+
+ // Test on subspan.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 3));
+ EXPECT_OK(subspan_buffer->Fill8(2, kWholeBuffer, 0xDDu));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0xDD, 0));
+}
+
+TEST(BufferTest, Fill16) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Data should be zeroed by default.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0));
+
+ // Fill with a sentinel.
+ EXPECT_OK(buffer->Fill16(0, 4, 0x1122u));
+
+ // Verify data.
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0));
+
+ // Zero fills are fine.
+ EXPECT_OK(buffer->Fill16(0, 0, 0x5566u));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0));
+
+ // Fill the remaining parts of the buffer by using kWholeBuffer.
+ auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8);
+ EXPECT_OK(aligned_buffer->Fill16(4, kWholeBuffer, 0x5566u));
+ std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size());
+ EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
+ aligned_actual_data.size()));
+ EXPECT_THAT(aligned_actual_data,
+ ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55));
+
+ // Whole buffer helper.
+ EXPECT_OK(aligned_buffer->Fill16(0x5566u));
+ EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
+ aligned_actual_data.size()));
+ EXPECT_THAT(aligned_actual_data,
+ ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55));
+}
+
+TEST(BufferTest, Fill16OutOfRange) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Try to fill with invalid (but 2-byte aligned) ranges.
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 444, 0x5566u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 4, 0x5566u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill16(0, 444, 0x5566u)));
+}
+
+TEST(BufferTest, Fill16Unaligned) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Try to fill with unaligned ranges.
+ EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(1, 4, 0x5566u)));
+ EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(0, 5, 0x5566u)));
+}
+
+TEST(BufferTest, Fill16BadMode) {
+ // Fail to fill buffers not supporting mapping.
+ auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ EXPECT_TRUE(
+ IsPermissionDenied(nonmapping_buffer->Fill16(0, kWholeBuffer, 0x99AAu)));
+
+ // Fail to fill constant buffers.
+ std::vector<uint8_t> const_data = {1, 2, 3};
+ auto constant_buffer =
+ HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping,
+ absl::MakeConstSpan(const_data));
+ EXPECT_TRUE(
+ IsPermissionDenied(constant_buffer->Fill16(0, kWholeBuffer, 0x99AAu)));
+}
+
+TEST(BufferTest, Fill16Subspan) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Fill with a sentinel.
+ EXPECT_OK(buffer->Fill16(0, 4, 0x1122u));
+
+ // Test on subspan.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 2, 4));
+ EXPECT_OK(subspan_buffer->Fill16(2, kWholeBuffer, 0xAABBu));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data,
+ ElementsAre(0x22, 0x11, 0x22, 0x11, 0xBB, 0xAA, 0, 0, 0));
+}
+
+TEST(BufferTest, Fill32) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Data should be zeroed by default.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0));
+
+ // Fill with a sentinel.
+ EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u));
+
+ // Verify data.
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data,
+ ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0));
+
+ // Zero fills are fine.
+ EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data,
+ ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0));
+
+ // Fill the remaining parts of the buffer by using kWholeBuffer.
+ auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8);
+ EXPECT_OK(aligned_buffer->Fill32(4, kWholeBuffer, 0x55667788u));
+ std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size());
+ EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
+ aligned_actual_data.size()));
+ EXPECT_THAT(aligned_actual_data,
+ ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55));
+
+ // Whole buffer helper.
+ EXPECT_OK(aligned_buffer->Fill32(0x55667788u));
+ EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
+ aligned_actual_data.size()));
+ EXPECT_THAT(aligned_actual_data,
+ ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55));
+}
+
+TEST(BufferTest, Fill32OutOfRange) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Try to fill with invalid (but 4-byte aligned) ranges.
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 444, 0x55667788u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 4, 0x55667788u)));
+ EXPECT_TRUE(IsOutOfRange(buffer->Fill32(0, 444, 0x55667788u)));
+}
+
+TEST(BufferTest, Fill32Unaligned) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Try to fill with unaligned ranges.
+ EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(1, 4, 0x55667788u)));
+ EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(0, 5, 0x55667788u)));
+}
+
+TEST(BufferTest, Fill32BadMode) {
+ // Fail to fill buffers not supporting mapping.
+ auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ EXPECT_TRUE(IsPermissionDenied(
+ nonmapping_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu)));
+
+ // Fail to fill constant buffers.
+ std::vector<uint8_t> const_data = {1, 2, 3};
+ auto constant_buffer =
+ HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping,
+ absl::MakeConstSpan(const_data));
+ EXPECT_TRUE(IsPermissionDenied(
+ constant_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu)));
+}
+
+TEST(BufferTest, Fill32Subspan) {
+ auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
+ ASSERT_TRUE(buffer);
+
+ // Fill with a sentinel.
+ EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u));
+
+ // Test on subspan.
+ std::vector<uint8_t> actual_data(buffer->allocation_size());
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 4, 4));
+ EXPECT_OK(subspan_buffer->Fill32(0, kWholeBuffer, 0xAABBCCDDu));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data,
+ ElementsAre(0x44, 0x33, 0x22, 0x11, 0xDD, 0xCC, 0xBB, 0xAA, 0));
+}
+
+TEST(BufferTest, ReadData) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Read the data back.
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Reading zero bytes is valid.
+ std::vector<uint8_t> zero_data(0);
+ EXPECT_OK(buffer->ReadData(1, zero_data.data(), 0));
+
+ // Read a portion of the data.
+ std::vector<uint8_t> partial_data(2);
+ EXPECT_OK(buffer->ReadData(1, partial_data.data(), 2));
+ EXPECT_THAT(partial_data, ElementsAre(1, 2));
+}
+
+TEST(BufferTest, ReadDataOutOfRange) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Try to read out of range.
+ std::vector<uint8_t> partial_data(2);
+ EXPECT_TRUE(IsOutOfRange(buffer->ReadData(0, partial_data.data(), 444)));
+ EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 444)));
+ EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 1)));
+ EXPECT_TRUE(IsInvalidArgument(
+ buffer->ReadData(0, partial_data.data(), kWholeBuffer)));
+}
+
+TEST(BufferTest, ReadDataBadMode) {
+ // Fail to read buffers not supporting mapping.
+ std::vector<uint8_t> actual_data(1);
+ auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ EXPECT_TRUE(IsPermissionDenied(
+ nonmapping_buffer->ReadData(0, actual_data.data(), 1)));
+}
+
+TEST(BufferTest, ReadDataSubspan) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Test on subspan.
+ std::vector<uint8_t> subspan_data(1);
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2));
+ EXPECT_OK(subspan_buffer->ReadData(1, subspan_data.data(), 1));
+ EXPECT_THAT(subspan_data, ElementsAre(2));
+}
+
+TEST(BufferTest, WriteData) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Read the data back - should still match.
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(src_data));
+
+ // Write over the entire buffer.
+ std::vector<uint8_t> new_data = {10, 20, 30, 40};
+ EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size()));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(new_data));
+
+ // Writing zero bytes is valid.
+ std::vector<uint8_t> zero_data;
+ EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(new_data));
+
+ // Write over a portion of the buffer.
+ std::vector<uint8_t> partial_data = {99};
+ EXPECT_OK(buffer->WriteData(1, partial_data.data(), partial_data.size()));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40));
+}
+
+TEST(BufferTest, WriteDataOutOfRange) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Try to write out of range.
+ std::vector<uint8_t> partial_data = {99};
+ EXPECT_TRUE(IsOutOfRange(buffer->WriteData(0, partial_data.data(), 444)));
+ EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 444)));
+ EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 1)));
+ EXPECT_TRUE(IsInvalidArgument(
+ buffer->WriteData(0, partial_data.data(), kWholeBuffer)));
+}
+
+TEST(BufferTest, WriteDataBadMode) {
+ std::vector<uint8_t> actual_data(4);
+
+ // Fail to write buffers not supporting mapping.
+ auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ EXPECT_TRUE(IsPermissionDenied(
+ nonmapping_buffer->WriteData(0, actual_data.data(), 1)));
+
+ // Fail to write to constant buffers even when mapping usage is allowed.
+ std::vector<uint8_t> const_data = {1, 2, 3};
+ auto constant_buffer = HeapBuffer::Wrap(
+ MemoryType::kHostLocal, BufferUsage::kTransfer | BufferUsage::kMapping,
+ absl::MakeConstSpan(const_data));
+ EXPECT_TRUE(
+ IsPermissionDenied(constant_buffer->WriteData(0, actual_data.data(), 2)));
+}
+
+TEST(BufferTest, WriteDataSubspan) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Test on subspan.
+ std::vector<uint8_t> subspan_data = {0xAA};
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2));
+ EXPECT_OK(subspan_buffer->WriteData(1, subspan_data.data(), 1));
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xAA, 3));
+}
+
+TEST(BufferTest, CopyData) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto src_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(src_buffer);
+ std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
+ auto dst_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ dst_data.data(), dst_data.size());
+ ASSERT_TRUE(dst_buffer);
+
+ // Copy of length 0 should not change the dest buffer.
+ EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0));
+ std::vector<uint8_t> actual_data(dst_data.size());
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, Eq(dst_data));
+
+ // Copy a subrange of the buffer.
+ EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4));
+
+ // Copy the entire buffer using kWholeBuffer. This will adjust sizes
+ // to ensure that the min buffer is taken. We test both src and dst buffer
+ // offset/length calculations (note that some may end up as 0 copies).
+ EXPECT_OK(dst_buffer->CopyData(3, src_buffer.get(), 0, kWholeBuffer));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1));
+ EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 2, kWholeBuffer));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1));
+ EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 3, kWholeBuffer));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1));
+ EXPECT_OK(dst_buffer->CopyData(4, src_buffer.get(), 0, kWholeBuffer));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0));
+}
+
+TEST(BufferTest, CopyDataOutOfRange) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto src_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(src_buffer);
+ std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
+ auto dst_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ dst_data.data(), dst_data.size());
+ ASSERT_TRUE(dst_buffer);
+
+ // Try to copy out of range of source and dest.
+ EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 0, 1)));
+ EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(4, src_buffer.get(), 0, 4)));
+ EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 1)));
+ EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 0, 123)));
+ EXPECT_TRUE(
+ IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 123, 123)));
+ EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 0)));
+}
+
+TEST(BufferTest, CopyDataOverlapping) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto src_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(src_buffer);
+ std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
+ auto dst_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ dst_data.data(), dst_data.size());
+ ASSERT_TRUE(dst_buffer);
+
+ // Test overlap. Non-overlapping regions should be fine, otherwise fail.
+ std::vector<uint8_t> actual_data(dst_data.size());
+ EXPECT_OK(dst_buffer->CopyData(0, dst_buffer.get(), 4, 1));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(4, 1, 2, 3, 4));
+ EXPECT_TRUE(
+ IsInvalidArgument(dst_buffer->CopyData(2, dst_buffer.get(), 0, 3)));
+ EXPECT_TRUE(
+ IsInvalidArgument(dst_buffer->CopyData(0, dst_buffer.get(), 0, 3)));
+}
+
+TEST(BufferTest, CopyDataBadMode) {
+ // Both source and target buffers must support mapping.
+ auto nonmapping_src_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ auto nonmapping_dst_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
+ EXPECT_TRUE(IsPermissionDenied(nonmapping_dst_buffer->CopyData(
+ 0, nonmapping_src_buffer.get(), 0, kWholeBuffer)));
+ EXPECT_TRUE(IsPermissionDenied(nonmapping_src_buffer->CopyData(
+ 0, nonmapping_dst_buffer.get(), 0, kWholeBuffer)));
+
+ // Fail to copy into constant buffers.
+ std::vector<uint8_t> const_data = {1, 2, 3};
+ auto constant_buffer =
+ HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer,
+ absl::MakeConstSpan(const_data));
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto src_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ src_data.data(), src_data.size());
+ EXPECT_TRUE(IsPermissionDenied(
+ constant_buffer->CopyData(0, src_buffer.get(), 0, kWholeBuffer)));
+}
+
+TEST(BufferTest, CopyDataSubspan) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ auto src_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(src_buffer);
+ std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
+ auto dst_buffer =
+ HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
+ dst_data.data(), dst_data.size());
+ ASSERT_TRUE(dst_buffer);
+
+ // Test on subspan.
+ std::vector<uint8_t> actual_data(dst_data.size());
+ ASSERT_OK_AND_ASSIGN(auto subspan_src_buffer,
+ Buffer::Subspan(src_buffer, 1, 3));
+ ASSERT_OK_AND_ASSIGN(auto subspan_dst_buffer,
+ Buffer::Subspan(dst_buffer, 2, 3));
+ EXPECT_OK(subspan_dst_buffer->CopyData(1, subspan_src_buffer.get(), 1, 2));
+ EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 1, 2, 2, 3));
+}
+
+// NOTE: more tests related specifically to MappedMemory are in
+// buffer_mapping_test.cc. This tests the MapMemory operation and enough to
+// ensure the memory was mapped to the correct range and the HostBuffer and
+// SubspanBuffer work as intended for basic usage.
+TEST(BufferTest, MapMemory) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // 0-length mappings are valid.
+ ASSERT_OK_AND_ASSIGN(auto mapping,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 0, 0));
+ EXPECT_TRUE(mapping.empty());
+ EXPECT_EQ(0, mapping.size());
+ EXPECT_EQ(0, mapping.byte_length());
+ EXPECT_NE(nullptr, mapping.data());
+ ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan());
+ EXPECT_TRUE(span.empty());
+ mapping.reset();
+
+ // Map the whole buffer for reading.
+ ASSERT_OK_AND_ASSIGN(mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead,
+ 0, kWholeBuffer));
+ EXPECT_EQ(src_data.size(), mapping.size());
+ ASSERT_OK_AND_ASSIGN(span, mapping.Subspan());
+ EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6));
+ mapping.reset();
+
+ // Map a portion of the buffer for reading.
+ ASSERT_OK_AND_ASSIGN(mapping,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 1, 2));
+ EXPECT_EQ(2, mapping.size());
+ ASSERT_OK_AND_ASSIGN(span, mapping.Subspan());
+ EXPECT_THAT(span, ElementsAre(1, 2));
+ mapping.reset();
+}
+
+TEST(BufferTest, MapMemoryNonByte) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Map the buffer as non-byte values.
+ // Note that we'll round down to the number of valid elements at the
+ // alignment.
+ ASSERT_OK_AND_ASSIGN(auto mapping16,
+ buffer->MapMemory<uint16_t>(MemoryAccess::kRead));
+ EXPECT_EQ(3, mapping16.size());
+ EXPECT_LE(6, mapping16.byte_length());
+ ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan());
+ EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504));
+ mapping16.reset();
+}
+
+TEST(BufferTest, MapMemoryOutOfRange) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Test invalid mapping ranges.
+ EXPECT_TRUE(IsOutOfRange(
+ buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 0, 123).status()));
+ EXPECT_TRUE(IsOutOfRange(
+ buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 5, 1231).status()));
+ EXPECT_TRUE(IsOutOfRange(
+ buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 6, kWholeBuffer)
+ .status()));
+ EXPECT_TRUE(IsOutOfRange(
+ buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 1236, 1).status()));
+}
+
+TEST(BufferTest, MapMemoryBadMode) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto read_buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(read_buffer);
+
+ // Test mapping the read-only buffer for writing.
+ EXPECT_TRUE(IsPermissionDenied(
+ read_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status()));
+ EXPECT_TRUE(IsPermissionDenied(
+ read_buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status()));
+ EXPECT_TRUE(IsPermissionDenied(
+ read_buffer
+ ->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kDiscard)
+ .status()));
+ EXPECT_TRUE(IsInvalidArgument(
+ read_buffer->MapMemory<uint8_t>(MemoryAccess::kNone).status()));
+}
+
+TEST(BufferTest, MapMemoryWrite) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Map and modify the data. We should see it when we read back.
+ ASSERT_OK_AND_ASSIGN(auto mapping,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kWrite, 1, 2));
+ auto mutable_data = mapping.mutable_data();
+ mutable_data[0] = 0xAA;
+ mutable_data[1] = 0xBB;
+ mapping.reset();
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6));
+}
+
+TEST(BufferTest, MapMemoryDiscard) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(buffer);
+
+ // Map for discard. Note that we can't really rely on the value of the data
+ // so we just trust that it's been discarded. It's a hint, anyway. We can be
+ // sure that the data we didn't want to discard is the same though.
+ std::vector<uint8_t> actual_data(src_data.size());
+ ASSERT_OK_AND_ASSIGN(auto mapping, buffer->MapMemory<uint8_t>(
+ MemoryAccess::kDiscardWrite, 1, 2));
+ EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6));
+ mapping.reset();
+}
+
+TEST(BufferTest, MapMemorySubspan) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
+ auto parent_buffer = HeapBuffer::AllocateCopy(
+ BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll,
+ src_data.data(), src_data.size());
+ ASSERT_TRUE(parent_buffer);
+ ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
+ Buffer::Subspan(parent_buffer, 1, 3));
+ ASSERT_OK_AND_ASSIGN(auto mapping, subspan_buffer->MapMemory<uint8_t>(
+ MemoryAccess::kDiscardWrite, 1, 2));
+ auto* mutable_data = mapping.mutable_data();
+ mutable_data[0] = 0xCC;
+ mutable_data[1] = 0xDD;
+ mapping.reset();
+
+ std::vector<uint8_t> actual_data(src_data.size());
+ EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size()));
+ EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6));
+
+ // Just here to make coverage happy; they are currently no-ops on the host.
+ // buffer_mapping_test.cc contains tests that ensure they are called
+ // correctly.
+ std::vector<uint8_t> external_data = {0, 1, 2, 3, 4};
+ auto external_buffer = HeapBuffer::WrapMutable(
+ MemoryType::kHostVisible | MemoryType::kHostCached, MemoryAccess::kAll,
+ BufferUsage::kAll, absl::MakeSpan(external_data));
+ ASSERT_OK_AND_ASSIGN(auto external_subspan_buffer,
+ Buffer::Subspan(external_buffer, 0, 1));
+ ASSERT_OK_AND_ASSIGN(
+ mapping, external_subspan_buffer->MapMemory<uint8_t>(MemoryAccess::kAll));
+ EXPECT_OK(mapping.Invalidate());
+ EXPECT_OK(mapping.Flush());
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/buffer_view.cc b/hal/buffer_view.cc
new file mode 100644
index 0000000..ff02041
--- /dev/null
+++ b/hal/buffer_view.cc
@@ -0,0 +1,180 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/buffer_view.h"
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+// Pretty prints an array, e.g. [1,2,3,4]
+inline std::string PrettyPrint(absl::Span<const int32_t> arr) {
+ return "[" + absl::StrJoin(arr, ",") + "]";
+}
+} // namespace
+
+// static
+bool BufferView::Equal(const BufferView& lhs, const BufferView& rhs) {
+ return lhs.buffer.get() == rhs.buffer.get() &&
+ lhs.element_size == rhs.element_size && lhs.shape == rhs.shape;
+}
+
+std::string BufferView::DebugStringShort() const {
+ if (element_size == 0) {
+ return "Ø";
+ }
+ return shape.empty() ? std::to_string(element_size)
+ : absl::StrCat(absl::StrJoin(shape.subspan(), "x"), "x",
+ element_size);
+}
+
+StatusOr<device_size_t> BufferView::CalculateOffset(
+ absl::Span<const int32_t> indices) const {
+ if (indices.empty()) {
+ return 0;
+ } else if (shape.empty() || indices.size() > shape.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Indices " << PrettyPrint(indices)
+ << " out of bounds of the rank of buffer_view "
+ << DebugStringShort();
+ }
+ device_size_t offset = 0;
+ for (int i = 0; i < indices.size(); ++i) {
+ if (indices[i] >= shape[i]) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Indices[" << i << "]=" << indices[i]
+ << " out of bounds of buffer_view " << DebugStringShort();
+ }
+ device_size_t axis_offset = indices[i];
+ for (int j = i + 1; j < shape.size(); ++j) {
+ axis_offset *= shape[j];
+ }
+ offset += axis_offset;
+ }
+ offset *= element_size;
+ return offset;
+}
+
+StatusOr<BufferView> BufferView::Slice(
+ absl::Span<const int32_t> start_indices,
+ absl::Span<const int32_t> lengths) const {
+ if (start_indices.size() != shape.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Slice start_indices " << PrettyPrint(start_indices)
+ << " do not match rank of buffer_view " << DebugStringShort();
+ }
+ if (start_indices.size() != lengths.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Slice start_indices " << PrettyPrint(start_indices)
+ << " and lengths " << PrettyPrint(lengths)
+ << " are not the same size";
+ }
+
+ // Buffer::Subspan only supports contiguous memory. To ensure that this slice
+ // only requests such, we validate that the offset in the buffer between the
+ // start and end indices is the same as the requested size of the slice.
+ absl::InlinedVector<int32_t, 6> end_indices(lengths.size());
+ device_size_t subspan_length = element_size;
+ for (int i = 0; i < lengths.size(); ++i) {
+ subspan_length *= lengths[i];
+ end_indices[i] = start_indices[i] + lengths[i] - 1;
+ }
+
+ ASSIGN_OR_RETURN(auto start_byte_offset, CalculateOffset(start_indices));
+ // Also validates the ends are in bounds.
+ ASSIGN_OR_RETURN(auto end_byte_offset, CalculateOffset(end_indices));
+
+ auto offset_length = end_byte_offset - start_byte_offset + element_size;
+ if (subspan_length != offset_length) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Slice for non-contiguous region of memory unimplemented. "
+ "start_indices: "
+ << PrettyPrint(start_indices) << " lengths: " << PrettyPrint(lengths)
+ << " " << subspan_length << " " << offset_length << " "
+ << PrettyPrint(end_indices);
+ }
+
+ ASSIGN_OR_RETURN(auto new_buffer,
+ Buffer::Subspan(buffer, start_byte_offset, subspan_length));
+ return BufferView(std::move(new_buffer), Shape(lengths), element_size);
+}
+
+// static
+Status BufferView::Copy(BufferView* src,
+ absl::Span<const int32_t> src_start_indices,
+ BufferView* dst,
+ absl::Span<const int32_t> dst_start_indices,
+ absl::Span<const int32_t> lengths) {
+ if (src_start_indices.size() != src->shape.size() ||
+ dst_start_indices.size() != dst->shape.size() ||
+ src_start_indices.size() != lengths.size() ||
+ dst_start_indices.size() != lengths.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Src/dst shape/size mismatch: src=" << src->DebugStringShort()
+ << ", dst=" << dst->DebugStringShort()
+ << ", src_indices=" << PrettyPrint(src_start_indices)
+ << ", dst_indices=" << PrettyPrint(dst_start_indices)
+ << ", lengths=" << PrettyPrint(lengths);
+ }
+
+ // Copies only support contiguous memory. To ensure that this copy
+ // only requests such, we validate that the offset in the buffer between the
+ // start and end indices is the same as the requested size of the copy.
+ absl::InlinedVector<int32_t, 4> src_end_indices(lengths.size());
+ absl::InlinedVector<int32_t, 4> dst_end_indices(lengths.size());
+ device_size_t total_length = src->element_size;
+ for (int i = 0; i < lengths.size(); ++i) {
+ total_length *= lengths[i];
+ src_end_indices[i] = src_start_indices[i] + lengths[i] - 1;
+ dst_end_indices[i] = dst_start_indices[i] + lengths[i] - 1;
+ }
+
+ ASSIGN_OR_RETURN(auto src_start_byte_offset,
+ src->CalculateOffset(src_start_indices));
+ ASSIGN_OR_RETURN(auto src_end_byte_offset,
+ src->CalculateOffset(src_end_indices));
+ ASSIGN_OR_RETURN(auto dst_start_byte_offset,
+ dst->CalculateOffset(dst_start_indices));
+ ASSIGN_OR_RETURN(auto dst_end_byte_offset,
+ dst->CalculateOffset(dst_end_indices));
+
+ auto src_length =
+ src_end_byte_offset - src_start_byte_offset + src->element_size;
+ auto dst_length =
+ dst_end_byte_offset - dst_start_byte_offset + dst->element_size;
+ if (src_length != dst_length || src_length != total_length) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Copy for non-contiguous region of memory unimplemented: "
+ << src->DebugStringShort() << ", dst=" << dst->DebugStringShort()
+ << ", src_indices=" << PrettyPrint(src_start_indices)
+ << ", dst_indices=" << PrettyPrint(dst_start_indices)
+ << ", lengths=" << PrettyPrint(lengths);
+ }
+
+ RETURN_IF_ERROR(dst->buffer->CopyData(dst_start_byte_offset,
+ src->buffer.get(),
+ src_start_byte_offset, total_length));
+
+ return OkStatus();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/buffer_view.h b/hal/buffer_view.h
new file mode 100644
index 0000000..ef5bc22
--- /dev/null
+++ b/hal/buffer_view.h
@@ -0,0 +1,108 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_BUFFER_VIEW_H_
+#define IREE_HAL_BUFFER_VIEW_H_
+
+#include <memory>
+#include <ostream>
+
+#include "base/shape.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+// A view of a Buffer described by a shape and a per-element byte size.
+struct BufferView {
+  // Returns true if the given buffer_views are exactly equal.
+  static bool Equal(const BufferView& lhs, const BufferView& rhs);
+
+  BufferView() = default;
+  BufferView(ref_ptr<Buffer> buffer, Shape shape, int8_t element_size) noexcept
+      : buffer(std::move(buffer)), shape(shape), element_size(element_size) {}
+
+  BufferView(const BufferView& other) noexcept
+      : buffer(add_ref(other.buffer)), shape(other.shape),
+        element_size(other.element_size) {}
+  BufferView& operator=(const BufferView& other) noexcept {
+    buffer = add_ref(other.buffer);
+    shape = other.shape;
+    element_size = other.element_size;
+    return *this;
+  }
+  BufferView(BufferView&& other) noexcept
+      : buffer(std::move(other.buffer)), shape(other.shape),
+        element_size(other.element_size) {}
+  BufferView& operator=(BufferView&& other) noexcept {
+    buffer = std::move(other.buffer);
+    shape = other.shape;
+    element_size = other.element_size;
+    return *this;
+  }
+
+  // Returns a string useful for printing debug messages.
+  std::string DebugStringShort() const;
+
+  // Total length of the valid view range in bytes.
+  device_size_t byte_length() const {
+    return shape.element_count() * element_size;
+  }
+
+  // TODO(b/134586626): remove this when byte ranges are encoded in IR.
+  // Calculates a byte offset into the buffer_view at the given dimension
+  // indices.
+  StatusOr<device_size_t> CalculateOffset(
+      absl::Span<const int32_t> indices) const;
+
+  // TODO(b/134586626): remove this when byte ranges are encoded in IR.
+  // Returns a view onto the given range of the buffer underlying this view. The
+  // returned view starts at the offset indicated by |start_indices| and has a
+  // shape of |lengths|.
+  // Only contiguous regions of memory are supported at the moment.
+  StatusOr<BufferView> Slice(absl::Span<const int32_t> start_indices,
+                             absl::Span<const int32_t> lengths) const;
+
+  // TODO(b/134586626): remove this when byte ranges are encoded in IR.
+  // Copies the |lengths| region at |src_start_indices| of |src| into |dst|.
+  static Status Copy(BufferView* src,
+                     absl::Span<const int32_t> src_start_indices,
+                     BufferView* dst, absl::Span<const int32_t> dst_start_indices,
+                     absl::Span<const int32_t> lengths);
+
+  ref_ptr<Buffer> buffer;
+  Shape shape;
+  // Bytes per element; 0 for a default-constructed view.
+  int8_t element_size = 0;
+  // TODO(benvanik): strides.
+};
+
+inline bool operator==(const BufferView& a, const BufferView& b) {
+ return BufferView::Equal(a, b);
+}
+
+inline bool operator!=(const BufferView& a, const BufferView& b) {
+ return !(a == b);
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const BufferView& buffer_view) {
+ stream << buffer_view.DebugStringShort();
+ return stream;
+}
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_BUFFER_VIEW_H_
diff --git a/hal/buffer_view_string_util.cc b/hal/buffer_view_string_util.cc
new file mode 100644
index 0000000..130dabf
--- /dev/null
+++ b/hal/buffer_view_string_util.cc
@@ -0,0 +1,542 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/buffer_view_string_util.h"
+
+#include <functional>
+#include <sstream>
+#include <type_traits>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/strip.h"
+#include "absl/types/optional.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "hal/heap_buffer.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+
+/* clang-format off */
+constexpr char kHexValue[256] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9'
+ 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F'
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f'
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+};
+/* clang-format on */
+
+// Decodes |num| bytes from the 2 * |num| hex characters at |from| into |to|.
+template <typename T>
+void HexStringToBytes(const char* from, T to, ptrdiff_t num) {
+  for (ptrdiff_t i = 0; i < num; ++i, from += 2) {
+    to[i] = (kHexValue[from[0] & 0xFF] << 4) + kHexValue[from[1] & 0xFF];
+  }
+}
+
+constexpr char kHexTable[513] =
+ "000102030405060708090a0b0c0d0e0f"
+ "101112131415161718191a1b1c1d1e1f"
+ "202122232425262728292a2b2c2d2e2f"
+ "303132333435363738393a3b3c3d3e3f"
+ "404142434445464748494a4b4c4d4e4f"
+ "505152535455565758595a5b5c5d5e5f"
+ "606162636465666768696a6b6c6d6e6f"
+ "707172737475767778797a7b7c7d7e7f"
+ "808182838485868788898a8b8c8d8e8f"
+ "909192939495969798999a9b9c9d9e9f"
+ "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf"
+ "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf"
+ "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf"
+ "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf"
+ "e0e1e2e3e4e5e6e7e8e9eaebecedeeef"
+ "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff";
+
+template <typename T>
+void BytesToHexString(const unsigned char* src, T dest, ptrdiff_t num) {
+ auto dest_ptr = &dest[0];
+ for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest_ptr += 2) {
+ const char* hex_p = &kHexTable[*src_ptr * 2];
+ std::copy(hex_p, hex_p + 2, dest_ptr);
+ }
+}
+
+// Returns true if the given type is represented as binary hex data.
+bool IsBinaryType(absl::string_view type_str) {
+ return !type_str.empty() && absl::ascii_isdigit(type_str[0]);
+}
+
+// Parses binary hex data.
+Status ParseBinaryData(absl::string_view data_str, Buffer* buffer) {
+ data_str = absl::StripAsciiWhitespace(data_str);
+ ASSIGN_OR_RETURN(auto mapping,
+ buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite));
+ auto contents = mapping.mutable_contents();
+ size_t dst_i = 0;
+ size_t src_i = 0;
+ while (src_i < data_str.size() && dst_i < contents.size()) {
+ char c = data_str[src_i];
+ if (absl::ascii_isspace(c) || c == ',') {
+ ++src_i;
+ continue;
+ }
+ if (src_i + 1 >= data_str.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Invalid input hex data (offset=" << src_i << ")";
+ }
+ HexStringToBytes(data_str.data() + src_i, contents.data() + dst_i, 1);
+ src_i += 2;
+ ++dst_i;
+ }
+ if (dst_i < contents.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Too few elements to fill type; expected " << contents.size()
+ << " but only read " << dst_i;
+ } else if (data_str.size() - src_i > 0) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Input data string contains more elements than the underlying "
+ "buffer ("
+ << contents.size() << ")";
+ }
+ return OkStatus();
+}
+
+// Prints binary hex data, one |element_size|-byte hex group per element.
+Status PrintBinaryData(int element_size, Buffer* buffer, size_t max_entries,
+                       std::ostream* stream) {
+  max_entries *= element_size;  // Counting bytes, but treat them as elements.
+  ASSIGN_OR_RETURN(auto mapping,
+                   buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+  auto contents = mapping.contents();
+  std::string hex_buffer(element_size * 2, '\0');  // Two hex chars per byte.
+  for (size_t i = 0; i < std::min(max_entries, mapping.size());
+       i += element_size) {
+    if (i > 0) *stream << " ";
+    BytesToHexString(contents.data() + i, &hex_buffer[0], element_size);
+    *stream << hex_buffer;  // std::string carries its length; no NUL needed.
+  }
+  if (mapping.size() > max_entries) *stream << "...";
+  return OkStatus();
+}
+
+template <typename ElementType, typename Enabled = void>
+struct SimpleStrToValue {
+ absl::optional<ElementType> operator()(absl::string_view text) const = delete;
+};
+
+template <typename IntegerType>
+struct SimpleStrToValue<
+    IntegerType, typename std::enable_if<(sizeof(IntegerType) < 4)>::type> {
+  absl::optional<IntegerType> operator()(absl::string_view text) const {
+    int32_t value = 0;  // Parsed wide so out-of-range input can be rejected.
+    return absl::SimpleAtoi(text, &value) && IntegerType(value) == value
+               ? absl::optional<IntegerType>{IntegerType(value)}
+               : absl::nullopt;
+  }
+};
+
+template <typename IntegerType>
+struct SimpleStrToValue<
+ IntegerType,
+ typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> {
+ absl::optional<IntegerType> operator()(absl::string_view text) const {
+ IntegerType value;
+ return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value}
+ : absl::nullopt;
+ }
+};
+
+template <>
+struct SimpleStrToValue<float, void> {
+ absl::optional<float> operator()(absl::string_view text) const {
+ float value;
+ return absl::SimpleAtof(text, &value) ? absl::optional<float>{value}
+ : absl::nullopt;
+ }
+};
+
+template <>
+struct SimpleStrToValue<double, void> {
+ absl::optional<double> operator()(absl::string_view text) const {
+ double value;
+ return absl::SimpleAtod(text, &value) ? absl::optional<double>{value}
+ : absl::nullopt;
+ }
+};
+
+template <typename T>
+Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start,
+ size_t token_end, absl::Span<T> contents,
+ int dst_i) {
+ if (dst_i >= contents.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Input data string contains more elements than the underlying "
+ "buffer ("
+ << contents.size() << ")";
+ }
+ auto element_str = data_str.substr(token_start, token_end - token_start + 1);
+ auto element = SimpleStrToValue<T>()(element_str);
+ if (!element.has_value()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Unable to parse element " << dst_i << " = '" << element_str
+ << "'";
+ }
+ contents[dst_i] = element.value();
+ return OkStatus();
+}
+
+template <typename T>
+Status ParseNumericalDataAsType(absl::string_view data_str, Buffer* buffer) {
+ ASSIGN_OR_RETURN(auto mapping,
+ buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
+ auto contents = mapping.mutable_contents();
+ size_t src_i = 0;
+ size_t dst_i = 0;
+ size_t token_start = std::string::npos;
+ while (src_i < data_str.size()) {
+ char c = data_str[src_i++];
+ bool is_separator =
+ absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']';
+ if (token_start == std::string::npos) {
+ if (!is_separator) {
+ token_start = src_i - 1;
+ }
+ continue;
+ } else if (token_start != std::string::npos && !is_separator) {
+ continue;
+ }
+ RETURN_IF_ERROR(ParseNumericalDataElement<T>(data_str, token_start,
+ src_i - 2, contents, dst_i++));
+ token_start = std::string::npos;
+ }
+ if (token_start != std::string::npos) {
+ RETURN_IF_ERROR(ParseNumericalDataElement<T>(
+ data_str, token_start, data_str.size() - 1, contents, dst_i++));
+ }
+ if (dst_i < contents.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Input data string contains fewer elements than the underlying "
+ "buffer (expected "
+ << contents.size() << ")";
+ }
+ return OkStatus();
+}
+
+// Parses numerical data (ints, floats, etc) in some typed form.
+Status ParseNumericalData(absl::string_view type_str,
+ absl::string_view data_str, Buffer* buffer) {
+ if (type_str == "i8") {
+ return ParseNumericalDataAsType<int8_t>(data_str, buffer);
+ } else if (type_str == "u8") {
+ return ParseNumericalDataAsType<uint8_t>(data_str, buffer);
+ } else if (type_str == "i16") {
+ return ParseNumericalDataAsType<int16_t>(data_str, buffer);
+ } else if (type_str == "u16") {
+ return ParseNumericalDataAsType<uint16_t>(data_str, buffer);
+ } else if (type_str == "i32") {
+ return ParseNumericalDataAsType<int32_t>(data_str, buffer);
+ } else if (type_str == "u32") {
+ return ParseNumericalDataAsType<uint32_t>(data_str, buffer);
+ } else if (type_str == "i64") {
+ return ParseNumericalDataAsType<int64_t>(data_str, buffer);
+ } else if (type_str == "u64") {
+ return ParseNumericalDataAsType<uint64_t>(data_str, buffer);
+ } else if (type_str == "f32") {
+ return ParseNumericalDataAsType<float>(data_str, buffer);
+ } else if (type_str == "f64") {
+ return ParseNumericalDataAsType<double>(data_str, buffer);
+ } else {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Unsupported type: " << type_str;
+ }
+}
+
+template <typename T>
+void PrintElementList(const Shape& shape, absl::Span<const T> data,
+ size_t* max_entries, std::ostream* stream) {
+ if (shape.empty()) {
+ // Scalar value.
+ PrintElementList({1}, data, max_entries, stream);
+ return;
+ } else if (shape.size() == 1) {
+ // Leaf dimension; output data.
+ size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0]));
+ *stream << absl::StrJoin(data.subspan(0, max_count), " ");
+ if (max_count < shape[0]) {
+ *stream << "...";
+ }
+ *max_entries -= max_count;
+ } else {
+ // Nested; recurse into next dimension.
+ Shape nested_shape = Shape(shape.subspan(1));
+ size_t length = nested_shape.element_count();
+ size_t offset = 0;
+ for (int i = 0; i < shape[0]; ++i) {
+ *stream << "[";
+ PrintElementList<T>(nested_shape, data.subspan(offset, length),
+ max_entries, stream);
+ offset += length;
+ *stream << "]";
+ }
+ }
+}
+
+template <typename T>
+Status PrintNumericalDataAsType(const Shape& shape, Buffer* buffer,
+ size_t max_entries, std::ostream* stream) {
+ ASSIGN_OR_RETURN(auto mapping, buffer->MapMemory<T>(MemoryAccess::kRead));
+ PrintElementList(shape, mapping.contents(), &max_entries, stream);
+ return OkStatus();
+}
+
+// Prints numerical data (ints, floats, etc) from some typed form.
+Status PrintNumericalData(const Shape& shape, absl::string_view type_str,
+ Buffer* buffer, size_t max_entries,
+ std::ostream* stream) {
+ if (type_str == "i8") {
+ return PrintNumericalDataAsType<int8_t>(shape, buffer, max_entries, stream);
+ } else if (type_str == "u8") {
+ return PrintNumericalDataAsType<uint8_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "i16") {
+ return PrintNumericalDataAsType<int16_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "u16") {
+ return PrintNumericalDataAsType<uint16_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "i32") {
+ return PrintNumericalDataAsType<int32_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "u32") {
+ return PrintNumericalDataAsType<uint32_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "i64") {
+ return PrintNumericalDataAsType<int64_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "u64") {
+ return PrintNumericalDataAsType<uint64_t>(shape, buffer, max_entries,
+ stream);
+ } else if (type_str == "f32") {
+ return PrintNumericalDataAsType<float>(shape, buffer, max_entries, stream);
+ } else if (type_str == "f64") {
+ return PrintNumericalDataAsType<double>(shape, buffer, max_entries, stream);
+ } else {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Unsupported type: " << type_str;
+ }
+}
+
+} // namespace
+
+StatusOr<int> GetTypeElementSize(absl::string_view type_str) {
+ if (type_str.empty()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "Type is empty";
+ } else if (IsBinaryType(type_str)) {
+ // If the first character is a digit then we are dealing with binary data.
+ // The type is just the number of bytes per element.
+ int element_size = 0;
+ if (!absl::SimpleAtoi(type_str, &element_size)) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Unable to parse element size type '" << type_str << "'";
+ }
+ return element_size;
+ }
+ // We know that our types are single characters followed by bit counts.
+ // If we start to support other types we may need to do something more clever.
+ int bit_count = 0;
+ if (!absl::SimpleAtoi(type_str.substr(1), &bit_count)) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Unable to parse type bit count from '" << type_str
+ << "'; expecting something like 'i32'";
+ }
+ return bit_count / 8;
+}
+
+StatusOr<Shape> ParseShape(absl::string_view shape_str) {
+  std::vector<int> extents;  // One entry per non-blank 'x'-separated token.
+  for (auto token : absl::StrSplit(shape_str, 'x', absl::SkipWhitespace())) {
+    int extent = 0;
+    if (!absl::SimpleAtoi(token, &extent)) {  // Token is not an integer.
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Invalid shape dimension '" << token
+             << "' while parsing shape '" << shape_str << "'";
+    }
+    extents.push_back(extent);
+  }
+  return Shape{extents};
+}
+
+StatusOr<BufferView> ParseBufferViewFromString(
+    absl::string_view buffer_view_str, hal::Allocator* allocator) {
+  // Surrounding whitespace (linefeeds/etc) carries no meaning; drop it first.
+  buffer_view_str = absl::StripAsciiWhitespace(buffer_view_str);
+  if (buffer_view_str.empty()) {
+    // A blank input denotes an empty BufferView.
+    return BufferView{};
+  }
+
+  // Carve the input into "[shape]x[type]" and the optional data after '='.
+  // With no '=' the data stays empty and nothing is parsed into the buffer.
+  const auto equal_index = buffer_view_str.find('=');
+  const bool has_eq = equal_index != std::string::npos;
+  const absl::string_view shape_and_type_str =
+      has_eq ? buffer_view_str.substr(0, equal_index) : buffer_view_str;
+  const absl::string_view data_str =
+      has_eq ? buffer_view_str.substr(equal_index + 1) : absl::string_view();
+
+  // The element type follows the last 'x'; everything before it is the shape.
+  // With no 'x' at all the value is a scalar and the shape stays empty.
+  const auto last_x_index = shape_and_type_str.rfind('x');
+  const bool has_shape = last_x_index != std::string::npos;
+  const absl::string_view shape_str =
+      has_shape ? shape_and_type_str.substr(0, last_x_index)
+                : absl::string_view();
+  const absl::string_view type_str =
+      has_shape ? shape_and_type_str.substr(last_x_index + 1)
+                : shape_and_type_str;
+
+  // Fill in the metadata that determines how many bytes to allocate.
+  BufferView result;
+  ASSIGN_OR_RETURN(result.element_size, GetTypeElementSize(type_str));
+  ASSIGN_OR_RETURN(result.shape, ParseShape(shape_str));
+  const size_t allocation_size =
+      result.shape.element_count() * result.element_size;
+
+  // Use the caller's allocator when given (host-local and device-visible);
+  // otherwise fall back to a host-local heap buffer.
+  if (allocator != nullptr) {
+    ASSIGN_OR_RETURN(
+        result.buffer,
+        allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible,
+                            BufferUsage::kAll | BufferUsage::kConstant,
+                            allocation_size));
+  } else {
+    result.buffer = HeapBuffer::Allocate(
+        MemoryType::kHostLocal, BufferUsage::kAll | BufferUsage::kConstant,
+        allocation_size);
+  }
+
+  // Without data there is nothing to write into the freshly allocated buffer.
+  if (data_str.empty()) {
+    return result;
+  }
+
+  // Parse the data straight into the buffer: binary types are read as hex,
+  // everything else as formatted numbers of the named type.
+  if (IsBinaryType(type_str)) {
+    RETURN_IF_ERROR(ParseBinaryData(data_str, result.buffer.get()));
+  } else {
+    RETURN_IF_ERROR(
+        ParseNumericalData(type_str, data_str, result.buffer.get()));
+  }
+  return result;
+}
+
+StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str) {
+  // Only the first character selects the mode. An empty string maps to '?',
+  // which matches no mode and is therefore reported as unsupported.
+  const char mode_char = str.empty() ? '?' : str[0];
+  if (mode_char == 'b') {
+    return BufferViewPrintMode::kBinary;
+  } else if (mode_char == 'i') {
+    return BufferViewPrintMode::kSignedInteger;
+  } else if (mode_char == 'u') {
+    return BufferViewPrintMode::kUnsignedInteger;
+  } else if (mode_char == 'f') {
+    return BufferViewPrintMode::kFloatingPoint;
+  }
+  return InvalidArgumentErrorBuilder(IREE_LOC)
+         << "Unsupported output type '" << str << "'";
+}
+
+StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view,
+                                              BufferViewPrintMode print_mode,
+                                              size_t max_entries) {
+  std::string text;  // Filled in by the out-parameter overload.
+  RETURN_IF_ERROR(
+      PrintBufferViewToString(buffer_view, print_mode, max_entries, &text));
+  return text;
+}
+
+Status PrintBufferViewToString(const BufferView& buffer_view,
+                               BufferViewPrintMode print_mode,
+                               size_t max_entries, std::string* out_result) {
+  std::ostringstream oss;  // |out_result| is only written on success.
+  RETURN_IF_ERROR(
+      PrintBufferViewToStream(buffer_view, print_mode, max_entries, &oss));
+  *out_result = oss.str();
+  return OkStatus();
+}
+
+Status PrintBufferViewToStream(const BufferView& buffer_view,
+                               BufferViewPrintMode print_mode,
+                               size_t max_entries, std::ostream* stream) {
+  // An empty BufferView (no buffer) prints as the empty string.
+  if (!buffer_view.buffer) {
+    return OkStatus();
+  }
+
+  // Build the type string from the print mode and element size: binary mode
+  // uses the byte count, the other modes a prefix plus the width in bits.
+  const auto bit_count = buffer_view.element_size * 8;
+  std::string type_str;
+  switch (print_mode) {
+    case BufferViewPrintMode::kBinary:
+      type_str = std::to_string(buffer_view.element_size);
+      break;
+    case BufferViewPrintMode::kSignedInteger:
+      absl::StrAppend(&type_str, "i", bit_count);
+      break;
+    case BufferViewPrintMode::kUnsignedInteger:
+      absl::StrAppend(&type_str, "u", bit_count);
+      break;
+    case BufferViewPrintMode::kFloatingPoint:
+      absl::StrAppend(&type_str, "f", bit_count);
+      break;
+  }
+
+  // Emit the "[shape]x[type]=" prefix; scalars have no shape and no 'x'.
+  const auto& dims = buffer_view.shape;
+  *stream << absl::StrJoin(dims.begin(), dims.end(), "x");
+  if (!dims.empty()) *stream << "x";
+  *stream << type_str << "=";
+
+  // Dispatch on the type string to the binary or the numerical printer.
+  if (!IsBinaryType(type_str)) {
+    return PrintNumericalData(buffer_view.shape, type_str,
+                              buffer_view.buffer.get(), max_entries, stream);
+  }
+  return PrintBinaryData(buffer_view.element_size, buffer_view.buffer.get(),
+                         max_entries, stream);
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/buffer_view_string_util.h b/hal/buffer_view_string_util.h
new file mode 100644
index 0000000..0864ddb
--- /dev/null
+++ b/hal/buffer_view_string_util.h
@@ -0,0 +1,95 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities for working with BufferView data, mostly useful for testing.
+// These functions allow for conversion between types, parsing and printing, and
+// basic comparisons.
+//
+// The canonical BufferView string format is:
+// [shape]x[type]=value,value,...
+// For example:
+// 2x2xi32=0,1,2,3
+// Characters like [] are optional and will be ignored during parsing:
+// 2x2xi32=[[0 1][2 3]]
+//
+// The type may be one of the following:
+// * 1/2/4/8 = 1/2/4/8 byte elements in binary hex format.
+// * i8/u8 = signed/unsigned 8-bit integers.
+// * i16/u16 = signed/unsigned 16-bit integers.
+// * i32/u32 = signed/unsigned 32-bit integers.
+// * i64/u64 = signed/unsigned 64-bit integers.
+// * f32 = 32-bit floating-point number.
+// * f64 = 64-bit floating-point number.
+
+#ifndef IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
+#define IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
+
+#include <ostream>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/buffer_view.h"
+
+namespace iree {
+namespace hal {
+
+// Returns the size, in bytes, of the given type.
+StatusOr<int> GetTypeElementSize(absl::string_view type_str);
+
+// Returns a Shape parsed from the given NxMx... string.
+StatusOr<Shape> ParseShape(absl::string_view shape_str);
+
+// Parses a BufferView encoded in a string.
+// If an |allocator| is provided the buffer will be allocated as host-local and
+// device-visible. Otherwise, buffers will be host-local.
+// The format accepted matches that produced by PrintBufferViewToString.
+StatusOr<BufferView> ParseBufferViewFromString(
+ absl::string_view buffer_view_str, hal::Allocator* allocator = nullptr);
+
+// Defines how the elements within a BufferView are interpreted during printing.
+enum class BufferViewPrintMode {
+ // Interpret the data as if it were serialized bytes.
+ // In this mode no conversion is performed and the bytes in memory are printed
+ // as hex in groupings based on the element size. Shortened to 'b'.
+ kBinary,
+ // Interpret elements as signed integers; shortened to 'i'.
+ kSignedInteger,
+ // Interpret elements as unsigned integers; shortened to 'u'.
+ kUnsignedInteger,
+ // Interpret elements as floating-point values; shortened to 'f'.
+ kFloatingPoint,
+};
+
+// Returns the BufferViewPrintMode based on the shortened char in |str|.
+StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str);
+
+// Prints a BufferView to a string encoded in the canonical format.
+StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view,
+ BufferViewPrintMode print_mode,
+ size_t max_entries);
+Status PrintBufferViewToString(const BufferView& buffer_view,
+ BufferViewPrintMode print_mode,
+ size_t max_entries, std::string* out_result);
+
+// Prints a BufferView to a string stream encoded in the canonical format.
+Status PrintBufferViewToStream(const BufferView& buffer_view,
+ BufferViewPrintMode print_mode,
+ size_t max_entries, std::ostream* stream);
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
diff --git a/hal/buffer_view_string_util_test.cc b/hal/buffer_view_string_util_test.cc
new file mode 100644
index 0000000..0cf39ba
--- /dev/null
+++ b/hal/buffer_view_string_util_test.cc
@@ -0,0 +1,186 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/buffer_view_string_util.h"
+
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+StatusOr<std::vector<T>> ReadBuffer(const ref_ptr<Buffer>& buffer) {
+  // Size the vector to the number of whole T elements in the buffer.
+  std::vector<T> elements(buffer->byte_length() / sizeof(T));
+  RETURN_IF_ERROR(
+      buffer->ReadData(0, elements.data(), elements.size() * sizeof(T)));
+  return elements;
+}
+
+TEST(BufferViewUtilTest, GetTypeElementSize) {
+ EXPECT_EQ(1, GetTypeElementSize("1").ValueOrDie());
+ EXPECT_EQ(7, GetTypeElementSize("7").ValueOrDie());
+ EXPECT_EQ(4, GetTypeElementSize("i32").ValueOrDie());
+ EXPECT_EQ(8, GetTypeElementSize("f64").ValueOrDie());
+
+ EXPECT_FALSE(GetTypeElementSize("").ok());
+ EXPECT_FALSE(GetTypeElementSize(" ").ok());
+ EXPECT_FALSE(GetTypeElementSize("a").ok());
+ EXPECT_FALSE(GetTypeElementSize("ib").ok());
+ EXPECT_FALSE(GetTypeElementSize("i").ok());
+ EXPECT_FALSE(GetTypeElementSize("i543ff").ok());
+}
+
+TEST(BufferViewUtilTest, ParseShape) {
+ EXPECT_EQ((Shape{}), ParseShape("").ValueOrDie());
+ EXPECT_EQ((Shape{1}), ParseShape("1").ValueOrDie());
+ EXPECT_EQ((Shape{1, 2}), ParseShape("1x2").ValueOrDie());
+ EXPECT_EQ((Shape{1, 2}), ParseShape(" 1 x 2 ").ValueOrDie());
+
+ EXPECT_FALSE(ParseShape("abc").ok());
+ EXPECT_FALSE(ParseShape("1xf").ok());
+ EXPECT_FALSE(ParseShape("1xff23").ok());
+}
+
+TEST(BufferViewUtilTest, ParseBufferViewFromStringEmpty) {
+ // Empty string = empty buffer_view.
+ ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString(""));
+ EXPECT_EQ(nullptr, m0.buffer.get());
+ EXPECT_EQ(Shape{}, m0.shape);
+ EXPECT_EQ(0, m0.element_size);
+
+ // No = means no data.
+ ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x2xf32"));
+ EXPECT_EQ(4 * 2 * 4, m1.buffer->allocation_size());
+ EXPECT_EQ(Shape({4, 2}), m1.shape);
+ EXPECT_EQ(4, m1.element_size);
+ EXPECT_THAT(ReadBuffer<float>(m1.buffer).ValueOrDie(),
+ ElementsAre(0, 0, 0, 0, 0, 0, 0, 0));
+
+ // No data after = means no data.
+ ASSERT_OK_AND_ASSIGN(auto m2, ParseBufferViewFromString("4x2xf32="));
+ EXPECT_EQ(4 * 2 * 4, m2.buffer->allocation_size());
+ EXPECT_EQ(Shape({4, 2}), m2.shape);
+ EXPECT_EQ(4, m2.element_size);
+ EXPECT_THAT(ReadBuffer<float>(m2.buffer).ValueOrDie(),
+ ElementsAre(0, 0, 0, 0, 0, 0, 0, 0));
+}
+
+TEST(BufferViewUtilTest, ParseBufferViewFromStringBinary) {
+ ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString("4x1=00 01 02 03"));
+ EXPECT_EQ(Shape({4}), m0.shape);
+ EXPECT_EQ(1, m0.element_size);
+ EXPECT_THAT(ReadBuffer<uint8_t>(m0.buffer).ValueOrDie(),
+ ElementsAre(0, 1, 2, 3));
+
+ // Whitespace shouldn't matter.
+ ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x1=00,010203"));
+ EXPECT_EQ(Shape({4}), m1.shape);
+ EXPECT_EQ(1, m1.element_size);
+ EXPECT_THAT(ReadBuffer<uint8_t>(m1.buffer).ValueOrDie(),
+ ElementsAre(0, 1, 2, 3));
+
+ // Should fail on malformed hex bytes.
+ EXPECT_FALSE(ParseBufferViewFromString("4x1=1").ok());
+ EXPECT_FALSE(ParseBufferViewFromString("4x1=00003").ok());
+ EXPECT_FALSE(ParseBufferViewFromString("4x1=%0123%\1").ok());
+ EXPECT_FALSE(ParseBufferViewFromString("4x1=00010203040506").ok());
+}
+
+TEST(BufferViewUtilTest, ParseBufferViewFromStringAllowBrackets) {
+ ASSERT_OK_AND_ASSIGN(auto m0,
+ ParseBufferViewFromString("4xi16=[[0][ 1 ][2]][3]"));
+ EXPECT_EQ(Shape({4}), m0.shape);
+ EXPECT_EQ(2, m0.element_size);
+ EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(),
+ ElementsAre(0, 1, 2, 3));
+}
+
+TEST(BufferViewUtilTest, ParseBufferViewFromStringInteger) {
+ // Signed int16.
+ ASSERT_OK_AND_ASSIGN(auto m0,
+ ParseBufferViewFromString("4xi16=0 12345 65535 -2"));
+ EXPECT_EQ(Shape({4}), m0.shape);
+ EXPECT_EQ(2, m0.element_size);
+ EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(),
+ ElementsAre(0, 12345, -1, -2));
+
+ // Unsigned int16.
+ ASSERT_OK_AND_ASSIGN(auto m1,
+ ParseBufferViewFromString("4xu16=0 12345 65535 -2"));
+ EXPECT_EQ(Shape({4}), m1.shape);
+ EXPECT_EQ(2, m1.element_size);
+ EXPECT_THAT(ReadBuffer<uint16_t>(m1.buffer).ValueOrDie(),
+ ElementsAre(0, 12345, 65535, 65534));
+
+ // Mixing separator types is ok.
+ ASSERT_OK_AND_ASSIGN(auto m2,
+ ParseBufferViewFromString("4xu16=0, 12345, 65535, -2"));
+ EXPECT_EQ(Shape({4}), m2.shape);
+ EXPECT_EQ(2, m2.element_size);
+ EXPECT_THAT(ReadBuffer<uint16_t>(m2.buffer).ValueOrDie(),
+ ElementsAre(0, 12345, 65535, 65534));
+
+ // Should fail on malformed integers and out of bounds values.
+ EXPECT_FALSE(ParseBufferViewFromString("4xi32=asodfj").ok());
+ EXPECT_FALSE(ParseBufferViewFromString("4xi32=0 1 2 3 4").ok());
+}
+
+TEST(BufferViewUtilTest, ParseBufferViewFromStringFloat) {
+ // Float.
+ ASSERT_OK_AND_ASSIGN(auto m0,
+ ParseBufferViewFromString("4xf32=0 1.0 1234 -2.0e-5"));
+ EXPECT_EQ(Shape({4}), m0.shape);
+ EXPECT_EQ(4, m0.element_size);
+ EXPECT_THAT(ReadBuffer<float>(m0.buffer).ValueOrDie(),
+ ElementsAre(0.0f, 1.0f, 1234.0f, -2.0e-5f));
+
+ // Double.
+ ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString(
+ "4xf64=0 1.0 123456789012345 -2.0e-5"));
+ EXPECT_EQ(Shape({4}), m1.shape);
+ EXPECT_EQ(8, m1.element_size);
+ EXPECT_THAT(ReadBuffer<double>(m1.buffer).ValueOrDie(),
+ ElementsAre(0.0, 1.0, 123456789012345.0, -2.0e-5));
+
+ // Should fail on malformed floats and out of bounds values.
+ EXPECT_FALSE(ParseBufferViewFromString("4xf32=asodfj").ok());
+ EXPECT_FALSE(ParseBufferViewFromString("4xf32=0").ok());
+ EXPECT_FALSE(ParseBufferViewFromString("4xf32=0 1 2 3 4").ok());
+}
+
+TEST(BufferViewUtilTest, ParseBufferViewPrintMode) {
+ EXPECT_EQ(BufferViewPrintMode::kBinary,
+ ParseBufferViewPrintMode("b").ValueOrDie());
+ EXPECT_EQ(BufferViewPrintMode::kSignedInteger,
+ ParseBufferViewPrintMode("i").ValueOrDie());
+ EXPECT_EQ(BufferViewPrintMode::kUnsignedInteger,
+ ParseBufferViewPrintMode("u").ValueOrDie());
+ EXPECT_EQ(BufferViewPrintMode::kFloatingPoint,
+ ParseBufferViewPrintMode("f").ValueOrDie());
+
+ EXPECT_FALSE(ParseBufferViewPrintMode("").ok());
+ EXPECT_FALSE(ParseBufferViewPrintMode("s").ok());
+ EXPECT_FALSE(ParseBufferViewPrintMode("asdfasdf").ok());
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/buffer_view_test.cc b/hal/buffer_view_test.cc
new file mode 100644
index 0000000..efc9699
--- /dev/null
+++ b/hal/buffer_view_test.cc
@@ -0,0 +1,285 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/buffer_view.h"
+
+#include <numeric>
+#include <vector>
+
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/buffer.h"
+#include "hal/heap_buffer.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+template <typename T>
+BufferView MakeView(const std::vector<T>& src_data, Shape shape) {
+  // |src_data| is taken by const reference so the vector is not copied here.
+  auto parent_buffer = HeapBuffer::AllocateCopy(
+      BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data));
+  return BufferView(std::move(parent_buffer), shape, sizeof(T));
+}
+
+template <typename T>
+std::vector<T> ReadData(BufferView view) {
+  std::vector<T> out(view.shape.element_count());  // One slot per element.
+  EXPECT_OK(view.buffer->ReadData(0, out.data(), out.size() * sizeof(T)));
+  return out;
+}
+
+TEST(BufferViewTest, SliceWholeBuffer) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {0, 0};
+ std::vector<int32_t> lengths = {2, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_TRUE(BufferView::Equal(parent_view, slice))
+ << "original parent_view " << parent_view.DebugStringShort()
+ << " and whole slice " << slice.DebugStringShort() << " are not equal";
+}
+
+TEST(BufferViewTest, SliceSingleRow) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({2, 3}));
+}
+
+TEST(BufferViewTest, SliceRowStart) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7};
+ Shape shape = {2, 4};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {1, 3};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({4, 5, 6}));
+}
+
+TEST(BufferViewTest, SliceRowEnd) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7};
+ Shape shape = {2, 4};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 1};
+ std::vector<int32_t> lengths = {1, 3};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6, 7}));
+}
+
+TEST(BufferViewTest, SliceRowMiddle) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7};
+ Shape shape = {2, 4};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 1};
+ std::vector<int32_t> lengths = {1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6}));
+}
+
+TEST(BufferViewTest, SliceMultiRow) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ Shape shape = {3, 3};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {2, 3};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({3, 4, 5, 6, 7, 8}));
+}
+
+TEST(BufferViewTest, SliceHighRank) {
+ std::vector<uint8_t> src_data(81);
+ std::iota(src_data.begin(), src_data.end(), 0);
+ Shape shape = {3, 3, 3, 3};
+
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 2, 2, 1};
+ std::vector<int32_t> lengths = {1, 1, 1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({52, 53}));
+}
+
+TEST(BufferViewTest, SliceModifySlice) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_OK(slice.buffer->Fill8(0, kWholeBuffer, 0xFFu));
+
+ auto parent_data = ReadData<uint8_t>(parent_view);
+ EXPECT_EQ(parent_data, std::vector<uint8_t>({0, 1, 0xFFu, 0xFFu}));
+}
+
+TEST(BufferViewTest, SliceModifyParent) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_OK(parent_view.buffer->Fill8(0, kWholeBuffer, 0xFFu));
+
+ EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({0xFFu, 0xFFu}));
+}
+
+TEST(BufferViewTest, SliceMultiByteElementWholeBuffer) {
+ const std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3};
+
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {0, 0};
+ std::vector<int32_t> lengths = {2, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_TRUE(BufferView::Equal(parent_view, slice))
+ << "original parent_view " << parent_view.DebugStringShort()
+ << " and whole slice " << slice.DebugStringShort() << " are not equal";
+}
+
+TEST(BufferViewTest, SliceShapeAndElementSize) {
+ std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+ EXPECT_EQ(slice.shape, Shape(lengths));
+ EXPECT_EQ(slice.element_size, 4);
+}
+
+TEST(BufferViewTest, SliceMultiByteElement) {
+ std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {1, 2};
+ ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
+
+ EXPECT_EQ(ReadData<int32_t>(slice), std::vector<int32_t>({2, 3}));
+}
+
+TEST(BufferViewTest, SliceIndexBadRank) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {0};
+ std::vector<int32_t> lengths = {2};
+ EXPECT_TRUE(
+ IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
+}
+
+TEST(BufferViewTest, SliceIndexLengthMismatch) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {0, 0};
+ std::vector<int32_t> lengths = {2};
+ EXPECT_TRUE(
+ IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
+}
+
+TEST(BufferViewTest, SliceIndicesOutOfBounds) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {0, 3};
+ std::vector<int32_t> lengths = {1, 1};
+ EXPECT_TRUE(
+ IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
+}
+
+TEST(BufferViewTest, SliceLengthsOutOfBounds) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3};
+ Shape shape = {2, 2};
+
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {0, 0};
+ std::vector<int32_t> lengths = {1, 3};
+ EXPECT_TRUE(
+ IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
+}
+
+TEST(BufferViewTest, SliceNonContiguous) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ Shape shape = {3, 3};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 1};
+ std::vector<int32_t> lengths = {2, 2};
+ EXPECT_TRUE(
+ IsUnimplemented(parent_view.Slice(start_indices, lengths).status()));
+}
+
+TEST(BufferViewTest, SliceNonContiguousMultiRowLeft) {
+ std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ Shape shape = {3, 3};
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0};
+ std::vector<int32_t> lengths = {2, 1};
+ EXPECT_TRUE(
+ IsUnimplemented(parent_view.Slice(start_indices, lengths).status()));
+}
+
+TEST(BufferViewTest, SliceHighRankNonContiguous) {
+ std::vector<uint8_t> src_data(81);
+ std::iota(src_data.begin(), src_data.end(), 0);
+ Shape shape = {3, 3, 3, 3};
+
+ auto parent_view = MakeView(src_data, shape);
+
+ std::vector<int32_t> start_indices = {1, 0, 2, 1};
+ std::vector<int32_t> lengths = {1, 2, 1, 2};
+ EXPECT_TRUE(
+ IsUnimplemented(parent_view.Slice(start_indices, lengths).status()));
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/command_buffer.cc b/hal/command_buffer.cc
new file mode 100644
index 0000000..d634c75
--- /dev/null
+++ b/hal/command_buffer.cc
@@ -0,0 +1,29 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/command_buffer.h"
+
+namespace iree {
+namespace hal {
+
+std::string CommandCategoryString(CommandCategoryBitfield categories) {
+  // Each CommandCategory flag is paired with the name passed along for it to
+  // FormatBitfieldValue.
+  return FormatBitfieldValue(categories,
+                             {{CommandCategory::kTransfer, "kTransfer"},
+                              {CommandCategory::kDispatch, "kDispatch"}});
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/command_buffer.h b/hal/command_buffer.h
new file mode 100644
index 0000000..df9b40f
--- /dev/null
+++ b/hal/command_buffer.h
@@ -0,0 +1,383 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_COMMAND_BUFFER_H_
+#define IREE_HAL_COMMAND_BUFFER_H_
+
+#include <cstdint>
+
+#include "base/bitfield.h"
+#include "base/shape.h"
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/buffer.h"
+#include "hal/buffer_view.h"
+#include "hal/event.h"
+#include "hal/executable.h"
+#include "hal/resource.h"
+
+namespace iree {
+namespace hal {
+
+// A bitfield specifying the mode of operation for a command buffer.
+enum class CommandBufferMode : uint32_t {
+ // Command buffer will be submitted once and never used again.
+ // This may enable in-place patching of command buffers that reduce overhead
+ // when it's known that command buffers will not be reused.
+ kOneShot = 1 << 0,
+};
+IREE_BITFIELD(CommandBufferMode);
+using CommandBufferModeBitfield = CommandBufferMode;
+std::string CommandBufferModeString(CommandBufferModeBitfield mode);
+
+// A bitfield specifying the category of commands in a command queue.
+enum class CommandCategory : uint32_t {
+ // Command is considered a transfer operation (memcpy, etc).
+ kTransfer = 1 << 0,
+ // Command is considered a dispatch operation (dispatch/execute).
+ kDispatch = 1 << 1,
+};
+IREE_BITFIELD(CommandCategory);
+using CommandCategoryBitfield = CommandCategory;
+std::string CommandCategoryString(CommandCategoryBitfield categories);
+
+// Bitfield specifying which execution stage a barrier should start/end at.
+//
+// Maps to VkPipelineStageFlagBits.
+enum class ExecutionStage : uint32_t {
+ // Top of the pipeline when commands are initially issued by the device.
+ kCommandIssue = 1 << 0,
+ // Stage of the pipeline when dispatch parameter data is consumed.
+ kCommandProcess = 1 << 1,
+ // Stage where dispatch commands execute.
+ kDispatch = 1 << 2,
+ // Stage where transfer (copy/clear/fill/etc) commands execute.
+ kTransfer = 1 << 3,
+ // Final stage in the pipeline when commands are retired on the device.
+ kCommandRetire = 1 << 4,
+ // Pseudo-stage for read/writes by the host. Not executed on device.
+ kHost = 1 << 5,
+};
+IREE_BITFIELD(ExecutionStage);
+using ExecutionStageBitfield = ExecutionStage;
+
+// Bitfield specifying which scopes will access memory and how.
+//
+// Maps to VkAccessFlagBits.
+enum class AccessScope : uint32_t {
+ // Read access to indirect command data as part of an indirect dispatch.
+ kIndirectCommandRead = 1 << 0,
+ // Constant uniform buffer reads by the device.
+ kConstantRead = 1 << 1,
+ // Storage buffer reads by dispatch commands.
+ kDispatchRead = 1 << 2,
+ // Storage buffer writes by dispatch commands.
+ kDispatchWrite = 1 << 3,
+ // Source of a transfer operation.
+ kTransferRead = 1 << 4,
+ // Target of a transfer operation.
+ kTransferWrite = 1 << 5,
+ // Read operation by the host through mapped memory.
+ kHostRead = 1 << 6,
+ // Write operation by the host through mapped memory.
+ kHostWrite = 1 << 7,
+ // External/non-specific read.
+ kMemoryRead = 1 << 8,
+ // External/non-specific write.
+ kMemoryWrite = 1 << 9,
+};
+IREE_BITFIELD(AccessScope);
+using AccessScopeBitfield = AccessScope;
+
+// Defines a global memory barrier.
+// These are cheaper to encode than buffer-specific barriers but may cause
+// stalls and bubbles in device pipelines if applied too broadly. Prefer them
+// over equivalently large sets of buffer-specific barriers (such as when
+// completely changing execution contexts).
+//
+// Maps to VkMemoryBarrier.
+struct MemoryBarrier {
+ // All access scopes prior-to the barrier (inclusive).
+ AccessScopeBitfield source_scope;
+ // All access scopes following the barrier (inclusive).
+ AccessScopeBitfield target_scope;
+};
+
+// Defines a memory barrier that applies to a range of a specific buffer.
+// Use of these (vs. global memory barriers) provides fine-grained execution
+// ordering to device command processors and allows for more aggressive
+// reordering.
+//
+// Maps to VkBufferMemoryBarrier.
+struct BufferBarrier {
+ // All access scopes prior-to the barrier (inclusive).
+ AccessScopeBitfield source_scope;
+ // All access scopes following the barrier (inclusive).
+ AccessScopeBitfield target_scope;
+ // Buffer the barrier is restricted to.
+ // The barrier will apply to the entire physical device allocation.
+ Buffer* buffer = nullptr;
+ // Relative offset/length within |buffer| (which may itself be mapped into the
+ // device allocation at an offset).
+ device_size_t offset = 0;
+ device_size_t length = kWholeBuffer;
+};
+
+// Represents a binding to a buffer with a set of attributes.
+// This may be used by drivers to validate alignment.
+struct BufferBinding {
+ // Access rights of the buffer contents by the executable.
+ MemoryAccessBitfield access = MemoryAccess::kAll;
+
+ // The buffer this binding references.
+ // The buffer is not retained by the binding and must be kept alive externally
+ // for the duration it is in use by the queue.
+ Buffer* buffer = nullptr;
+
+ // Shape of the buffer contents.
+ Shape shape;
+
+ // Size of each element within the buffer, in bytes.
+ int8_t element_size = 0;
+
+ BufferBinding() = default;
+ BufferBinding(MemoryAccessBitfield access, Buffer* buffer)
+ : access(access), buffer(buffer) {}
+ BufferBinding(MemoryAccessBitfield access, Buffer* buffer, Shape shape,
+ int8_t element_size)
+ : access(access),
+ buffer(buffer),
+ shape(shape),
+ element_size(element_size) {}
+ BufferBinding(MemoryAccessBitfield access, const BufferView& buffer_view)
+ : access(access),
+ buffer(buffer_view.buffer.get()),
+ shape(buffer_view.shape),
+ element_size(buffer_view.element_size) {}
+};
+
+// Wraps parameters for a Dispatch request.
+struct DispatchRequest {
+  // Executable prepared for use on the device.
+  // The executable must remain alive until all in-flight dispatch requests
+  // that use it have completed.
+  Executable* executable = nullptr;
+
+  // Executable entry point ordinal.
+  int entry_point = 0;
+
+  // TODO(benvanik): predication.
+
+  // Static workload (X, Y, Z workgroup counts); zero unless set by the caller.
+  std::array<int32_t, 3> workload = {{0, 0, 0}};
+
+  // An optional buffer containing the dynamic workload to dispatch.
+  // The contents need not be available at the time of recording but must be
+  // made visible prior to execution of the dispatch command.
+  //
+  // Buffer contents are expected to be 3 int32 values defining the X, Y, and Z
+  // workgroup counts.
+  //
+  // The buffer must have been allocated with BufferUsage::kDispatch and be
+  // of MemoryType::kDeviceVisible.
+  Buffer* workload_buffer = nullptr;
+
+  // A list of buffers that contain the execution inputs/outputs.
+  // Order is dependent on executable arg layout.
+  //
+  // Buffers must have been allocated with BufferUsage::kDispatch and be
+  // of MemoryType::kDeviceVisible.
+  absl::Span<const BufferBinding> bindings;
+
+  // TODO(benvanik): push-constant equivalent (uniforms, etc).
+};
+
+// Asynchronous command buffer recording interface.
+// Commands are recorded by the implementation for later submission to command
+// queues.
+//
+// Buffers and synchronization objects referenced must remain valid and not be
+// modified or read while there are commands in-flight. The usual flow is to
+// populate input buffers, Dispatch using those buffers, wait on a Fence until
+// the buffers are guaranteed to no longer be in use, and then reuse or release
+// the buffers.
+//
+// Errors that can be recognized when operations are enqueued will be returned
+// immediately, such as invalid argument errors. Errors that can only be
+// determined at execution time will be returned on fences. Once a failure
+// occurs the device queue will enter an error state that invalidates all
+// operations on the device queue (as ordering is not strict and any may still
+// be in-flight). In this case the user of the device queue should treat all
+// in-flight operations as cancelled and fully reset themselves. Other device
+// queues that may be waiting on events from the device queue will also enter
+// error states. Only once a user has acknowledged and cleared the error state
+// with a Reset the queue will become usable, and otherwise all operations will
+// return errors.
+//
+// Command buffers are thread-compatible. Use multiple command buffers if trying
+// to record commands from multiple threads. Command buffers must not be mutated
+// between when they are submitted for execution on a queue and when the
+// fence fires indicating the completion of their execution.
+class CommandBuffer : public Resource {
+ public:
+ virtual CommandBuffer* impl() { return this; }
+
+ // Device allocator that commands encoded into the buffer share compatibility
+ // with.
+ Allocator* allocator() const { return allocator_; }
+
+ // Command buffer operation mode.
+ CommandBufferModeBitfield mode() const { return mode_; }
+
+ // Command categories that may be recorded into the buffer.
+ CommandCategoryBitfield command_categories() const {
+ return command_categories_;
+ }
+
+ // True if the command buffer is between a Begin/End recording block.
+ virtual bool is_recording() const = 0;
+
+ // Resets and begins recording into the command buffer, clearing all
+ // previously recorded contents.
+ // The command buffer must not be in-flight.
+ virtual Status Begin() = 0;
+
+ // Ends recording into the command buffer.
+ // This must be called prior to submitting the command buffer for execution.
+ virtual Status End() = 0;
+
+ // TODO(benvanik): annotations for debugging and tracing:
+ // enter/exit
+ // stack frame manipulation
+ // explicit timers? or profiling buffer?
+
+ // TODO(b/138719910): cross-queue and external acquire/release.
+ // virtual Status AcquireBuffer() = 0;
+ // virtual Status ReleaseBuffer() = 0;
+
+ // Defines a memory dependency between commands recorded before and after the
+ // barrier. One or more memory or buffer barriers can be specified to indicate
+ // between which stages or buffers the dependencies exist.
+ virtual Status ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) = 0;
+
+ // Sets an event to the signaled state.
+ // |source_stage_mask| specifies when the event is signaled.
+ //
+ // Events are only valid within a single command buffer. Events can only be
+ // used on non-transfer queues.
+ virtual Status SignalEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) = 0;
+
+ // Resets an event to the non-signaled state.
+ // |source_stage_mask| specifies when the event is unsignaled.
+ //
+ // Events are only valid within a single command buffer. Events can only be
+ // used on non-transfer queues.
+ virtual Status ResetEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) = 0;
+
+ // Waits for one or more events to be signaled and defines a memory dependency
+ // between the synchronization scope of the signal operations and the commands
+ // following the wait.
+ //
+ // |source_stage_mask| must include ExecutionStage::kHost for Event::Signal to
+ // be visible.
+ //
+ // Events are only valid within a single command buffer. Events remain
+ // signaled even after waiting and must be reset to be reused. Events can only
+ // be used on non-transfer queues.
+ virtual Status WaitEvents(
+ absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) = 0;
+
+ // Fills the target buffer with the given repeating value.
+ // Expects that pattern_length is one of 1, 2, or 4 and that the offset and
+ // length are aligned to the natural alignment of the value.
+ // The target buffer must be compatible with the devices owned by this
+ // device queue and be allocated with BufferUsage::kTransfer.
+ virtual Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) = 0;
+
+ // Hints to the device queue that the given buffer will not be used again.
+ // After encoding a discard the buffer contents will be considered undefined.
+ // This is because the discard may be used to elide write backs to host memory
+ // or aggressively reuse the allocation for other purposes.
+ //
+ // For buffers allocated with MemoryType::kTransient this may allow
+ // the device queue to reclaim the memory used by the buffer earlier than
+ // otherwise possible.
+ virtual Status DiscardBuffer(Buffer* buffer) = 0;
+
+ // Updates a range of the given target buffer from the source host memory.
+ // The source host memory is copied immediately into the command buffer and
+ // occupies command buffer space. It is strongly recommended that large buffer
+ // updates are performed via CopyBuffer where there is the possibility of a
+ // zero-copy path.
+ // The |source_buffer| may be released by the caller immediately after this
+ // call returns.
+ // The |target_buffer| must be compatible with the devices owned by this
+ // device queue and be allocated with BufferUsage::kTransfer.
+ virtual Status UpdateBuffer(const void* source_buffer,
+ device_size_t source_offset,
+ Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length) = 0;
+
+ // Copies a range of one buffer to another.
+ // Both buffers must be compatible with the devices owned by this device
+ // queue and be allocated with BufferUsage::kTransfer. Though the source and
+ // target buffer may be the same the ranges must not overlap (as with memcpy).
+ //
+ // This can be used to perform device->host, host->device, and device->device
+ // copies.
+ virtual Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) = 0;
+
+ // Dispatches an execution request.
+ // The request may execute overlapped with any other transfer operation or
+ // dispatch made within the same barrier-defined sequence.
+ //
+ // The executable specified must be registered for use with the device driver
+ // owning this queue. It must not be unregistered until all requests that use
+ // it have completed.
+ //
+ // Fails if the queue does not support dispatch operations (as indicated by
+ // can_dispatch).
+ virtual Status Dispatch(const DispatchRequest& dispatch_request) = 0;
+
+ protected:
+ CommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories)
+ : allocator_(allocator),
+ mode_(mode),
+ command_categories_(command_categories) {}
+
+ private:
+ Allocator* const allocator_;
+ const CommandBufferModeBitfield mode_;
+ const CommandCategoryBitfield command_categories_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_COMMAND_BUFFER_H_
diff --git a/hal/command_buffer_validation.cc b/hal/command_buffer_validation.cc
new file mode 100644
index 0000000..9d86b82
--- /dev/null
+++ b/hal/command_buffer_validation.cc
@@ -0,0 +1,403 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/command_buffer_validation.h"
+
+#include "base/logging.h"
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+
+// Command buffer validation shim.
+// Wraps an existing command buffer to provide in-depth validation during
+// recording. This should be enabled whenever the command buffer is being driven
+// by unsafe code or when early and readable diagnostics are needed.
+class ValidatingCommandBuffer : public CommandBuffer {
+ public:
+ explicit ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl);
+ ~ValidatingCommandBuffer() override;
+
+ CommandBuffer* impl() override { return impl_.get(); }
+
+ bool is_recording() const override;
+
+ Status Begin() override;
+ Status End() override;
+
+ Status ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+ Status SignalEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+ Status ResetEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+ Status WaitEvents(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+ Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) override;
+ Status DiscardBuffer(Buffer* buffer) override;
+ Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+ Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+ Status Dispatch(const DispatchRequest& dispatch_request) override;
+
+ private:
+ // Returns a failure if the queue does not support the given caps.
+ Status ValidateCategories(CommandCategoryBitfield required_categories) const;
+ // Returns a failure if the memory type the buffer was allocated from is not
+ // compatible with the given type.
+ Status ValidateCompatibleMemoryType(Buffer* buffer,
+ MemoryTypeBitfield memory_type) const;
+ // Returns a failure if the buffer memory type or usage disallows the given
+ // access type.
+ Status ValidateAccess(Buffer* buffer,
+ MemoryAccessBitfield memory_access) const;
+ // Returns a failure if the buffer was not allocated for the given usage.
+ Status ValidateUsage(Buffer* buffer, BufferUsageBitfield usage) const;
+ // Validates that the range provided is within the given buffer.
+ Status ValidateRange(Buffer* buffer, device_size_t byte_offset,
+ device_size_t byte_length) const;
+
+ ref_ptr<CommandBuffer> impl_;
+};
+
+ValidatingCommandBuffer::ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl)
+ : CommandBuffer(impl->allocator(), impl->mode(),
+ impl->command_categories()),
+ impl_(std::move(impl)) {}
+
+ValidatingCommandBuffer::~ValidatingCommandBuffer() = default;
+
+bool ValidatingCommandBuffer::is_recording() const {
+ return impl_->is_recording();
+}
+
+Status ValidatingCommandBuffer::Begin() {
+ DVLOG(3) << "CommandBuffer::Begin()";
+ if (impl_->is_recording()) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Command buffer is already recording";
+ }
+ return impl_->Begin();
+}
+
+Status ValidatingCommandBuffer::End() {
+ DVLOG(3) << "CommandBuffer::End()";
+ if (!impl_->is_recording()) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Command buffer is not recording";
+ }
+ return impl_->End();
+}
+
+Status ValidatingCommandBuffer::ValidateCategories(
+ CommandCategoryBitfield required_categories) const {
+ if (!AllBitsSet(command_categories(), required_categories)) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Operation requires categories "
+ << CommandCategoryString(required_categories)
+ << " but buffer only supports "
+ << CommandCategoryString(command_categories());
+ }
+ return OkStatus();
+}
+
+Status ValidatingCommandBuffer::ValidateCompatibleMemoryType(
+ Buffer* buffer, MemoryTypeBitfield memory_type) const {
+ if ((buffer->memory_type() & memory_type) != memory_type) {
+ // Missing one or more bits.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "Buffer memory type is not compatible with the requested "
+ "operation; buffer has "
+ << MemoryTypeString(buffer->memory_type()) << ", operation requires "
+ << MemoryTypeString(memory_type);
+ }
+ return OkStatus();
+}
+
+Status ValidatingCommandBuffer::ValidateAccess(
+ Buffer* buffer, MemoryAccessBitfield memory_access) const {
+ if ((buffer->allowed_access() & memory_access) != memory_access) {
+ // One or more of the requested access bits is not allowed by the buffer.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "The buffer does not support the requested access type; buffer "
+ "allows "
+ << MemoryAccessString(buffer->allowed_access())
+ << ", operation requires " << MemoryAccessString(memory_access);
+ }
+ return OkStatus();
+}
+
+// Returns a failure if the buffer was not allocated for the given usage.
+Status ValidatingCommandBuffer::ValidateUsage(Buffer* buffer,
+ BufferUsageBitfield usage) const {
+ if (!allocator()->CanUseBuffer(buffer, usage)) {
+ // Buffer cannot be used on the queue for the given usage.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "Requested usage of " << buffer->DebugString()
+ << " is not supported for the buffer on this queue; "
+ "buffer allows "
+ << BufferUsageString(buffer->usage()) << ", queue requires "
+ << BufferUsageString(usage);
+ }
+
+ if ((buffer->usage() & usage) != usage) {
+ // Missing one or more bits.
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "Requested usage was not specified when the buffer was "
+ "allocated; buffer allows "
+ << BufferUsageString(buffer->usage()) << ", operation requires "
+ << BufferUsageString(usage);
+ }
+
+ return OkStatus();
+}
+
+// Validates that the range provided is within the given buffer.
+Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer,
+                                              device_size_t byte_offset,
+                                              device_size_t byte_length) const {
+  // Check if the start of the range runs off the end of the buffer.
+  if (byte_offset > buffer->byte_length()) {
+    return OutOfRangeErrorBuilder(IREE_LOC)
+           << "Attempted to access an address off the end of the valid buffer "
+              "range (offset="
+           << byte_offset << ", length=" << byte_length
+           << ", buffer byte_length=" << buffer->byte_length() << ")";
+  }
+
+  if (byte_length == 0) {
+    // Fine to have a zero length.
+    return OkStatus();
+  }
+
+  // Check if the end runs over the allocation. Compare against the bytes left
+  // after the offset: offset + length can wrap around for very large lengths.
+  if (byte_length > buffer->byte_length() - byte_offset) {
+    return OutOfRangeErrorBuilder(IREE_LOC)
+           << "Attempted to access an address outside of the valid buffer "
+              "range (offset="
+           << byte_offset << ", length=" << byte_length
+           << ", end(inc)=" << (byte_offset + byte_length - 1)
+           << ", buffer byte_length=" << buffer->byte_length() << ")";
+  }
+
+  return OkStatus();
+}
+
+Status ValidatingCommandBuffer::ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) {
+ DVLOG(3) << "CommandBuffer::ExecutionBarrier(...)";
+
+ // TODO(benvanik): additional synchronization validation.
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer |
+ CommandCategory::kDispatch));
+
+ return impl_->ExecutionBarrier(source_stage_mask, target_stage_mask,
+ memory_barriers, buffer_barriers);
+}
+
+Status ValidatingCommandBuffer::SignalEvent(
+ Event* event, ExecutionStageBitfield source_stage_mask) {
+ DVLOG(3) << "CommandBuffer::SignalEvent(...)";
+
+ // TODO(benvanik): additional synchronization validation.
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
+
+ return impl_->SignalEvent(event, source_stage_mask);
+}
+
+Status ValidatingCommandBuffer::ResetEvent(
+ Event* event, ExecutionStageBitfield source_stage_mask) {
+ DVLOG(3) << "CommandBuffer::ResetEvent(...)";
+
+ // TODO(benvanik): additional synchronization validation.
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
+
+ return impl_->ResetEvent(event, source_stage_mask);
+}
+
+Status ValidatingCommandBuffer::WaitEvents(
+ absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) {
+ DVLOG(3) << "CommandBuffer::WaitEvents(...)";
+
+ // TODO(benvanik): additional synchronization validation.
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
+
+ return impl_->WaitEvents(events, source_stage_mask, target_stage_mask,
+ memory_barriers, buffer_barriers);
+}
+
+Status ValidatingCommandBuffer::FillBuffer(Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length,
+ const void* pattern,
+ size_t pattern_length) {
+ DVLOG(3) << "CommandBuffer::FillBuffer(" << target_buffer->DebugString()
+ << ", " << target_offset << ", " << length << ", ??, "
+ << pattern_length << ")";
+
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
+ RETURN_IF_ERROR(
+ ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
+ RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
+ RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
+ RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
+
+ // Ensure the value length is supported.
+ if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Fill value length is not one of the supported values "
+ "(pattern_length="
+ << pattern_length << ")";
+ }
+
+ // Ensure the offset and length have an alignment matching the value length.
+ if ((target_offset % pattern_length) != 0 || (length % pattern_length) != 0) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Fill offset and/or length do not match the natural alignment of "
+ "the fill value (target_offset="
+ << target_offset << ", length=" << length
+ << ", pattern_length=" << pattern_length << ")";
+ }
+
+ return impl_->FillBuffer(target_buffer, target_offset, length, pattern,
+ pattern_length);
+}
+
+Status ValidatingCommandBuffer::DiscardBuffer(Buffer* buffer) {
+ DVLOG(3) << "CommandBuffer::DiscardBuffer(" << buffer->DebugString() << ")";
+
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
+ RETURN_IF_ERROR(
+ ValidateCompatibleMemoryType(buffer, MemoryType::kDeviceVisible));
+ RETURN_IF_ERROR(ValidateUsage(buffer, BufferUsage::kNone));
+
+ return impl_->DiscardBuffer(buffer);
+}
+
+Status ValidatingCommandBuffer::UpdateBuffer(const void* source_buffer,
+ device_size_t source_offset,
+ Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length) {
+ DVLOG(3) << "CommandBuffer::UpdateBuffer(" << source_buffer << ", "
+ << source_offset << ", " << target_buffer->DebugString() << ", "
+ << target_offset << ", " << length << ")";
+
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
+ RETURN_IF_ERROR(
+ ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
+ RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
+ RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
+ RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
+
+ return impl_->UpdateBuffer(source_buffer, source_offset, target_buffer,
+ target_offset, length);
+}
+
+Status ValidatingCommandBuffer::CopyBuffer(Buffer* source_buffer,
+ device_size_t source_offset,
+ Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length) {
+ DVLOG(3) << "CommandBuffer::CopyBuffer(" << source_buffer->DebugString()
+ << ", " << source_offset << ", " << target_buffer->DebugString()
+ << ", " << target_offset << ", " << length << ")";
+
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
+
+ // At least source or destination must be device-visible to enable
+ // host->device, device->host, and device->device.
+ // TODO(b/117338171): host->host copies.
+ if (!AnyBitSet(source_buffer->memory_type() & MemoryType::kDeviceVisible) &&
+ !AnyBitSet(target_buffer->memory_type() & MemoryType::kDeviceVisible)) {
+ return PermissionDeniedErrorBuilder(IREE_LOC)
+ << "At least one buffer must be device-visible for a copy; "
+ "source_buffer="
+ << MemoryTypeString(source_buffer->memory_type())
+ << ", target_buffer="
+ << MemoryTypeString(target_buffer->memory_type());
+ }
+
+ RETURN_IF_ERROR(ValidateAccess(source_buffer, MemoryAccess::kRead));
+ RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
+ RETURN_IF_ERROR(ValidateUsage(source_buffer, BufferUsage::kTransfer));
+ RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
+ RETURN_IF_ERROR(ValidateRange(source_buffer, source_offset, length));
+ RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
+
+ // Check for overlap - just like memcpy we don't handle that.
+ if (Buffer::TestOverlap(source_buffer, source_offset, length, target_buffer,
+ target_offset,
+ length) != Buffer::Overlap::kDisjoint) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Source and target ranges overlap within the same buffer";
+ }
+
+ return impl_->CopyBuffer(source_buffer, source_offset, target_buffer,
+ target_offset, length);
+}
+
+Status ValidatingCommandBuffer::Dispatch(
+ const DispatchRequest& dispatch_request) {
+ DVLOG(3) << "CommandBuffer::Dispatch(?)";
+
+ RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
+
+ // Validate all buffers referenced have compatible memory types, access
+ // rights, and usage.
+ for (const auto& binding : dispatch_request.bindings) {
+ RETURN_IF_ERROR(ValidateCompatibleMemoryType(binding.buffer,
+ MemoryType::kDeviceVisible))
+ << "input buffer: " << MemoryAccessString(binding.access) << " "
+ << binding.buffer->DebugStringShort();
+ RETURN_IF_ERROR(ValidateAccess(binding.buffer, binding.access));
+ RETURN_IF_ERROR(ValidateUsage(binding.buffer, BufferUsage::kDispatch));
+ // TODO(benvanik): validate it matches the executable expectations.
+ // TODO(benvanik): validate buffer contains enough data for shape+size.
+ }
+
+ // TODO(benvanik): validate no aliasing?
+
+ return impl_->Dispatch(dispatch_request);
+}
+
+} // namespace
+
+ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
+ ref_ptr<CommandBuffer> impl) {
+ return make_ref<ValidatingCommandBuffer>(std::move(impl));
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/command_buffer_validation.h b/hal/command_buffer_validation.h
new file mode 100644
index 0000000..f60d465
--- /dev/null
+++ b/hal/command_buffer_validation.h
@@ -0,0 +1,32 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
+#define IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
+
+#include "hal/command_buffer.h"
+
+namespace iree {
+namespace hal {
+
+// Wraps an existing command buffer to provide in-depth validation during
+// recording. This should be enabled whenever the command buffer is being driven
+// by unsafe code or when early and readable diagnostics are needed.
+ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
+ ref_ptr<CommandBuffer> impl);
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
diff --git a/hal/command_queue.h b/hal/command_queue.h
new file mode 100644
index 0000000..e4f5239
--- /dev/null
+++ b/hal/command_queue.h
@@ -0,0 +1,119 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_COMMAND_QUEUE_H_
+#define IREE_HAL_COMMAND_QUEUE_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "base/bitfield.h"
+#include "base/status.h"
+#include "base/time.h"
+#include "hal/command_buffer.h"
+#include "hal/fence.h"
+#include "hal/semaphore.h"
+
+namespace iree {
+namespace hal {
+
+// A batch of command buffers with synchronization information for submission.
+struct SubmissionBatch {
+ // Semaphores that must be signaled prior to the execution of any command
+ // buffer in this submission. For TimelineSemaphores the specified payload
+ // must be reached or exceeded.
+ absl::Span<const SemaphoreValue> wait_semaphores;
+
+ // Command buffers that will execute in this batch.
+ // The command buffers will begin execution in order but may complete out of
+ // order.
+ absl::Span<CommandBuffer* const> command_buffers;
+
+ // Semaphores to signal after execution of all command buffers complete.
+ // TimelineSemaphores will be set to the maximum of the specified payload or
+ // their current payload.
+ absl::Span<const SemaphoreValue> signal_semaphores;
+};
+
+// Asynchronous command execution queue.
+//
+// CommandQueues may capture device status at Fence barriers, including
+// information about device state such as thermal throttling. This information
+// is a snapshot of the state at the time the fence was signaled and not
+// necessarily live at the time of the application query.
+//
+// Command queues are thread-safe and submissions may occur from multiple
+// threads.
+class CommandQueue {
+ public:
+ virtual ~CommandQueue() = default;
+
+ // Name of the queue used for logging purposes.
+ // Try to keep at 4 characters total for prettier logging.
+ const std::string& name() const { return name_; }
+
+ // Capabilities of the command queue.
+ CommandCategoryBitfield supported_categories() const {
+ return supported_categories_;
+ }
+
+ // Whether this queue may be used for transfer commands.
+ bool can_transfer() const {
+ return AllBitsSet(supported_categories_, CommandCategory::kTransfer);
+ }
+
+ // Whether this queue may be used for dispatch commands.
+ bool can_dispatch() const {
+ return AllBitsSet(supported_categories_, CommandCategory::kDispatch);
+ }
+
+ // Submits one or more command batches for execution on the queue.
+ // Dependencies between |batches| on BinarySemaphores must be sorted in order
+ // such that all semaphores are signaled prior to any waits on them.
+ // Dependencies between TimelineSemaphores may occur in any order.
+ //
+ // The provided |fence| will be signaled when all |batches| have retired.
+ virtual Status Submit(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence) = 0;
+ inline Status Submit(const SubmissionBatch& batch, FenceValue fence) {
+ return Submit(absl::MakeConstSpan(&batch, 1), std::move(fence));
+ }
+
+ // Blocks until all outstanding requests have been completed.
+ // This is equivalent to having waited on all outstanding fences.
+ // Implicitly calls Flush to ensure delayed requests are scheduled.
+ //
+ // If the command queue has encountered an error during submission at any
+ // point it will be returned here (repeatedly).
+ virtual Status WaitIdle(absl::Time deadline) = 0;
+ inline Status WaitIdle(absl::Duration timeout) {
+ return WaitIdle(RelativeTimeoutToDeadline(timeout));
+ }
+ inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); }
+
+ protected:
+ CommandQueue(std::string name, CommandCategoryBitfield supported_categories)
+ : name_(std::move(name)), supported_categories_(supported_categories) {}
+
+ const std::string name_;
+ const CommandCategoryBitfield supported_categories_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_COMMAND_QUEUE_H_
diff --git a/hal/dawn/BUILD b/hal/dawn/BUILD
new file mode 100644
index 0000000..0f5c39a
--- /dev/null
+++ b/hal/dawn/BUILD
@@ -0,0 +1,72 @@
+# HAL implementation using Dawn and SPIR-V executables.
+# https://dawn.googlesource.com/dawn
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "dawn_device",
+ srcs = ["dawn_device.cc"],
+ hdrs = ["dawn_device.h"],
+ deps = [
+ "///base:memory",
+ "///base:status",
+ "///base:tracing",
+ "///hal:command_queue",
+ "///hal:device",
+ "///hal:executable_cache",
+ "///hal:fence",
+ "///hal/host:host_local_allocator",
+ "//third_party/dawn:dawn_headers",
+ "//third_party/dawn:dawn_native",
+ "//third_party/dawn:dawn_static_proc",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "dawn_driver",
+ srcs = ["dawn_driver.cc"],
+ hdrs = ["dawn_driver.h"],
+ deps = [
+ ":dawn_device",
+ "///base:status",
+ "///base:tracing",
+ "///hal:device_info",
+ "///hal:driver",
+ "//third_party/dawn:dawn_headers",
+ "//third_party/dawn:dawn_native",
+ "//third_party/dawn:dawn_static_proc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+# TODO(scotttodd): Use SwiftShader to test Vulkan backend
+cc_test(
+ name = "dawn_driver_test",
+ srcs = ["dawn_driver_test.cc"],
+ deps = [
+ ":dawn_driver",
+ "///base:status_matchers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "dawn_driver_module",
+ srcs = ["dawn_driver_module.cc"],
+ deps = [
+ ":dawn_driver",
+ "///base:init",
+ "///base:status",
+ "///base:tracing",
+ "///hal:driver_registry",
+ "@com_google_absl//absl/flags:flag",
+ ],
+ alwayslink = 1,
+)
diff --git a/hal/dawn/dawn_device.cc b/hal/dawn/dawn_device.cc
new file mode 100644
index 0000000..c8a3618
--- /dev/null
+++ b/hal/dawn/dawn_device.cc
@@ -0,0 +1,139 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/dawn/dawn_device.h"
+
+#include "absl/memory/memory.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/command_queue.h"
+#include "hal/executable_cache.h"
+#include "hal/fence.h"
+
+namespace iree {
+namespace hal {
+namespace dawn {
+
+namespace {
+
+// ExecutableCache implementation that compiles but does nothing.
+// This will be replaced with something functional soon.
+class NoopExecutableCache final : public ExecutableCache {
+ public:
+ explicit NoopExecutableCache() {}
+ ~NoopExecutableCache() override = default;
+
+ bool CanPrepareFormat(ExecutableFormat format) const override {
+ return false;
+ }
+
+ StatusOr<ref_ptr<Executable>> PrepareExecutable(
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override {
+ return UnimplementedErrorBuilder(IREE_LOC) << "PrepareExecutable NYI";
+ }
+};
+
+} // namespace
+
+DawnDevice::DawnDevice(const DeviceInfo& device_info,
+ ::dawn::Device backend_device)
+ : Device(device_info), backend_device_(backend_device) {
+ IREE_TRACE_SCOPE0("DawnDevice::ctor");
+
+ // TODO(scotttodd): construct command queues, perform other initialization
+
+ // Log basic device info; unlisted backend types are reported as "Unknown".
+ std::string backend_type_str = "Unknown";
+ auto* adapter =
+ static_cast<dawn_native::Adapter*>(device_info.driver_handle());
+ switch (adapter->GetBackendType()) {
+ case dawn_native::BackendType::D3D12:
+ backend_type_str = "D3D12";
+ break;
+ case dawn_native::BackendType::Metal:
+ backend_type_str = "Metal";
+ break;
+ case dawn_native::BackendType::Null:
+ backend_type_str = "Null";
+ break;
+ case dawn_native::BackendType::OpenGL:
+ backend_type_str = "OpenGL";
+ break;
+ case dawn_native::BackendType::Vulkan:
+ backend_type_str = "Vulkan";
+ break;
+ }
+ LOG(INFO) << "Created DawnDevice '" << device_info.name() << "' ("
+ << backend_type_str << ")";
+}
+
+DawnDevice::~DawnDevice() = default;
+
+std::shared_ptr<ExecutableCache> DawnDevice::CreateExecutableCache() {
+ return std::make_shared<NoopExecutableCache>();
+}
+
+StatusOr<ref_ptr<CommandBuffer>> DawnDevice::CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) {
+ return UnimplementedErrorBuilder(IREE_LOC) << "CreateCommandBuffer NYI";
+}
+
+StatusOr<ref_ptr<Event>> DawnDevice::CreateEvent() {
+ return UnimplementedErrorBuilder(IREE_LOC) << "CreateEvent NYI";
+}
+
+StatusOr<ref_ptr<BinarySemaphore>> DawnDevice::CreateBinarySemaphore(
+ bool initial_value) {
+ IREE_TRACE_SCOPE0("DawnDevice::CreateBinarySemaphore");
+
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Binary semaphores not yet implemented";
+}
+
+StatusOr<ref_ptr<TimelineSemaphore>> DawnDevice::CreateTimelineSemaphore(
+ uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("DawnDevice::CreateTimelineSemaphore");
+
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Timeline semaphores not yet implemented";
+}
+
+StatusOr<ref_ptr<Fence>> DawnDevice::CreateFence(uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("DawnDevice::CreateFence");
+
+ return UnimplementedErrorBuilder(IREE_LOC) << "CreateFence NYI";
+}
+
+Status DawnDevice::WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) {
+ IREE_TRACE_SCOPE0("DawnDevice::WaitAllFences");
+
+ return UnimplementedErrorBuilder(IREE_LOC) << "WaitAllFences NYI";
+}
+
+StatusOr<int> DawnDevice::WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Time deadline) {
+ IREE_TRACE_SCOPE0("DawnDevice::WaitAnyFence");
+
+ return UnimplementedErrorBuilder(IREE_LOC) << "WaitAnyFence NYI";
+}
+
+Status DawnDevice::WaitIdle(absl::Time deadline) {
+ return UnimplementedErrorBuilder(IREE_LOC) << "WaitIdle NYI";
+}
+
+} // namespace dawn
+} // namespace hal
+} // namespace iree
diff --git a/hal/dawn/dawn_device.h b/hal/dawn/dawn_device.h
new file mode 100644
index 0000000..2c6b4d6
--- /dev/null
+++ b/hal/dawn/dawn_device.h
@@ -0,0 +1,78 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DAWN_DAWN_DEVICE_H_
+#define IREE_HAL_DAWN_DAWN_DEVICE_H_
+
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "base/memory.h"
+#include "hal/device.h"
+#include "hal/host/host_local_allocator.h"
+#include "third_party/dawn/src/include/dawn/dawncpp.h"
+#include "third_party/dawn/src/include/dawn_native/DawnNative.h"
+
+namespace iree {
+namespace hal {
+namespace dawn {
+
+class DawnDevice final : public Device {
+ public:
+ explicit DawnDevice(const DeviceInfo& device_info,
+ ::dawn::Device backend_device);
+ ~DawnDevice() override;
+
+ Allocator* allocator() const override { return &allocator_; }
+
+ absl::Span<CommandQueue*> dispatch_queues() const override {
+ return RawPtrSpan(absl::MakeSpan(command_queues_));
+ }
+
+ absl::Span<CommandQueue*> transfer_queues() const override {
+ return RawPtrSpan(absl::MakeSpan(command_queues_));
+ }
+
+ std::shared_ptr<ExecutableCache> CreateExecutableCache() override;
+
+ StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) override;
+
+ StatusOr<ref_ptr<Event>> CreateEvent() override;
+
+ StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
+ bool initial_value) override;
+ StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
+ uint64_t initial_value) override;
+
+ StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override;
+ Status WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) override;
+ StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Time deadline) override;
+
+ Status WaitIdle(absl::Time deadline) override;
+
+ private:
+ mutable HostLocalAllocator allocator_;
+ mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_;
+
+ ::dawn::Device backend_device_;
+};
+
+} // namespace dawn
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DAWN_DAWN_DEVICE_H_
diff --git a/hal/dawn/dawn_driver.cc b/hal/dawn/dawn_driver.cc
new file mode 100644
index 0000000..565e98a
--- /dev/null
+++ b/hal/dawn/dawn_driver.cc
@@ -0,0 +1,120 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/dawn/dawn_driver.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/dawn/dawn_device.h"
+#include "hal/device_info.h"
+
+namespace iree {
+namespace hal {
+namespace dawn {
+
+namespace {
+
+// Builds a DeviceInfo for |adapter|, storing it as the opaque driver handle.
+StatusOr<DeviceInfo> PopulateDeviceInfo(dawn_native::Adapter* adapter) {
+ // TODO(scotttodd): Query these for each backend or implement?
+ DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
+ // supported_features |= DeviceFeature::kDebugging;
+ // supported_features |= DeviceFeature::kCoverage;
+ // supported_features |= DeviceFeature::kProfiling;
+
+ // TODO(scotttodd): more clever/sanitized device naming.
+ std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name);
+
+ return DeviceInfo(device_name, supported_features,
+ static_cast<void*>(adapter));
+}
+
+} // namespace
+
+DawnDriver::DawnDriver() : Driver("dawn") {
+ dawn_instance_ = absl::make_unique<dawn_native::Instance>();
+}
+
+DawnDriver::~DawnDriver() = default;
+
+StatusOr<std::vector<DeviceInfo>> DawnDriver::EnumerateAvailableDevices() {
+ IREE_TRACE_SCOPE0("DawnDriver::EnumerateAvailableDevices");
+
+ if (dawn_backend_adapters_.empty()) {
+ // Discover adapters (i.e. devices and their associated backend APIs).
+ // Retain the list of adapters so pointers are valid for the lifetime of
+ // this object.
+ dawn_instance_->DiscoverDefaultAdapters();
+ dawn_backend_adapters_ = dawn_instance_->GetAdapters();
+ } else {
+ // Assume that the list of adapters does not change. This is not guaranteed
+ // to be true, but we also don't want to invalidate pointers by requesting
+ // a new list each time. If the list of available devices would change,
+ // tearing down and creating a new DawnDriver may be your best option.
+ }
+
+ // Convert to our HAL structure.
+ std::vector<DeviceInfo> device_infos;
+ device_infos.reserve(dawn_backend_adapters_.size());
+ for (auto& adapter : dawn_backend_adapters_) {
+ // TODO(scotttodd): if we fail should we just ignore the device in the list?
+ ASSIGN_OR_RETURN(auto device_info, PopulateDeviceInfo(&adapter));
+ device_infos.push_back(std::move(device_info));
+ }
+ return device_infos;
+}
+
+StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDefaultDevice() {
+ IREE_TRACE_SCOPE0("DawnDriver::CreateDefaultDevice");
+
+ // Query available devices; fail with NotFound if there are none at all.
+ ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices());
+ if (available_devices.empty()) {
+ return NotFoundErrorBuilder(IREE_LOC) << "No devices are available";
+ }
+
+ // Prefer the first adapter whose backend type is not BackendType::Null.
+ for (const auto& device : available_devices) {
+ auto* adapter = static_cast<dawn_native::Adapter*>(device.driver_handle());
+ if (adapter->GetBackendType() != dawn_native::BackendType::Null) {
+ return CreateDevice(device);
+ }
+ }
+
+ // Only Null-backend adapters were found; fall back to the first of them.
+ return CreateDevice(available_devices.front());
+}
+
+StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDevice(
+ const DeviceInfo& device_info) {
+ IREE_TRACE_SCOPE0("DawnDriver::CreateDevice");
+
+ auto* adapter =
+ static_cast<dawn_native::Adapter*>(device_info.driver_handle());
+ ::DawnDevice c_backend_device = adapter->CreateDevice();
+ if (!c_backend_device) {
+ return InternalErrorBuilder(IREE_LOC) << "Failed to create a Dawn device";
+ }
+ DawnProcTable backend_procs = dawn_native::GetProcs();
+ dawnSetProcs(&backend_procs);
+ ::dawn::Device backend_device = ::dawn::Device::Acquire(c_backend_device);
+
+ return std::make_shared<DawnDevice>(device_info, backend_device);
+}
+
+} // namespace dawn
+} // namespace hal
+} // namespace iree
diff --git a/hal/dawn/dawn_driver.h b/hal/dawn/dawn_driver.h
new file mode 100644
index 0000000..03a2cdb
--- /dev/null
+++ b/hal/dawn/dawn_driver.h
@@ -0,0 +1,50 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DAWN_DAWN_DRIVER_H_
+#define IREE_HAL_DAWN_DAWN_DRIVER_H_
+
+#include <memory>
+#include <vector>
+
+#include "hal/driver.h"
+#include "third_party/dawn/src/include/dawn/dawncpp.h"
+#include "third_party/dawn/src/include/dawn_native/DawnNative.h"
+
+namespace iree {
+namespace hal {
+namespace dawn {
+
+class DawnDriver final : public Driver {
+ public:
+ DawnDriver();
+ ~DawnDriver() override;
+
+ StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
+
+ StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override;
+
+ StatusOr<std::shared_ptr<Device>> CreateDevice(
+ const DeviceInfo& device_info) override;
+
+ private:
+ std::unique_ptr<dawn_native::Instance> dawn_instance_;
+ std::vector<dawn_native::Adapter> dawn_backend_adapters_;
+};
+
+} // namespace dawn
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DAWN_DAWN_DRIVER_H_
diff --git a/hal/dawn/dawn_driver_module.cc b/hal/dawn/dawn_driver_module.cc
new file mode 100644
index 0000000..317d9fe
--- /dev/null
+++ b/hal/dawn/dawn_driver_module.cc
@@ -0,0 +1,41 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+
+#include "base/init.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/dawn/dawn_driver.h"
+#include "hal/driver_registry.h"
+
+namespace iree {
+namespace hal {
+namespace dawn {
+namespace {
+
+StatusOr<std::shared_ptr<Driver>> CreateDawnDriver() {
+ return std::make_shared<DawnDriver>();
+}
+
+} // namespace
+} // namespace dawn
+} // namespace hal
+} // namespace iree
+
+IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dawn_driver, {
+ QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
+ "dawn", ::iree::hal::dawn::CreateDawnDriver));
+});
+IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dawn_driver);
diff --git a/hal/dawn/dawn_driver_test.cc b/hal/dawn/dawn_driver_test.cc
new file mode 100644
index 0000000..7e196f1
--- /dev/null
+++ b/hal/dawn/dawn_driver_test.cc
@@ -0,0 +1,45 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/dawn/dawn_driver.h"
+
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace dawn {
+namespace {
+
+TEST(DawnDriverTest, CreateDefaultDevice) {
+ DawnDriver dawn_driver;
+ ASSERT_OK_AND_ASSIGN(auto default_device, dawn_driver.CreateDefaultDevice());
+}
+
+TEST(DawnDriverTest, EnumerateDevicesAndCreate) {
+ DawnDriver dawn_driver;
+
+ ASSERT_OK_AND_ASSIGN(auto available_devices,
+ dawn_driver.EnumerateAvailableDevices());
+ ASSERT_GT(available_devices.size(), 0u);
+
+ ASSERT_OK_AND_ASSIGN(auto first_device,
+ dawn_driver.CreateDevice(available_devices[0]));
+}
+
+} // namespace
+} // namespace dawn
+} // namespace hal
+} // namespace iree
diff --git a/hal/deferred_buffer.cc b/hal/deferred_buffer.cc
new file mode 100644
index 0000000..5b3e82d
--- /dev/null
+++ b/hal/deferred_buffer.cc
@@ -0,0 +1,162 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/deferred_buffer.h"
+
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+
+DeferredBuffer::DeferredBuffer(Allocator* allocator,
+ MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage,
+ device_size_t byte_length)
+ : Buffer(allocator, memory_type, allowed_access, usage, 0, 0, byte_length) {
+}
+
+DeferredBuffer::~DeferredBuffer() = default;
+
+Status DeferredBuffer::GrowByteLength(device_size_t new_byte_length) {
+ if (parent_buffer_) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Attempting to set min allocation size while bound to an "
+ "allocation";
+ }
+ if (byte_length_ != kWholeBuffer && new_byte_length < byte_length_) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Attempting to shrink a buffer to " << new_byte_length
+ << " when it has a minimum size of " << byte_length_;
+ }
+ byte_length_ = new_byte_length;
+ return OkStatus();
+}
+
+Status DeferredBuffer::BindAllocation(ref_ptr<Buffer> allocated_buffer,
+ device_size_t byte_offset,
+ device_size_t byte_length) {
+ // We can only be bound to allocations that are compatible with our specified
+ // allocator and usage.
+ if (!allocator_->CanUseBuffer(allocated_buffer.get(), usage())) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Allocation is not compatible with the allocator specified for "
+ "the deferred buffer";
+ }
+
+ // Calculate the range in the allocated_buffer that we are interested in.
+ RETURN_IF_ERROR(Buffer::CalculateRange(0, allocated_buffer->byte_length(),
+ byte_offset, byte_length, &byte_offset,
+ &byte_length));
+
+ // Verify that we have enough bytes for what we've promised. The message
+ // gives offset and length; an end offset could wrap when the length is 0.
+ if (byte_length < byte_length_) {
+ return OutOfRangeErrorBuilder(IREE_LOC)
+ << "Allocation range is too small; min_allocation_size="
+ << byte_length_ << " but the range at offset " << byte_offset
+ << " is only " << byte_length << "b";
+ }
+
+ allocated_buffer_ = allocated_buffer.get();
+ parent_buffer_ = std::move(allocated_buffer);
+ byte_offset_ = byte_offset;
+ return OkStatus();
+}
+
+void DeferredBuffer::ResetAllocation() {
+ allocated_buffer_ = this;
+ parent_buffer_.reset();
+ byte_offset_ = 0;
+}
+
+StatusOr<Buffer*> DeferredBuffer::ResolveAllocation() const {
+ // If you get errors here then someone allocated the buffer with
+ // MemoryType::kTransient and you are trying to use it outside of the time
+ // it is actually allocated (such as during CommandBuffer evaluation). If
+ // you need to use the buffer in non-transient ways then allocate the buffer
+ // without the MemoryType::kTransient flag.
+ if (!parent_buffer_) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Attempting to use a transient buffer prior to allocation: "
+ << DebugString();
+ }
+ return parent_buffer_.get();
+}
+
+Status DeferredBuffer::FillImpl(device_size_t byte_offset,
+ device_size_t byte_length, const void* pattern,
+ device_size_t pattern_length) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->FillImpl(byte_offset, byte_length, pattern,
+ pattern_length);
+}
+
+Status DeferredBuffer::ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->ReadDataImpl(source_offset, data, data_length);
+}
+
+Status DeferredBuffer::WriteDataImpl(device_size_t target_offset,
+ const void* data,
+ device_size_t data_length) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->WriteDataImpl(target_offset, data, data_length);
+}
+
+Status DeferredBuffer::CopyDataImpl(device_size_t target_offset,
+ Buffer* source_buffer,
+ device_size_t source_offset,
+ device_size_t data_length) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->CopyDataImpl(target_offset, source_buffer,
+ source_offset, data_length);
+}
+
+Status DeferredBuffer::MapMemoryImpl(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void** out_data) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->MapMemoryImpl(mapping_mode, memory_access,
+ local_byte_offset, local_byte_length,
+ out_data);
+}
+
+Status DeferredBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void* data) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->UnmapMemoryImpl(local_byte_offset, local_byte_length,
+ data);
+}
+
+Status DeferredBuffer::InvalidateMappedMemoryImpl(
+ device_size_t local_byte_offset, device_size_t local_byte_length) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->InvalidateMappedMemoryImpl(local_byte_offset,
+ local_byte_length);
+}
+
+Status DeferredBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) {
+ ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
+ return allocated_buffer->FlushMappedMemoryImpl(local_byte_offset,
+ local_byte_length);
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/deferred_buffer.h b/hal/deferred_buffer.h
new file mode 100644
index 0000000..ab03c0d
--- /dev/null
+++ b/hal/deferred_buffer.h
@@ -0,0 +1,106 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DEFERRED_BUFFER_H_
+#define IREE_HAL_DEFERRED_BUFFER_H_
+
+#include <cstddef>
+#include <memory>
+#include <utility>
+
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+// A Buffer that can have its underlying allocation changed at runtime.
+// Unbound buffers act as a way to logically group dependent ranges of memory
+// without needing to have allocated that memory yet.
+//
+// Usage:
+// // Setup two spans referencing ranges of a deferred buffer.
+// auto deferred_buffer = std::make_shared<DeferredBuffer>(..., 200);
+// ASSIGN_OR_RETURN(auto span0, Buffer::Subspan(deferred_buffer, 0, 100));
+// ASSIGN_OR_RETURN(auto span1, Buffer::Subspan(deferred_buffer, 100, 100));
+//
+// // Attempting to access |deferred_buffer| or |span0| or |span1| will fail.
+// // ERROR: span0->Fill(false);
+//
+// // Now allocate a real buffer to serve as storage for the data.
+// ASSIGN_OR_RETURN(auto allocated_buffer, Buffer::Allocate(..., 200));
+// RETURN_IF_ERROR(deferred_buffer->BindAllocation(
+// allocated_buffer, 0, kWholeBuffer));
+//
+// // And now we can use the spans.
+// RETURN_IF_ERROR(span0->Fill(false));
+//
+// // If at some point we want to detach the buffer from the allocation (so we
+// // can use a different allocation, reuse the memory, etc).
+// deferred_buffer->ResetAllocation();
+//
+// Thread-compatible. Attempting to rebind the allocation while other threads
+// are using the buffer will lead to undefined behavior.
+class DeferredBuffer : public Buffer {
+ public:
+ DeferredBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
+ device_size_t byte_length);
+ ~DeferredBuffer() override;
+
+ // Grows the minimum allocation size of the buffer to |new_byte_length|.
+ // Attempting to bind an allocation less than this size will fail. This must
+ // only be called when the buffer is not bound to an allocation.
+ Status GrowByteLength(device_size_t new_byte_length);
+
+ // Binds or rebinds the deferred buffer to an allocated buffer.
+ Status BindAllocation(ref_ptr<Buffer> allocated_buffer,
+ device_size_t byte_offset, device_size_t byte_length);
+
+ // Resets the deferred buffer to have no binding.
+ void ResetAllocation();
+
+ private:
+ // Resolves the buffer that this deferred buffer is currently bound to.
+ // This will fail if the buffer has not yet been bound to an allocation
+ // (or if the binding has since been reset).
+ StatusOr<Buffer*> ResolveAllocation() const;
+
+ Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern, device_size_t pattern_length) override;
+ Status ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) override;
+ Status WriteDataImpl(device_size_t target_offset, const void* data,
+ device_size_t data_length) override;
+ Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
+ device_size_t source_offset,
+ device_size_t data_length) override;
+ Status MapMemoryImpl(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void** out_data) override;
+ Status UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data) override;
+ Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override;
+ Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DEFERRED_BUFFER_H_
diff --git a/hal/deferred_buffer_test.cc b/hal/deferred_buffer_test.cc
new file mode 100644
index 0000000..6a8f051
--- /dev/null
+++ b/hal/deferred_buffer_test.cc
@@ -0,0 +1,174 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/deferred_buffer.h"
+
+#include "absl/memory/memory.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/heap_buffer.h"
+#include "hal/testing/mock_allocator.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+using ::iree::hal::testing::MockAllocator;
+using ::testing::_;
+using ::testing::Return;
+
+ // Tests the properties reported by a buffer that has never been bound.
+ TEST(DeferredBufferTest, Unbound) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);  // Requested byte length.
+ EXPECT_EQ(&allocator, deferred_buffer->allocator());
+ EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer());  // Unbound: refers to itself.
+ EXPECT_EQ(0, deferred_buffer->allocation_size());  // Nothing is allocated yet.
+ EXPECT_EQ(0, deferred_buffer->byte_offset());
+ EXPECT_EQ(100, deferred_buffer->byte_length());
+ }
+
+ // Tests that binding fails when the allocator reports the buffer as unusable.
+ TEST(DeferredBufferTest, AllocatorCheck) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);
+ auto real_buffer =
+ HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
+ EXPECT_CALL(
+ allocator,
+ CanUseBufferLike(real_buffer->allocator(), real_buffer->memory_type(),
+ real_buffer->usage(), BufferUsage::kAll))
+ .WillOnce(Return(false));  // The mock rejects the candidate buffer.
+ EXPECT_TRUE(IsInvalidArgument(
+ deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100)));  // Rejection surfaces as InvalidArgument.
+ }
+
+ // Tests that binding validates the requested range against the allocation.
+ TEST(DeferredBufferTest, SizeCheck) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);
+ auto real_buffer =
+ HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
+ EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
+ .WillRepeatedly(Return(true));  // Allocator compatibility always passes in this test.
+
+ EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 100));  // Explicit in-range length.
+ EXPECT_EQ(256, deferred_buffer->allocation_size());
+ EXPECT_EQ(10, deferred_buffer->byte_offset());
+ EXPECT_EQ(100, deferred_buffer->byte_length());
+ EXPECT_OK(
+ deferred_buffer->BindAllocation(add_ref(real_buffer), 10, kWholeBuffer));
+ EXPECT_EQ(256, deferred_buffer->allocation_size());
+ EXPECT_EQ(10, deferred_buffer->byte_offset());
+ EXPECT_EQ(100, deferred_buffer->byte_length());  // kWholeBuffer leaves byte_length at 100.
+
+ EXPECT_TRUE(IsOutOfRange(
+ deferred_buffer->BindAllocation(add_ref(real_buffer), 200, 100)));  // 200 + 100 exceeds the 256 byte allocation.
+ EXPECT_TRUE(IsOutOfRange(deferred_buffer->BindAllocation(add_ref(real_buffer),
+ 200, kWholeBuffer)));
+ EXPECT_TRUE(IsOutOfRange(
+ deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 10)));  // Length 10 is below the buffer's byte_length of 100.
+ }
+
+ // Tests growing the byte length before binding and that growth fails once bound.
+ TEST(DeferredBufferTest, Resizing) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);
+ auto real_buffer =
+ HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
+ EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
+ .WillRepeatedly(Return(true));
+
+ // Growing an unbound buffer from 100 to 150 bytes succeeds.
+ EXPECT_EQ(100, deferred_buffer->byte_length());
+ EXPECT_OK(deferred_buffer->GrowByteLength(150));
+ EXPECT_EQ(150, deferred_buffer->byte_length());
+
+ // Shrinking (5 < 150) is rejected as an invalid argument.
+ EXPECT_TRUE(IsInvalidArgument(deferred_buffer->GrowByteLength(5)));
+
+ // Once bound, any GrowByteLength call fails with FailedPrecondition.
+ EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 150));
+ EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->GrowByteLength(100)));
+ }
+
+ // Tests the bind -> reset -> bind cycle and the properties seen at each step.
+ TEST(DeferredBufferTest, Rebinding) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);
+ auto real_buffer =
+ HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
+ EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
+ .WillRepeatedly(Return(true));
+
+ // Resetting a buffer that was never bound leaves it in the unbound state.
+ deferred_buffer->ResetAllocation();
+ EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer());
+ EXPECT_EQ(0, deferred_buffer->allocation_size());
+
+ EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100));  // First bind.
+ EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer());
+ EXPECT_EQ(256, deferred_buffer->allocation_size());
+ deferred_buffer->ResetAllocation();  // Back to the unbound state.
+ EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer());
+ EXPECT_EQ(0, deferred_buffer->allocation_size());
+ EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100));  // Rebinding the same buffer works.
+ EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer());
+ EXPECT_EQ(256, deferred_buffer->allocation_size());
+ }
+
+ // Tests that debug strings and a fill operation work on a bound buffer.
+ TEST(DeferredBufferTest, BoundUsage) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);
+ auto real_buffer =
+ HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
+ EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
+ .WillRepeatedly(Return(true));
+ EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100));
+
+ EXPECT_FALSE(deferred_buffer->DebugString().empty());
+ EXPECT_FALSE(deferred_buffer->DebugStringShort().empty());
+
+ EXPECT_OK(deferred_buffer->Fill8(0, 10, 0xFF));  // Succeeds because an allocation is bound.
+ }
+
+ // Tests that a fill on an unbound buffer fails while debug strings still work.
+ TEST(DeferredBufferTest, UnboundUsage) {
+ MockAllocator allocator;
+ auto deferred_buffer = absl::make_unique<DeferredBuffer>(
+ &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
+ 100);
+ EXPECT_FALSE(deferred_buffer->DebugString().empty());  // Debug strings do not require a binding.
+ EXPECT_FALSE(deferred_buffer->DebugStringShort().empty());
+
+ EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->Fill8(0, 10, 0xFF)));  // No allocation is bound.
+ }
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/device.h b/hal/device.h
new file mode 100644
index 0000000..69586ca
--- /dev/null
+++ b/hal/device.h
@@ -0,0 +1,165 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DEVICE_H_
+#define IREE_HAL_DEVICE_H_
+
+#include <memory>
+
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "base/status.h"
+#include "base/time.h"
+#include "hal/allocator.h"
+#include "hal/buffer.h"
+#include "hal/command_queue.h"
+#include "hal/device_info.h"
+#include "hal/event.h"
+#include "hal/executable_cache.h"
+#include "hal/semaphore.h"
+
+namespace iree {
+namespace hal {
+
+ class Device {  // Abstract device interface: queues, allocator, and synchronization primitives.
+ public:
+ virtual ~Device() = default;
+
+ // Information about the device (name, supported features) given at construction.
+ const DeviceInfo& info() const { return device_info_; }
+
+ // TODO(benvanik): status (thermal, power mode, etc).
+
+ // TODO(benvanik): throttling adjustment/power profile.
+
+ // TODO(benvanik): control (suspend/resume, delay, etc).
+
+ // An allocator providing buffers usable by the device.
+ // This allocator may be shared with other devices in the same family.
+ virtual Allocator* allocator() const = 0;
+
+ // Returns a list of all general-purpose dispatch queues provided by the
+ // device. In general these map 1:1 with independent execution contexts,
+ // though some devices may hide that and expose only a single queue that is
+ // scheduled internally.
+ virtual absl::Span<CommandQueue*> dispatch_queues() const = 0;
+
+ // Returns a list of transfer queues provided by the device. These queues may
+ // perform transfer operations asynchronously with respect to execution on the
+ // dispatch queues. For large sequences of transfer operations always prefer
+ // using one of these queues.
+ // Note that if the device does not support a dedicated transfer queue this
+ // list may be the same as (or a subset of) dispatch_queues.
+ virtual absl::Span<CommandQueue*> transfer_queues() const = 0;
+
+ // TODO(b/137153339): accept initial cache data.
+ // Creates a device-specific cache for executables prepared for dispatch.
+ // The cache manages executable compilation, caching (on disk or in memory),
+ // and lifetime. Users can decide to use one or more caches to allow differing
+ // lifetimes (such as unloading modules), persistent on disk caching of only
+ // specific hot executables, etc.
+ //
+ // Returns a thread-safe cache that must remain alive until all executables
+ // using the cache are no longer in-flight.
+ virtual std::shared_ptr<ExecutableCache> CreateExecutableCache() = 0;
+
+ // Creates a command buffer for recording commands to submit to queues owned
+ // by this device. The command buffer may come from a pool but will be reset
+ // prior to being returned to the caller.
+ virtual StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) = 0;
+
+ // Creates an event for recording into command buffers.
+ // The returned event object is only usable with this device and events must
+ // only be used to synchronize within the same queue.
+ virtual StatusOr<ref_ptr<Event>> CreateEvent() = 0;
+
+ // Creates a binary semaphore that can be used with command queues owned by
+ // this device. To use the semaphores with other devices or instances they
+ // must first be exported.
+ virtual StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
+ bool initial_value) = 0;
+
+ // Creates a timeline semaphore that can be used with command queues owned by
+ // this device. To use the semaphores with other devices or instances they
+ // must first be exported.
+ virtual StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
+ uint64_t initial_value) = 0;
+
+ // Creates a fence that can be used with command queues owned by this device.
+ // To use the fences with other devices or instances they must first be
+ // exported.
+ virtual StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) = 0;
+
+ // TODO(benvanik): import/export semaphore utilities.
+ // TODO(benvanik): import/export fence utilities.
+ // TODO(benvanik): fences to wait handles.
+
+ // Blocks the caller until all passed |fences| reach or exceed the specified
+ // payload values or the |deadline| elapses. All |fences| must be created from
+ // this device (or be imported into it).
+ //
+ // Returns success if the wait is successful and all fences have been
+ // signaled.
+ //
+ // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all fences
+ // having been signaled. Note that a subset of the |fences| may have been
+ // signaled and each can be queried to see which ones.
+ virtual Status WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) = 0;
+ inline Status WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Duration timeout) {
+ return WaitAllFences(fences, RelativeTimeoutToDeadline(timeout));  // Relative-timeout convenience overload.
+ }
+
+ // Blocks the caller until at least one of the |fences| reaches or exceeds the
+ // specified payload value or the |deadline| elapses. All |fences| must be
+ // created from this device (or be imported into it).
+ //
+ // Returns an arbitrary index into |fences| of a fence that was signaled. Note
+ // that more than one fence may have been signaled and all of the other
+ // |fences| should be queried or waited on again until waits for them
+ // succeed.
+ //
+ // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any fences
+ // having been signaled.
+ virtual StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Time deadline) = 0;
+ inline StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Duration timeout) {
+ return WaitAnyFence(fences, RelativeTimeoutToDeadline(timeout));  // Relative-timeout convenience overload.
+ }
+
+ // Blocks until all outstanding requests on all queues have been
+ // completed. This is equivalent to having waited on all outstanding
+ // fences.
+ virtual Status WaitIdle(absl::Time deadline) = 0;
+ inline Status WaitIdle(absl::Duration timeout) {
+ return WaitIdle(RelativeTimeoutToDeadline(timeout));  // Relative-timeout convenience overload.
+ }
+ inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); }  // Waits with no deadline.
+
+ protected:
+ explicit Device(DeviceInfo device_info)
+ : device_info_(std::move(device_info)) {}  // Only subclasses construct a Device.
+
+ private:
+ const DeviceInfo device_info_;  // Immutable after construction.
+ };
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DEVICE_H_
diff --git a/hal/device_info.h b/hal/device_info.h
new file mode 100644
index 0000000..c7bbe67
--- /dev/null
+++ b/hal/device_info.h
@@ -0,0 +1,90 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DEVICE_INFO_H_
+#define IREE_HAL_DEVICE_INFO_H_
+
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "base/bitfield.h"
+
+namespace iree {
+namespace hal {
+
+ // Describes features supported by the device.
+ // These flags indicate the availability of features that may be enabled at the
+ // request of the calling application. Note that certain features may disable
+ // runtime optimizations or require compilation flags to ensure the required
+ // metadata is present in executables.
+ enum class DeviceFeature : uint32_t {
+ kNone = 0,  // No optional feature bits set.
+
+ // Device supports executable debugging.
+ // When present executables *may* be compiled with
+ // ExecutableCachingMode::kEnableDebugging and will have usable debugging
+ // related methods. Note that if the input executables do not have embedded
+ // debugging information they still may not be able to perform disassembly or
+ // fine-grained breakpoint insertion.
+ kDebugging = 1 << 0,
+
+ // Device supports executable coverage information.
+ // When present executables *may* be compiled with
+ // ExecutableCachingMode::kEnableCoverage and will produce coverage buffers
+ // during dispatch. Note that input executables must have partial embedded
+ // debug information to allow mapping back to source offsets.
+ kCoverage = 1 << 1,
+
+ // Device supports executable and command queue profiling.
+ // When present executables *may* be compiled with
+ // ExecutableCachingMode::kEnableProfiling and will produce profiling buffers
+ // during dispatch. Note that input executables must have partial embedded
+ // debug information to allow mapping back to source offsets.
+ kProfiling = 1 << 2,
+ };
+ IREE_BITFIELD(DeviceFeature);  // NOTE(review): presumably enables bitwise operators on the enum (base/bitfield.h) - confirm.
+ using DeviceFeatureBitfield = DeviceFeature;  // Alias used where a combination of feature bits is expected.
+
+ // Value type describing a device: its name, feature bits, and a driver handle.
+ // TODO(benvanik): device info (caps, physical mappings, etc).
+ class DeviceInfo {
+ public:
+ DeviceInfo(std::string name, DeviceFeatureBitfield supported_features,
+ void* driver_handle = nullptr)  // |driver_handle| is optional and defaults to null.
+ : name_(std::move(name)),
+ supported_features_(supported_features),
+ driver_handle_(driver_handle) {}
+
+ const std::string& name() const { return name_; }  // Name supplied at construction.
+
+ // Features supported by the device.
+ DeviceFeatureBitfield supported_features() const {
+ return supported_features_;
+ }
+
+ // Opaque handle used by drivers to correlate this device with their internal
+ // listing. This handle will not be valid across driver instances or outside
+ // of the current process.
+ void* driver_handle() const { return driver_handle_; }
+
+ private:
+ const std::string name_;
+ const DeviceFeatureBitfield supported_features_;
+ void* driver_handle_;  // Stored as given; never dereferenced by this class.
+ };
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DEVICE_INFO_H_
diff --git a/hal/device_manager.cc b/hal/device_manager.cc
new file mode 100644
index 0000000..4d1feeb
--- /dev/null
+++ b/hal/device_manager.cc
@@ -0,0 +1,201 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/device_manager.h"
+
+#include <algorithm>
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/heap_buffer.h"
+
+namespace iree {
+namespace hal {
+
+ DeviceManager::DeviceManager() = default;  // Starts with no registered devices.
+
+ DeviceManager::~DeviceManager() {
+ IREE_TRACE_SCOPE0("DeviceManager::dtor");
+ WaitIdle().IgnoreError();  // Best-effort wait on every registered device; any failure is dropped.
+ }
+
+ Status DeviceManager::RegisterDevice(std::shared_ptr<Device> device) {  // Takes shared ownership of |device|.
+ IREE_TRACE_SCOPE0("DeviceManager::RegisterDevice");
+ absl::MutexLock lock(&device_mutex_);
+ if (std::find(devices_.begin(), devices_.end(), device) != devices_.end()) {  // Same pointer already present.
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Device already registered";
+ }
+ devices_.push_back(std::move(device));
+ return OkStatus();
+ }
+
+ Status DeviceManager::UnregisterDevice(Device* device) {  // Returns NotFound if |device| was never registered.
+ IREE_TRACE_SCOPE0("DeviceManager::UnregisterDevice");
+ absl::MutexLock lock(&device_mutex_);
+ auto it = std::find_if(devices_.begin(), devices_.end(),
+ [device](const std::shared_ptr<Device>& other_device) {
+ return device == other_device.get();  // Match on raw pointer identity.
+ });
+ if (it == devices_.end()) {
+ return NotFoundErrorBuilder(IREE_LOC) << "Device not registered";
+ }
+ devices_.erase(it);  // Drops the manager's shared reference to the device.
+ return OkStatus();
+ }
+
+ StatusOr<DevicePlacement> DeviceManager::ResolvePlacement(
+ const PlacementSpec& placement_spec) const {  // NOTE: |placement_spec| is not consulted yet.
+ IREE_TRACE_SCOPE0("DeviceManager::ResolvePlacement");
+ absl::MutexLock lock(&device_mutex_);
+ if (devices_.empty()) {
+ return NotFoundErrorBuilder(IREE_LOC) << "No devices registered";
+ }
+
+ // TODO(benvanik): multiple devices and placement.
+ QCHECK_EQ(devices_.size(), 1)  // NOTE(review): presumably fatal rather than a returned Status - confirm.
+ << "Multiple devices not yet supported (need placement)";
+ DevicePlacement device_placement;
+ device_placement.device = devices_.front();  // Always the single registered device.
+
+ return device_placement;
+ }
+
+ StatusOr<Allocator*> DeviceManager::FindCompatibleAllocator(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ absl::Span<const DevicePlacement> device_placements) const {
+ IREE_TRACE_SCOPE0("DeviceManager::FindCompatibleAllocator");
+ if (device_placements.empty()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "No placements provided";
+ }
+
+ // Find the first allocator. As we only return an allocator if all placements
+ // are compatible we'll compare allocator[0] against allocator[1,N].
+ Allocator* some_allocator = nullptr;
+ for (const auto& device_placement : device_placements) {
+ auto* allocator = device_placement.device->allocator();  // NOTE(review): assumes a non-null device - confirm.
+ if (!some_allocator) {
+ some_allocator = allocator;  // The first placement's allocator becomes the reference.
+ continue;
+ }
+ // NOTE: as there can be asymmetry between usage restrictions (A can use B
+ // but B cannot use A) we have to compare both directions.
+ if (!some_allocator->CanUseBufferLike(allocator, memory_type, buffer_usage,
+ buffer_usage) ||
+ !allocator->CanUseBufferLike(some_allocator, memory_type, buffer_usage,
+ buffer_usage)) {
+ // Allocators are not compatible.
+ return NotFoundErrorBuilder(IREE_LOC)
+ << "No single allocator found that is compatible with all "
+ "placements";
+ }
+ }
+ return some_allocator;  // The first placement's allocator, compatible with all the others.
+ }
+
+ StatusOr<ref_ptr<Buffer>> DeviceManager::TryAllocateDeviceVisibleBuffer(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements) {
+ IREE_TRACE_SCOPE("DeviceManager::TryAllocateDeviceVisibleBuffer:size", int)
+ (static_cast<int>(allocation_size));
+ if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) {  // Callers must request host-local memory.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Host-local buffers require the kHostLocal bit: "
+ << MemoryTypeString(memory_type);
+ }
+
+ // Strip kDeviceVisible as we conditionally add it based on support.
+ memory_type &= ~MemoryType::kDeviceVisible;
+
+ // Find an allocator that works for device-visible buffers.
+ // If this fails we'll fall back to allocating a non-device-visible buffer.
+ auto allocator_or =
+ FindCompatibleAllocator(memory_type | MemoryType::kDeviceVisible,
+ buffer_usage, device_placements);
+ if (allocator_or.ok()) {
+ return allocator_or.ValueOrDie()->Allocate(
+ memory_type | MemoryType::kDeviceVisible, buffer_usage,
+ allocation_size);
+ }
+
+ // Fallback to allocating a host-local buffer (the lookup error is discarded).
+ return HeapBuffer::Allocate(memory_type, buffer_usage, allocation_size);
+ }
+
+ StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceVisibleBuffer(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements) {
+ IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceVisibleBuffer:size", int)
+ (static_cast<int>(allocation_size));
+ if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) {  // Callers must request host-local memory.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Host-local buffers require the kHostLocal bit: "
+ << MemoryTypeString(memory_type);
+ }
+
+ // Always use device-visible.
+ memory_type |= MemoryType::kDeviceVisible;
+
+ // Find an allocator that works for device-visible buffers; unlike the Try
+ // variant, a failed lookup is returned to the caller (no heap fallback).
+ ASSIGN_OR_RETURN(
+ auto* allocator,
+ FindCompatibleAllocator(memory_type, buffer_usage, device_placements));
+ return allocator->Allocate(memory_type, buffer_usage, allocation_size);
+ }
+
+ StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceLocalBuffer(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements) {
+ IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceLocalBuffer:size", int)
+ (static_cast<int>(allocation_size));
+ if (!AnyBitSet(memory_type & MemoryType::kDeviceLocal)) {  // Callers must request device-local memory.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Device-local buffers require the kDeviceLocal bit: "
+ << MemoryTypeString(memory_type);
+ }
+
+ // Find an allocator that works for device-local buffers.
+ ASSIGN_OR_RETURN(
+ auto* allocator,
+ FindCompatibleAllocator(memory_type, buffer_usage, device_placements));
+ return allocator->Allocate(memory_type, buffer_usage, allocation_size);  // |memory_type| is passed through unmodified.
+ }
+
+ Status DeviceManager::Submit(Device* device, CommandQueue* command_queue,
+ absl::Span<const SubmissionBatch> batches,
+ absl::Time deadline, FenceValue fence) {  // NOTE: |device| and |deadline| are currently unused.
+ IREE_TRACE_SCOPE0("DeviceManager::Submit");
+ return command_queue->Submit(batches, fence);  // Forwarded directly to the queue; nothing is deferred here.
+ }
+
+ Status DeviceManager::Flush() {
+ IREE_TRACE_SCOPE0("DeviceManager::Flush");
+ return OkStatus();  // Currently a no-op: Submit forwards work to the queue without deferring it.
+ }
+
+ Status DeviceManager::WaitIdle(absl::Time deadline) {  // NOTE: does not call Flush() before waiting.
+ IREE_TRACE_SCOPE0("DeviceManager::WaitIdle");
+ absl::MutexLock lock(&device_mutex_);  // Held for the full duration of the waits below.
+ for (const auto& device : devices_) {
+ RETURN_IF_ERROR(device->WaitIdle(deadline));  // The same absolute |deadline| is passed to every device.
+ }
+ return OkStatus();
+ }
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/device_manager.h b/hal/device_manager.h
new file mode 100644
index 0000000..07c22e0
--- /dev/null
+++ b/hal/device_manager.h
@@ -0,0 +1,209 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DEVICE_MANAGER_H_
+#define IREE_HAL_DEVICE_MANAGER_H_
+
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "base/time.h"
+#include "hal/allocator.h"
+#include "hal/buffer.h"
+#include "hal/command_queue.h"
+#include "hal/device.h"
+#include "hal/device_placement.h"
+#include "hal/executable_format.h"
+#include "hal/fence.h"
+
+namespace iree {
+namespace hal {
+
+ // Specifies how devices should be resolved to DevicePlacements.
+ // Most fields are optional and when not included will be ignored.
+ struct PlacementSpec {
+ // TODO(benvanik): other requirements (features/caps, power, etc).
+
+ // A list of executable formats that the placement should support.
+ // If more than one format is provided any device satisfying at least one
+ // will be considered for placement. The formats can be sorted in descending
+ // priority order to prefer the first available format in the case of ties.
+ absl::Span<const ExecutableFormat> available_formats;  // Non-owning view; the caller keeps the storage alive.
+ };
+
+ // Manages device lifetime and placement resolution.
+ // Optionally the DeviceManager may be used for automatic device selection for
+ // allocations or batched submissions, however this is not required if specific
+ // devices and scheduling behavior are known to the caller.
+ //
+ // Thread-safe. Note that callers must ensure that unregistered devices are kept
+ // alive for as long as any commands are in-flight that may be using them.
+ class DeviceManager final {
+ public:
+ DeviceManager();
+ ~DeviceManager();  // Waits for all registered devices to go idle, ignoring errors.
+
+ // Registers a device with the manager.
+ // The device will be used to resolve placements. Any placements resolved
+ // prior to the addition of the device will need to be refreshed by the caller
+ // if they want to make use of the new device.
+ Status RegisterDevice(std::shared_ptr<Device> device);
+
+ // Unregisters a device with the manager.
+ // Placements that resolved to the device prior to unregistering will remain
+ // valid for that device. Callers will need to refresh the placements to
+ // ensure the device stops being used.
+ Status UnregisterDevice(Device* device);
+
+ // TODO(benvanik): dispatch info + requirements + etc -> DevicePlacement.
+
+ // Resolves a placement spec to a device placement based on the registered
+ // devices.
+ // If the placement is not fully specified the device and queue may be chosen
+ // at random. See PlacementSpec for more information about resolution and
+ // ranking.
+ StatusOr<DevicePlacement> ResolvePlacement(
+ const PlacementSpec& placement_spec) const;
+
+ // Finds an allocator that can allocate buffers of the given |memory_type| and
+ // |buffer_usage| such that the buffers can be used interchangeably.
+ // Fails if there is no Allocator that can satisfy that requirement.
+ StatusOr<Allocator*> FindCompatibleAllocator(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ absl::Span<const DevicePlacement> device_placements) const;
+
+ // Tries to allocate a host-local buffer that _may_ be optimal for use with
+ // the given |device_placements| and _may_ be device-visible. The buffer can
+ // be used for staging uploads to device-local buffers and is useful for times
+ // when the buffer will be used more on the host than the device. If a buffer
+ // never needs to be used with a device prefer instead
+ // Allocator::host_local()::Allocate.
+ //
+ // Returns a buffer even if it's not possible to satisfy the requested
+ // |buffer_usage| for the |device_placements| at the cost of a run-time
+ // performance hit.
+ StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements);
+ StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer(
+ BufferUsageBitfield buffer_usage, device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements) {
+ return TryAllocateDeviceVisibleBuffer(
+ MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage,  // Default memory type for this overload.
+ allocation_size, device_placements);
+ }
+
+ // Allocates a host-local buffer that is optimal for use on the host but is
+ // usable by the given |device_placements| (at a possible performance
+ // penalty). The buffer can be used for staging uploads to device-local
+ // buffers and is useful for times when the buffer will be used more on the
+ // host than the device. If a buffer never needs to be used with a device
+ // prefer instead HeapBuffer::Allocate.
+ //
+ // Fails if it is not possible to allocate and satisfy all |device_placements|
+ // for the requested |buffer_usage|.
+ StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements);
+ StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer(
+ BufferUsageBitfield buffer_usage, device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements) {
+ return AllocateDeviceVisibleBuffer(
+ MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage,  // Default memory type for this overload.
+ allocation_size, device_placements);
+ }
+
+ // Allocates a device-local buffer that is optimal for use with the given
+ // |device_placements|. The buffer will not be host-visible and can only be
+ // used from compatible device queues.
+ //
+ // Fails if it is not possible to allocate and satisfy all |device_placements|
+ // for the requested |buffer_usage|.
+ StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements);
+ StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer(
+ BufferUsageBitfield buffer_usage, device_size_t allocation_size,
+ absl::Span<const DevicePlacement> device_placements) {
+ return AllocateDeviceLocalBuffer(MemoryType::kDeviceLocal, buffer_usage,  // Default memory type for this overload.
+ allocation_size, device_placements);
+ }
+
+ // Enqueues a submission against the given target |device| |command_queue|.
+ // The provided |deadline| is used to determine how long the submission can
+ // stay waiting in the queue prior to flushing, with absl::InfinitePast
+ // indicating immediate submission and absl::InfiniteFuture indicating that
+ // Flush must be called.
+ //
+ // If a |fence| is provided it will be signaled when the submission has
+ // completed and otherwise the caller must use WaitIdle to ensure completion.
+ // If a sequence of submissions are performed then the semaphore relationships
+ // can be used to elide waits. Submit(A)+Submit(B, fence) where there is a
+ // dependency from A->B is safe.
+ //
+ // All provided resources must remain alive until the provided |fence|
+ // resolves or WaitIdle succeeds.
+ //
+ // Submissions may be made from any thread. Behavior is undefined
+ // if a thread is performing a WaitIdle while another thread submits work.
+ Status Submit(Device* device, CommandQueue* command_queue,
+ absl::Span<const SubmissionBatch> batches, absl::Time deadline,
+ FenceValue fence = {});  // NOTE: the current implementation ignores |deadline| and submits directly.
+ Status Submit(Device* device, CommandQueue* command_queue,
+ absl::Span<const SubmissionBatch> batches,
+ absl::Duration timeout, FenceValue fence = {}) {
+ return Submit(device, command_queue, batches,
+ RelativeTimeoutToDeadline(timeout), fence);  // Relative-timeout convenience overload.
+ }
+ Status Submit(Device* device, CommandQueue* command_queue,
+ absl::Span<const SubmissionBatch> batches,
+ FenceValue fence = {}) {
+ return Submit(device, command_queue, batches, absl::InfinitePast(), fence);  // Requests immediate submission.
+ }
+
+ // Flushes any requests that are pending in the scheduler and ensures they
+ // begin executing ASAP regardless of policy.
+ //
+ // If any used device has encountered an error during submission at any
+ // point it will be returned here (repeatedly).
+ Status Flush();  // NOTE: currently a no-op that always returns OK (see device_manager.cc).
+
+ // Blocks until all outstanding requests have been completed.
+ // This is equivalent to having waited on all outstanding fences.
+ // Documented to implicitly Flush delayed requests; NOTE(review): the .cc
+ // implementation only waits on each device and never calls Flush. Work
+ // submitted from other threads during a wait may not be included.
+ //
+ // If any used device has encountered an error during submission at any
+ // point it will be returned here (repeatedly).
+ Status WaitIdle(absl::Time deadline);
+ inline Status WaitIdle(absl::Duration timeout) {
+ return WaitIdle(RelativeTimeoutToDeadline(timeout));  // Relative-timeout convenience overload.
+ }
+ inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); }  // Waits with no deadline.
+
+ private:
+ mutable absl::Mutex device_mutex_;  // Mutable so const methods such as ResolvePlacement can lock it.
+ std::vector<std::shared_ptr<Device>> devices_ ABSL_GUARDED_BY(device_mutex_);
+ };
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DEVICE_MANAGER_H_
diff --git a/iree/hal/device_placement.h b/hal/device_placement.h
similarity index 100%
rename from iree/hal/device_placement.h
rename to hal/device_placement.h
diff --git a/hal/driver.h b/hal/driver.h
new file mode 100644
index 0000000..bc0dbc1
--- /dev/null
+++ b/hal/driver.h
@@ -0,0 +1,61 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DRIVER_H_
+#define IREE_HAL_DRIVER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "base/status.h"
+#include "hal/device.h"
+#include "hal/device_info.h"
+
+namespace iree {
+namespace hal {
+
+class Driver {
+ public:
+ virtual ~Driver() = default;
+
+ // Driver name used during registration.
+ const std::string& name() const { return name_; }
+
+ // TODO(benvanik): info/query (version number, etc).
+
+ // Enumerates devices available for creation from the driver.
+ // This may fail if the driver is in an invalid state but otherwise will
+ // return an empty list if no devices are available.
+ virtual StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() = 0;
+
+ // Creates the driver-defined 'default' device.
+ // This may simply be the first device enumerated.
+ virtual StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() = 0;
+
+ // Creates a device as queried with the given |device_info|.
+ virtual StatusOr<std::shared_ptr<Device>> CreateDevice(
+ const DeviceInfo& device_info) = 0;
+
+ protected:
+ explicit Driver(std::string name) : name_(std::move(name)) {}
+
+ private:
+ const std::string name_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_DRIVER_H_
diff --git a/hal/driver_registry.cc b/hal/driver_registry.cc
new file mode 100644
index 0000000..cb49a0a
--- /dev/null
+++ b/hal/driver_registry.cc
@@ -0,0 +1,87 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/driver_registry.h"
+
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+
+// static: created on first call and never deleted by this function.
+DriverRegistry* DriverRegistry::shared_registry() {
+  static DriverRegistry* const shared = new DriverRegistry();
+  return shared;
+}
+
+DriverRegistry::DriverRegistry() = default;
+
+DriverRegistry::~DriverRegistry() = default;
+
+Status DriverRegistry::Register(std::string driver_name, FactoryFn factory_fn) {
+  absl::MutexLock lock(&mutex_);
+  for (const auto& pair : driver_factory_fns_) {
+    if (pair.first != driver_name) continue;
+    return AlreadyExistsErrorBuilder(IREE_LOC)
+           << "Driver already registered: " << driver_name;
+  }
+  driver_factory_fns_.emplace_back(std::move(driver_name),
+                                   std::move(factory_fn));  // By-value sinks.
+  return OkStatus();
+}
+
+bool DriverRegistry::HasDriver(absl::string_view driver_name) const {
+  // Scan the registered (name, factory) pairs while holding |mutex_|.
+  absl::MutexLock lock(&mutex_);
+  auto it = driver_factory_fns_.begin();
+  while (it != driver_factory_fns_.end() && !(it->first == driver_name)) {
+    ++it;
+  }
+  return it != driver_factory_fns_.end();
+}
+
+std::vector<std::string> DriverRegistry::EnumerateAvailableDrivers() const {
+  absl::MutexLock lock(&mutex_);  // Snapshot the names while locked.
+  std::vector<std::string> names;
+  names.reserve(driver_factory_fns_.size());
+  for (size_t i = 0; i < driver_factory_fns_.size(); ++i) {
+    names.push_back(driver_factory_fns_[i].first);
+  }
+  return names;
+}
+
+StatusOr<std::shared_ptr<Driver>> DriverRegistry::Create(
+    absl::string_view driver_name) const {
+  // Copy the factory out under the lock, then invoke it after unlocking.
+  FactoryFn found_fn;
+  {
+    absl::MutexLock lock(&mutex_);
+    auto it = driver_factory_fns_.begin();
+    while (it != driver_factory_fns_.end() && !(it->first == driver_name)) {
+      ++it;
+    }
+    if (it != driver_factory_fns_.end()) found_fn = it->second;
+  }
+  if (!found_fn) {
+    return NotFoundErrorBuilder(IREE_LOC)
+           << "Driver " << driver_name << " not found";
+  }
+  return found_fn();
+}
+
+} // namespace hal
+} // namespace iree
+
+IREE_REGISTER_MODULE_INITIALIZER(
+ iree_hal, ::iree::hal::DriverRegistry::shared_registry());
diff --git a/hal/driver_registry.h b/hal/driver_registry.h
new file mode 100644
index 0000000..6cca616
--- /dev/null
+++ b/hal/driver_registry.h
@@ -0,0 +1,83 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DRIVER_REGISTRY_H_
+#define IREE_HAL_DRIVER_REGISTRY_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "base/init.h"
+#include "base/status.h"
+#include "hal/driver.h"
+
+namespace iree {
+namespace hal {
+
+// Driver registry and factory.
+// Factory functions for available drivers are registered with a given name and
+// can be invoked with a call to Create. The configuration of the drivers is
+// generally contained within the factory function and consumers of the drivers
+// don't need to fiddle with things.
+//
+// This is used for dynamic *safe* link-time driver module registration.
+// Roughly: driver_registry provides the shared registry and a way to create
+// drivers and *_driver_module.cc files register drivers when linked in.
+// Remember to alwayslink=1 on cc_libraries providing modules.
+//
+// If link-time driver registration is not desired (or possible) it's also
+// possible to explicitly register drivers via this registry. This is useful
+// when programmatically enabling drivers.
+//
+// Thread-safe.
+class DriverRegistry final {
+ public:
+ using FactoryFn = std::function<StatusOr<std::shared_ptr<Driver>>()>;
+
+ // The shared driver registry singleton that modules use when linked in.
+ static DriverRegistry* shared_registry();
+
+ DriverRegistry();
+ ~DriverRegistry();
+
+ // Registers a driver and its factory function.
+ // The function will be called to create a new driver whenever it is requested
+ // via Create.
+ Status Register(std::string driver_name, FactoryFn factory_fn);
+
+ // Returns true if there is a driver registered with the given name.
+ bool HasDriver(absl::string_view driver_name) const;
+
+ // Returns a list of registered drivers.
+ std::vector<std::string> EnumerateAvailableDrivers() const;
+
+ // TODO(benvanik): flags for enabling debug validation/control/etc.
+ // Creates a driver by name.
+ StatusOr<std::shared_ptr<Driver>> Create(absl::string_view driver_name) const;
+
+ private:
+ mutable absl::Mutex mutex_;
+ std::vector<std::pair<std::string, FactoryFn>> driver_factory_fns_
+ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace hal
+} // namespace iree
+
+IREE_DECLARE_MODULE_INITIALIZER(iree_hal);
+IREE_REQUIRE_MODULE_LINKED(iree_hal);
+
+#endif // IREE_HAL_DRIVER_REGISTRY_H_
diff --git a/hal/event.h b/hal/event.h
new file mode 100644
index 0000000..53b981d
--- /dev/null
+++ b/hal/event.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_EVENT_H_
+#define IREE_HAL_EVENT_H_
+
+#include "hal/resource.h"
+
+namespace iree {
+namespace hal {
+
+// Events are used for defining synchronization scopes within CommandBuffers.
+// An event only exists within a single CommandBuffer and must not be used
+// across CommandBuffers from the same device or others.
+//
+// See CommandBuffer::SignalEvent and CommandBuffer::WaitEvents for more info.
+class Event : public Resource {
+ public:
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_EVENT_H_
diff --git a/hal/executable.h b/hal/executable.h
new file mode 100644
index 0000000..7a39b10
--- /dev/null
+++ b/hal/executable.h
@@ -0,0 +1,57 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_EXECUTABLE_H_
+#define IREE_HAL_EXECUTABLE_H_
+
+#include "hal/resource.h"
+
+namespace iree {
+namespace hal {
+
+class Executable : public Resource {
+ public:
+ ~Executable() override = default;
+
+ // True if the executable was prepared with debugging enabled and the device
+ // and input data support debugging (symbols present, etc).
+ virtual bool supports_debugging() const = 0;
+
+ // TODO(benvanik): disassembly methods.
+
+ // TODO(benvanik): relative offset calculation:
+ // - step once
+ // - step over
+ // - step out
+
+ // TODO(benvanik): create executable split on breakpoint.
+ // Executable should return when the breakpoint is hit without any future
+ // modifications to output buffers. If the breakpoint is not hit the
+ // executable should run to completion as normal.
+
+ // TODO(benvanik): retrieve coverage info.
+ // Returns a buffer containing offset -> coverage metrics. Note that depending
+ // on the device this may only contain a single coverage metric for the entire
+ // executable or some subset of the available offsets.
+
+ // TODO(benvanik): retrieve profiling info.
+
+ protected:
+ Executable() = default;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_EXECUTABLE_H_
diff --git a/hal/executable_cache.cc b/hal/executable_cache.cc
new file mode 100644
index 0000000..05617cd
--- /dev/null
+++ b/hal/executable_cache.cc
@@ -0,0 +1,25 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/executable_cache.h"
+
+namespace iree {
+namespace hal {
+
+ExecutableCache::ExecutableCache() = default;
+
+ExecutableCache::~ExecutableCache() = default;
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/executable_cache.h b/hal/executable_cache.h
new file mode 100644
index 0000000..cd3eaa3
--- /dev/null
+++ b/hal/executable_cache.h
@@ -0,0 +1,126 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_EXECUTABLE_CACHE_H_
+#define IREE_HAL_EXECUTABLE_CACHE_H_
+
+#include "base/bitfield.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "hal/executable.h"
+#include "hal/executable_format.h"
+#include "hal/executable_spec.h"
+
+namespace iree {
+namespace hal {
+
+// Defines how the executable cache performs preparation.
+enum class ExecutableCachingMode : uint32_t {
+ // Allows the cache to reference the provided executable_data after it has
+ // prepared the executable. Callers must ensure the data remains valid for the
+ // lifetime of the cache. If memory mapping constant executable data from
+ // disk this can be used to avoid copies.
+ kAliasProvidedData = 1 << 0,
+
+ // Allows the prepared executable to be cached persistently (on disk/etc).
+ // Enable for any executable that is likely to be used in future runs.
+ // Note that not all caches support persistent serialization and this is just
+ // a hint.
+ kAllowPersistentCaching = 1 << 1,
+
+ // Allows the cache to optimize the executable as much as it can.
+ // This may cause preparation to take significantly longer while (hopefully)
+ // improving runtime performance. Avoid for one-shot executables.
+ kAllowOptimization = 1 << 2,
+
+ // Enables Executable debugging methods if supported by the device and
+ // executable. This may disable certain optimizations or retain additional
+ // data to allow disassembly, stepping, etc.
+ //
+ // Device must support the DeviceFeature::kDebugging feature and executables
+ // must support the ExecutableFeature::kDebugging feature.
+ kEnableDebugging = 1 << 3,
+
+ // Enables Executable coverage if supported by the device and executable.
+ // Depending on the optimization mode this may produce partial coverage
+ // results (for example, when certain source operations were optimized away).
+ //
+ // Device must support the DeviceFeature::kCoverage feature and executables
+ // must support the ExecutableFeature::kCoverage feature.
+ kEnableCoverage = 1 << 4,
+
+ // Enables Executable profiling if supported by the device and executable.
+ // Depending on the optimization mode this may produce partial profiling
+ // results. Profiling attribution (whether to the entire executable or
+ // specific operations) depends on the implementation.
+ //
+ // Device must support the DeviceFeature::kProfiling feature and executables
+ // must support the ExecutableFeature::kProfiling feature.
+ kEnableProfiling = 1 << 5,
+
+ // Default caching mode.
+ kDefault = kAllowPersistentCaching | kAllowOptimization,
+};
+IREE_BITFIELD(ExecutableCachingMode);
+using ExecutableCachingModeBitfield = ExecutableCachingMode;
+
+// A cache of prepared executables for a particular device.
+// Caches may be shared across multiple devices from the same driver or specific
+// to individual devices. Caches may persist prepared executables across process
+// launches or reprepare them each run. Callers should assume that the cache is
+// a no-op and the returned Executables only live for as long as the cache does.
+//
+// The term 'cache' here is rather optimistic - it's perfectly acceptable for
+// implementations to not cache at all and return new Executables for each
+ // PrepareExecutable call (even for the same executable). Callers should
+// expect such behavior and try to retain the results of the PrepareExecutable
+// calls to reduce overhead in re-preparing executables.
+//
+// Thread-safe - multiple threads may prepare executables (including the *same*
+// executable) simultaneously.
+class ExecutableCache {
+ public:
+ virtual ~ExecutableCache();
+
+ // TODO(benvanik): status/queries (size, etc).
+
+ // TODO(b/137153339): serialization/deserialization.
+
+ // Returns true if the executable cache can prepare the given executable input
+ // format. Preparation may still fail if the particular version or features
+ // required by the executable are not supported.
+ virtual bool CanPrepareFormat(ExecutableFormat format) const = 0;
+
+ // Prepares an executable for use.
+ // The provided |spec| and |executable_data| will be used to either lookup a
+ // previously prepared executable in the cache or prepare a new one.
+ //
+ // Depending on the driver preparation may take a non-trivial amount of time
+ // (such as when JITing/etc). As the cache is internally synchronized callers
+ // can issue preparation requests from multiple threads - even for the same
+ // executables - and calls will block until preparation completes.
+ //
+ // When preparing a large number of executables it's recommended to use the
+ // PrepareExecutables method to batch and wait on the results.
+ virtual StatusOr<ref_ptr<Executable>> PrepareExecutable(
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) = 0;
+
+ protected:
+ ExecutableCache();
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/executable_format.h b/hal/executable_format.h
similarity index 100%
rename from iree/hal/executable_format.h
rename to hal/executable_format.h
diff --git a/hal/executable_spec.h b/hal/executable_spec.h
new file mode 100644
index 0000000..f871819
--- /dev/null
+++ b/hal/executable_spec.h
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_EXECUTABLE_SPEC_H_
+#define IREE_HAL_EXECUTABLE_SPEC_H_
+
+#include "absl/types/span.h"
+#include "hal/executable_format.h"
+
+namespace iree {
+namespace hal {
+
+// Defines an executable specification used by a cache to prepare an executable.
+struct ExecutableSpec {
+ // TODO(benvanik): pre-populated hash_code/key to avoid calculation.
+
+ // Format of the executable input data.
+ ExecutableFormat format = kExecutableFormatUnspecified;
+
+ // A reference to the executable data as input to the cache.
+ // If ExecutableCachingMode::kAliasProvidedData is set then this reference
+ // may be retained by the cache and the backing buffer must be kept valid for
+ // the lifetime of the cache.
+ absl::Span<const uint8_t> executable_data;
+
+ // TODO(benvanik): add specialization info (constants/defines).
+ // TODO(benvanik): add compiler flags? could treat as opaque.
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_EXECUTABLE_SPEC_H_
diff --git a/hal/fence.h b/hal/fence.h
new file mode 100644
index 0000000..40071c5
--- /dev/null
+++ b/hal/fence.h
@@ -0,0 +1,72 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_FENCE_H_
+#define IREE_HAL_FENCE_H_
+
+#include <cstdint>
+
+#include "base/status.h"
+#include "hal/resource.h"
+
+namespace iree {
+namespace hal {
+
+// Synchronization mechanism for device->host notification.
+// Fences behave like timeline semaphores and contain a monotonically increasing
+// uint64_t payload. They may be waited on any number of times - even if they
+// have already been signaled.
+//
+// A fence is updated to its new value after all prior commands have completed
+// but the delay between completion and the host being woken varies. Some
+// implementations may coalesce fences to avoid spurious waking while others
+// will immediately synchronize with the host.
+//
+// The primary use of fences is for resource lifetime management: all resources
+// used by a set of submission batches must be considered live until the fence
+// attached to the submission has signaled.
+//
+// Fences may be set to a permanently failed state by implementations when
+// errors occur during asynchronous execution. Users are expected to propagate
+// the failures and possibly reset the entire device that produced the error.
+//
+// For more information on fences see the following docs describing how
+// timelines are generally used (specifically in the device->host case):
+// https://www.youtube.com/watch?v=SpE--Rf516Y
+// https://www.khronos.org/assets/uploads/developers/library/2018-xdc/Vulkan-Timeline-Semaphores-Part-1_Sep18.pdf
+// https://docs.microsoft.com/en-us/windows/win32/direct3d12/user-mode-heap-synchronization
+class Fence : public Resource {
+ public:
+ // Returns a permanent failure status if the fence is indicating an
+ // asynchronous failure.
+ //
+ // Returns the status at the time the method is called without blocking and as
+ // such is only valid after a fence has been signaled. The same failure status
+ // will be returned regardless of when in the timeline the error occurred.
+ virtual Status status() const = 0;
+
+ // Queries the current payload of the fence. As the payload is monotonically
+ // increasing it is guaranteed that the value is at least equal to the
+ // previous result of a QueryValue call and coherent with any waits for a
+ // specified value via Device::WaitAllFences.
+ virtual StatusOr<uint64_t> QueryValue() = 0;
+};
+
+// A reference to a fence and associated payload value.
+using FenceValue = std::pair<Fence*, uint64_t>;
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_FENCE_H_
diff --git a/hal/heap_buffer.cc b/hal/heap_buffer.cc
new file mode 100644
index 0000000..5ea4326
--- /dev/null
+++ b/hal/heap_buffer.cc
@@ -0,0 +1,190 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/heap_buffer.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <string>
+#include <utility>
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/allocator.h"
+#include "hal/host/host_buffer.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+
+// An allocator that allocates or wraps host-only buffers.
+// The resulting buffers are not usable by most devices without a copy and
+// using a device allocator is strongly preferred.
+class HeapAllocator : public Allocator {
+ public:
+ // Returns a singleton heap allocator that can provide buffers that have
+ // MemoryType::kHostLocal and are allocated with malloc/free.
+ // These buffers will not be usable by devices directly and may incur
+ // additional copies.
+ static Allocator* std_heap();
+
+ // TODO(benvanik): specify custom allocator (not malloc/free).
+ HeapAllocator();
+ ~HeapAllocator() override;
+
+ bool CanUseBufferLike(Allocator* source_allocator,
+ MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage) const override;
+
+ bool CanAllocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) const override;
+
+ StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) override;
+
+ StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage,
+ void* data,
+ size_t data_length) override;
+};
+
+// static: the shared heap allocator, created on first call.
+Allocator* HeapAllocator::std_heap() {
+  static Allocator* const heap = new HeapAllocator();
+  return heap;
+}
+
+HeapAllocator::HeapAllocator() = default;
+
+HeapAllocator::~HeapAllocator() = default;
+
+bool HeapAllocator::CanUseBufferLike(Allocator* source_allocator,
+                                     MemoryTypeBitfield memory_type,
+                                     BufferUsageBitfield buffer_usage,
+                                     BufferUsageBitfield intended_usage) const {
+  // Only |memory_type| and |buffer_usage| are consulted; |source_allocator|
+  // and |intended_usage| do not affect the result.
+  //
+  // The host can use anything with kHostVisible.
+  const bool is_host_visible =
+      AnyBitSet(memory_type & MemoryType::kHostVisible);
+
+  // Host currently uses mapping to copy buffers, so mapping is required too.
+  const bool is_mappable = AnyBitSet(buffer_usage & BufferUsage::kMapping);
+
+  return is_host_visible && is_mappable;
+}
+
+bool HeapAllocator::CanAllocate(MemoryTypeBitfield memory_type,
+                                BufferUsageBitfield buffer_usage,
+                                size_t allocation_size) const {
+  // Host-only: device local/visible requests are refused since we can't know
+  // which devices the buffers would be used with.
+  if (AnyBitSet(memory_type & MemoryType::kDeviceLocal)) return false;
+  if (AnyBitSet(memory_type & MemoryType::kDeviceVisible)) return false;
+  return (memory_type & MemoryType::kHostLocal) == MemoryType::kHostLocal;
+}
+
+StatusOr<ref_ptr<Buffer>> HeapAllocator::Allocate(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ size_t allocation_size) {
+ IREE_TRACE_SCOPE0("HeapAllocator::Allocate");
+ // Refuse requests that CanAllocate rejects (non host-local memory types).
+ if (!CanAllocate(memory_type, buffer_usage, allocation_size)) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Allocation not supported; memory_type="
+ << MemoryTypeString(memory_type)
+ << ", buffer_usage=" << BufferUsageString(buffer_usage)
+ << ", allocation_size=" << allocation_size;
+ }
+ // calloc zero-fills; NOTE(review): a 0-byte request may also yield null.
+ void* malloced_data = std::calloc(1, allocation_size);
+ if (!malloced_data) {
+ return ResourceExhaustedErrorBuilder(IREE_LOC)
+ << "Failed to malloc " << allocation_size << " bytes";
+ }
+ // NOTE(review): trailing `true` presumably transfers ownership - confirm.
+ auto buffer =
+ make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage,
+ allocation_size, malloced_data, true);
+ return buffer;
+}
+
+StatusOr<ref_ptr<Buffer>> HeapAllocator::WrapMutable(
+ MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage, void* data, size_t data_length) {
+ auto buffer = make_ref<HostBuffer>(this, memory_type, allowed_access,
+ buffer_usage, data_length, data, false);
+ return buffer;
+}
+
+} // namespace
+
+// static
+ref_ptr<Buffer> HeapBuffer::Allocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield usage,
+ size_t allocation_size) {
+ auto buffer_or =
+ HeapAllocator::std_heap()->Allocate(memory_type, usage, allocation_size);
+ return std::move(buffer_or.ValueOrDie());
+}
+
+// static
+ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
+ const void* data, size_t data_length) {
+ return AllocateCopy(usage, MemoryAccess::kAll, data, data_length);
+}
+
+// static
+ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
+ MemoryAccessBitfield allowed_access,
+ const void* data, size_t data_length) {
+ IREE_TRACE_SCOPE0("HeapBuffer::AllocateCopy");
+ // Ensure we can map so that we can copy into it.
+ usage |= BufferUsage::kMapping;
+ auto buffer_or = HeapAllocator::std_heap()->Allocate(MemoryType::kHostLocal,
+ usage, data_length);
+ auto buffer = std::move(buffer_or.ValueOrDie());
+ buffer->WriteData(0, data, data_length).IgnoreError();
+ buffer->set_allowed_access(allowed_access);
+ return buffer;
+}
+
+// static
+ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield usage, const void* data,
+ size_t data_length) {
+ auto buffer_or =
+ HeapAllocator::std_heap()->Wrap(memory_type, usage, data, data_length);
+ return std::move(buffer_or.ValueOrDie());
+}
+
+// static
+ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage, void* data,
+ size_t data_length) {
+ auto buffer_or = HeapAllocator::std_heap()->WrapMutable(
+ memory_type, allowed_access, usage, data, data_length);
+ return std::move(buffer_or.ValueOrDie());
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/heap_buffer.h b/hal/heap_buffer.h
new file mode 100644
index 0000000..ceaa10f
--- /dev/null
+++ b/hal/heap_buffer.h
@@ -0,0 +1,117 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HEAP_BUFFER_H_
+#define IREE_HAL_HEAP_BUFFER_H_
+
+#include <memory>
+
+#include "base/status.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+// Factory for buffers that are allocated from the host heap (malloc/free).
+// These buffers cannot be used by devices and will incur copies/transfers when
+// used. Prefer device-specific allocators instead.
+class HeapBuffer {
+ public:
+ // Allocates a zeroed host heap buffer of the given size.
+ // Returns a buffer that is allocated with malloc, has MemoryType::kHostLocal,
+ // and will not be usable by devices without copies.
+ static ref_ptr<Buffer> Allocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield usage,
+ size_t allocation_size);
+ static ref_ptr<Buffer> Allocate(BufferUsageBitfield usage,
+ size_t allocation_size) {
+ return Allocate(MemoryType::kHostLocal, usage, allocation_size);
+ }
+
+ // Allocates a host heap buffer with a copy of the given data.
+ // Returns a buffer that is allocated with malloc, has MemoryType::kHostLocal,
+ // and will not be usable by devices without copies.
+ static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
+ const void* data, size_t data_length);
+ static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
+ MemoryAccessBitfield allowed_access,
+ const void* data, size_t data_length);
+ template <typename T>
+ static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
+ absl::Span<const T> data);
+ template <typename T>
+ static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
+ MemoryAccessBitfield allowed_access,
+ absl::Span<const T> data);
+
+ // Wraps an existing host heap allocation in a buffer.
+ // Ownership of the host allocation remains with the caller and the memory
+ // must remain valid for so long as the Buffer may be in use.
+ // Will have MemoryType::kHostLocal in most cases and may not be usable
+ // by the device.
+ static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield usage, const void* data,
+ size_t data_length);
+ static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage, void* data,
+ size_t data_length);
+ template <typename T>
+ static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield usage,
+ absl::Span<const T> data);
+ template <typename T>
+ static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage,
+ absl::Span<T> data);
+};
+
+// Inline functions and template definitions follow:
+
+template <typename T>
+ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
+ absl::Span<const T> data) {
+ return HeapBuffer::AllocateCopy(usage, MemoryAccess::kAll, data);
+}
+
+template <typename T>
+ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
+ MemoryAccessBitfield allowed_access,
+ absl::Span<const T> data) {
+ return HeapBuffer::AllocateCopy(usage, allowed_access, data.data(),
+ data.size() * sizeof(T));
+}
+
+template <typename T>
+ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield usage,
+ absl::Span<const T> data) {
+ return HeapBuffer::Wrap(memory_type, usage, data.data(),
+ data.size() * sizeof(T));
+}
+
+template <typename T>
+ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage,
+ absl::Span<T> data) {
+ return HeapBuffer::WrapMutable(memory_type, allowed_access, usage,
+ data.data(), data.size() * sizeof(T));
+}
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HEAP_BUFFER_H_
diff --git a/hal/host/BUILD b/hal/host/BUILD
new file mode 100644
index 0000000..76f055c
--- /dev/null
+++ b/hal/host/BUILD
@@ -0,0 +1,155 @@
+# Default implementations for HAL types that use the host resources.
+# These are generally just wrappers around host heap memory and host threads.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+    name = "async_command_queue",  # CommandQueue wrapper run on its own thread.
+    srcs = ["async_command_queue.cc"],
+    hdrs = ["async_command_queue.h"],
+    deps = [
+        ":host_submission_queue",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_queue",
+        "//hal:fence",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+cc_test(
+    name = "async_command_queue_test",  # Drives AsyncCommandQueue with mocks.
+    srcs = ["async_command_queue_test.cc"],
+    deps = [
+        ":async_command_queue",
+        ":host_submission_queue",
+        "//base:status",
+        "//base:status_matchers",
+        "//base:time",
+        "//hal:command_queue",
+        "//hal/testing:mock_command_buffer",
+        "//hal/testing:mock_command_queue",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/time",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "host_buffer",  # Buffer implementation over a host pointer.
+    srcs = ["host_buffer.cc"],
+    hdrs = ["host_buffer.h"],
+    deps = [
+        "//base:logging",
+        "//base:source_location",
+        "//base:status",
+        "//hal:buffer",
+        "@com_google_absl//absl/base:core_headers",
+    ],
+)
+
+cc_library(
+    name = "host_event",  # Host implementation of the HAL Event type.
+    srcs = ["host_event.cc"],
+    hdrs = ["host_event.h"],
+    deps = [
+        "//hal:event",
+    ],
+)
+
+cc_library(
+    name = "host_fence",  # Host-only Fence implemented with a mutex.
+    srcs = ["host_fence.cc"],
+    hdrs = ["host_fence.h"],
+    deps = [
+        "//base:status",
+        "//base:tracing",
+        "//hal:fence",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_test(
+    name = "host_fence_test",  # Unit tests for :host_fence.
+    srcs = ["host_fence_test.cc"],
+    deps = [
+        ":host_fence",
+        "//base:status",
+        "//base:status_matchers",
+        "@com_google_absl//absl/time",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "host_local_allocator",
+    srcs = ["host_local_allocator.cc"],
+    hdrs = ["host_local_allocator.h"],
+    deps = [
+        ":host_buffer",  # Buffer type from this package.
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:allocator",
+        "//hal:buffer",
+    ],
+)
+
+cc_library(
+    name = "host_local_command_processor",
+    srcs = ["host_local_command_processor.cc"],
+    hdrs = ["host_local_command_processor.h"],
+    deps = [
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_buffer",  # HAL interface this target builds on.
+    ],
+)
+
+cc_library(
+    name = "host_submission_queue",  # Queue that manages submission ordering.
+    srcs = ["host_submission_queue.cc"],
+    hdrs = ["host_submission_queue.h"],
+    deps = [
+        ":host_fence",
+        "//base:intrusive_list",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_queue",
+        "//hal:fence",
+        "//hal:semaphore",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+cc_test(
+ name = "host_submission_queue_test",
+ srcs = ["host_submission_queue_test.cc"],
+ deps = [
+ ":host_submission_queue",  # Library under test.
+ "@com_google_googletest//:gtest_main",  # Supplies the test main().
+ ],
+)
+
+cc_library(
+    name = "inproc_command_buffer",
+    srcs = ["inproc_command_buffer.cc"],
+    hdrs = ["inproc_command_buffer.h"],
+    deps = [
+        "//base:arena",
+        "//base:intrusive_list",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_buffer",  # HAL interface this target builds on.
+    ],
+)
diff --git a/iree/hal/host/CMakeLists.txt b/hal/host/CMakeLists.txt
similarity index 100%
rename from iree/hal/host/CMakeLists.txt
rename to hal/host/CMakeLists.txt
diff --git a/hal/host/async_command_queue.cc b/hal/host/async_command_queue.cc
new file mode 100644
index 0000000..0d3bc2c
--- /dev/null
+++ b/hal/host/async_command_queue.cc
@@ -0,0 +1,127 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/async_command_queue.h"
+
+#include "absl/base/thread_annotations.h"
+#include "base/status.h"
+#include "base/tracing.h"
+
+namespace iree {
+namespace hal {
+
+AsyncCommandQueue::AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue)
+    : CommandQueue(target_queue->name(), target_queue->supported_categories()),
+      target_queue_(std::move(target_queue)) {
+  IREE_TRACE_SCOPE0("AsyncCommandQueue::ctor");
+  // The worker starts here, after every member has been constructed.
+  thread_ = std::thread(&AsyncCommandQueue::ThreadMain, this); }
+
+AsyncCommandQueue::~AsyncCommandQueue() {
+ IREE_TRACE_SCOPE0("AsyncCommandQueue::dtor");
+ {
+ // Ask the worker thread to stop. It may already have exited, which is fine
+ // because the join() below then returns right away. ThreadMain() runs
+ // ProcessBatches on a non-empty queue before it honors the shutdown flag.
+ absl::MutexLock lock(&submission_mutex_);
+ submission_queue_.SignalShutdown();
+ }
+ thread_.join();
+
+ // Anything still queued after the join was never completed; CHECK on it.
+ {
+ absl::MutexLock lock(&submission_mutex_);
+ CHECK(submission_queue_.empty())
+ << "Dirty shutdown of async queue (unexpected thread exit?)";
+ }
+}
+
+void AsyncCommandQueue::ThreadMain() {  // Worker loop; runs on thread_.
+ // TODO(benvanik): make this safer (may die if trace is flushed late).
+ IREE_TRACE_THREAD_ENABLE(target_queue_->name().c_str());
+
+ bool is_exiting = false;
+ while (!is_exiting) {
+ // Sleep on the mutex until shutdown has been signaled or the queue holds
+ // at least one submission.
+ submission_mutex_.Lock();
+ submission_mutex_.Await(absl::Condition(
+ +[](HostSubmissionQueue* queue) {
+ return queue->has_shutdown() || !queue->empty();
+ },
+ &submission_queue_));
+ if (!submission_queue_.empty()) {
+ // Run all ready submissions (the callback may be invoked many times).
+ submission_mutex_.AssertHeld();
+ submission_queue_
+ .ProcessBatches(
+ [this](absl::Span<CommandBuffer* const> command_buffers)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(submission_mutex_) {
+ // Drop the lock for the duration of the target Submit so
+ // that other threads can enqueue more work meanwhile.
+ submission_mutex_.AssertHeld();
+ submission_mutex_.Unlock();
+
+ // Relay only the command buffers to the target queue: the
+ // wait/signal lists are passed empty and the fence is
+ // {nullptr, 0}, as synchronization is handled on this side.
+ auto status = target_queue_->Submit(
+ {{}, command_buffers, {}}, {nullptr, 0u});
+
+ // Re-acquire before returning; the caller expects it held.
+ submission_mutex_.Lock();
+ submission_mutex_.AssertHeld();
+
+ return status;
+ })
+ .IgnoreError();  // NOTE(review): presumably kept as permanent_error().
+ submission_mutex_.AssertHeld();
+ }
+ if (submission_queue_.has_shutdown()) {
+ // Shutdown has been signaled (NOTE(review): or the queue errored out,
+ // per the original note - confirm); ready work already ran above.
+ is_exiting = true;
+ }
+ submission_mutex_.Unlock();
+ }
+}
+
+Status AsyncCommandQueue::Submit(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence) {
+ IREE_TRACE_SCOPE0("AsyncCommandQueue::Submit");
+ absl::MutexLock lock(&submission_mutex_);  // Guards submission_queue_.
+ return submission_queue_.Enqueue(batches, fence);  // Run by ThreadMain().
+}
+
+Status AsyncCommandQueue::WaitIdle(absl::Time deadline) {
+  IREE_TRACE_SCOPE0("AsyncCommandQueue::WaitIdle");
+
+  // The wait ends once the queue has drained or has recorded a permanent
+  // error, whichever happens first; otherwise it gives up at |deadline|.
+  absl::MutexLock lock(&submission_mutex_);
+  const absl::Condition idle_or_failed(
+      +[](HostSubmissionQueue* queue) {
+        return queue->empty() || !queue->permanent_error().ok();
+      },
+      &submission_queue_);
+  if (!submission_mutex_.AwaitWithDeadline(idle_or_failed, deadline)) {
+    return DeadlineExceededErrorBuilder(IREE_LOC)
+           << "Deadline exceeded waiting for submission thread to go idle";
+  }
+  // Whatever the queue recorded: OK after a clean drain, else the failure.
+  return submission_queue_.permanent_error();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/async_command_queue.h b/hal/host/async_command_queue.h
new file mode 100644
index 0000000..54224db
--- /dev/null
+++ b/hal/host/async_command_queue.h
@@ -0,0 +1,71 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
+#define IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
+
+#include <memory>
+#include <thread> // NOLINT
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "hal/command_queue.h"
+#include "hal/fence.h"
+#include "hal/host/host_submission_queue.h"
+
+namespace iree {
+namespace hal {
+
+// Asynchronous command queue wrapper.
+// This creates a single thread to perform all CommandQueue operations. Any
+// submitted CommandBuffer is dispatched in FIFO order on the queue thread
+// against the provided |target_queue|.
+//
+// Target queues will receive submissions containing only command buffers as
+// all semaphore synchronization is handled by the wrapper. Fences will also be
+// omitted and code should safely handle nullptr.
+//
+// AsyncCommandQueue (as with CommandQueue) is thread-safe. Multiple threads
+// may submit command buffers concurrently, though the order of execution in
+// such a case depends entirely on the synchronization primitives provided.
+class AsyncCommandQueue final : public CommandQueue {
+ public:
+ explicit AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue);
+ ~AsyncCommandQueue() override;  // Signals shutdown, then joins the worker.
+
+ Status Submit(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence) override;  // Enqueues for the worker thread.
+
+ Status WaitIdle(absl::Time deadline) override;  // Until drained or failed.
+
+ private:
+ // Entry point of the worker thread started by the constructor.
+ // Waits for submissions to be queued up and processes them eagerly.
+ void ThreadMain();
+
+ // CommandQueue that the async queue relays submissions into (owned).
+ std::unique_ptr<CommandQueue> target_queue_;
+
+ // Thread that runs the ThreadMain() function and processes submissions.
+ std::thread thread_;
+
+ // Queue that manages submission ordering, and the mutex guarding it.
+ mutable absl::Mutex submission_mutex_;
+ HostSubmissionQueue submission_queue_ ABSL_GUARDED_BY(submission_mutex_);
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
diff --git a/hal/host/async_command_queue_test.cc b/hal/host/async_command_queue_test.cc
new file mode 100644
index 0000000..995ef42
--- /dev/null
+++ b/hal/host/async_command_queue_test.cc
@@ -0,0 +1,232 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/async_command_queue.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "absl/time/time.h"
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "base/time.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/command_queue.h"
+#include "hal/host/host_submission_queue.h"
+#include "hal/testing/mock_command_buffer.h"
+#include "hal/testing/mock_command_queue.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+using ::testing::_;
+
+using testing::MockCommandBuffer;
+using testing::MockCommandQueue;
+
+struct AsyncCommandQueueTest : public ::testing::Test {
+ MockCommandQueue* mock_target_queue;  // Borrowed; owned by command_queue.
+ std::unique_ptr<CommandQueue> command_queue;  // AsyncCommandQueue under test.
+
+ void SetUp() override {
+ auto mock_queue = absl::make_unique<MockCommandQueue>(
+ "mock", CommandCategory::kTransfer | CommandCategory::kDispatch);
+ mock_target_queue = mock_queue.get();  // Kept to set expectations later.
+ command_queue = absl::make_unique<AsyncCommandQueue>(std::move(mock_queue));
+ }
+
+ void TearDown() override {
+ command_queue.reset();  // Destroys the queue, which joins its worker.
+ mock_target_queue = nullptr;  // The mock died with the queue above.
+ }
+};
+
+// Tests that submitting a command buffer and immediately waiting will not
+// deadlock.
+TEST_F(AsyncCommandQueueTest, BlockingSubmit) {
+ ::testing::InSequence sequence;
+
+ auto cmd_buffer = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+
+ EXPECT_CALL(*mock_target_queue, Submit(_, _))
+ .WillOnce(
+ [&](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
+ CHECK_EQ(1, batches.size());
+ CHECK_EQ(1, batches[0].command_buffers.size());
+ CHECK_EQ(cmd_buffer.get(), batches[0].command_buffers[0]);
+ CHECK_EQ(nullptr, fence.first);  // The wrapper relays no fence.
+ return OkStatus();
+ });
+ HostFence fence(0u);
+ ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u}));
+ ASSERT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));  // Hangs if 1 is never signaled.
+}
+
+// Tests that a target-queue Submit failure reaches waiters on the fence.
+TEST_F(AsyncCommandQueueTest, PropagateSubmitFailure) {
+ ::testing::InSequence sequence;
+
+ auto cmd_buffer = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+
+ EXPECT_CALL(*mock_target_queue, Submit(_, _))
+ .WillOnce(
+ [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
+ return DataLossErrorBuilder(IREE_LOC);  // Simulated target failure.
+ });
+ HostFence fence(0u);
+ ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u}));
+ EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
+ {{&fence, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
+}
+
+// Tests that WaitIdle() succeeds when nothing has ever been submitted.
+TEST_F(AsyncCommandQueueTest, WaitIdleWhileIdle) {
+ ASSERT_OK(command_queue->WaitIdle());  // No-argument form from CommandQueue.
+}
+
+// Tests that waiting for idle will block when work is pending/in-flight.
+TEST_F(AsyncCommandQueueTest, WaitIdleWithPending) {
+ ::testing::InSequence sequence;
+
+ auto cmd_buffer = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+
+ EXPECT_CALL(*mock_target_queue, Submit(_, _))
+ .WillOnce(
+ [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
+ Sleep(absl::Milliseconds(100));  // Keeps the submission in flight.
+ return OkStatus();
+ });
+ HostFence fence(0u);
+ ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u}));
+
+ // This should block until the mocked 100ms submit above has completed.
+ ASSERT_OK(command_queue->WaitIdle());
+
+ // The fence should already have been signaled to its target value.
+ ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue());
+ ASSERT_EQ(1u, value);
+}
+
+// Tests that waiting for idle with multiple pending submissions will wait until
+// all of them complete while still allowing incremental progress.
+TEST_F(AsyncCommandQueueTest, WaitIdleAndProgress) {
+ ::testing::InSequence sequence;
+
+ EXPECT_CALL(*mock_target_queue, Submit(_, _))
+ .WillRepeatedly(
+ [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
+ Sleep(absl::Milliseconds(100));  // Each target submit takes a while.
+ return OkStatus();
+ });
+
+ auto cmd_buffer_0 = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+ auto cmd_buffer_1 = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+
+ HostFence fence_0(0u);
+ ASSERT_OK(
+ command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u}));
+ HostFence fence_1(0u);
+ ASSERT_OK(
+ command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u}));
+
+ // This should block until both mocked 100ms submits have completed.
+ ASSERT_OK(command_queue->WaitIdle());
+
+ // Both fences should already have been signaled to their target values.
+ ASSERT_OK_AND_ASSIGN(uint64_t value_0, fence_0.QueryValue());
+ ASSERT_EQ(1u, value_0);
+ ASSERT_OK_AND_ASSIGN(uint64_t value_1, fence_1.QueryValue());
+ ASSERT_EQ(1u, value_1);
+}
+
+// Tests that once a submission fails, later queue operations keep failing.
+TEST_F(AsyncCommandQueueTest, StickyFailures) {
+ ::testing::InSequence sequence;
+
+ // The only expected target submit fails (after a short delay).
+ EXPECT_CALL(*mock_target_queue, Submit(_, _))
+ .WillOnce(
+ [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
+ Sleep(absl::Milliseconds(100));
+ return DataLossErrorBuilder(IREE_LOC);
+ });
+ auto cmd_buffer_0 = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+ HostFence fence_0(0u);
+ ASSERT_OK(
+ command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u}));
+ EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
+ {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
+
+ // Future flushes/waits/etc should also fail.
+ EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle()));
+
+ // Future submits should fail too: Submit itself now returns the error.
+ auto cmd_buffer_1 = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+ HostFence fence_1(0u);
+ EXPECT_TRUE(IsDataLoss(
+ command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u})));
+}
+
+// Tests that a failure with two submissions pending causes the second to
+// bail as well.
+TEST_F(AsyncCommandQueueTest, FailuresCascadeAcrossSubmits) {
+ ::testing::InSequence sequence;
+
+ // Exactly one target submit is expected, and it fails after a short delay.
+ EXPECT_CALL(*mock_target_queue, Submit(_, _))
+ .WillOnce(
+ [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
+ Sleep(absl::Milliseconds(100));
+ return DataLossErrorBuilder(IREE_LOC);
+ });
+
+ auto cmd_buffer_0 = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+ auto cmd_buffer_1 = make_ref<MockCommandBuffer>(
+ nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
+
+ HostBinarySemaphore semaphore_0_1(false);  // Listed by both submits below.
+ HostFence fence_0(0u);
+ ASSERT_OK(command_queue->Submit({{}, {cmd_buffer_0.get()}, {&semaphore_0_1}},
+ {&fence_0, 1u}));
+ HostFence fence_1(0u);
+ ASSERT_OK(command_queue->Submit({{&semaphore_0_1}, {cmd_buffer_1.get()}, {}},
+ {&fence_1, 1u}));
+
+ EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle()));
+
+ EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
+ {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
+ EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
+ {{&fence_1, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
+
+ // Future flushes/waits/etc should also fail.
+ EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle()));
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_buffer.cc b/hal/host/host_buffer.cc
new file mode 100644
index 0000000..fe4bce4
--- /dev/null
+++ b/hal/host/host_buffer.cc
@@ -0,0 +1,148 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_buffer.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+#include "base/logging.h"
+#include "base/source_location.h"
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+
+class Allocator;
+
+HostBuffer::HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage, device_size_t allocation_size,
+ void* data, bool owns_data)
+ : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, 0,
+ allocation_size),  // NOTE(review): 0 looks like a byte offset - confirm.
+ data_(data),  // Base host pointer used by all the *Impl methods.
+ owns_data_(owns_data) {}  // When true the destructor std::free()s data_.
+
+HostBuffer::~HostBuffer() {
+  // Memory that is not owned (or absent) is left alone.
+  if (!owns_data_ || data_ == nullptr) return;
+  std::free(data_);
+  data_ = nullptr;
+}
+
+Status HostBuffer::FillImpl(device_size_t byte_offset,
+                            device_size_t byte_length, const void* pattern,
+                            device_size_t pattern_length) {
+  // Repeats the |pattern_length|-byte value at |pattern| across the range.
+  // |byte_offset| and |byte_length| arrive in bytes; for 2- and 4-byte
+  // patterns they are divided down to element units, so a remainder that is
+  // not a whole element is dropped. Any other pattern width is rejected
+  // with an InvalidArgument error and nothing is written.
+  if (pattern_length == 1) {
+    // Byte pattern: offset and length are usable as-is.
+    uint8_t* bytes = static_cast<uint8_t*>(data_);
+    const uint8_t fill_value = *static_cast<const uint8_t*>(pattern);
+    std::fill_n(bytes + byte_offset, byte_length, fill_value);
+  } else if (pattern_length == 2) {
+    // 16-bit pattern.
+    uint16_t* halfwords = static_cast<uint16_t*>(data_);
+    const uint16_t fill_value = *static_cast<const uint16_t*>(pattern);
+    std::fill_n(halfwords + byte_offset / sizeof(uint16_t),
+                byte_length / sizeof(uint16_t), fill_value);
+  } else if (pattern_length == 4) {
+    // 32-bit pattern.
+    uint32_t* words = static_cast<uint32_t*>(data_);
+    const uint32_t fill_value = *static_cast<const uint32_t*>(pattern);
+    std::fill_n(words + byte_offset / sizeof(uint32_t),
+                byte_length / sizeof(uint32_t), fill_value);
+  } else {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Unsupported scalar data size: " << pattern_length;
+  }
+  return OkStatus();
+}
+
+Status HostBuffer::ReadDataImpl(device_size_t source_offset, void* data,
+                                device_size_t data_length) {
+  const auto* src = static_cast<const uint8_t*>(data_) + source_offset;
+  std::memcpy(data, src, data_length);  // No bounds checks at this layer.
+  return OkStatus();
+}
+
+Status HostBuffer::WriteDataImpl(device_size_t target_offset, const void* data,
+                                 device_size_t data_length) {
+  uint8_t* dst = static_cast<uint8_t*>(data_) + target_offset;
+  std::memcpy(dst, data, data_length);  // No bounds checks at this layer.
+  return OkStatus();
+}
+
+Status HostBuffer::CopyDataImpl(device_size_t target_offset,
+                                Buffer* source_buffer,
+                                device_size_t source_offset,
+                                device_size_t data_length) {
+  // Slow path: map the source range for reading and memcpy out of the mapping.
+  // TODO(benvanik): a way for allocators to indicate transfer compat.
+  ASSIGN_OR_RETURN(auto source_data,
+                   source_buffer->MapMemory<uint8_t>(
+                       MemoryAccess::kRead, source_offset, data_length));
+  CHECK_EQ(data_length, source_data.size());
+  uint8_t* dst = static_cast<uint8_t*>(data_) + target_offset;
+  std::memcpy(dst, source_data.data(), data_length);
+  return OkStatus();
+}
+
+Status HostBuffer::MapMemoryImpl(MappingMode mapping_mode,
+                                 MemoryAccessBitfield memory_access,
+                                 device_size_t local_byte_offset,
+                                 device_size_t local_byte_length,
+                                 void** out_data) {
+  uint8_t* mapped = static_cast<uint8_t*>(data_) + local_byte_offset;
+  *out_data = mapped;
+
+  // Debug builds scribble 0xCD over ranges mapped for discard. Nothing
+  // mandates this, but it makes reads of discarded contents easy to spot.
+  // For heap buffers a reallocation would let ASAN catch such reads instead,
+  // though only when the entire buffer is discarded.
+#ifndef NDEBUG
+  if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) {
+    std::memset(mapped, 0xCD, local_byte_length);
+  }
+#endif  // !NDEBUG
+
+  return OkStatus();
+}
+
+Status HostBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void* data) {
+ // No-op for now; misuse checking is still wanted here but not implemented.
+ return OkStatus();
+}
+
+Status HostBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) {
+ // No-op for now; misuse checking is still wanted here but not implemented.
+ return OkStatus();
+}
+
+Status HostBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) {
+ // No-op for now; misuse checking is still wanted here but not implemented.
+ return OkStatus();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_buffer.h b/hal/host/host_buffer.h
new file mode 100644
index 0000000..2d20ea7
--- /dev/null
+++ b/hal/host/host_buffer.h
@@ -0,0 +1,67 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_BUFFER_H_
+#define IREE_HAL_HOST_BUFFER_H_
+
+#include <cstdint>
+
+#include "base/status.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+// A buffer type that operates on host pointers.
+// This can be used by Allocator implementations when they support operating
+// on host memory (or mapping their memory to host memory).
+class HostBuffer : public Buffer {
+ public:
+ HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
+ device_size_t allocation_size, void* data, bool owns_data);
+
+ ~HostBuffer() override;  // std::free()s |data| only if |owns_data| was set.
+
+ protected:
+ Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern, device_size_t pattern_length) override;
+ Status ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) override;
+ Status WriteDataImpl(device_size_t target_offset, const void* data,
+ device_size_t data_length) override;
+ Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
+ device_size_t source_offset,
+ device_size_t data_length) override;
+ Status MapMemoryImpl(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void** out_data) override;
+ Status UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data) override;
+ Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override;
+ Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override;
+
+ private:
+ void* data_ = nullptr;  // Base host pointer of the buffer contents.
+ bool owns_data_ = false;  // If true, data_ is std::free()d on destruction.
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_BUFFER_H_
diff --git a/hal/host/host_event.cc b/hal/host/host_event.cc
new file mode 100644
index 0000000..1b7abac
--- /dev/null
+++ b/hal/host/host_event.cc
@@ -0,0 +1,25 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_event.h"
+
+namespace iree {
+namespace hal {
+
+HostEvent::HostEvent() = default;  // The class declares no state of its own.
+
+HostEvent::~HostEvent() = default;  // Nothing to release.
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_event.h b/hal/host/host_event.h
new file mode 100644
index 0000000..cfdbe09
--- /dev/null
+++ b/hal/host/host_event.h
@@ -0,0 +1,32 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_HOST_EVENT_H_
+#define IREE_HAL_HOST_HOST_EVENT_H_
+
+#include "hal/event.h"
+
+namespace iree {
+namespace hal {
+
+class HostEvent final : public Event {  // Host Event; adds no members.
+ public:
+ HostEvent();
+ ~HostEvent() override;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_HOST_EVENT_H_
diff --git a/hal/host/host_fence.cc b/hal/host/host_fence.cc
new file mode 100644
index 0000000..a983af9
--- /dev/null
+++ b/hal/host/host_fence.cc
@@ -0,0 +1,110 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_fence.h"
+
+#include <atomic>
+#include <cstdint>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "base/status.h"
+#include "base/tracing.h"
+
+namespace iree {
+namespace hal {
+
+HostFence::HostFence(uint64_t initial_value) : value_(initial_value) {}
+
+HostFence::~HostFence() = default;
+
+Status HostFence::status() const {
+ absl::MutexLock lock(&mutex_);  // status_ is guarded by mutex_.
+ return status_;  // Returned by value, so the copy is taken under the lock.
+}
+
+StatusOr<uint64_t> HostFence::QueryValue() {
+ return value_.load(std::memory_order_acquire);  // Atomic read; no mutex.
+}
+
+// Advances the fence to |value|; a rejected signal leaves the value unchanged.
+Status HostFence::Signal(uint64_t value) {
+  absl::MutexLock lock(&mutex_);  // Every writer of value_ holds mutex_.
+  if (!status_.ok()) return status_;  // A failed fence stays failed.
+  if (value_.load(std::memory_order_acquire) >= value) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Fence values must be monotonically increasing";
+  }
+  value_.store(value, std::memory_order_release);  // Publish only if valid.
+  return OkStatus();
+}
+
+Status HostFence::Fail(Status status) {
+ absl::MutexLock lock(&mutex_);
+ status_ = status;  // Replaces any status recorded earlier.
+ value_.store(UINT64_MAX, std::memory_order_release);  // Failure sentinel.
+ return OkStatus();  // WaitForFences maps the sentinel back to status_.
+}
+
+// static
+Status HostFence::WaitForFences(absl::Span<const FenceValue> fences,
+ bool wait_all, absl::Time deadline) {
+ IREE_TRACE_SCOPE0("HostFence::WaitForFences");
+
+ // Some of the fences may already be signaled; we only need to wait for those
+ // that are not yet at the expected value.
+ using HostFenceValue = std::pair<HostFence*, uint64_t>;
+ absl::InlinedVector<HostFenceValue, 4> waitable_fences;
+ waitable_fences.reserve(fences.size());
+ for (auto& fence_value : fences) {
+ auto* fence = reinterpret_cast<HostFence*>(fence_value.first);
+ ASSIGN_OR_RETURN(uint64_t current_value, fence->QueryValue());
+ if (current_value == UINT64_MAX) {
+ // Fence has failed. Return the error.
+ return fence->status();
+ } else if (current_value < fence_value.second) {
+ // Fence has not yet hit the required value; wait for it.
+ waitable_fences.push_back({fence, fence_value.second});
+ }
+ }
+
+ // TODO(benvanik): maybe sort fences by value in case we are waiting on
+ // multiple values from the same fence.
+
+ // Loop over the fences and wait for them to complete.
+ // TODO(b/140026716): add WaitHandle support for !wait_all (wait any).
+ for (auto& fence_value : waitable_fences) {
+ auto* fence = fence_value.first;
+ absl::MutexLock lock(&fence->mutex_);
+ if (!fence->mutex_.AwaitWithDeadline(
+ absl::Condition(
+ +[](HostFenceValue* fence_value) {
+ return fence_value->first->value_.load(
+ std::memory_order_acquire) >= fence_value->second;
+ },
+ &fence_value),
+ deadline)) {
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline exceeded waiting for fences";
+ }
+ if (!fence->status_.ok()) {
+ return fence->status_;
+ }
+ }
+
+ return OkStatus();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_fence.h b/hal/host/host_fence.h
new file mode 100644
index 0000000..16464db
--- /dev/null
+++ b/hal/host/host_fence.h
@@ -0,0 +1,64 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_HOST_FENCE_H_
+#define IREE_HAL_HOST_HOST_FENCE_H_
+
+#include <atomic>
+#include <cstdint>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "hal/fence.h"
+
+namespace iree {
+namespace hal {
+
+// TODO(b/140026716): add WaitHandle support for better multi-wait.
+// Simple host-only fence semaphore implemented with a mutex.
+//
+// Thread-safe (as instances may be imported and used by others).
+class HostFence final : public Fence {
+ public:
+ // Waits for one or more (or all) fences to reach or exceed the given values.
+ static Status WaitForFences(absl::Span<const FenceValue> fences,
+ bool wait_all, absl::Time deadline);  // DeadlineExceeded if |deadline| passes.
+
+ explicit HostFence(uint64_t initial_value);
+ ~HostFence() override;
+
+ Status status() const override;
+ StatusOr<uint64_t> QueryValue() override;
+
+ Status Signal(uint64_t value);  // Tests expect InvalidArgument unless |value| increases.
+ Status Fail(Status status);  // Stores |status| and sets the value to UINT64_MAX.
+
+ private:
+ // The mutex is not required to query the value; this lets us quickly check if
+ // a required value has been exceeded. The mutex is only used to update and
+ // notify waiters.
+ std::atomic<uint64_t> value_{0};  // UINT64_MAX marks a failed fence.
+
+ // We have a full mutex here so that we can perform condvar waits on value
+ // changes.
+ mutable absl::Mutex mutex_;
+ Status status_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_HOST_FENCE_H_
diff --git a/hal/host/host_fence_test.cc b/hal/host/host_fence_test.cc
new file mode 100644
index 0000000..c0d12a5
--- /dev/null
+++ b/hal/host/host_fence_test.cc
@@ -0,0 +1,148 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_fence.h"
+
+#include <cstdint>
+#include <thread> // NOLINT
+
+#include "absl/time/time.h"
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+// Tests that a fence that is unused properly cleans itself up.
+TEST(HostFenceTest, NoOp) {
+ HostFence fence(123u);
+ EXPECT_TRUE(fence.status().ok());
+ ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue());
+ EXPECT_EQ(123u, value);
+}
+
+// Tests that a fence will accept new values as it is signaled.
+TEST(HostFenceTest, NormalSignaling) {
+ HostFence fence(2u);
+ EXPECT_EQ(2u, fence.QueryValue().ValueOrDie());
+ EXPECT_OK(fence.Signal(3u));
+ EXPECT_EQ(3u, fence.QueryValue().ValueOrDie());
+ EXPECT_OK(fence.Signal(40u));
+ EXPECT_EQ(40u, fence.QueryValue().ValueOrDie());
+}
+
+// Tests that a fence will fail to set non-increasing values.
+TEST(HostFenceTest, RequireIncreasingValues) {
+ HostFence fence(2u);
+ EXPECT_EQ(2u, fence.QueryValue().ValueOrDie());
+ // Same value.
+ EXPECT_TRUE(IsInvalidArgument(fence.Signal(2u)));
+ // Decreasing.
+ EXPECT_TRUE(IsInvalidArgument(fence.Signal(1u)));
+}
+
+// Tests that a fence that has failed will remain in a failed state.
+TEST(HostFenceTest, StickyFailure) {
+ HostFence fence(2u);
+ // Signal to 3.
+ EXPECT_OK(fence.Signal(3u));
+ EXPECT_TRUE(fence.status().ok());
+ EXPECT_EQ(3u, fence.QueryValue().ValueOrDie());
+
+ // Fail now.
+ EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC)));
+ EXPECT_TRUE(IsUnknown(fence.status()));
+ EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie());
+
+ // Unable to signal again (it'll return the sticky failure).
+ EXPECT_TRUE(IsUnknown(fence.Signal(4u)));
+ EXPECT_TRUE(IsUnknown(fence.status()));
+ EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie());
+}
+
+// Tests waiting on no fences.
+TEST(HostFenceTest, EmptyWait) {
+ EXPECT_OK(
+ HostFence::WaitForFences({}, /*wait_all=*/true, absl::InfiniteFuture()));
+}
+
+// Tests waiting on a fence that has already been signaled.
+TEST(HostFenceTest, WaitAlreadySignaled) {
+ HostFence fence(2u);
+ // Test both previous and current values.
+ EXPECT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));
+ EXPECT_OK(HostFence::WaitForFences({{&fence, 2u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));
+}
+
+// Tests waiting on a fence that has not been signaled.
+TEST(HostFenceTest, WaitUnsignaled) {
+ HostFence fence(2u);
+ // NOTE: we don't actually block here because otherwise we'd lock up.
+ EXPECT_TRUE(IsDeadlineExceeded(HostFence::WaitForFences(
+ {{&fence, 3u}}, /*wait_all=*/true, absl::InfinitePast())));
+}
+
+// Tests waiting on a failed fence (it should return the error on the fence).
+TEST(HostFenceTest, WaitAlreadyFailed) {
+ HostFence fence(2u);
+ EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC)));
+ EXPECT_TRUE(IsUnknown(HostFence::WaitForFences(
+ {{&fence, 2u}}, /*wait_all=*/true, absl::InfinitePast())));
+}
+
+// Tests threading behavior by ping-ponging between the test main thread and
+// a little thread.
+TEST(HostFenceTest, PingPong) {
+ HostFence a2b(0u);
+ HostFence b2a(0u);
+ std::thread thread([&]() {
+ // Should advance right past this because the value is already set.
+ ASSERT_OK(HostFence::WaitForFences({{&a2b, 0u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));
+ ASSERT_OK(b2a.Signal(1u));
+ // Jump ahead.
+ ASSERT_OK(HostFence::WaitForFences({{&a2b, 4u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));
+ });
+ ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));
+ ASSERT_OK(a2b.Signal(4u));
+ thread.join();
+}
+
+// Tests that failure still wakes waiters and propagates the error.
+TEST(HostFenceTest, FailNotifies) {
+ HostFence a2b(0u);
+ HostFence b2a(0u);
+ bool got_failure = false;
+ std::thread thread([&]() {
+ ASSERT_OK(b2a.Signal(1u));
+ got_failure = IsUnknown(HostFence::WaitForFences(
+ {{&a2b, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()));
+ });
+ ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true,
+ absl::InfiniteFuture()));
+ ASSERT_OK(a2b.Fail(UnknownErrorBuilder(IREE_LOC)));
+ thread.join();
+ ASSERT_TRUE(got_failure);
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_local_allocator.cc b/hal/host/host_local_allocator.cc
new file mode 100644
index 0000000..6b60769
--- /dev/null
+++ b/hal/host/host_local_allocator.cc
@@ -0,0 +1,111 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_local_allocator.h"
+
+#include <cstdlib>
+#include <string>
+#include <utility>
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/host/host_buffer.h"
+
+namespace iree {
+namespace hal {
+
+HostLocalAllocator::HostLocalAllocator() = default;
+
+HostLocalAllocator::~HostLocalAllocator() = default;
+
+bool HostLocalAllocator::CanUseBufferLike(
+ Allocator* source_allocator, MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage) const {
+ // Must always have visibility to the device, which ensures we can test
+ // against the host but have things work on devices with separate address
+ // spaces.
+ if (!AnyBitSet(memory_type & MemoryType::kDeviceVisible)) {
+ return false;
+ }
+
+ // kHostVisible is required for mapping.
+ if (AnyBitSet(intended_usage & BufferUsage::kMapping) &&
+ !AnyBitSet(memory_type & MemoryType::kHostVisible)) {
+ return false;
+ }
+
+ // Dispatch needs to be specified if we intend to dispatch.
+ if (AnyBitSet(intended_usage & BufferUsage::kDispatch) &&
+ !AnyBitSet(buffer_usage & BufferUsage::kDispatch)) {
+ return false;
+ }
+
+ return true;
+}
+
+bool HostLocalAllocator::CanAllocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) const {
+ // Host allows everything, pretty much, so long as it is device-visible (as
+ // the host is the device here).
+ return AnyBitSet(memory_type & MemoryType::kDeviceVisible);
+}
+
+Status HostLocalAllocator::MakeCompatible(
+ MemoryTypeBitfield* memory_type, BufferUsageBitfield* buffer_usage) const {
+ // Always ensure we are host-visible.
+ *memory_type |= MemoryType::kHostVisible;
+
+ // Host currently uses mapping to copy buffers, which is done a lot.
+ // We could probably remove this restriction somehow.
+ *buffer_usage |= BufferUsage::kMapping;
+
+ // TODO(b/111372612): tensorflow needs transfer too, but shouldn't.
+ *buffer_usage |= BufferUsage::kTransfer;
+
+ return OkStatus();
+}
+
+StatusOr<ref_ptr<Buffer>> HostLocalAllocator::Allocate(
+    MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+    size_t allocation_size) {
+  IREE_TRACE_SCOPE0("HostLocalAllocator::Allocate");
+
+  if (!CanAllocate(memory_type, buffer_usage, allocation_size)) {
+    return FailedPreconditionErrorBuilder(IREE_LOC)
+           << "Allocation not supported; memory_type="
+           << MemoryTypeString(memory_type)
+           << ", buffer_usage=" << BufferUsageString(buffer_usage)
+           << ", allocation_size=" << allocation_size;
+  }
+
+  // Make compatible with our requirements.
+  RETURN_IF_ERROR(MakeCompatible(&memory_type, &buffer_usage));
+  // Zero-filled storage. calloc(1, 0) may return null, so request >= 1 byte.
+  void* malloced_data = std::calloc(1, allocation_size ? allocation_size : 1);
+  if (!malloced_data) {
+    return ResourceExhaustedErrorBuilder(IREE_LOC)
+           << "Failed to malloc " << allocation_size << " bytes";
+  }
+
+  auto buffer =
+      make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage,
+                           allocation_size, malloced_data, true);
+  return buffer;
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_local_allocator.h b/hal/host/host_local_allocator.h
new file mode 100644
index 0000000..fd38910
--- /dev/null
+++ b/hal/host/host_local_allocator.h
@@ -0,0 +1,60 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
+#define IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
+
+#include <cstddef>
+#include <memory>
+
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/buffer.h"
+
+namespace iree {
+namespace hal {
+
+// An allocator implementation that allocates buffers from host memory.
+// This can be used for drivers that do not have a memory space of their own.
+//
+// Buffers allocated will be MemoryType::kHostLocal | kDeviceVisible as
+// the 'device' in the case of a host-local queue *is* the host. To keep code
+// written initially for a host-local queue working when other queues are used
+// the allocator only works with buffers that are kDeviceVisible.
+class HostLocalAllocator : public Allocator {
+ public:
+ HostLocalAllocator();
+ ~HostLocalAllocator() override;
+
+ bool CanUseBufferLike(Allocator* source_allocator,
+ MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage) const override;
+
+ bool CanAllocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) const override;
+
+ Status MakeCompatible(MemoryTypeBitfield* memory_type,
+ BufferUsageBitfield* buffer_usage) const override;
+
+ StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) override;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
diff --git a/hal/host/host_local_command_processor.cc b/hal/host/host_local_command_processor.cc
new file mode 100644
index 0000000..72a3c88
--- /dev/null
+++ b/hal/host/host_local_command_processor.cc
@@ -0,0 +1,120 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_local_command_processor.h"
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+
+namespace iree {
+namespace hal {
+
+HostLocalCommandProcessor::HostLocalCommandProcessor(
+ Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories)
+ : CommandBuffer(allocator, mode, command_categories) {}
+
+HostLocalCommandProcessor::~HostLocalCommandProcessor() = default;
+
+Status HostLocalCommandProcessor::Begin() {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::Begin");
+ is_recording_ = true;
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::End() {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::End");
+ is_recording_ = false;
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ExecutionBarrier");
+ // No-op.
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::SignalEvent(
+ Event* event, ExecutionStageBitfield source_stage_mask) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::SignalEvent");
+ // No-op.
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::ResetEvent(
+ Event* event, ExecutionStageBitfield source_stage_mask) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ResetEvent");
+ // No-op.
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::WaitEvents(
+ absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::WaitEvents");
+ // No-op.
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::FillBuffer(Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length,
+ const void* pattern,
+ size_t pattern_length) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::FillBuffer");
+ return target_buffer->Fill(target_offset, length, pattern, pattern_length);
+}
+
+Status HostLocalCommandProcessor::DiscardBuffer(Buffer* buffer) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::DiscardBuffer");
+ // No-op as we don't support lazily allocated buffers.
+ return OkStatus();
+}
+
+Status HostLocalCommandProcessor::UpdateBuffer(const void* source_buffer,
+ device_size_t source_offset,
+ Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::UpdateBuffer");
+ return target_buffer->WriteData(
+ target_offset, static_cast<const uint8_t*>(source_buffer) + source_offset,
+ length);
+}
+
+Status HostLocalCommandProcessor::CopyBuffer(Buffer* source_buffer,
+ device_size_t source_offset,
+ Buffer* target_buffer,
+ device_size_t target_offset,
+ device_size_t length) {
+ IREE_TRACE_SCOPE0("HostLocalCommandProcessor::CopyBuffer");
+ return target_buffer->CopyData(target_offset, source_buffer, source_offset,
+ length);
+}
+
+Status HostLocalCommandProcessor::Dispatch(
+ const DispatchRequest& dispatch_request) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Command processor does not support dispatch operations";
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_local_command_processor.h b/hal/host/host_local_command_processor.h
new file mode 100644
index 0000000..e18d2a8
--- /dev/null
+++ b/hal/host/host_local_command_processor.h
@@ -0,0 +1,85 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
+#define IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
+
+#include "hal/command_buffer.h"
+
+namespace iree {
+namespace hal {
+
+// Host-local command processor for dispatching transfer operations against
+// buffers allocated from the HostLocalAllocator.
+// This assumes that all buffers are host-visible (if not local) and that all
+// buffers can be mapped for access.
+//
+// Subclasses may implement Dispatch, otherwise the default implementation just
+// returns failure.
+//
+// Thread-compatible (as with CommandBuffer itself).
+class HostLocalCommandProcessor : public CommandBuffer {
+ public:
+ HostLocalCommandProcessor(Allocator* allocator,
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories);
+ ~HostLocalCommandProcessor() override;
+
+ bool is_recording() const override { return is_recording_; }
+
+ Status Begin() override;
+ Status End() override;
+
+ Status ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status SignalEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+
+ Status ResetEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+
+ Status WaitEvents(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) override;
+
+ Status DiscardBuffer(Buffer* buffer) override;
+
+ Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+
+ Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+
+ Status Dispatch(const DispatchRequest& dispatch_request) override;
+
+ private:
+ bool is_recording_ = false;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
diff --git a/hal/host/host_submission_queue.cc b/hal/host/host_submission_queue.cc
new file mode 100644
index 0000000..aa672b0
--- /dev/null
+++ b/hal/host/host_submission_queue.cc
@@ -0,0 +1,295 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_submission_queue.h"
+
+#include <atomic>
+#include <cstdint>
+
+#include "absl/synchronization/mutex.h"
+#include "base/status.h"
+#include "base/tracing.h"
+
+namespace iree {
+namespace hal {
+
+HostBinarySemaphore::HostBinarySemaphore(bool initial_value) {
+ State state = {0};
+ state.signaled = initial_value ? 1 : 0;
+ state_ = state;
+}
+
+bool HostBinarySemaphore::is_signaled() const {
+ return state_.load(std::memory_order_acquire).signaled == 1;
+}
+
+Status HostBinarySemaphore::BeginSignaling() {  // Sets signal_pending; errors if already set.
+ State old_state = state_.load(std::memory_order_acquire);
+ if (old_state.signal_pending != 0) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "A signal operation on a binary semaphore is already pending";
+ }
+ State new_state = old_state;
+ new_state.signal_pending = 1;
+ state_.compare_exchange_strong(old_state, new_state);  // NOTE(review): result ignored; if state_ changed since the load this update is not applied yet OkStatus is returned.
+ return OkStatus();
+}
+
+Status HostBinarySemaphore::EndSignaling() {  // Clears signal_pending and sets signaled.
+ State old_state = state_.load(std::memory_order_acquire);
+ DCHECK_EQ(old_state.signal_pending, 1)
+ << "A signal operation on a binary semaphore was not pending";
+ if (old_state.signaled != 0) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "A binary semaphore cannot be signaled multiple times";
+ }
+ State new_state = old_state;
+ new_state.signal_pending = 0;
+ new_state.signaled = 1;
+ state_.compare_exchange_strong(old_state, new_state);  // NOTE(review): result ignored; if state_ changed since the load this update is not applied yet OkStatus is returned.
+ return OkStatus();
+}
+
+Status HostBinarySemaphore::BeginWaiting() {  // Sets wait_pending; errors if already set.
+ State old_state = state_.load(std::memory_order_acquire);
+ if (old_state.wait_pending != 0) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "A wait operation on a binary semaphore is already pending";
+ }
+ State new_state = old_state;
+ new_state.wait_pending = 1;
+ state_.compare_exchange_strong(old_state, new_state);  // NOTE(review): result ignored; if state_ changed since the load this update is not applied yet OkStatus is returned.
+ return OkStatus();
+}
+
+Status HostBinarySemaphore::EndWaiting() {  // Clears wait_pending and signaled.
+ State old_state = state_.load(std::memory_order_acquire);
+ DCHECK_EQ(old_state.wait_pending, 1)
+ << "A wait operation on a binary semaphore was not pending";
+ if (old_state.signaled != 1) {  // Waiting can only complete on a signaled semaphore.
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "A binary semaphore cannot be reset multiple times";
+ }
+ State new_state = old_state;
+ new_state.wait_pending = 0;
+ new_state.signaled = 0;
+ state_.compare_exchange_strong(old_state, new_state);  // NOTE(review): result ignored; if state_ changed since the load this update is not applied yet OkStatus is returned.
+ return OkStatus();
+}
+
+HostSubmissionQueue::HostSubmissionQueue() = default;
+
+HostSubmissionQueue::~HostSubmissionQueue() = default;
+
+bool HostSubmissionQueue::IsBatchReady(const PendingBatch& batch) const {
+ for (auto& wait_point : batch.wait_semaphores) {
+ if (wait_point.index() == 0) {
+ auto* binary_semaphore =
+ reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(wait_point));
+ if (!binary_semaphore->is_signaled()) {
+ return false;
+ }
+ } else {
+ // TODO(b/140141417): implement timeline semaphores.
+ return false;
+ }
+ }
+ return true;
+}
+
+Status HostSubmissionQueue::Enqueue(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence) {  // |fence| is later signaled or failed by CompleteSubmission.
+ IREE_TRACE_SCOPE0("HostSubmissionQueue::Enqueue");
+
+ if (has_shutdown_) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Cannot enqueue new submissions; queue is exiting";
+ } else if (!permanent_error_.ok()) {
+ return permanent_error_;
+ }
+
+ // Verify waiting/signaling behavior on semaphores and prepare them all.
+ // We need to track this to ensure that we are modeling the Vulkan behavior
+ // and are consistent across HAL implementations.
+ for (auto& batch : batches) {  // NOTE(review): early returns below leave already-updated semaphores marked pending.
+ for (auto& semaphore_value : batch.wait_semaphores) {
+ if (semaphore_value.index() == 0) {
+ auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>(
+ absl::get<0>(semaphore_value));
+ RETURN_IF_ERROR(binary_semaphore->BeginWaiting());
+ } else {
+ // TODO(b/140141417): implement timeline semaphores.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
+ }
+ }
+ for (auto& semaphore_value : batch.signal_semaphores) {
+ if (semaphore_value.index() == 0) {
+ auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>(
+ absl::get<0>(semaphore_value));
+ RETURN_IF_ERROR(binary_semaphore->BeginSignaling());
+ } else {
+ // TODO(b/140141417): implement timeline semaphores.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
+ }
+ }
+ }
+
+ // Add to list - order does not matter as Process evaluates semaphores.
+ auto submission = absl::make_unique<Submission>();
+ submission->fence = std::move(fence);
+ submission->pending_batches.resize(batches.size());
+ for (int i = 0; i < batches.size(); ++i) {  // Copy each batch's spans into owned storage.
+ submission->pending_batches[i] = PendingBatch{
+ {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()},
+ {batches[i].command_buffers.begin(), batches[i].command_buffers.end()},
+ {batches[i].signal_semaphores.begin(),
+ batches[i].signal_semaphores.end()},
+ };
+ }
+ list_.push_back(std::move(submission));
+
+ return OkStatus();
+}
+
+Status HostSubmissionQueue::ProcessBatches(ExecuteFn execute_fn) {
+  IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatches");
+
+  if (!permanent_error_.ok()) {
+    // Sticky failure state.
+    return permanent_error_;
+  }
+  // Repeatedly try to run things until the list drains, a batch fails, or we
+  // are blocked (a full pass over the list found no batch that is ready).
+  while (permanent_error_.ok() && !list_.empty()) {
+    // NOTE: to support re-entrancy where |execute_fn| may modify the submission
+    // list we need to always start from the beginning. If we wanted we could
+    // track a list of ready submissions however that's a lot of bookkeeping and
+    // the list is usually short.
+    bool restart_iteration = false;
+    for (auto* submission : list_) {
+      for (int i = 0; i < submission->pending_batches.size(); ++i) {
+        auto& batch = submission->pending_batches[i];
+        if (!IsBatchReady(batch)) {
+          // Try the next batch in the submission until we find one that is
+          // ready. If none are ready we'll return to the caller.
+          continue;
+        }
+
+        // Batch can run! Process it now and erase it so it never runs twice.
+        auto batch_status = ProcessBatch(batch, execute_fn);
+        submission->pending_batches.erase(submission->pending_batches.begin() +
+                                          i);
+        if (batch_status.ok()) {
+          // Batch succeeded. Since we want to preserve submission order we'll
+          // break out of the loop and try from the first submission again.
+          if (submission->pending_batches.empty()) {
+            // All work for this submission completed successfully. Signal the
+            // fence and remove the submission from the list.
+            RETURN_IF_ERROR(CompleteSubmission(submission, OkStatus()));
+            list_.take(submission).reset();
+          }
+        } else {
+          // Batch failed; set the permanent error flag and abort so we don't
+          // try to process anything else.
+          permanent_error_ = batch_status;
+          RETURN_IF_ERROR(CompleteSubmission(submission, batch_status));
+          list_.take(submission).reset();
+        }
+        restart_iteration = true;
+        break;
+      }
+      if (restart_iteration) break;
+    }
+    if (!restart_iteration) break;  // Blocked: return instead of spinning.
+  }
+
+  if (!permanent_error_.ok()) {
+    // If the sticky error got set while processing we need to abort all
+    // remaining submissions (simulating a device loss).
+    FailAllPending(permanent_error_);
+    return permanent_error_;
+  }
+
+  return OkStatus();
+}
+
+Status HostSubmissionQueue::ProcessBatch(const PendingBatch& batch,
+ const ExecuteFn& execute_fn) {  // Ends waits, runs |execute_fn|, then ends signals.
+ IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatch");
+
+ // Complete the waits on all semaphores and reset them.
+ for (auto& semaphore_value : batch.wait_semaphores) {
+ if (semaphore_value.index() == 0) {
+ auto* binary_semaphore =
+ reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
+ RETURN_IF_ERROR(binary_semaphore->EndWaiting());
+ } else {
+ // TODO(b/140141417): implement timeline semaphores.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
+ }
+ }
+
+ // Let the caller handle execution of the command buffers.
+ RETURN_IF_ERROR(execute_fn(batch.command_buffers));  // On error we return before the signal loop below runs.
+
+ // Signal all semaphores to allow them to unblock waiters.
+ for (auto& semaphore_value : batch.signal_semaphores) {
+ if (semaphore_value.index() == 0) {
+ auto* binary_semaphore =
+ reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
+ RETURN_IF_ERROR(binary_semaphore->EndSignaling());
+ } else {
+ // TODO(b/140141417): implement timeline semaphores.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
+ }
+ }
+
+ return OkStatus();
+}
+
+Status HostSubmissionQueue::CompleteSubmission(Submission* submission,
+ Status status) {
+ IREE_TRACE_SCOPE0("HostSubmissionQueue::CompleteSubmission");
+
+ // It's safe to drop any remaining batches - their semaphores will never be
+ // signaled but that's fine as we should be the only thing relying on them.
+ submission->pending_batches.clear();
+
+ // Signal the fence.
+ auto* fence = static_cast<HostFence*>(submission->fence.first);
+ if (status.ok()) {
+ RETURN_IF_ERROR(fence->Signal(submission->fence.second));
+ } else {
+ RETURN_IF_ERROR(fence->Fail(std::move(status)));
+ }
+
+ return OkStatus();
+}
+
+void HostSubmissionQueue::FailAllPending(Status status) {
+ IREE_TRACE_SCOPE0("HostSubmissionQueue::FailAllPending");
+ while (!list_.empty()) {
+ auto submission = list_.take(list_.front());
+ CompleteSubmission(submission.get(), status).IgnoreError();
+ submission.reset();
+ }
+}
+
+void HostSubmissionQueue::SignalShutdown() {
+ IREE_TRACE_SCOPE0("HostSubmissionQueue::SignalShutdown");
+ has_shutdown_ = true;
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/host_submission_queue.h b/hal/host/host_submission_queue.h
new file mode 100644
index 0000000..1ec141f
--- /dev/null
+++ b/hal/host/host_submission_queue.h
@@ -0,0 +1,163 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
+#define IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "base/intrusive_list.h"
+#include "base/status.h"
+#include "hal/command_queue.h"
+#include "hal/host/host_fence.h"
+#include "hal/semaphore.h"
+
+namespace iree {
+namespace hal {
+
+class HostSubmissionQueue;
+
+// Simple host-only binary semaphore backed by a single atomic state word.
+// To match the expected HAL behavior (mostly dictated by Vulkan) we can only
+// have a single waiter and waits can only occur once a signal has been
+// enqueued.
+//
+// Thread-safe (as instances may be imported and used by others).
+class HostBinarySemaphore final : public BinarySemaphore {
+ public:
+ explicit HostBinarySemaphore(bool initial_value);
+
+ // Returns true if the semaphore has been signaled.
+ bool is_signaled() const;
+
+ private:
+ friend class HostSubmissionQueue;  // Granted access to Begin*/End* below.
+
+ // Begins a signal operation and ensures no other signal operation is pending.
+ Status BeginSignaling();
+ // Ends a signal operation by setting the semaphore to the signaled state.
+ Status EndSignaling();
+
+ // Begins a wait operation and ensures no other wait operation is pending.
+ Status BeginWaiting();
+ // Ends a wait operation by resetting the semaphore to the unsignaled state.
+ Status EndWaiting();
+
+ // Three one-bit flags packed into one struct and held in a std::atomic,
+ // intended to give lock-free semaphore behavior. The extra tracking gives
+ // consistent behavior across HAL implementations with strict semantics.
+ struct State {
+ uint32_t signal_pending : 1;
+ uint32_t wait_pending : 1;
+ uint32_t signaled : 1;
+ };
+ std::atomic<State> state_{{0, 0, 0}};
+};
+
+// Placeholder for a host-only timeline semaphore. The class body is currently
+// empty: nothing is implemented yet (see the TODO below).
+// Intended to be thread-safe (instances may be imported and used by others).
+class HostTimelineSemaphore final : public TimelineSemaphore {
+ public:
+ // TODO(b/140141417): implement timeline semaphores.
+};
+
+// A queue managing CommandQueue submissions that uses host-local
+// synchronization primitives. Evaluates submission order by respecting the
+// wait and signal semaphores defined per batch and notifies fences upon
+// submission completion.
+//
+// Note that it's possible for HAL users to deadlock themselves; we don't try to
+// avoid that as in device backends it may not be possible and we want to have
+// some kind of warning in the host implementation that TSAN can catch.
+//
+// Thread-compatible. Const methods may be called from any thread.
+class HostSubmissionQueue {
+ public:
+ using ExecuteFn =
+ std::function<Status(absl::Span<CommandBuffer* const> command_buffers)>;
+
+ HostSubmissionQueue();
+ ~HostSubmissionQueue();
+
+ // Returns true if the queue is currently empty.
+ bool empty() const { return list_.empty(); }
+ // Returns true if SignalShutdown has been called.
+ bool has_shutdown() const { return has_shutdown_; }
+ // The sticky error status, if an error has occurred (returned as a copy).
+ Status permanent_error() const { return permanent_error_; }
+
+ // Enqueues a new submission.
+ // No work will be performed until ProcessBatches is called.
+ Status Enqueue(absl::Span<const SubmissionBatch> batches, FenceValue fence);
+
+ // Processes all ready batches using the provided |execute_fn|.
+ // The function may be called several times if new batches become ready due to
+ // prior batches in the sequence completing during processing.
+ //
+ // Returns any errors returned by |execute_fn| (which will be the same as
+ // permanent_error()). When an error occurs all in-flight submissions are
+ // aborted, the permanent_error() is set, and the queue is shut down.
+ Status ProcessBatches(ExecuteFn execute_fn);
+
+ // Marks the queue as having shut down. All pending submissions will be
+ // allowed to complete but future enqueues will fail.
+ void SignalShutdown();
+
+ private:
+ // A submitted command buffer batch and its synchronization information.
+ struct PendingBatch {
+ absl::InlinedVector<SemaphoreValue, 4> wait_semaphores;
+ absl::InlinedVector<CommandBuffer*, 4> command_buffers;
+ absl::InlinedVector<SemaphoreValue, 4> signal_semaphores;
+ };
+ struct Submission : public IntrusiveLinkBase<void> {
+ absl::InlinedVector<PendingBatch, 4> pending_batches;
+ FenceValue fence;
+ };
+
+ // Returns true if all wait semaphores in the |batch| are signaled.
+ bool IsBatchReady(const PendingBatch& batch) const;
+
+ // Processes a batch by resetting semaphores, dispatching the command buffers
+ // to the specified |execute_fn|, and signaling semaphores.
+ //
+ // Preconditions: IsBatchReady(batch) == true
+ Status ProcessBatch(const PendingBatch& batch, const ExecuteFn& execute_fn);
+
+ // Completes a submission: drops its remaining batches, then signals the
+ // fence when |status| is OK or fails the fence with |status| otherwise.
+ Status CompleteSubmission(Submission* submission, Status status);
+
+ // Fails all pending submissions with the given status.
+ // Errors that occur during this process are silently ignored.
+ void FailAllPending(Status status);
+
+ // Set by SignalShutdown(). True to exit the thread after submissions complete.
+ bool has_shutdown_ = false;
+
+ // A sticky error that is set on the first failed submit. All future
+ // submissions will be skipped except for fences, which will receive this
+ // error.
+ Status permanent_error_;
+
+ // Pending submissions in submission order.
+ // Note that we may evaluate batches within the list out of order.
+ IntrusiveList<std::unique_ptr<Submission>> list_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
diff --git a/hal/host/host_submission_queue_test.cc b/hal/host/host_submission_queue_test.cc
new file mode 100644
index 0000000..ed71de2
--- /dev/null
+++ b/hal/host/host_submission_queue_test.cc
@@ -0,0 +1,30 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/host_submission_queue.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+TEST(HostSubmissionQueueTest, TBD) {
+ // TODO(benvanik): test! Empty placeholder; no behavior is verified yet.
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/inproc_command_buffer.cc b/hal/host/inproc_command_buffer.cc
new file mode 100644
index 0000000..4115825
--- /dev/null
+++ b/hal/host/inproc_command_buffer.cc
@@ -0,0 +1,264 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/host/inproc_command_buffer.h"
+
+#include "base/tracing.h"
+
+namespace iree {
+namespace hal {
+
+InProcCommandBuffer::InProcCommandBuffer(
+ Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories)
+ : CommandBuffer(allocator, mode, command_categories) {}  // Forwards only.
+
+InProcCommandBuffer::~InProcCommandBuffer() { Reset(); }  // Resets the arena.
+
+Status InProcCommandBuffer::Begin() {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::Begin");
+  Reset();  // Discard anything left over from an earlier recording.
+  is_recording_ = true;
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::End() {
+ IREE_TRACE_SCOPE0("InProcCommandBuffer::End");
+ is_recording_ = false;  // The recorded command list itself is left intact.
+ return OkStatus();
+}
+
+Status InProcCommandBuffer::ExecutionBarrier(
+    ExecutionStageBitfield source_stage_mask,
+    ExecutionStageBitfield target_stage_mask,
+    absl::Span<const MemoryBarrier> memory_barriers,
+    absl::Span<const BufferBarrier> buffer_barriers) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::ExecutionBarrier");
+  ASSIGN_OR_RETURN(auto* entry, AppendCmd<ExecutionBarrierCmd>());
+  entry->memory_barriers = AppendStructSpan(memory_barriers);
+  entry->buffer_barriers = AppendStructSpan(buffer_barriers);
+  entry->source_stage_mask = source_stage_mask;  // Masks are kept by value;
+  entry->target_stage_mask = target_stage_mask;  // spans are arena copies.
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::SignalEvent(
+    Event* event, ExecutionStageBitfield source_stage_mask) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::SignalEvent");
+  ASSIGN_OR_RETURN(auto* signal_cmd, AppendCmd<SignalEventCmd>());
+  signal_cmd->source_stage_mask = source_stage_mask;
+  signal_cmd->event = event;  // The raw pointer is recorded as-is.
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::ResetEvent(
+    Event* event, ExecutionStageBitfield source_stage_mask) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::ResetEvent");
+  ASSIGN_OR_RETURN(auto* reset_cmd, AppendCmd<ResetEventCmd>());
+  reset_cmd->source_stage_mask = source_stage_mask;
+  reset_cmd->event = event;  // The raw pointer is recorded as-is.
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::WaitEvents(
+    absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
+    ExecutionStageBitfield target_stage_mask,
+    absl::Span<const MemoryBarrier> memory_barriers,
+    absl::Span<const BufferBarrier> buffer_barriers) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::WaitEvents");
+  ASSIGN_OR_RETURN(auto* wait_cmd, AppendCmd<WaitEventsCmd>());
+  wait_cmd->source_stage_mask = source_stage_mask;
+  wait_cmd->target_stage_mask = target_stage_mask;
+  wait_cmd->events = AppendStructSpan(events);  // Pointer array is copied.
+  wait_cmd->memory_barriers = AppendStructSpan(memory_barriers);
+  wait_cmd->buffer_barriers = AppendStructSpan(buffer_barriers);
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::FillBuffer(
+    Buffer* target_buffer, device_size_t target_offset, device_size_t length,
+    const void* pattern, size_t pattern_length) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::FillBuffer");
+  if (pattern_length > sizeof(FillBufferCmd::pattern))  // Inline array size.
+    return InvalidArgumentErrorBuilder(IREE_LOC) << "Fill pattern too large";
+  ASSIGN_OR_RETURN(auto* cmd, AppendCmd<FillBufferCmd>());
+  cmd->target_buffer = target_buffer;
+  cmd->target_offset = target_offset;
+  cmd->length = length;
+  std::memcpy(cmd->pattern, pattern, pattern_length);  // Length checked above.
+  cmd->pattern_length = pattern_length;
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::DiscardBuffer(Buffer* buffer) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::DiscardBuffer");
+  ASSIGN_OR_RETURN(auto* discard_cmd, AppendCmd<DiscardBufferCmd>());
+  discard_cmd->buffer = buffer;  // Only the pointer is recorded.
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::UpdateBuffer(
+    const void* source_buffer, device_size_t source_offset,
+    Buffer* target_buffer, device_size_t target_offset, device_size_t length) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::UpdateBuffer");
+  ASSIGN_OR_RETURN(auto* update, AppendCmd<UpdateBufferCmd>());
+  // Snapshot the source now: |length| bytes starting at |source_offset| are
+  // copied into the arena and the command points at that copy.
+  update->source_buffer = AppendCmdData(source_buffer, source_offset, length);
+  update->target_buffer = target_buffer;
+  update->target_offset = target_offset;
+  update->length = length;
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::CopyBuffer(
+    Buffer* source_buffer, device_size_t source_offset, Buffer* target_buffer,
+    device_size_t target_offset, device_size_t length) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::CopyBuffer");
+  ASSIGN_OR_RETURN(auto* copy_cmd, AppendCmd<CopyBufferCmd>());
+  // Only the pointers, offsets and length are recorded; no buffer contents
+  // are read or copied by this function.
+  copy_cmd->length = length;
+  copy_cmd->source_buffer = source_buffer;
+  copy_cmd->source_offset = source_offset;
+  copy_cmd->target_buffer = target_buffer;
+  copy_cmd->target_offset = target_offset;
+  return OkStatus();
+}
+
+Status InProcCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::Dispatch");
+  ASSIGN_OR_RETURN(auto* entry, AppendCmd<DispatchCmd>());
+  entry->request.bindings = AppendStructSpan(dispatch_request.bindings);
+  entry->request.executable = dispatch_request.executable;
+  entry->request.entry_point = dispatch_request.entry_point;
+  entry->request.workload = dispatch_request.workload;
+  entry->request.workload_buffer = dispatch_request.workload_buffer;
+  return OkStatus();  // |bindings| now refers to the arena copy.
+}
+
+void InProcCommandBuffer::Reset() {
+  // Unlink all commands, then recycle the arena storage that held them.
+  current_cmd_list_.head = current_cmd_list_.tail = nullptr;
+  current_cmd_list_.arena.Reset();
+}
+
+InProcCommandBuffer::CmdHeader* InProcCommandBuffer::AppendCmdHeader(
+    CmdType type, size_t cmd_size) {
+  // A single arena allocation holds the header plus |cmd_size| payload bytes.
+  auto* cmd_header = reinterpret_cast<CmdHeader*>(
+      current_cmd_list_.arena.AllocateBytes(sizeof(CmdHeader) + cmd_size));
+  cmd_header->next = nullptr;
+  cmd_header->type = type;
+  if (!current_cmd_list_.head) {
+    current_cmd_list_.head = cmd_header;  // First command becomes the head.
+  } else if (current_cmd_list_.tail) {
+    current_cmd_list_.tail->next = cmd_header;  // Chain after the old tail.
+  }
+  current_cmd_list_.tail = cmd_header;
+  return cmd_header;
+}
+
+void* InProcCommandBuffer::AppendCmdData(const void* source_buffer,
+                                         device_size_t source_offset,
+                                         device_size_t source_length) {
+  // Copies |source_length| bytes starting at |source_offset| into the arena.
+  uint8_t* arena_bytes = current_cmd_list_.arena.AllocateBytes(source_length);
+  if (source_length > 0) {  // Empty spans may carry a null data pointer.
+    const auto* source_bytes = static_cast<const uint8_t*>(source_buffer);
+    std::memcpy(arena_bytes, source_bytes + source_offset, source_length);
+  }
+  return arena_bytes;
+}
+
+Status InProcCommandBuffer::Process(CommandBuffer* command_processor) const {
+  IREE_TRACE_SCOPE0("InProcCommandBuffer::Process");
+  RETURN_IF_ERROR(command_processor->Begin());
+
+  // Replay commands in recorded order. Stop at the first failure and return
+  // it (it used to be logged and then dropped, so callers saw OkStatus()).
+  const auto* cmd_list = &current_cmd_list_;
+  for (CmdHeader* cmd_header = cmd_list->head; cmd_header != nullptr;
+       cmd_header = cmd_header->next) {
+    auto command_status = ProcessCmd(cmd_header, command_processor);
+    if (!command_status.ok()) {
+      LOG(ERROR) << "DeviceQueue failure while executing command; permanently "
+                    "failing all future commands: "
+                 << command_status;
+      // Keep Begin/End balanced; the command failure takes precedence.
+      command_processor->End().IgnoreError();
+      return command_status;
+    }
+  }
+  return command_processor->End();
+}
+
+Status InProcCommandBuffer::ProcessCmd(CmdHeader* cmd_header,
+ CommandBuffer* command_processor) const {
+ switch (cmd_header->type) {  // Each payload is read from just past its header.
+ case CmdType::kExecutionBarrier: {
+ auto* cmd = reinterpret_cast<ExecutionBarrierCmd*>(cmd_header + 1);
+ return command_processor->ExecutionBarrier(
+ cmd->source_stage_mask, cmd->target_stage_mask, cmd->memory_barriers,
+ cmd->buffer_barriers);
+ }
+ case CmdType::kSignalEvent: {
+ auto* cmd = reinterpret_cast<SignalEventCmd*>(cmd_header + 1);
+ return command_processor->SignalEvent(cmd->event, cmd->source_stage_mask);
+ }
+ case CmdType::kResetEvent: {
+ auto* cmd = reinterpret_cast<ResetEventCmd*>(cmd_header + 1);
+ return command_processor->ResetEvent(cmd->event, cmd->source_stage_mask);
+ }
+ case CmdType::kWaitEvents: {
+ auto* cmd = reinterpret_cast<WaitEventsCmd*>(cmd_header + 1);
+ return command_processor->WaitEvents(
+ cmd->events, cmd->source_stage_mask, cmd->target_stage_mask,
+ cmd->memory_barriers, cmd->buffer_barriers);
+ }
+ case CmdType::kFillBuffer: {
+ auto* cmd = reinterpret_cast<FillBufferCmd*>(cmd_header + 1);
+ return command_processor->FillBuffer(cmd->target_buffer,
+ cmd->target_offset, cmd->length,
+ cmd->pattern, cmd->pattern_length);
+ }
+ case CmdType::kDiscardBuffer: {
+ auto* cmd = reinterpret_cast<DiscardBufferCmd*>(cmd_header + 1);
+ return command_processor->DiscardBuffer(cmd->buffer);
+ }
+ case CmdType::kUpdateBuffer: {
+ auto* cmd = reinterpret_cast<UpdateBufferCmd*>(cmd_header + 1);
+ return command_processor->UpdateBuffer(cmd->source_buffer, 0,
+ cmd->target_buffer,  // Offset 0: the source offset was applied at record time.
+ cmd->target_offset, cmd->length);
+ }
+ case CmdType::kCopyBuffer: {
+ auto* cmd = reinterpret_cast<CopyBufferCmd*>(cmd_header + 1);
+ return command_processor->CopyBuffer(
+ cmd->source_buffer, cmd->source_offset, cmd->target_buffer,
+ cmd->target_offset, cmd->length);
+ }
+ case CmdType::kDispatch: {
+ auto* cmd = reinterpret_cast<DispatchCmd*>(cmd_header + 1);
+ return command_processor->Dispatch(cmd->request);
+ }
+ default:  // Type value not handled above; reported as a data-loss error.
+ return DataLossErrorBuilder(IREE_LOC)
+ << "Unrecognized command type "
+ << static_cast<int>(cmd_header->type) << "; corrupt buffer?";
+ }
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/host/inproc_command_buffer.h b/hal/host/inproc_command_buffer.h
new file mode 100644
index 0000000..a3bca41
--- /dev/null
+++ b/hal/host/inproc_command_buffer.h
@@ -0,0 +1,241 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
+#define IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
+
+#include "base/arena.h"
+#include "base/intrusive_list.h"
+#include "base/status.h"
+#include "hal/command_buffer.h"
+
+namespace iree {
+namespace hal {
+
+// In-process command buffer with support for recording and playback.
+// Commands are recorded into heap-allocated arenas with pointers to used
+// resources (Buffer*, etc). To replay a command buffer against a real
+// implementation use Process to call each command method as it was originally
+// recorded.
+//
+// Thread-compatible (as with CommandBuffer itself).
+class InProcCommandBuffer final : public CommandBuffer {
+ public:
+ InProcCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories);
+ ~InProcCommandBuffer() override;
+
+ bool is_recording() const override { return is_recording_; }
+
+ Status Begin() override;
+ Status End() override;
+
+ Status ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status SignalEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+
+ Status ResetEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+
+ Status WaitEvents(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) override;
+
+ Status DiscardBuffer(Buffer* buffer) override;
+
+ Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+
+ Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+
+ Status Dispatch(const DispatchRequest& dispatch_request) override;
+
+ // Processes all commands in the buffer using the given |command_processor|.
+ // The commands are issued in the order they were recorded.
+ Status Process(CommandBuffer* command_processor) const;
+
+ private:
+ // Type of Cmd, used by CmdHeader to identify the command payload.
+ enum class CmdType {
+ kExecutionBarrier,
+ kSignalEvent,
+ kResetEvent,
+ kWaitEvents,
+ kFillBuffer,
+ kDiscardBuffer,
+ kUpdateBuffer,
+ kCopyBuffer,
+ kDispatch,
+ };
+
+ // Prefix for commands encoded into the CmdList.
+ // This is used to identify the type of a command as well as connect commands
+ // in the list sequence. Command data immediately follows the header in
+ // memory.
+ struct CmdHeader {
+ // Optional next command in the list.
+ CmdHeader* next;
+ // Type of the command.
+ CmdType type;
+ };
+
+ // A lightweight linked list of commands and an arena that stores them.
+ // CmdLists are designed to be reused so that the arena allocations are
+ // amortized across multiple uses.
+ //
+ // Note that this and the CmdHeader/Cmd types include raw pointers and as
+ // such are *not* portable across processes. It'd be possible, though, to
+ // extend this for cross-process use if a shared-memory Buffer was also
+ // implemented. For YAGNI we avoid that here.
+ struct CmdList : public IntrusiveLinkBase<void> {
+ static constexpr size_t kArenaBlockSize = 64 * 1024;
+
+ Arena arena{kArenaBlockSize};
+ CmdHeader* head = nullptr;
+ CmdHeader* tail = nullptr;
+ };
+
+ // Defines an execution barrier.
+ struct ExecutionBarrierCmd {
+ static constexpr CmdType kType = CmdType::kExecutionBarrier;
+ ExecutionStageBitfield source_stage_mask;
+ ExecutionStageBitfield target_stage_mask;
+ absl::Span<const MemoryBarrier> memory_barriers;
+ absl::Span<const BufferBarrier> buffer_barriers;
+ };
+
+ // Signals an event.
+ struct SignalEventCmd {
+ static constexpr CmdType kType = CmdType::kSignalEvent;
+ Event* event;
+ ExecutionStageBitfield source_stage_mask;
+ };
+
+ // Resets an event.
+ struct ResetEventCmd {
+ static constexpr CmdType kType = CmdType::kResetEvent;
+ Event* event;
+ ExecutionStageBitfield source_stage_mask;
+ };
+
+ // Waits for one or more events.
+ struct WaitEventsCmd {
+ static constexpr CmdType kType = CmdType::kWaitEvents;
+ absl::Span<Event*> events;
+ ExecutionStageBitfield source_stage_mask;
+ ExecutionStageBitfield target_stage_mask;
+ absl::Span<const MemoryBarrier> memory_barriers;
+ absl::Span<const BufferBarrier> buffer_barriers;
+ };
+
+ // Fills the target buffer with the given repeating value.
+ struct FillBufferCmd {
+ static constexpr CmdType kType = CmdType::kFillBuffer;
+ Buffer* target_buffer;
+ device_size_t target_offset;
+ device_size_t length;
+ uint8_t pattern[4];  // Inline storage; FillBuffer memcpy's the pattern here.
+ size_t pattern_length;
+ };
+
+ // Hints to the device queue that the given buffer will not be used again.
+ struct DiscardBufferCmd {
+ static constexpr CmdType kType = CmdType::kDiscardBuffer;
+ Buffer* buffer;
+ };
+
+ // Writes a range of the given target buffer from the embedded memory.
+ // The source bytes are copied into the arena; source_buffer points at them.
+ struct UpdateBufferCmd {
+ static constexpr CmdType kType = CmdType::kUpdateBuffer;
+ const void* source_buffer;
+ Buffer* target_buffer;
+ device_size_t target_offset;
+ device_size_t length;
+ };
+
+ // Copies a range of one buffer to another.
+ struct CopyBufferCmd {
+ static constexpr CmdType kType = CmdType::kCopyBuffer;
+ Buffer* source_buffer;
+ device_size_t source_offset;
+ Buffer* target_buffer;
+ device_size_t target_offset;
+ device_size_t length;
+ };
+
+ // Dispatches an execution request.
+ struct DispatchCmd {
+ static constexpr CmdType kType = CmdType::kDispatch;
+ DispatchRequest request;
+ };
+
+ // Resets the command list.
+ void Reset();
+
+ // Allocates a command and appends it to the current command list. The caller
+ // must populate the returned fields. No error path exists in this body.
+ template <typename T>
+ StatusOr<T*> AppendCmd() {
+ return reinterpret_cast<T*>(AppendCmdHeader(T::kType, sizeof(T)) + 1);
+ }
+
+ // Appends a command with the given |type| and payload |cmd_size| prefixed
+ // with a CmdHeader. Returns a pointer to the CmdHeader; |cmd_size| payload
+ // bytes follow it. NOTE(review): zero-fill depends on the Arena - confirm.
+ CmdHeader* AppendCmdHeader(CmdType type, size_t cmd_size);
+
+ // Appends a byte buffer to the command buffer and returns a pointer to the
+ // copied data within the command buffer arena.
+ void* AppendCmdData(const void* source_buffer, device_size_t source_offset,
+ device_size_t source_length);
+
+ // Appends a span of structs to the current CmdList via AppendCmdData (a byte
+ // copy) and returns a span pointing into the CmdList arena.
+ template <typename T>
+ absl::Span<T> AppendStructSpan(absl::Span<T> value) {
+ static_assert(std::is_standard_layout<T>::value,
+ "Struct must be a POD type");
+ void* data_ptr = AppendCmdData(value.data(), 0, value.size() * sizeof(T));
+ return absl::MakeSpan(static_cast<T*>(data_ptr), value.size());
+ }
+
+ // Processes a single command.
+ Status ProcessCmd(CmdHeader* cmd_header,
+ CommandBuffer* command_processor) const;
+
+ bool is_recording_ = false;
+
+ // NOTE: not synchronized. Expected to be used from a single thread.
+ CmdList current_cmd_list_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
diff --git a/hal/interpreter/BUILD b/hal/interpreter/BUILD
new file mode 100644
index 0000000..5984600
--- /dev/null
+++ b/hal/interpreter/BUILD
@@ -0,0 +1,190 @@
+# HAL implementation running on the CPU using the IREE bytecode.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+    name = "bytecode_cache",  # ExecutableCache for IREE bytecode executables.
+    srcs = ["bytecode_cache.cc"],
+    hdrs = ["bytecode_cache.h"],
+    deps = [
+        ":bytecode_executable",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:allocator",
+        "//hal:executable",
+        "//hal:executable_cache",
+        "//hal:executable_format",
+        "//rt",
+    ],
+)
+
+cc_library(
+    name = "bytecode_dispatch",
+    srcs = [  # Private headers live in srcs; only bytecode_dispatch.h is public.
+        "bytecode_dispatch.cc",
+        "bytecode_dispatch_conversion.h",
+        "bytecode_dispatch_util.cc",
+        "bytecode_dispatch_util.h",
+    ],
+    hdrs = ["bytecode_dispatch.h"],
+    deps = [
+        ":bytecode_kernels",
+        "//base:logging",
+        "//base:memory",
+        "//base:status",
+        "//hal:allocator",
+        "//hal:buffer_view",
+        "//hal:heap_buffer",
+        "//rt",
+        "//schemas/bytecode:interpreter_bytecode_v0",
+        "//vm:bytecode_module",
+        "//vm:bytecode_reader",
+        "//vm:bytecode_tables_interpreter",
+        "//vm:bytecode_util",
+        "//vm:opcode_info",
+        "//vm:type",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "bytecode_executable",  # Executable type loaded by :bytecode_cache.
+    srcs = ["bytecode_executable.cc"],
+    hdrs = ["bytecode_executable.h"],
+    deps = [
+        ":interpreter_module",
+        "//base:status",
+        "//hal:allocator",
+        "//hal:executable",
+        "//hal:executable_spec",
+        "//rt",
+        "//vm:bytecode_tables_interpreter",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "bytecode_kernels",  # Header-only target: no srcs are listed.
+    hdrs = ["bytecode_kernels.h"],
+    textual_hdrs = [
+        # TODO(benvanik): SIMD variants.
+        "bytecode_kernels_generic.h",
+        "bytecode_kernels_ruy.h",
+    ],
+    deps = [
+        "//base:shape",
+        "//base:status",
+        "//hal:buffer_view",
+        "@com_google_absl//absl/algorithm",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/types:span",
+        "@org_tensorflow//tensorflow/lite/experimental/ruy",
+        "@org_tensorflow//tensorflow/lite/experimental/ruy:context",
+    ],
+)
+
+cc_test(
+    name = "bytecode_kernels_test",  # Test target for :bytecode_kernels.
+    srcs = ["bytecode_kernels_test.cc"],
+    deps = [
+        ":bytecode_kernels",
+        "//base:memory",
+        "//base:status_matchers",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
+    name = "interpreter_command_processor",
+    srcs = ["interpreter_command_processor.cc"],
+    hdrs = ["interpreter_command_processor.h"],
+    deps = [
+        ":bytecode_executable",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:buffer_view",
+        "//hal/host:host_local_command_processor",  # Shared host-side processor.
+        "//rt",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "interpreter_device",
+    srcs = ["interpreter_device.cc"],
+    hdrs = ["interpreter_device.h"],
+    deps = [  # Pulls in the //hal/host queue, event and allocator targets.
+        ":bytecode_cache",
+        ":bytecode_kernels",
+        ":interpreter_command_processor",
+        "//base:memory",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_buffer_validation",
+        "//hal:command_queue",
+        "//hal:device",
+        "//hal:fence",
+        "//hal/host:async_command_queue",
+        "//hal/host:host_event",
+        "//hal/host:host_local_allocator",
+        "//hal/host:host_submission_queue",
+        "//hal/host:inproc_command_buffer",
+        "//rt",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "interpreter_driver",  # Depends on :interpreter_device.
+    srcs = ["interpreter_driver.cc"],
+    hdrs = ["interpreter_driver.h"],
+    deps = [
+        ":interpreter_device",
+        "//hal:device_info",
+        "//hal:driver",
+    ],
+)
+
+cc_library(
+    name = "interpreter_driver_module",
+    srcs = ["interpreter_driver_module.cc"],
+    deps = [
+        ":interpreter_driver",
+        "//base:init",
+        "//base:status",
+        "//hal:driver_registry",
+    ],
+    alwayslink = 1,  # Keep the object linked even if no symbol references it.
+)
+
+cc_library(
+    name = "interpreter_module",  # Dependency of :bytecode_executable.
+    srcs = ["interpreter_module.cc"],
+    hdrs = ["interpreter_module.h"],
+    deps = [
+        ":bytecode_dispatch",
+        ":bytecode_kernels",
+        "//base:flatbuffer_util",
+        "//base:status",
+        "//base:tracing",
+        "//hal:allocator",
+        "//hal:buffer_view",
+        "//rt",
+        "//vm:bytecode_module",
+        "//vm:bytecode_tables_interpreter",
+        "@com_google_absl//absl/types:span",
+    ],
+)
diff --git a/iree/hal/interpreter/CMakeLists.txt b/hal/interpreter/CMakeLists.txt
similarity index 100%
rename from iree/hal/interpreter/CMakeLists.txt
rename to hal/interpreter/CMakeLists.txt
diff --git a/hal/interpreter/bytecode_cache.cc b/hal/interpreter/bytecode_cache.cc
new file mode 100644
index 0000000..5eee0d8
--- /dev/null
+++ b/hal/interpreter/bytecode_cache.cc
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/bytecode_cache.h"
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/executable_format.h"
+#include "hal/interpreter/bytecode_executable.h"
+
+namespace iree {
+namespace hal {
+
+BytecodeCache::BytecodeCache(ref_ptr<rt::Instance> instance,
+ hal::Allocator* allocator)
+ : instance_(std::move(instance)), allocator_(allocator) {}
+
+BytecodeCache::~BytecodeCache() = default;  // allocator_ is not deleted here.
+
+bool BytecodeCache::CanPrepareFormat(ExecutableFormat format) const {
+ return format == kExecutableFormatIreeBytecode;  // The only accepted format.
+}
+
+StatusOr<ref_ptr<Executable>> BytecodeCache::PrepareExecutable(
+    ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
+  IREE_TRACE_SCOPE0("BytecodeCache::PrepareExecutable");
+  if (!CanPrepareFormat(spec.format)) {
+    return UnimplementedErrorBuilder(IREE_LOC)
+           << "Unsupported format: " << spec.format;
+  }
+
+  // The loader receives the inverse of the kAliasProvidedData bit: when the
+  // caller did not allow aliasing, the data is copied rather than wrapped.
+  const bool disallow_aliasing =
+      !AllBitsSet(mode, ExecutableCachingMode::kAliasProvidedData);
+  ASSIGN_OR_RETURN(auto loaded_executable,
+                   BytecodeExecutable::Load(add_ref(instance_), allocator_,
+                                            spec, disallow_aliasing));
+  return loaded_executable;
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/bytecode_cache.h b/hal/interpreter/bytecode_cache.h
new file mode 100644
index 0000000..7f56e3e
--- /dev/null
+++ b/hal/interpreter/bytecode_cache.h
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
+
+#include "hal/allocator.h"
+#include "hal/executable.h"
+#include "hal/executable_cache.h"
+#include "rt/instance.h"
+
+namespace iree {
+namespace hal {
+
+class BytecodeCache final : public ExecutableCache { // ExecutableCache for IREE bytecode executables.
+ public:
+ BytecodeCache(ref_ptr<rt::Instance> instance, hal::Allocator* allocator);
+ ~BytecodeCache() override;
+ // Returns true only for kExecutableFormatIreeBytecode.
+ bool CanPrepareFormat(ExecutableFormat format) const override;
+ // Loads `spec` via BytecodeExecutable::Load; Unimplemented error if the format is not accepted.
+ StatusOr<ref_ptr<Executable>> PrepareExecutable(
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override;
+
+ private:
+ ref_ptr<rt::Instance> instance_; // A new reference (add_ref) is handed to each Load() call.
+ hal::Allocator* allocator_; // Raw pointer; forwarded to Load() and never deleted by this class.
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
diff --git a/hal/interpreter/bytecode_dispatch.cc b/hal/interpreter/bytecode_dispatch.cc
new file mode 100644
index 0000000..8dd1892
--- /dev/null
+++ b/hal/interpreter/bytecode_dispatch.cc
@@ -0,0 +1,850 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Implements a full bytecode dispatch system.
+// Currently this is verbose and object oriented, but future revisions
+// (once we have interesting benchmarks) will likely simplify and inline
+// a lot of the checks to make things faster. Consider this to be as
+// experimental an implementation as the entire rest of the project :)
+
+#include "hal/interpreter/bytecode_dispatch.h"
+
+#include <algorithm>
+
+#include "absl/base/attributes.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "base/logging.h"
+#include "base/memory.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "hal/heap_buffer.h"
+#include "hal/interpreter/bytecode_dispatch_conversion.h"
+#include "hal/interpreter/bytecode_dispatch_util.h"
+#include "hal/interpreter/bytecode_kernels.h"
+#include "rt/function.h"
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+#include "vm/bytecode_module.h"
+#include "vm/bytecode_reader.h"
+#include "vm/bytecode_tables_interpreter.h"
+#include "vm/bytecode_util.h"
+#include "vm/opcode_info.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+
+using ::iree::rt::Stack;
+using ::iree::rt::StackFrame;
+using ::iree::vm::BytecodeReader;
+
+} // namespace
+// Runs the interpreter from entry_stack_frame; when that frame returns, its results are moved into entry_results.
+Status Dispatch(hal::Allocator* allocator,
+ kernels::RuntimeState* kernel_runtime_state, Stack* stack,
+ StackFrame* entry_stack_frame,
+ absl::Span<BufferView> entry_results) {
+ // Dispatch table mapping 1:1 with bytecode ops.
+ // Each entry is a label within this function that can be used for computed
+ // goto. You can find more information on computed goto here:
+ // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables
+ //
+ // Note that we ensure the table is 256 elements long exactly to make sure
+ // that unused opcodes are handled gracefully.
+ static const void* kDispatchTable[256] = {
+#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name,
+#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled,
+ IREE_INTERPRETER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED)
+#undef DECLARE_DISPATCH
+#undef DECLARE_DISPATCH_RESERVED
+ };
+
+ // Primary dispatch state. This is our 'native stack frame' and really just
+ // enough to make dereferencing common addresses (like the current offset)
+ // faster. You can think of this like CPU state (like PC).
+ //
+ // We hope that LLVM decides to keep these in registers (as they are touched
+ // for every instruction executed). The stack_frame will change as we call
+ // into different functions.
+ BytecodeReader reader(stack);
+ RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame));
+ // DISPATCH_NEXT fetches the next opcode byte and jumps to its handler label; ValueOrDie presumably aborts on a failed read - TODO confirm.
+#define DISPATCH_NEXT() \
+ { \
+ uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \
+ DVLOG(1) \
+ << "Interpreter dispatching op code: " \
+ << GetOpcodeInfo(vm::interpreter_opcode_table(), opcode).mnemonic; \
+ goto* kDispatchTable[opcode]; \
+ }
+ // Each handler is a label `_dispatch_<opcode>` whose body falls into DISPATCH_NEXT() unless it returns.
+#define DISPATCH_CORE_OPCODE(opcode, body) \
+ _dispatch_##opcode : {body} DISPATCH_NEXT()
+#if defined(IREE_SUPPORT_F32) || defined(IREE_SUPPORT_F64)
+#define DISPATCH_FLOAT_OPCODE(opcode, body) \
+ _dispatch_##opcode : {body} DISPATCH_NEXT()
+#else
+#define DISPATCH_FLOAT_OPCODE(...) // Float support disabled: float handlers expand to nothing.
+#endif // IREE_SUPPORT_F32 || IREE_SUPPORT_F64
+ // NOTE(review): with float support off the float labels are not defined; confirm the opcode list then maps those opcodes to _dispatch_unhandled.
+ DISPATCH_NEXT();
+
+ DISPATCH_CORE_OPCODE(kConstant, { // Reads a constant and move-assigns it into the destination local.
+ ASSIGN_OR_RETURN(auto value, reader.ReadConstant());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ *dst_local = std::move(value);
+ });
+
+ DISPATCH_CORE_OPCODE(kCall, { // Pushes a frame for the target function and switches the reader to it.
+ auto* old_stack_frame = stack->current_frame();
+ ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction());
+ // TODO(benvanik): rework register storage interface.
+ ASSIGN_OR_RETURN(
+ const auto* function_def,
+ static_cast<const vm::BytecodeModule*>(target_function.module()) // Unchecked downcast: assumes a bytecode module.
+ ->GetFunctionDef(target_function.linkage(),
+ target_function.ordinal()));
+ ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function));
+ new_stack_frame->mutable_registers()->buffer_views.resize( // One buffer_view register per function local.
+ function_def->bytecode()->local_count());
+ RETURN_IF_ERROR(
+ reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame));
+ DVLOG(1) << "Call; stack now: " << stack->DebugString();
+ });
+
+ DISPATCH_CORE_OPCODE(kCallImport, { // Not supported: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Non-module imports not supported";
+ });
+
+ DISPATCH_CORE_OPCODE(kCallIndirect, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect";
+ });
+
+ DISPATCH_CORE_OPCODE(kReturn, { // Returns to the caller frame, or ends dispatch when leaving the entry frame.
+ auto* old_stack_frame = stack->current_frame();
+ auto* new_stack_frame = stack->caller_frame();
+ if (old_stack_frame == entry_stack_frame) {
+ // Returning from entry function. Marshal results from the return stmt.
+ ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount());
+ for (int i = 0; i < src_count; ++i) {
+ ASSIGN_OR_RETURN(
+ auto* src_local,
+ reader.ReadLocal(old_stack_frame->mutable_registers()));
+ entry_results[i] = std::move(*src_local); // NOTE(review): i is not checked against entry_results.size(); confirm callers size the span.
+ }
+ DVLOG(1) << "Returning to entry";
+ return OkStatus();
+ } else if (!new_stack_frame) {
+ return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow";
+ }
+ RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame,
+ new_stack_frame));
+ RETURN_IF_ERROR(stack->PopFrame()); // The frame is popped only after its results were copied out.
+ DVLOG(1) << "Return; stack now: " << stack->DebugString();
+ });
+
+ DISPATCH_CORE_OPCODE(kBranch, { // Reads the target offset, runs reader.CopySlots(), then branches.
+ ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset());
+ RETURN_IF_ERROR(reader.CopySlots());
+ RETURN_IF_ERROR(reader.BranchToOffset(offset));
+ });
+
+ DISPATCH_CORE_OPCODE(kCondBranch, {
+ // Evaluate condition first so we can do the copies as we read them for
+ // which side of the branch we take.
+ ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
+ bool cond_value = BufferViewIsTrue(*cond_local);
+ ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset());
+ if (cond_value) {
+ RETURN_IF_ERROR(reader.CopySlots());
+ RETURN_IF_ERROR(reader.BranchToOffset(true_offset));
+ } else {
+ ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount());
+ RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count)); // Skips the true-side copy operands: two locals per entry.
+ ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset());
+ RETURN_IF_ERROR(reader.CopySlots());
+ RETURN_IF_ERROR(reader.BranchToOffset(false_offset));
+ }
+ });
+
+ DISPATCH_CORE_OPCODE(kCmpI, { // Integer compare: kEq/kNe/kS* use the IS helper, kU* use the IU helper.
+ ASSIGN_OR_RETURN(uint8_t predicate, reader.ReadUint8_t());
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+
+ switch (static_cast<CmpIPredicate>(predicate)) {
+ case CmpIPredicate::kEq:
+ RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareEQ>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kNe:
+ RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareNE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kSlt:
+ RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLT>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kSle:
+ RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kSgt:
+ RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGT>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kSge:
+ RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kUlt:
+ RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLT>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kUle:
+ RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kUgt:
+ RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGT>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpIPredicate::kUge:
+ RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ } // No default case: a predicate byte matching none of the cases writes nothing and is not reported.
+ });
+
+ DISPATCH_FLOAT_OPCODE(kCmpF, { // Float compare: only kOeq, kUne, kOlt, kOle, kOgt and kOge are implemented.
+ ASSIGN_OR_RETURN(uint8_t p, reader.ReadUint8_t());
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+
+ auto predicate = static_cast<CmpFPredicate>(p);
+ switch (predicate) {
+ case CmpFPredicate::kOeq:
+ RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareEQ>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpFPredicate::kUne:
+ RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareNE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpFPredicate::kOlt:
+ RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLT>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpFPredicate::kOle:
+ RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpFPredicate::kOgt:
+ RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGT>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpFPredicate::kOge:
+ RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGE>(
+ lhs_local, rhs_local, dst_local));
+ break;
+ case CmpFPredicate::kFalse:
+ case CmpFPredicate::kOne:
+ case CmpFPredicate::kOrd:
+ case CmpFPredicate::kUeq:
+ case CmpFPredicate::kUgt:
+ case CmpFPredicate::kUge:
+ case CmpFPredicate::kUlt:
+ case CmpFPredicate::kUle:
+ case CmpFPredicate::kUno:
+ case CmpFPredicate::kTrue:
+ // TODO(b/132183250) support these if we ever need them.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unsupported comparison predicate value "
+ << static_cast<int>(p) << " ("
+ << vm::PredicateToString(predicate) << ")";
+ } // No default case: a byte outside the listed predicates writes nothing and is not reported.
+ });
+
+ DISPATCH_CORE_OPCODE(kAllocStatic, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static";
+ });
+
+ DISPATCH_CORE_OPCODE(kAllocStack, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack";
+ });
+
+ DISPATCH_CORE_OPCODE(kAllocStackInit, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented alloc_stack_init";
+ });
+
+  DISPATCH_CORE_OPCODE(kAllocHeap, {  // Allocates the buffer of a shaped local.
+    ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32());
+    // TODO(benvanik): properly allocate with attributes from op.
+    // Only heap type 0 is handled; fail with a Status instead of aborting.
+    if (heap_type != 0) {
+      return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented heap type";
+    }
+    ASSIGN_OR_RETURN(auto type, reader.ReadType());
+    size_t element_size = type.element_size();
+    // TODO(benvanik): more efficient reading and storage.
+    size_t element_count = 0;
+    ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count));
+    size_t allocation_size = element_size * element_count;
+    ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+    dst_local->element_size = element_size;
+    dst_local->shape = shape;
+    ASSIGN_OR_RETURN(
+        dst_local->buffer,
+        allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible,
+                            BufferUsage::kAll, allocation_size));
+  });
+
+ DISPATCH_CORE_OPCODE(kDiscard, {
+ // NOTE: if we were an encoder we would actually discard the buffer.
+ ASSIGN_OR_RETURN(auto* local, reader.ReadLocal());
+ *local = {}; // Resets the local by assigning an empty-brace (default) value.
+ });
+
+ DISPATCH_CORE_OPCODE(kRank, { // Writes src's shape.size() as an int32 at offset 0 of dst's existing buffer.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ int32_t rank = src_local->shape.size();
+ RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &rank, sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kDim, { // Writes one dimension of src (axis is a single byte operand) as an int32.
+ ASSIGN_OR_RETURN(int32_t axis, reader.ReadUint8_t());
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(int32_t dim, src_local->shape.ResolveAxis(axis));
+ RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &dim, sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kShape, { // Writes all of src's dimensions as consecutive int32 values into dst's buffer.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(dst_local->buffer->WriteData(
+ 0, src_local->shape.subspan().data(),
+ src_local->shape.subspan().size() * sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kLength, { // Writes src's shape.element_count() as an int32 into dst's buffer.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ int32_t length = src_local->shape.element_count();
+ RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kDynamicSlice, { // Slice whose indices/lengths are read as int32 slot elements.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths));
+ });
+
+ DISPATCH_CORE_OPCODE(kStaticSlice, { // Same as kDynamicSlice, but indices/lengths come from ReadIndexList().
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto indices, reader.ReadIndexList());
+ ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths));
+ });
+
+ DISPATCH_CORE_OPCODE(kDynamicCopy, { // ApplyCopy with indices/lengths read as int32 slot elements.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto src_indices, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_indices, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>());
+ RETURN_IF_ERROR(
+ ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths));
+ });
+
+ DISPATCH_CORE_OPCODE(kStaticCopy, { // Same as kDynamicCopy, but indices/lengths come from ReadIndexList().
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto src_indices, reader.ReadIndexList());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_indices, reader.ReadIndexList());
+ ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList());
+ RETURN_IF_ERROR(
+ ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths));
+ });
+
+ DISPATCH_CORE_OPCODE(kClone, { // Gives dst a newly allocated buffer holding a copy of src's data.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ dst_local->element_size = src_local->element_size;
+ dst_local->shape = src_local->shape;
+ dst_local->buffer = HeapBuffer::Allocate(src_local->buffer->usage(), // NOTE(review): bypasses the `allocator` argument; confirm intended.
+ src_local->buffer->byte_length());
+ RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get()));
+ });
+
+ DISPATCH_CORE_OPCODE(kSplit, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented split";
+ });
+
+ DISPATCH_CORE_OPCODE(kAssign, { // Copy-assigns the source local to the destination local.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ *dst_local = *src_local;
+ });
+
+ DISPATCH_CORE_OPCODE(kCondAssign, { // Copy-assigns lhs when BufferViewIsTrue(cond), otherwise rhs.
+ ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local;
+ });
+
+ DISPATCH_CORE_OPCODE(kReshape, {
+ // TODO(benvanik): more logic required if strides differ.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ Shape new_shape = Shape{shape_data};
+ if (src_local->shape.element_count() != new_shape.element_count()) { // Reshape must preserve the element count.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "New element count " << new_shape.element_count()
+ << " != source element count " << src_local->shape.element_count();
+ }
+ dst_local->shape = new_shape;
+ dst_local->buffer = add_ref(src_local->buffer); // No data copy: dst references the same buffer as src.
+ dst_local->element_size = src_local->element_size;
+ });
+
+ DISPATCH_CORE_OPCODE(kSelect, { // Elementwise select between lhs and rhs driven by a 1-byte-per-element condition.
+ ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto cond_buffer, cond_local->buffer->MapMemory<uint8_t>(
+ MemoryAccess::kRead));
+ ASSIGN_OR_RETURN(auto lhs_buffer, lhs_local->buffer->MapMemory<uint8_t>(
+ MemoryAccess::kRead));
+ ASSIGN_OR_RETURN(auto rhs_buffer, rhs_local->buffer->MapMemory<uint8_t>(
+ MemoryAccess::kRead));
+ ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>(
+ MemoryAccess::kDiscardWrite));
+ if (cond_local->element_size != 1) { // Validation runs after all four buffers have been mapped.
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "Select cond must be i8";
+ } else if (lhs_buffer.size() != rhs_buffer.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "LHS " << lhs_buffer.size() << "b != RHS " << rhs_buffer.size()
+ << "b; both arguments must match";
+ } else if (lhs_buffer.size() != dst_buffer.size()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Dest " << dst_buffer.size() << "b != LHS/RHS "
+ << lhs_buffer.size() << "b; dest must match inputs";
+ } // NOTE(review): the size of cond_buffer is not compared against the element count here.
+ switch (lhs_local->element_size) { // Element width in bytes selects the unsigned type used by the kernel.
+ case 1:
+ RETURN_IF_ERROR(kernels::Select::Execute<uint8_t>(
+ cond_buffer.contents(), lhs_buffer.contents(),
+ rhs_buffer.contents(), dst_buffer.mutable_contents()));
+ break;
+ case 2:
+ RETURN_IF_ERROR(kernels::Select::Execute<uint16_t>(
+ cond_buffer.contents(),
+ ReinterpretSpan<uint16_t>(lhs_buffer.contents()),
+ ReinterpretSpan<uint16_t>(rhs_buffer.contents()),
+ ReinterpretSpan<uint16_t>(dst_buffer.mutable_contents())));
+ break;
+ case 4:
+ RETURN_IF_ERROR(kernels::Select::Execute<uint32_t>(
+ cond_buffer.contents(),
+ ReinterpretSpan<uint32_t>(lhs_buffer.contents()),
+ ReinterpretSpan<uint32_t>(rhs_buffer.contents()),
+ ReinterpretSpan<uint32_t>(dst_buffer.mutable_contents())));
+ break;
+ case 8:
+ RETURN_IF_ERROR(kernels::Select::Execute<uint64_t>(
+ cond_buffer.contents(),
+ ReinterpretSpan<uint64_t>(lhs_buffer.contents()),
+ ReinterpretSpan<uint64_t>(rhs_buffer.contents()),
+ ReinterpretSpan<uint64_t>(dst_buffer.mutable_contents())));
+ break;
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented element size: " << lhs_local->element_size;
+ }
+ });
+
+ DISPATCH_CORE_OPCODE(kTranspose, { // Applies kernels::Transpose with src's shape and the int32 slot elements read here.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Transpose>(
+ src_local, dst_local, src_local->shape,
+ absl::MakeConstSpan(perm_data)));
+ });
+
+ DISPATCH_CORE_OPCODE(kReverse, { // Applies kernels::Reverse with src's shape and the int32 slot elements read here.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(
+ ApplyUnaryOpIU<kernels::Reverse>(src_local, dst_local, src_local->shape,
+ absl::MakeConstSpan(perm_data)));
+ });
+
+ DISPATCH_CORE_OPCODE(kPad, { // Applies kernels::Pad with low/high edge and interior padding read as int32 slot elements.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* padding_value, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto edge_padding_low, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto edge_padding_high,
+ reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto interior_padding, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // dst's shape is used as already set on the local; this handler does not assign it.
+ RETURN_IF_ERROR(ApplyBinaryOpIU<kernels::Pad>(
+ src_local, padding_value, dst_local, src_local->shape, dst_local->shape,
+ absl::MakeConstSpan(edge_padding_low),
+ absl::MakeConstSpan(edge_padding_high),
+ absl::MakeConstSpan(interior_padding)));
+ });
+
+ DISPATCH_CORE_OPCODE(kBroadcast, { // Sets dst's shape from the operand, then applies kernels::Broadcast.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ dst_local->shape = Shape{shape_data};
+ RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Broadcast>(src_local, dst_local));
+ });
+
+ DISPATCH_CORE_OPCODE(kTile, { // Sets dst's shape from the operand, then applies kernels::Tile with both shapes.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ dst_local->shape = Shape{shape_data};
+ RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Tile>(
+ src_local, dst_local, src_local->shape, dst_local->shape));
+ });
+
+ DISPATCH_CORE_OPCODE(kNot, { // Bitwise/shift ops: operands are decoded inside the DispatchElementwise* helpers.
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpIU<kernels::Not>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kAnd, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::And>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kOr, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Or>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kXor, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Xor>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kShiftLeft, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::ShiftLeft>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kShiftRightLogical, { // Same kernel as the arithmetic shift, via the IU helper.
+ RETURN_IF_ERROR(
+ DispatchElementwiseBinaryOpIU<kernels::ShiftRight>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kShiftRightArithmetic, { // Same kernel as the logical shift, via the IS helper.
+ RETURN_IF_ERROR(
+ DispatchElementwiseBinaryOpIS<kernels::ShiftRight>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kAddI, { // Elementwise arithmetic: *I ops use the IU/IS helpers, *F ops the F helpers.
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Add>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kAddF, { // Float handlers exist only when IREE_SUPPORT_F32 or IREE_SUPPORT_F64 is defined.
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Add>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kSubI, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Sub>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kSubF, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Sub>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kAbsI, { // Integer abs goes through the IS helper.
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpIS<kernels::Abs>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kAbsF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Abs>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kMulI, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Mul>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kMulF, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Mul>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kDivIS, { // Division has separate opcodes for the IS and IU helpers.
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Div>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kDivIU, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Div>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kDivF, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Div>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kMulAddI, { // Ternary op: three source operands decoded by the helper.
+ RETURN_IF_ERROR(DispatchElementwiseTernaryOpIU<kernels::MulAdd>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kMulAddF, {
+ RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::MulAdd>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kExpF, { // Float-only math ops follow; none has an integer counterpart here.
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Exp>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kLogF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Log>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kRsqrtF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Rsqrt>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kCosF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Cos>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kSinF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Sin>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kTanhF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Tanh>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kAtan2F, { // The only binary op in this float math group.
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Atan2>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kMinIS, { // Min/max/clamp: *IS opcodes use the IS helpers, *IU the IU helpers, *F the F helpers.
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Min>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kMinIU, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Min>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kMinF, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Min>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kMaxIS, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Max>(&reader));
+ });
+ DISPATCH_CORE_OPCODE(kMaxIU, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Max>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kMaxF, {
+ RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Max>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kClampIS, { // Ternary: the helper decodes three source operands plus the destination.
+ RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader));
+ });
+  DISPATCH_CORE_OPCODE(kClampIU, {  // Unsigned clamp: IU helper, as kMinIU/kMaxIU.
+    RETURN_IF_ERROR(DispatchElementwiseTernaryOpIU<kernels::Clamp>(&reader));
+  });
+ DISPATCH_FLOAT_OPCODE(kClampF, { // Float clamp via the ternary F helper.
+ RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::Clamp>(&reader));
+ });
+
+ DISPATCH_FLOAT_OPCODE(kFloorF, { // Float-only unary rounding ops.
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Floor>(&reader));
+ });
+ DISPATCH_FLOAT_OPCODE(kCeilF, {
+ RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Ceil>(&reader));
+ });
+
+ DISPATCH_CORE_OPCODE(kConvertSS, { // All four converts read (src type, src local, dst type, dst local) in that order.
+ ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(
+ ApplyConvertSS::Apply(src_type, src_local, dst_type, dst_local));
+ });
+ DISPATCH_CORE_OPCODE(kConvertUU, { // Same operand layout; applies ApplyConvertUU.
+ ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(
+ ApplyConvertUU::Apply(src_type, src_local, dst_type, dst_local));
+ });
+ DISPATCH_CORE_OPCODE(kConvertSU, { // Same operand layout; applies ApplyConvertSU.
+ ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(
+ ApplyConvertSU::Apply(src_type, src_local, dst_type, dst_local));
+ });
+ DISPATCH_CORE_OPCODE(kConvertUS, { // Same operand layout; applies ApplyConvertUS.
+ ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(
+ ApplyConvertUS::Apply(src_type, src_local, dst_type, dst_local));
+ });
+
+ DISPATCH_CORE_OPCODE(kMatMulI, { // Integer matmul with multiplier mantissa/exponent operands; bias is always null.
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ // TODO(benvanik): add fused matmul-with-bias op in MLIR and lower to this.
+ BufferView* bias_local = nullptr;
+ ASSIGN_OR_RETURN(auto* multiplier_mantissa_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* multiplier_exponent_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(ValidateMatMulOpI(lhs_local, rhs_local, bias_local,
+ multiplier_mantissa_local,
+ multiplier_exponent_local, dst_local));
+ auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get();
+ // TODO(benvanik): define as a matrix of supported types to enable 8*8=16,
+ // accumulator options, and other precision modes.
+ switch (lhs_local->element_size) { // lhs element width in bytes selects the signed integer type.
+ case 1:
+ RETURN_IF_ERROR(ApplyMatMulOpI<int8_t>(
+ mat_mul_state, lhs_local, rhs_local, bias_local,
+ multiplier_mantissa_local, multiplier_exponent_local, dst_local));
+ break;
+ case 2:
+ RETURN_IF_ERROR(ApplyMatMulOpI<int16_t>(
+ mat_mul_state, lhs_local, rhs_local, bias_local,
+ multiplier_mantissa_local, multiplier_exponent_local, dst_local));
+ break;
+ case 4:
+ RETURN_IF_ERROR(ApplyMatMulOpI<int32_t>(
+ mat_mul_state, lhs_local, rhs_local, bias_local,
+ multiplier_mantissa_local, multiplier_exponent_local, dst_local));
+ break;
+ case 8:
+ RETURN_IF_ERROR(ApplyMatMulOpI<int64_t>(
+ mat_mul_state, lhs_local, rhs_local, bias_local,
+ multiplier_mantissa_local, multiplier_exponent_local, dst_local));
+ break;
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented element size: " << lhs_local->element_size;
+ }
+ });
+
+ DISPATCH_FLOAT_OPCODE(kMatMulF, { // Float matmul; bias is always null in this handler.
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ BufferView* bias_local = nullptr;
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(
+ ValidateMatMulOpF(lhs_local, rhs_local, bias_local, dst_local));
+ auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get();
+ switch (lhs_local->element_size) { // 4-byte elements run as float, 8-byte as double.
+ case 4:
+ RETURN_IF_ERROR(ApplyMatMulOpF<float>(
+ mat_mul_state, lhs_local, rhs_local, bias_local, dst_local));
+ break;
+ case 8:
+ RETURN_IF_ERROR(ApplyMatMulOpF<double>(
+ mat_mul_state, lhs_local, rhs_local, bias_local, dst_local));
+ break;
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented element size: " << lhs_local->element_size;
+ }
+ });
+
+ DISPATCH_CORE_OPCODE(kReduceSumI, { // Reductions read (src, init, int32 dimension, dst); integer ones use the IS helper.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(scotttodd): validate
+ RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceSum>(
+ src_local, init_local, dst_local, dimension, src_local->shape,
+ dst_local->shape));
+ });
+
+ DISPATCH_FLOAT_OPCODE(kReduceSumF, { // Float variant of kReduceSumI (F helper).
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(scotttodd): validate
+ RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceSum>(
+ src_local, init_local, dst_local, dimension, src_local->shape,
+ dst_local->shape));
+ });
+
+ DISPATCH_CORE_OPCODE(kReduceMinI, { // Same operand layout; kernels::ReduceMin through the IS helper only.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(scotttodd): validate
+ RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMin>(
+ src_local, init_local, dst_local, dimension, src_local->shape,
+ dst_local->shape));
+ });
+
+ DISPATCH_FLOAT_OPCODE(kReduceMinF, { // Float variant of kReduceMinI (F helper).
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(scotttodd): validate
+ RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMin>(
+ src_local, init_local, dst_local, dimension, src_local->shape,
+ dst_local->shape));
+ });
+
+ DISPATCH_CORE_OPCODE(kReduceMaxI, { // Same operand layout; kernels::ReduceMax through the IS helper only.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(scotttodd): validate
+ RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMax>(
+ src_local, init_local, dst_local, dimension, src_local->shape,
+ dst_local->shape));
+ });
+
+ DISPATCH_FLOAT_OPCODE(kReduceMaxF, { // Float variant of kReduceMaxI (F helper).
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(scotttodd): validate
+ RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMax>(
+ src_local, init_local, dst_local, dimension, src_local->shape,
+ dst_local->shape));
+ });
+
+ DISPATCH_CORE_OPCODE(kTrace, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace";
+ });
+
+ DISPATCH_CORE_OPCODE(kBreak, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break";
+ });
+
+ DISPATCH_CORE_OPCODE(kCondBreak, { // Not implemented: always fails with Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break";
+ });
+ // Target of every reserved opcode slot in kDispatchTable (see DECLARE_DISPATCH_RESERVED).
+_dispatch_unhandled:
+ // TODO(benvanik): better tracing.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode";
+} // NOLINT(readability/fn_size)
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/bytecode_dispatch.h b/hal/interpreter/bytecode_dispatch.h
new file mode 100644
index 0000000..4a9db68
--- /dev/null
+++ b/hal/interpreter/bytecode_dispatch.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
+
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/interpreter/bytecode_kernels.h"
+#include "rt/stack.h"
+#include "rt/stack_frame.h"
+
+namespace iree {
+namespace hal {
+
+// Entry point of the HAL interpreter bytecode dispatcher.
+//
+// NOTE(review): from the parameter names this presumably executes the function
+// described by |entry_stack_frame| on |stack|, allocating through |allocator|,
+// sharing kernel scratch state via |kernel_runtime_state|, and writing the
+// function outputs into |entry_results| - confirm against
+// bytecode_dispatch.cc.
+//
+// Returns a non-OK Status on failure; the implementation reports opcodes it
+// does not handle as Unimplemented ("Unknown dispatch opcode").
+Status Dispatch(hal::Allocator* allocator,
+ kernels::RuntimeState* kernel_runtime_state, rt::Stack* stack,
+ rt::StackFrame* entry_stack_frame,
+ absl::Span<BufferView> entry_results);
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
diff --git a/hal/interpreter/bytecode_dispatch_conversion.h b/hal/interpreter/bytecode_dispatch_conversion.h
new file mode 100644
index 0000000..8cedd75
--- /dev/null
+++ b/hal/interpreter/bytecode_dispatch_conversion.h
@@ -0,0 +1,395 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Conversion helper tables.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
+
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "hal/interpreter/bytecode_dispatch_util.h"
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+#include "vm/type.h"
+
+namespace iree {
+namespace hal {
+
+// Dispatches an element-type conversion between two BufferViews.
+//
+// Apply() picks a Thunk<SRC, DST>::Apply function pointer out of one of four
+// static tables. The table is chosen by the compile-time |src_signed| and
+// |dst_signed| flags and indexed by
+// src_type.type_index() * kBuiltinTypeCount + dst_type.type_index().
+// A nullptr entry marks a conversion that is not supported and results in an
+// InvalidArgument error. The selected thunk maps both buffers and forwards the
+// typed contents plus |args| to KERNEL::Execute.
+//
+// NOTE(review): the row/column comments in the tables assume the builtin type
+// indices are ordered kI8, kI16, kI32, kI64, kF16, kF32, kF64 and that
+// kBuiltinTypeCount == 7 (49 entries per table) - confirm against the builtin
+// type definitions whenever a type is added or reordered.
+template <typename KERNEL, bool src_signed, bool dst_signed, typename... ARGS>
+struct ApplyConversionOp {
+ // Converts |src_local| (elements of |src_type|) into |dst_local| (elements of
+ // |dst_type|). Returns InvalidArgument if either type index is outside
+ // [0, kBuiltinTypeCount) or if the table has no thunk for the pair; otherwise
+ // returns whatever the selected thunk returns.
+ static Status Apply(const vm::Type& src_type, BufferView* src_local,
+ const vm::Type& dst_type, BufferView* dst_local,
+ ARGS... args) {
+ // Validate ranges so that we cannot go out of bounds on thunk table.
+ int src_type_index = src_type.type_index();
+ int dst_type_index = dst_type.type_index();
+ if (src_type_index < 0 || src_type_index >= kBuiltinTypeCount) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Conversion from invalid source builtin type "
+ << src_type_index;
+ } else if (dst_type_index < 0 || dst_type_index >= kBuiltinTypeCount) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Conversion to invalid dest builtin type " << dst_type_index;
+ }
+
+ // All possible combinations of conversions.
+ // |src_signed| and |dst_signed| are template constants, so a given
+ // instantiation always takes the same branch below. Each table is laid out
+ // row-major: one row per source type, one column per destination type.
+ using KernelFn = Status (*)(BufferView * src_local, BufferView * dst_local,
+ ARGS... args);
+ KernelFn fn = nullptr;
+ if (src_signed && dst_signed) {
+ // Signed -> signed.
+ static const KernelFn
+ kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
+ // src_type = kI8:
+ /* kI8 */ Thunk<int8_t, int8_t>::Apply,
+ /* kI16 */ Thunk<int8_t, int16_t>::Apply,
+ /* kI32 */ Thunk<int8_t, int32_t>::Apply,
+ /* kI64 */ Thunk<int8_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<int8_t, float>::Apply,
+ /* kF64 */ Thunk<int8_t, double>::Apply,
+
+ // src_type = kI16:
+ /* kI8 */ Thunk<int16_t, int8_t>::Apply,
+ /* kI16 */ Thunk<int16_t, int16_t>::Apply,
+ /* kI32 */ Thunk<int16_t, int32_t>::Apply,
+ /* kI64 */ Thunk<int16_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<int16_t, float>::Apply,
+ /* kF64 */ Thunk<int16_t, double>::Apply,
+
+ // src_type = kI32:
+ /* kI8 */ Thunk<int32_t, int8_t>::Apply,
+ /* kI16 */ Thunk<int32_t, int16_t>::Apply,
+ /* kI32 */ Thunk<int32_t, int32_t>::Apply,
+ /* kI64 */ Thunk<int32_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<int32_t, float>::Apply,
+ /* kF64 */ Thunk<int32_t, double>::Apply,
+
+ // src_type = kI64:
+ /* kI8 */ Thunk<int64_t, int8_t>::Apply,
+ /* kI16 */ Thunk<int64_t, int16_t>::Apply,
+ /* kI32 */ Thunk<int64_t, int32_t>::Apply,
+ /* kI64 */ Thunk<int64_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<int64_t, float>::Apply,
+ /* kF64 */ Thunk<int64_t, double>::Apply,
+
+ // src_type = kF16:
+ // Only F16 -> F16 is provided, handled as raw 16-bit values.
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ Thunk<uint16_t, uint16_t>::Apply,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF32:
+ /* kI8 */ Thunk<float, int8_t>::Apply,
+ /* kI16 */ Thunk<float, int16_t>::Apply,
+ /* kI32 */ Thunk<float, int32_t>::Apply,
+ /* kI64 */ Thunk<float, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<float, float>::Apply,
+ /* kF64 */ Thunk<float, double>::Apply,
+
+ // src_type = kF64:
+ /* kI8 */ Thunk<double, int8_t>::Apply,
+ /* kI16 */ Thunk<double, int16_t>::Apply,
+ /* kI32 */ Thunk<double, int32_t>::Apply,
+ /* kI64 */ Thunk<double, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<double, float>::Apply,
+ /* kF64 */ Thunk<double, double>::Apply,
+ };
+ fn =
+ kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
+ } else if (src_signed && !dst_signed) {
+ // Signed -> unsigned.
+ // Float destinations have no unsigned form, so those columns are nullptr.
+ static const KernelFn
+ kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
+ // src_type = kI8:
+ /* kI8 */ Thunk<int8_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<int8_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<int8_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<int8_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kI16:
+ /* kI8 */ Thunk<int16_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<int16_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<int16_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<int16_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kI32:
+ /* kI8 */ Thunk<int32_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<int32_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<int32_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<int32_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kI64:
+ /* kI8 */ Thunk<int64_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<int64_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<int64_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<int64_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF16:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF32:
+ /* kI8 */ Thunk<float, uint8_t>::Apply,
+ /* kI16 */ Thunk<float, uint16_t>::Apply,
+ /* kI32 */ Thunk<float, uint32_t>::Apply,
+ /* kI64 */ Thunk<float, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF64:
+ /* kI8 */ Thunk<double, uint8_t>::Apply,
+ /* kI16 */ Thunk<double, uint16_t>::Apply,
+ /* kI32 */ Thunk<double, uint32_t>::Apply,
+ /* kI64 */ Thunk<double, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+ };
+ fn =
+ kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
+ } else if (!src_signed && dst_signed) {
+ // Unsigned -> signed.
+ // Float sources have no unsigned form, so those rows are all nullptr.
+ static const KernelFn
+ kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
+ // src_type = kI8:
+ /* kI8 */ Thunk<uint8_t, int8_t>::Apply,
+ /* kI16 */ Thunk<uint8_t, int16_t>::Apply,
+ /* kI32 */ Thunk<uint8_t, int32_t>::Apply,
+ /* kI64 */ Thunk<uint8_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<uint8_t, float>::Apply,
+ /* kF64 */ Thunk<uint8_t, double>::Apply,
+
+ // src_type = kI16:
+ /* kI8 */ Thunk<uint16_t, int8_t>::Apply,
+ /* kI16 */ Thunk<uint16_t, int16_t>::Apply,
+ /* kI32 */ Thunk<uint16_t, int32_t>::Apply,
+ /* kI64 */ Thunk<uint16_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<uint16_t, float>::Apply,
+ /* kF64 */ Thunk<uint16_t, double>::Apply,
+
+ // src_type = kI32:
+ /* kI8 */ Thunk<uint32_t, int8_t>::Apply,
+ /* kI16 */ Thunk<uint32_t, int16_t>::Apply,
+ /* kI32 */ Thunk<uint32_t, int32_t>::Apply,
+ /* kI64 */ Thunk<uint32_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<uint32_t, float>::Apply,
+ /* kF64 */ Thunk<uint32_t, double>::Apply,
+
+ // src_type = kI64:
+ /* kI8 */ Thunk<uint64_t, int8_t>::Apply,
+ /* kI16 */ Thunk<uint64_t, int16_t>::Apply,
+ /* kI32 */ Thunk<uint64_t, int32_t>::Apply,
+ /* kI64 */ Thunk<uint64_t, int64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ Thunk<uint64_t, float>::Apply,
+ /* kF64 */ Thunk<uint64_t, double>::Apply,
+
+ // src_type = kF16:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF32:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF64:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+ };
+ fn =
+ kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
+ } else if (!src_signed && !dst_signed) {
+ // Unsigned -> unsigned.
+ // Only integer-to-integer entries are populated.
+ static const KernelFn
+ kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
+ // src_type = kI8:
+ /* kI8 */ Thunk<uint8_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<uint8_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<uint8_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<uint8_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kI16:
+ /* kI8 */ Thunk<uint16_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<uint16_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<uint16_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<uint16_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kI32:
+ /* kI8 */ Thunk<uint32_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<uint32_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<uint32_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<uint32_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kI64:
+ /* kI8 */ Thunk<uint64_t, uint8_t>::Apply,
+ /* kI16 */ Thunk<uint64_t, uint16_t>::Apply,
+ /* kI32 */ Thunk<uint64_t, uint32_t>::Apply,
+ /* kI64 */ Thunk<uint64_t, uint64_t>::Apply,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF16:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF32:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+
+ // src_type = kF64:
+ /* kI8 */ nullptr,
+ /* kI16 */ nullptr,
+ /* kI32 */ nullptr,
+ /* kI64 */ nullptr,
+ /* kF16 */ nullptr,
+ /* kF32 */ nullptr,
+ /* kF64 */ nullptr,
+ };
+ fn =
+ kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
+ }
+ // A nullptr entry means the table has no conversion for this type pair.
+ if (!fn) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Unsupported conversion from " << src_type_index << " to "
+ << dst_type_index;
+ }
+ return fn(src_local, dst_local, args...);
+ }
+
+ // Type-erased entry point stored in the conversion tables. Maps the source
+ // buffer for reading as SRC elements and the destination buffer for
+ // discard-write as DST elements, then hands both to KERNEL::Execute along
+ // with |args|. A mapping failure is returned to the caller.
+ template <typename SRC, typename DST>
+ struct Thunk {
+ static Status Apply(BufferView* src_local, BufferView* dst_local,
+ ARGS... args) {
+ ASSIGN_OR_RETURN(auto src_buffer,
+ src_local->buffer->MapMemory<SRC>(MemoryAccess::kRead));
+ ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<DST>(
+ MemoryAccess::kDiscardWrite));
+ return KERNEL::Execute(src_buffer.contents(),
+ dst_buffer.mutable_contents(), args...);
+ }
+ };
+
+// Disable F32/F64 conversions if they are not supported.
+// The specializations below replace the float/double thunks with stubs that
+// return Unimplemented.
+// NOTE(review): Thunk<float, float> matches both Thunk<float, DST> and
+// Thunk<SRC, float>, and neither partial specialization is more specialized,
+// so the tables above (which name Thunk<float, float>) look like they would
+// hit an ambiguity if IREE_SUPPORT_F32 were undefined. The same applies to
+// Thunk<double, double>, and to the mixed float/double thunks when both macros
+// are undefined. Verify before disabling either macro.
+#if !defined(IREE_SUPPORT_F32)
+ template <typename DST>
+ struct Thunk<float, DST> {
+ static Status Apply(BufferView* src_local, BufferView* dst_local,
+ ARGS... args) {
+ return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported";
+ }
+ };
+ template <typename SRC>
+ struct Thunk<SRC, float> {
+ static Status Apply(BufferView* src_local, BufferView* dst_local,
+ ARGS... args) {
+ return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported";
+ }
+ };
+#endif // !IREE_SUPPORT_F32
+#if !defined(IREE_SUPPORT_F64)
+ template <typename DST>
+ struct Thunk<double, DST> {
+ static Status Apply(BufferView* src_local, BufferView* dst_local,
+ ARGS... args) {
+ return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported";
+ }
+ };
+ template <typename SRC>
+ struct Thunk<SRC, double> {
+ static Status Apply(BufferView* src_local, BufferView* dst_local,
+ ARGS... args) {
+ return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported";
+ }
+ };
+#endif // !IREE_SUPPORT_F64
+};
+
+// Convenience aliases binding ApplyConversionOp to kernels::Convert for each
+// signedness combination. The suffix reads <source><destination>, with S for
+// signed and U for unsigned (e.g. ApplyConvertSU converts signed -> unsigned).
+using ApplyConvertSS = ApplyConversionOp<kernels::Convert, /*src_signed=*/true,
+ /*dst_signed=*/true>;
+using ApplyConvertUU = ApplyConversionOp<kernels::Convert, /*src_signed=*/false,
+ /*dst_signed=*/false>;
+using ApplyConvertSU = ApplyConversionOp<kernels::Convert, /*src_signed=*/true,
+ /*dst_signed=*/false>;
+using ApplyConvertUS = ApplyConversionOp<kernels::Convert, /*src_signed=*/false,
+ /*dst_signed=*/true>;
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
diff --git a/hal/interpreter/bytecode_dispatch_util.cc b/hal/interpreter/bytecode_dispatch_util.cc
new file mode 100644
index 0000000..988b525
--- /dev/null
+++ b/hal/interpreter/bytecode_dispatch_util.cc
@@ -0,0 +1,107 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/bytecode_dispatch_util.h"
+
+namespace iree {
+namespace hal {
+
+// Reports whether the view's backing buffer contains at least one non-zero
+// byte. Views with no elements, no buffer, or no bytes are false, and a buffer
+// that cannot be mapped for reading is treated as false rather than reported
+// as an error.
+bool BufferViewIsTrue(const BufferView& buffer_view) {
+  // Nothing to inspect in any of these cases.
+  if (buffer_view.element_size == 0) return false;
+  if (!buffer_view.buffer) return false;
+  if (buffer_view.byte_length() == 0) return false;
+
+  // TODO(benvanik): map more efficiently (based on element size?).
+  auto byte_mapping =
+      buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead);
+  if (!byte_mapping.ok()) return false;
+
+  // Scan byte-wise and stop at the first set byte.
+  bool any_byte_set = false;
+  for (uint8_t byte : byte_mapping.ValueOrDie().contents()) {
+    if (byte) {
+      any_byte_set = true;
+      break;
+    }
+  }
+  return any_byte_set;
+}
+
+// Validation hook for elementwise unary ops. Currently a placeholder: neither
+// view is inspected and OkStatus() is always returned.
+Status ValidateElementwiseUnaryOp(BufferView* src_local,
+ BufferView* dst_local) {
+ // TODO(benvanik): validate shapes.
+ return OkStatus();
+}
+
+// Validation hook for elementwise binary ops. Currently a placeholder: no view
+// is inspected and OkStatus() is always returned.
+Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* dst_local) {
+ // TODO(benvanik): validate shapes.
+ return OkStatus();
+}
+
+// Validation hook for elementwise ternary ops. Currently a placeholder: no
+// view is inspected and OkStatus() is always returned.
+Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local,
+ BufferView* c_local,
+ BufferView* dst_local) {
+ // TODO(benvanik): validate shapes.
+ return OkStatus();
+}
+
+// Validation hook for integer matmul operands (lhs, rhs, bias, the two
+// multiplier views and dst). Currently a placeholder: no view is inspected and
+// OkStatus() is always returned.
+Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* bias_local,
+ BufferView* multiplier_mantissa_local,
+ BufferView* multiplier_exponent_local,
+ BufferView* dst_local) {
+ // TODO(benvanik): validate shapes.
+ return OkStatus();
+}
+
+// Validation hook for floating-point matmul operands (lhs, rhs, bias, dst).
+// Currently a placeholder: no view is inspected and OkStatus() is always
+// returned.
+Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* bias_local, BufferView* dst_local) {
+ // TODO(benvanik): validate shapes.
+ return OkStatus();
+}
+
+// Copies a region from |src_local| into |dst_local| with kernels::Copy.
+//
+// Both buffers are mapped as raw bytes (source for read, destination for
+// write) and the copy kernel is instantiated for the source view's element
+// size. |src_indices|, |dst_indices| and |lengths| are forwarded to the kernel
+// unchanged, together with both view shapes.
+//
+// Returns any mapping error, the kernel's status, or Unimplemented when the
+// source element size is not 1, 2, 4 or 8.
+Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices,
+                 BufferView* dst_local, absl::Span<const int32_t> dst_indices,
+                 absl::Span<const int32_t> lengths) {
+  ASSIGN_OR_RETURN(auto src_mapping,
+                   src_local->buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
+  // TODO(benvanik): discard if overwriting the entire buffer.
+  ASSIGN_OR_RETURN(auto dst_mapping,
+                   dst_local->buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
+
+  // The element width is a template argument of the kernel, so each supported
+  // width needs its own call site.
+  const auto element_size = src_local->element_size;
+  if (element_size == 1) {
+    return kernels::Copy::Execute<1>(
+        src_mapping.contents(), src_local->shape, src_indices,
+        dst_mapping.mutable_contents(), dst_local->shape, dst_indices, lengths);
+  }
+  if (element_size == 2) {
+    return kernels::Copy::Execute<2>(
+        src_mapping.contents(), src_local->shape, src_indices,
+        dst_mapping.mutable_contents(), dst_local->shape, dst_indices, lengths);
+  }
+  if (element_size == 4) {
+    return kernels::Copy::Execute<4>(
+        src_mapping.contents(), src_local->shape, src_indices,
+        dst_mapping.mutable_contents(), dst_local->shape, dst_indices, lengths);
+  }
+  if (element_size == 8) {
+    return kernels::Copy::Execute<8>(
+        src_mapping.contents(), src_local->shape, src_indices,
+        dst_mapping.mutable_contents(), dst_local->shape, dst_indices, lengths);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << src_local->element_size;
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/bytecode_dispatch_util.h b/hal/interpreter/bytecode_dispatch_util.h
new file mode 100644
index 0000000..65855c1
--- /dev/null
+++ b/hal/interpreter/bytecode_dispatch_util.h
@@ -0,0 +1,513 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities used by the bytecode_dispatch routines to aid in working with the
+// bytecode stream and kernel dispatch.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
+
+#include "absl/base/attributes.h"
+#include "absl/container/inlined_vector.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "hal/heap_buffer.h"
+#include "hal/interpreter/bytecode_kernels.h"
+#include "rt/function.h"
+#include "rt/stack.h"
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+#include "vm/bytecode_reader.h"
+#include "vm/type.h"
+
+// TODO(benvanik): move to dedicated config file/build flags.
+#define IREE_SUPPORT_F32 1
+#define IREE_SUPPORT_F64 1
+
+namespace iree {
+namespace hal {
+
+// Returns true if any byte of the BufferView's backing buffer is non-zero.
+// Returns false if element_size is 0, there is no buffer, byte_length() is 0,
+// the buffer cannot be mapped for reading, or every byte is zero.
+bool BufferViewIsTrue(const BufferView& buffer_view);
+
+// Operand validation hooks called by the dispatch helpers before a kernel is
+// applied. Each returns OkStatus() when the operands are acceptable.
+// NOTE: the definitions in bytecode_dispatch_util.cc are placeholders that
+// perform no checks yet (see their TODOs) and always return OkStatus().
+
+// Validates the operands of an elementwise unary op.
+Status ValidateElementwiseUnaryOp(BufferView* src_local, BufferView* dst_local);
+// Validates the operands of an elementwise binary op.
+Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* dst_local);
+// Validates the operands of an elementwise ternary op.
+Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local,
+ BufferView* c_local, BufferView* dst_local);
+// Validates the operands of an integer matmul, including the multiplier
+// mantissa/exponent views.
+Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* bias_local,
+ BufferView* multiplier_mantissa_local,
+ BufferView* multiplier_exponent_local,
+ BufferView* dst_local);
+// Validates the operands of a floating-point matmul.
+Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* bias_local, BufferView* dst_local);
+
+// Runs a unary KERNEL over buffers interpreted as elements of type T.
+//
+// The source is mapped for read and the destination for discard-write; the
+// typed contents and |args| are then passed to KERNEL::Execute. A mapping
+// failure is returned before the kernel runs.
+template <typename KERNEL, typename T, typename... ARGS>
+Status ApplyUnaryOp(BufferView* src_local, BufferView* dst_local,
+                    ARGS... args) {
+  // TODO(benvanik): avoid mapping by changing buffer type?
+  ASSIGN_OR_RETURN(auto src_mapping,
+                   src_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(
+      auto dst_mapping,
+      dst_local->buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
+
+  auto src_elements = src_mapping.contents();
+  auto dst_elements = dst_mapping.mutable_contents();
+  return KERNEL::Execute(src_elements, dst_elements, args...);
+}
+
+// Runs a binary KERNEL over buffers interpreted as elements of type T.
+//
+// Both operands are mapped for read and the destination for discard-write, in
+// that order; the typed contents and |args| are then passed to
+// KERNEL::Execute. A mapping failure is returned before the kernel runs.
+template <typename KERNEL, typename T, typename... ARGS>
+Status ApplyBinaryOp(BufferView* lhs_local, BufferView* rhs_local,
+                     BufferView* dst_local, ARGS... args) {
+  ASSIGN_OR_RETURN(auto lhs_mapping,
+                   lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto rhs_mapping,
+                   rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(
+      auto dst_mapping,
+      dst_local->buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
+
+  auto lhs_elements = lhs_mapping.contents();
+  auto rhs_elements = rhs_mapping.contents();
+  auto dst_elements = dst_mapping.mutable_contents();
+  return KERNEL::Execute(lhs_elements, rhs_elements, dst_elements, args...);
+}
+
+// Runs a ternary KERNEL over buffers interpreted as elements of type T.
+//
+// The three operands are mapped for read and the destination for
+// discard-write, in that order; the typed contents and |args| are then passed
+// to KERNEL::Execute. A mapping failure is returned before the kernel runs.
+template <typename KERNEL, typename T, typename... ARGS>
+Status ApplyTernaryOp(BufferView* a_local, BufferView* b_local,
+                      BufferView* c_local, BufferView* dst_local,
+                      ARGS... args) {
+  ASSIGN_OR_RETURN(auto a_mapping,
+                   a_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto b_mapping,
+                   b_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto c_mapping,
+                   c_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(
+      auto dst_mapping,
+      dst_local->buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
+
+  auto a_elements = a_mapping.contents();
+  auto b_elements = b_mapping.contents();
+  auto c_elements = c_mapping.contents();
+  auto dst_elements = dst_mapping.mutable_contents();
+  return KERNEL::Execute(a_elements, b_elements, c_elements, dst_elements,
+                         args...);
+}
+
+// Runs a comparison KERNEL over operands interpreted as elements of type T.
+//
+// Unlike the arithmetic helpers, the destination is always mapped as uint8_t
+// regardless of T. Operands are mapped for read and the destination for
+// discard-write; a mapping failure is returned before the kernel runs.
+template <typename KERNEL, typename T>
+Status ApplyComparisonOp(BufferView* lhs_local, BufferView* rhs_local,
+                         BufferView* dst_local) {
+  ASSIGN_OR_RETURN(auto lhs_mapping,
+                   lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto rhs_mapping,
+                   rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(
+      auto dst_mapping,
+      dst_local->buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite));
+
+  auto lhs_elements = lhs_mapping.contents();
+  auto rhs_elements = rhs_mapping.contents();
+  auto dst_bytes = dst_mapping.mutable_contents();
+  return KERNEL::Execute(lhs_elements, rhs_elements, dst_bytes);
+}
+
+// Applies a unary KERNEL using the signed integer type whose width matches
+// src_local->element_size (1, 2, 4 or 8 bytes). Any other size yields an
+// Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyUnaryOpIS(BufferView* src_local, BufferView* dst_local,
+                      ARGS... args) {
+  const auto element_size = src_local->element_size;
+  if (element_size == 1) {
+    return ApplyUnaryOp<KERNEL, int8_t>(src_local, dst_local, args...);
+  }
+  if (element_size == 2) {
+    return ApplyUnaryOp<KERNEL, int16_t>(src_local, dst_local, args...);
+  }
+  if (element_size == 4) {
+    return ApplyUnaryOp<KERNEL, int32_t>(src_local, dst_local, args...);
+  }
+  if (element_size == 8) {
+    return ApplyUnaryOp<KERNEL, int64_t>(src_local, dst_local, args...);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << src_local->element_size;
+}
+
+// Applies a unary KERNEL using the unsigned integer type whose width matches
+// src_local->element_size (1, 2, 4 or 8 bytes). Any other size yields an
+// Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyUnaryOpIU(BufferView* src_local, BufferView* dst_local,
+                      ARGS... args) {
+  const auto element_size = src_local->element_size;
+  if (element_size == 1) {
+    return ApplyUnaryOp<KERNEL, uint8_t>(src_local, dst_local, args...);
+  }
+  if (element_size == 2) {
+    return ApplyUnaryOp<KERNEL, uint16_t>(src_local, dst_local, args...);
+  }
+  if (element_size == 4) {
+    return ApplyUnaryOp<KERNEL, uint32_t>(src_local, dst_local, args...);
+  }
+  if (element_size == 8) {
+    return ApplyUnaryOp<KERNEL, uint64_t>(src_local, dst_local, args...);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << src_local->element_size;
+}
+
+// Applies a unary KERNEL using the floating-point type whose width matches
+// src_local->element_size: 4 bytes selects float and 8 bytes selects double,
+// each only when the corresponding IREE_SUPPORT_F32 / IREE_SUPPORT_F64 macro
+// is defined. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyUnaryOpF(BufferView* src_local, BufferView* dst_local,
+                     ARGS... args) {
+#if defined(IREE_SUPPORT_F32)
+  if (src_local->element_size == 4) {
+    return ApplyUnaryOp<KERNEL, float>(src_local, dst_local, args...);
+  }
+#endif  // IREE_SUPPORT_F32
+#if defined(IREE_SUPPORT_F64)
+  if (src_local->element_size == 8) {
+    return ApplyUnaryOp<KERNEL, double>(src_local, dst_local, args...);
+  }
+#endif  // IREE_SUPPORT_F64
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << src_local->element_size;
+}
+
+// Applies a binary KERNEL using the signed integer type whose width matches
+// lhs_local->element_size (1, 2, 4 or 8 bytes). Only the lhs view is consulted
+// for the width. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyBinaryOpIS(BufferView* lhs_local, BufferView* rhs_local,
+                       BufferView* dst_local, ARGS... args) {
+  const auto element_size = lhs_local->element_size;
+  if (element_size == 1) {
+    return ApplyBinaryOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local,
+                                         args...);
+  }
+  if (element_size == 2) {
+    return ApplyBinaryOp<KERNEL, int16_t>(lhs_local, rhs_local, dst_local,
+                                          args...);
+  }
+  if (element_size == 4) {
+    return ApplyBinaryOp<KERNEL, int32_t>(lhs_local, rhs_local, dst_local,
+                                          args...);
+  }
+  if (element_size == 8) {
+    return ApplyBinaryOp<KERNEL, int64_t>(lhs_local, rhs_local, dst_local,
+                                          args...);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << lhs_local->element_size;
+}
+
+// Applies a binary KERNEL using the unsigned integer type whose width matches
+// lhs_local->element_size (1, 2, 4 or 8 bytes). Only the lhs view is consulted
+// for the width. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyBinaryOpIU(BufferView* lhs_local, BufferView* rhs_local,
+                       BufferView* dst_local, ARGS... args) {
+  const auto element_size = lhs_local->element_size;
+  if (element_size == 1) {
+    return ApplyBinaryOp<KERNEL, uint8_t>(lhs_local, rhs_local, dst_local,
+                                          args...);
+  }
+  if (element_size == 2) {
+    return ApplyBinaryOp<KERNEL, uint16_t>(lhs_local, rhs_local, dst_local,
+                                           args...);
+  }
+  if (element_size == 4) {
+    return ApplyBinaryOp<KERNEL, uint32_t>(lhs_local, rhs_local, dst_local,
+                                           args...);
+  }
+  if (element_size == 8) {
+    return ApplyBinaryOp<KERNEL, uint64_t>(lhs_local, rhs_local, dst_local,
+                                           args...);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << lhs_local->element_size;
+}
+
+// Applies a binary KERNEL using the floating-point type whose width matches
+// lhs_local->element_size: 4 bytes selects float and 8 bytes selects double,
+// each only when the corresponding IREE_SUPPORT_F32 / IREE_SUPPORT_F64 macro
+// is defined. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyBinaryOpF(BufferView* lhs_local, BufferView* rhs_local,
+                      BufferView* dst_local, ARGS... args) {
+#if defined(IREE_SUPPORT_F32)
+  if (lhs_local->element_size == 4) {
+    return ApplyBinaryOp<KERNEL, float>(lhs_local, rhs_local, dst_local,
+                                        args...);
+  }
+#endif  // IREE_SUPPORT_F32
+#if defined(IREE_SUPPORT_F64)
+  if (lhs_local->element_size == 8) {
+    return ApplyBinaryOp<KERNEL, double>(lhs_local, rhs_local, dst_local,
+                                         args...);
+  }
+#endif  // IREE_SUPPORT_F64
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << lhs_local->element_size;
+}
+
+// Applies a ternary KERNEL using the signed integer type whose width matches
+// a_local->element_size (1, 2, 4 or 8 bytes). Only the first operand is
+// consulted for the width. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyTernaryOpIS(BufferView* a_local, BufferView* b_local,
+                        BufferView* c_local, BufferView* dst_local,
+                        ARGS... args) {
+  const auto element_size = a_local->element_size;
+  if (element_size == 1) {
+    return ApplyTernaryOp<KERNEL, int8_t>(a_local, b_local, c_local, dst_local,
+                                          args...);
+  }
+  if (element_size == 2) {
+    return ApplyTernaryOp<KERNEL, int16_t>(a_local, b_local, c_local,
+                                           dst_local, args...);
+  }
+  if (element_size == 4) {
+    return ApplyTernaryOp<KERNEL, int32_t>(a_local, b_local, c_local,
+                                           dst_local, args...);
+  }
+  if (element_size == 8) {
+    return ApplyTernaryOp<KERNEL, int64_t>(a_local, b_local, c_local,
+                                           dst_local, args...);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << a_local->element_size;
+}
+
+// Applies a ternary KERNEL using the unsigned integer type whose width matches
+// a_local->element_size (1, 2, 4 or 8 bytes). Only the first operand is
+// consulted for the width. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyTernaryOpIU(BufferView* a_local, BufferView* b_local,
+                        BufferView* c_local, BufferView* dst_local,
+                        ARGS... args) {
+  const auto element_size = a_local->element_size;
+  if (element_size == 1) {
+    return ApplyTernaryOp<KERNEL, uint8_t>(a_local, b_local, c_local,
+                                           dst_local, args...);
+  }
+  if (element_size == 2) {
+    return ApplyTernaryOp<KERNEL, uint16_t>(a_local, b_local, c_local,
+                                            dst_local, args...);
+  }
+  if (element_size == 4) {
+    return ApplyTernaryOp<KERNEL, uint32_t>(a_local, b_local, c_local,
+                                            dst_local, args...);
+  }
+  if (element_size == 8) {
+    return ApplyTernaryOp<KERNEL, uint64_t>(a_local, b_local, c_local,
+                                            dst_local, args...);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << a_local->element_size;
+}
+
+// Applies a ternary KERNEL using the floating-point type whose width matches
+// a_local->element_size: 4 bytes selects float and 8 bytes selects double,
+// each only when the corresponding IREE_SUPPORT_F32 / IREE_SUPPORT_F64 macro
+// is defined. Any other size yields an Unimplemented error.
+template <typename KERNEL, typename... ARGS>
+Status ApplyTernaryOpF(BufferView* a_local, BufferView* b_local,
+                       BufferView* c_local, BufferView* dst_local,
+                       ARGS... args) {
+#if defined(IREE_SUPPORT_F32)
+  if (a_local->element_size == 4) {
+    return ApplyTernaryOp<KERNEL, float>(a_local, b_local, c_local, dst_local,
+                                         args...);
+  }
+#endif  // IREE_SUPPORT_F32
+#if defined(IREE_SUPPORT_F64)
+  if (a_local->element_size == 8) {
+    return ApplyTernaryOp<KERNEL, double>(a_local, b_local, c_local, dst_local,
+                                          args...);
+  }
+#endif  // IREE_SUPPORT_F64
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << a_local->element_size;
+}
+
+// Applies a comparison KERNEL using the signed integer type whose width
+// matches lhs_local->element_size (1, 2, 4 or 8 bytes). Any other size yields
+// an Unimplemented error.
+template <typename KERNEL>
+Status ApplyComparisonOpIS(BufferView* lhs_local, BufferView* rhs_local,
+                           BufferView* dst_local) {
+  const auto element_size = lhs_local->element_size;
+  if (element_size == 1) {
+    return ApplyComparisonOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local);
+  }
+  if (element_size == 2) {
+    return ApplyComparisonOp<KERNEL, int16_t>(lhs_local, rhs_local, dst_local);
+  }
+  if (element_size == 4) {
+    return ApplyComparisonOp<KERNEL, int32_t>(lhs_local, rhs_local, dst_local);
+  }
+  if (element_size == 8) {
+    return ApplyComparisonOp<KERNEL, int64_t>(lhs_local, rhs_local, dst_local);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << lhs_local->element_size;
+}
+
+// Applies a comparison KERNEL using the unsigned integer type whose width
+// matches lhs_local->element_size (1, 2, 4 or 8 bytes). Any other size yields
+// an Unimplemented error.
+template <typename KERNEL>
+Status ApplyComparisonOpIU(BufferView* lhs_local, BufferView* rhs_local,
+                           BufferView* dst_local) {
+  const auto element_size = lhs_local->element_size;
+  if (element_size == 1) {
+    return ApplyComparisonOp<KERNEL, uint8_t>(lhs_local, rhs_local, dst_local);
+  }
+  if (element_size == 2) {
+    return ApplyComparisonOp<KERNEL, uint16_t>(lhs_local, rhs_local,
+                                               dst_local);
+  }
+  if (element_size == 4) {
+    return ApplyComparisonOp<KERNEL, uint32_t>(lhs_local, rhs_local,
+                                               dst_local);
+  }
+  if (element_size == 8) {
+    return ApplyComparisonOp<KERNEL, uint64_t>(lhs_local, rhs_local,
+                                               dst_local);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unimplemented element size: " << lhs_local->element_size;
+}
+
+template <typename KERNEL>
+Status ApplyComparisonOpF(BufferView* lhs_local, BufferView* rhs_local,
+ BufferView* dst_local) {
+ switch (lhs_local->element_size) {
+ case 4:
+ return ApplyComparisonOp<KERNEL, float>(lhs_local, rhs_local, dst_local);
+ case 8:
+ return ApplyComparisonOp<KERNEL, double>(lhs_local, rhs_local, dst_local);
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented element size: " << lhs_local->element_size;
+ }
+}
+
+// Runs the integer matmul kernel with operands of type T and accumulator/bias
+// type ACC.
+//
+// Mappings are acquired in the order lhs, rhs, optional bias, multiplier
+// mantissa, multiplier exponent, dst; the first failure is returned. The bias
+// is only used when |bias_local| is non-null, has a buffer and a non-empty
+// shape, and it must have sizeof(ACC)-byte elements (otherwise Unimplemented).
+// The mantissa is read as ACC, the exponent as int32_t, and dst is mapped for
+// discard-write as T.
+template <typename T, typename ACC = int32_t>
+Status ApplyMatMulOpI(kernels::MatMul::RuntimeState* runtime_state,
+                      BufferView* lhs_local, BufferView* rhs_local,
+                      BufferView* bias_local,
+                      BufferView* multiplier_mantissa_local,
+                      BufferView* multiplier_exponent_local,
+                      BufferView* dst_local) {
+  ASSIGN_OR_RETURN(auto lhs_mapping,
+                   lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto rhs_mapping,
+                   rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+
+  // Declared outside the branch so the mapping outlives the kernel call.
+  MappedMemory<ACC> bias_mapping;
+  const bool has_bias =
+      bias_local && bias_local->buffer && !bias_local->shape.empty();
+  if (has_bias) {
+    if (bias_local->element_size != sizeof(ACC)) {
+      return UnimplementedErrorBuilder(IREE_LOC)
+             << "Only " << sizeof(ACC) << "b biases are supported right now";
+    }
+    ASSIGN_OR_RETURN(bias_mapping,
+                     bias_local->buffer->MapMemory<ACC>(MemoryAccess::kRead));
+  }
+
+  ASSIGN_OR_RETURN(
+      auto mantissa_mapping,
+      multiplier_mantissa_local->buffer->MapMemory<ACC>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto exponent_mapping,
+                   multiplier_exponent_local->buffer->MapMemory<int32_t>(
+                       MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(
+      auto dst_mapping,
+      dst_local->buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
+
+  // Everything is mapped; describe the operands to the kernel.
+  kernels::MatMul::Buffers<T, ACC> buffers;
+  buffers.lhs_buffer = lhs_mapping.contents();
+  buffers.lhs_shape = lhs_local->shape;
+  buffers.rhs_buffer = rhs_mapping.contents();
+  buffers.rhs_shape = rhs_local->shape;
+  if (has_bias) {
+    buffers.bias_buffer = bias_mapping.contents();
+  }
+  buffers.multiplier_mantissa_buffer = mantissa_mapping.contents();
+  buffers.multiplier_exponent_buffer = exponent_mapping.contents();
+  buffers.dst_buffer = dst_mapping.mutable_contents();
+  buffers.dst_shape = dst_local->shape;
+  return kernels::MatMul::Execute(runtime_state, buffers);
+}
+
+// Runs the floating-point matmul kernel where operands, bias and result all
+// use element type T.
+//
+// Mappings are acquired in the order lhs, rhs, optional bias, dst; the first
+// failure is returned. The bias is only used when |bias_local| is non-null,
+// has a buffer and a non-empty shape. dst is mapped for discard-write.
+template <typename T>
+Status ApplyMatMulOpF(kernels::MatMul::RuntimeState* runtime_state,
+                      BufferView* lhs_local, BufferView* rhs_local,
+                      BufferView* bias_local, BufferView* dst_local) {
+  ASSIGN_OR_RETURN(auto lhs_mapping,
+                   lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  ASSIGN_OR_RETURN(auto rhs_mapping,
+                   rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+
+  // Declared outside the branch so the mapping outlives the kernel call.
+  MappedMemory<T> bias_mapping;
+  const bool has_bias =
+      bias_local && bias_local->buffer && !bias_local->shape.empty();
+  if (has_bias) {
+    ASSIGN_OR_RETURN(bias_mapping,
+                     bias_local->buffer->MapMemory<T>(MemoryAccess::kRead));
+  }
+
+  ASSIGN_OR_RETURN(
+      auto dst_mapping,
+      dst_local->buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
+
+  // Everything is mapped; describe the operands to the kernel.
+  kernels::MatMul::Buffers<T, T> buffers;
+  buffers.lhs_buffer = lhs_mapping.contents();
+  buffers.lhs_shape = lhs_local->shape;
+  buffers.rhs_buffer = rhs_mapping.contents();
+  buffers.rhs_shape = rhs_local->shape;
+  if (has_bias) {
+    buffers.bias_buffer = bias_mapping.contents();
+  }
+  buffers.dst_buffer = dst_mapping.mutable_contents();
+  buffers.dst_shape = dst_local->shape;
+  return kernels::MatMul::Execute(runtime_state, buffers);
+}
+
+template <typename KERNEL>  // Operand order in bytecode: source, destination.
+Status DispatchElementwiseUnaryOpIS(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* input, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* output, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseUnaryOp(input, output));
+  return ApplyUnaryOpIS<KERNEL>(input, output);
+}
+
+template <typename KERNEL>  // Operand order in bytecode: source, destination.
+Status DispatchElementwiseUnaryOpIU(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* input, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* output, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseUnaryOp(input, output));
+  return ApplyUnaryOpIU<KERNEL>(input, output);
+}
+
+template <typename KERNEL>  // Operand order in bytecode: source, destination.
+Status DispatchElementwiseUnaryOpF(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* input, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* output, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseUnaryOp(input, output));
+  return ApplyUnaryOpF<KERNEL>(input, output);
+}
+
+template <typename KERNEL>  // Operand order in bytecode: lhs, rhs, destination.
+Status DispatchElementwiseBinaryOpIS(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* left, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* right, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* result, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseBinaryOp(left, right, result));
+  return ApplyBinaryOpIS<KERNEL>(left, right, result);
+}
+
+template <typename KERNEL>  // Operand order in bytecode: lhs, rhs, destination.
+Status DispatchElementwiseBinaryOpIU(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* left, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* right, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* result, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseBinaryOp(left, right, result));
+  return ApplyBinaryOpIU<KERNEL>(left, right, result);
+}
+
+template <typename KERNEL>  // Operand order in bytecode: lhs, rhs, destination.
+Status DispatchElementwiseBinaryOpF(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* left, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* right, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* result, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseBinaryOp(left, right, result));
+  return ApplyBinaryOpF<KERNEL>(left, right, result);
+}
+
+// Decodes four locals (a, b, c, dst) in bytecode order, validates them, then
+template <typename KERNEL>  // forwards to ApplyTernaryOpIS<KERNEL>.
+Status DispatchElementwiseTernaryOpIS(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* first, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* second, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* third, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* result, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseTernaryOp(first, second, third, result));
+  return ApplyTernaryOpIS<KERNEL>(first, second, third, result);
+}
+
+// Decodes four locals (a, b, c, dst) in bytecode order, validates them, then
+template <typename KERNEL>  // forwards to ApplyTernaryOpIU<KERNEL>.
+Status DispatchElementwiseTernaryOpIU(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* first, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* second, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* third, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* result, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseTernaryOp(first, second, third, result));
+  return ApplyTernaryOpIU<KERNEL>(first, second, third, result);
+}
+
+// Decodes four locals (a, b, c, dst) in bytecode order, validates them, then
+template <typename KERNEL>  // forwards to ApplyTernaryOpF<KERNEL>.
+Status DispatchElementwiseTernaryOpF(vm::BytecodeReader* reader) {
+  ASSIGN_OR_RETURN(auto* first, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* second, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* third, reader->ReadLocal());
+  ASSIGN_OR_RETURN(auto* result, reader->ReadLocal());
+  RETURN_IF_ERROR(ValidateElementwiseTernaryOp(first, second, third, result));
+  return ApplyTernaryOpF<KERNEL>(first, second, third, result);
+}
+
+Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices,
+ BufferView* dst_local, absl::Span<const int32_t> dst_indices,
+ absl::Span<const int32_t> lengths);  // Declaration only; no body in this header.
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
diff --git a/hal/interpreter/bytecode_executable.cc b/hal/interpreter/bytecode_executable.cc
new file mode 100644
index 0000000..ff3d6d3
--- /dev/null
+++ b/hal/interpreter/bytecode_executable.cc
@@ -0,0 +1,64 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/bytecode_executable.h"
+
+#include <iostream>
+
+#include "hal/interpreter/interpreter_module.h"
+#include "rt/policy.h"
+
+namespace iree {
+namespace hal {
+
+// static
+StatusOr<ref_ptr<BytecodeExecutable>> BytecodeExecutable::Load(
+    ref_ptr<rt::Instance> instance, hal::Allocator* allocator,
+    ExecutableSpec spec, bool allow_aliasing_data) {
+  // Build the executable first: when aliasing is disallowed its constructor
+  // copies the bytecode, and the module below must be created from the
+  // executable's own view of the data rather than from the caller's buffer.
+  auto loaded = make_ref<BytecodeExecutable>(std::move(instance), allocator,
+                                             spec, allow_aliasing_data);
+
+  // Interpret the data as a ModuleDef flatbuffer and build the module.
+  const auto* def_root =
+      ::flatbuffers::GetRoot<ModuleDef>(loaded->executable_data().data());
+  ASSIGN_OR_RETURN(auto vm_module,
+                   InterpreterModule::FromDef(allocator, *def_root));
+  loaded->module_ = add_ref(vm_module);  // Keep a ref; the next line moves it.
+  RETURN_IF_ERROR(loaded->context()->RegisterModule(std::move(vm_module)));
+
+  return loaded;
+}
+
+BytecodeExecutable::BytecodeExecutable(ref_ptr<rt::Instance> instance,
+                                       hal::Allocator* allocator,
+                                       ExecutableSpec spec,
+                                       bool allow_aliasing_data)
+    : spec_(spec),
+      context_(
+          make_ref<rt::Context>(std::move(instance), make_ref<rt::Policy>())) {
+  if (allow_aliasing_data) return;  // Keep referencing the caller's bytes.
+  // Otherwise take a private copy of the bytecode and repoint the stored spec
+  // at it, so executable_data() no longer refers to caller-owned memory.
+  cloned_executable_data_.assign(spec.executable_data.begin(),
+                                 spec.executable_data.end());
+  spec_.executable_data = absl::MakeConstSpan(cloned_executable_data_);
+}
+
+BytecodeExecutable::~BytecodeExecutable() = default;  // No manual cleanup here.
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/bytecode_executable.h b/hal/interpreter/bytecode_executable.h
new file mode 100644
index 0000000..f90e144
--- /dev/null
+++ b/hal/interpreter/bytecode_executable.h
@@ -0,0 +1,68 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/executable.h"
+#include "hal/executable_spec.h"
+#include "rt/context.h"
+#include "rt/instance.h"
+#include "rt/module.h"
+
+namespace iree {
+namespace hal {
+
+class BytecodeExecutable final : public Executable {  // Executable backed by a bytecode blob.
+ public:
+ static StatusOr<ref_ptr<BytecodeExecutable>> Load(  // Builds and registers the module.
+ ref_ptr<rt::Instance> instance, hal::Allocator* allocator,
+ ExecutableSpec spec, bool allow_aliasing_data);
+
+ BytecodeExecutable(ref_ptr<rt::Instance> instance, hal::Allocator* allocator,
+ ExecutableSpec spec, bool allow_aliasing_data);  // Copies the data unless aliasing is allowed.
+ ~BytecodeExecutable() override;
+
+ bool supports_debugging() const override { return false; }
+
+ // View of the bytecode blob: the private copy when one was made, otherwise
+ absl::Span<const uint8_t> executable_data() const {  // the caller's data.
+ return spec_.executable_data;
+ }
+
+ // VM context that the executable's module is registered with by Load().
+ const ref_ptr<rt::Context>& context() const { return context_; }
+
+ // VM module representing the executable (assigned by Load()).
+ // Note that there may be more than one module in the Context and only this
+ // module can be used to look up executable exports.
+ const ref_ptr<rt::Module>& module() const { return module_; }
+
+ private:
+ ExecutableSpec spec_;  // executable_data is repointed at the clone below when copying.
+ std::vector<uint8_t> cloned_executable_data_;  // Filled only when aliasing is disallowed.
+
+ ref_ptr<rt::Context> context_;  // Created in the constructor.
+ ref_ptr<rt::Module> module_;  // Set by Load(); unset right after construction.
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
diff --git a/hal/interpreter/bytecode_kernels.h b/hal/interpreter/bytecode_kernels.h
new file mode 100644
index 0000000..2e7e90c
--- /dev/null
+++ b/hal/interpreter/bytecode_kernels.h
@@ -0,0 +1,371 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Defines kernel functions and provides their implementation via one (or more)
+// included files.
+//
+// Kernels should do the simplest possible operation. Buffer validation is
+// handled by the dispatch logic and need not be checked. Kernels may optionally
+// accept arguments beyond just the buffers, depending on the required state
+// and attributes.
+//
+// Kernels may optionally have runtime state. This is state that is allocated
+// once for the entire Runtime (and stored on RuntimeState) and shared across
+// all fibers. This lets kernels that need thread pools or device handles share
+// them, while kernels that need transient storage remain safe to use from
+// multiple fibers concurrently.
+//
+// All kernels are templated to enable specialization of particular types or
+// type combinations. By default the bytecode_kernels_generic.h will provide C++
+// semantics as reference and platform-specific versions can be implemented
+// as needed.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
+
+#include <cstdint>
+
+#include "absl/types/span.h"
+#include "base/shape.h"
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+namespace kernels {
+
+struct CompareEQ {  // dst[i] = (lhs[i] == rhs[i]); one uint8_t flag per element.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<uint8_t> dst_buffer);
+};
+struct CompareNE {  // dst[i] = (lhs[i] != rhs[i]); one uint8_t flag per element.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<uint8_t> dst_buffer);
+};
+struct CompareLT {  // dst[i] = (lhs[i] < rhs[i]); one uint8_t flag per element.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<uint8_t> dst_buffer);
+};
+struct CompareLE {  // dst[i] = (lhs[i] <= rhs[i]); one uint8_t flag per element.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<uint8_t> dst_buffer);
+};
+struct CompareGT {  // dst[i] = (lhs[i] > rhs[i]); one uint8_t flag per element.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<uint8_t> dst_buffer);
+};
+struct CompareGE {  // dst[i] = (lhs[i] >= rhs[i]); one uint8_t flag per element.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<uint8_t> dst_buffer);
+};
+
+struct Copy {  // Byte-wise region copy between two shaped buffers.
+ template <int element_size>  // Bytes per element; used to scale the strides.
+ static Status Execute(absl::Span<const uint8_t> src_buffer,
+ const Shape& src_shape,
+ absl::Span<const int32_t> src_indices,
+ absl::Span<uint8_t> dst_buffer, const Shape& dst_shape,
+ absl::Span<const int32_t> dst_indices,
+ absl::Span<const int32_t> lengths);  // Empty lengths copies one element.
+};
+
+struct Select {  // dst[i] = cond[i] ? lhs[i] : rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const uint8_t> cond_buffer,
+ absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Transpose {  // dst dimension i is taken from src dimension perm[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer, const Shape& src_shape,
+ absl::Span<const int32_t> perm);
+};
+
+struct Pad {  // Copies src into dst, writing a scalar value at padded positions.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> padding_value,  // Generic impl requires exactly 1 element.
+ absl::Span<T> dst_buffer, const Shape& src_shape,
+ const Shape& dst_shape,
+ absl::Span<const int32_t> edge_padding_low,
+ absl::Span<const int32_t> edge_padding_high,
+ absl::Span<const int32_t> interior_padding);
+};
+
+struct Reverse {  // Mirrors src along every dimension listed in `dimensions`.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer, const Shape& src_shape,
+ absl::Span<const int32_t> dimensions);
+};
+
+struct Broadcast {  // Fills every element of dst with src[0].
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Tile {  // Repeats src across dst; dst coordinates wrap modulo src extents.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer, const Shape& src_shape,
+ const Shape& dst_shape);
+};
+
+struct Not {  // dst[i] = ~src[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct And {  // dst[i] = lhs[i] & rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Or {  // dst[i] = lhs[i] | rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Xor {  // dst[i] = lhs[i] ^ rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct ShiftLeft {  // dst[i] = lhs[i] << rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct ShiftRight {  // dst[i] = lhs[i] >> rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Add {  // dst[i] = lhs[i] + rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Sub {  // dst[i] = lhs[i] - rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Abs {  // dst[i] = std::abs(src[i]).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Mul {  // dst[i] = lhs[i] * rhs[i].
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Div {  // Presumably dst[i] = lhs[i] / rhs[i]; definition not shown here.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+// dst = a + (b * c); presumably evaluated per element (definition not shown).
+struct MulAdd {
+ template <typename T>
+ static Status Execute(absl::Span<const T> a_buffer,
+ absl::Span<const T> b_buffer,
+ absl::Span<const T> c_buffer, absl::Span<T> dst_buffer);
+};
+
+struct Exp {  // Unary kernel; presumably elementwise exp (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Log {  // Unary kernel; presumably elementwise log (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Rsqrt {  // Unary kernel; presumably 1/sqrt(src[i]) (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Cos {  // Unary kernel; presumably elementwise cos (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Sin {  // Unary kernel; presumably elementwise sin (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Tanh {  // Unary kernel; presumably elementwise tanh (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Atan2 {  // Binary kernel; presumably atan2(lhs[i], rhs[i]) — TODO confirm.
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Min {  // Binary kernel; presumably elementwise minimum (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Max {  // Binary kernel; presumably elementwise maximum (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Clamp {  // Presumably limits src[i] to [min[i], max[i]] — TODO confirm.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> min_buffer,
+ absl::Span<const T> max_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Floor {  // Unary kernel; presumably elementwise floor (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Ceil {  // Unary kernel; presumably elementwise ceil (definition not shown).
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer);
+};
+
+struct Convert {  // Presumably converts each SRC element to DST — TODO confirm.
+ template <typename SRC, typename DST>
+ static Status Execute(absl::Span<const SRC> src_buffer,
+ absl::Span<DST> dst_buffer);
+};
+
+struct MatMul {  // Matrix-multiply kernel that carries its own runtime state.
+ struct RuntimeState;  // Opaque here; only forward-declared in this header.
+
+ static std::unique_ptr<RuntimeState> CreateRuntimeState();
+
+ template <typename T, typename ACC>  // T: lhs/rhs/dst elements; ACC: bias/mantissa.
+ struct Buffers {  // Non-owning views plus shapes handed to Execute().
+ Shape lhs_shape;
+ absl::Span<const T> lhs_buffer;
+ Shape rhs_shape;
+ absl::Span<const T> rhs_buffer;
+ Shape dst_shape;
+ absl::Span<T> dst_buffer;
+
+ // Optional bias buffer; left empty when no bias is supplied.
+ absl::Span<const ACC> bias_buffer;
+
+ // Fixed-point multiplier mantissa/exponent. May be a single value (for
+ // uniform quantization) or one element per row of the destination matrix
+ // for per-channel quantization.
+ absl::Span<const ACC> multiplier_mantissa_buffer;
+ absl::Span<const int32_t> multiplier_exponent_buffer;
+ };
+
+ template <typename T, typename ACC>
+ static Status Execute(RuntimeState* runtime_state,
+ const Buffers<T, ACC>& buffers);
+};
+
+struct RuntimeState {  // Bundle of kernel runtime state; only MatMul's so far.
+ std::unique_ptr<MatMul::RuntimeState> mat_mul_state =
+ MatMul::CreateRuntimeState();  // Created eagerly by the default initializer.
+};
+
+struct ReduceSum {  // Presumably sums src along `dimension` — TODO confirm.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ const Shape& src_shape, const Shape& dst_shape);
+};
+
+struct ReduceMin {  // Presumably takes the minimum along `dimension` — TODO confirm.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ const Shape& src_shape, const Shape& dst_shape);
+};
+
+struct ReduceMax {  // Presumably takes the maximum along `dimension` — TODO confirm.
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ const Shape& src_shape, const Shape& dst_shape);
+};
+
+} // namespace kernels
+} // namespace hal
+} // namespace iree
+
+#include "hal/interpreter/bytecode_kernels_generic.h" // IWYU pragma: export
+#include "hal/interpreter/bytecode_kernels_ruy.h" // IWYU pragma: export
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
diff --git a/hal/interpreter/bytecode_kernels_generic.h b/hal/interpreter/bytecode_kernels_generic.h
new file mode 100644
index 0000000..95d12dc
--- /dev/null
+++ b/hal/interpreter/bytecode_kernels_generic.h
@@ -0,0 +1,708 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+namespace kernels {
+
+// dst[i] = (lhs[i] == rhs[i]) stored as 0 or 1, for each i in [0, dst.size()).
+template <typename T>
+Status CompareEQ::Execute(absl::Span<const T> lhs_buffer,
+                          absl::Span<const T> rhs_buffer,
+                          absl::Span<uint8_t> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] == rhs_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = (lhs[i] != rhs[i]) stored as 0 or 1, for each i in [0, dst.size()).
+template <typename T>
+Status CompareNE::Execute(absl::Span<const T> lhs_buffer,
+                          absl::Span<const T> rhs_buffer,
+                          absl::Span<uint8_t> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] != rhs_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = (lhs[i] < rhs[i]) stored as 0 or 1, for each i in [0, dst.size()).
+template <typename T>
+Status CompareLT::Execute(absl::Span<const T> lhs_buffer,
+                          absl::Span<const T> rhs_buffer,
+                          absl::Span<uint8_t> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] < rhs_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = (lhs[i] <= rhs[i]) stored as 0 or 1, for each i in [0, dst.size()).
+template <typename T>
+Status CompareLE::Execute(absl::Span<const T> lhs_buffer,
+                          absl::Span<const T> rhs_buffer,
+                          absl::Span<uint8_t> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] <= rhs_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = (lhs[i] > rhs[i]) stored as 0 or 1, for each i in [0, dst.size()).
+template <typename T>
+Status CompareGT::Execute(absl::Span<const T> lhs_buffer,
+                          absl::Span<const T> rhs_buffer,
+                          absl::Span<uint8_t> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] > rhs_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = (lhs[i] >= rhs[i]) stored as 0 or 1, for each i in [0, dst.size()).
+template <typename T>
+Status CompareGE::Execute(absl::Span<const T> lhs_buffer,
+                          absl::Span<const T> rhs_buffer,
+                          absl::Span<uint8_t> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] >= rhs_buffer[i];
+  return OkStatus();
+}
+
+namespace impl {
+inline absl::InlinedVector<size_t, 6> ComputeCopyStrides(const Shape& shape,
+                                                         size_t element_size) {
+  absl::InlinedVector<size_t, 6> strides(shape.size());  // Byte stride per dim.
+  if (!strides.empty()) strides.back() = element_size;  // Rank 0: nothing to set.
+  for (int i = shape.size() - 2; i >= 0; --i) {
+    strides[i] = strides[i + 1] * shape[i + 1];  // Outer = inner * inner extent.
+  }
+  return strides;
+}
+
+inline void CopyRegion(absl::Span<const uint8_t> src_buffer,
+                       absl::Span<const size_t> src_strides,
+                       absl::Span<const int32_t> src_indices,
+                       absl::Span<uint8_t> dst_buffer,
+                       absl::Span<const size_t> dst_strides,
+                       absl::Span<const int32_t> dst_indices,
+                       absl::Span<const int32_t> lengths) {
+  if (lengths.size() <= 1) {
+    // Innermost dimension: one contiguous run of bytes.
+    DCHECK_EQ(dst_strides.size(), 1);
+    DCHECK_EQ(src_strides.size(), 1);
+    DCHECK_EQ(src_indices.size(), 1);
+    DCHECK_EQ(dst_indices.size(), 1);
+    DCHECK_EQ(lengths.size(), 1);
+    const uint8_t* from = src_buffer.data() + src_indices[0] * src_strides[0];
+    uint8_t* to = dst_buffer.data() + dst_indices[0] * dst_strides[0];
+    std::memcpy(to, from, dst_strides[0] * lengths[0]);
+    return;
+  }
+  // Outer dimension: recurse per index with the leading entries peeled off.
+  for (int step = 0; step < lengths[0]; ++step) {
+    const size_t src_skip = src_strides[0] * (src_indices[0] + step);
+    const size_t dst_skip = dst_strides[0] * (dst_indices[0] + step);
+    CopyRegion(src_buffer.subspan(src_skip), src_strides.subspan(1),
+               src_indices.subspan(1), dst_buffer.subspan(dst_skip),
+               dst_strides.subspan(1), dst_indices.subspan(1),
+               lengths.subspan(1));
+  }
+}
+} // namespace impl
+
+// TODO(benvanik): replace with a real implementation once copy is defined.
+// TODO(gcmn): More consistent/principled handling for scalars.
+template <int element_size>
+Status Copy::Execute(absl::Span<const uint8_t> src_buffer,
+                     const Shape& src_shape,
+                     absl::Span<const int32_t> src_indices,
+                     absl::Span<uint8_t> dst_buffer, const Shape& dst_shape,
+                     absl::Span<const int32_t> dst_indices,
+                     absl::Span<const int32_t> lengths) {
+  const size_t rank = lengths.size();  // All inputs must agree on this rank.
+  DCHECK_EQ(src_indices.size(), rank);
+  DCHECK_EQ(dst_indices.size(), rank);
+  DCHECK_EQ(src_shape.size(), rank);
+  DCHECK_EQ(dst_shape.size(), rank);
+  if (rank == 0) {  // Scalar: copy exactly one element.
+    std::memcpy(dst_buffer.data(), src_buffer.data(), element_size);
+    return OkStatus();
+  }
+  // TODO(gcmn) Maybe we can fast-path earlier if we detect contiguous memory
+  // across multiple rows.
+  const auto src_steps = impl::ComputeCopyStrides(src_shape, element_size);
+  const auto dst_steps = impl::ComputeCopyStrides(dst_shape, element_size);
+  DCHECK_EQ(src_steps.size(), rank);
+  DCHECK_EQ(dst_steps.size(), rank);
+  impl::CopyRegion(src_buffer, src_steps, src_indices, dst_buffer, dst_steps,
+                   dst_indices, lengths);
+  return OkStatus();
+}
+
+template <typename T>  // dst[i] = cond[i] ? lhs[i] : rhs[i], for i < dst.size().
+Status Select::Execute(absl::Span<const uint8_t> cond_buffer,
+                       absl::Span<const T> lhs_buffer,
+                       absl::Span<const T> rhs_buffer,
+                       absl::Span<T> dst_buffer) {
+  for (size_t i = 0, n = dst_buffer.size(); i != n; ++i) {
+    dst_buffer[i] = cond_buffer[i] != 0 ? lhs_buffer[i] : rhs_buffer[i];
+  }
+  return OkStatus();
+}
+
+template <typename T>
+Status Transpose::Execute(absl::Span<const T> src_buffer,
+                          absl::Span<T> dst_buffer, const Shape& src_shape,
+                          absl::Span<const int32_t> perm) {
+  // Straightforward but slow: every destination offset is decomposed into
+  // coordinates, which are mapped back to a source offset through `perm`.
+  const int rank = src_shape.size();
+  absl::InlinedVector<int, 8> in_strides(rank);
+  absl::InlinedVector<int, 8> out_strides(rank);
+  size_t in_step = 1;
+  size_t out_step = 1;
+  for (int d = rank - 1; d >= 0; --d) {
+    in_strides[d] = in_step;
+    out_strides[d] = out_step;
+    in_step *= src_shape[d];
+    out_step *= src_shape[perm[d]];
+  }
+  for (size_t out_i = 0; out_i < dst_buffer.size(); ++out_i) {
+    size_t in_i = 0;
+    size_t rest = out_i;
+    for (int d = 0; d < rank; ++d) {
+      in_i += (rest / out_strides[d]) * in_strides[perm[d]];
+      rest %= out_strides[d];
+    }
+    dst_buffer[out_i] = src_buffer[in_i];
+  }
+  return OkStatus();
+}
+
+namespace impl {
+inline void IncrementShapeIndex(absl::Span<int32_t> indices,
+                                const Shape& shape) {
+  for (int dim = indices.size(); dim-- > 0;) {  // Walk from the last index back.
+    if (++indices[dim] < shape[dim]) return;  // Still in range: no carry needed.
+    indices[dim] = 0;  // Wrapped: reset and carry into the next-outer index.
+  }
+}
+
+inline bool IsPadding(absl::Span<const int32_t> indices, const Shape& shape,
+                      absl::Span<const int32_t> edge_padding_low,
+                      absl::Span<const int32_t> edge_padding_high,
+                      absl::Span<const int32_t> interior_padding) {
+  // Padding if ANY dimension lies in an edge or between interior-strided cells.
+  for (int i = 0, rank = static_cast<int>(indices.size()); i < rank; ++i) {
+    const int32_t index = indices[i];
+    if (index < edge_padding_low[i] ||
+        index >= shape[i] - edge_padding_high[i] ||
+        (index - edge_padding_low[i]) % (interior_padding[i] + 1) != 0) {
+      return true;
+    }
+  }
+  return false;
+}
+} // namespace impl
+
+template <typename T>
+Status Pad::Execute(absl::Span<const T> src_buffer,
+                    absl::Span<const T> padding_value_buffer,
+                    absl::Span<T> dst_buffer, const Shape& src_shape,
+                    const Shape& dst_shape,
+                    absl::Span<const int32_t> edge_padding_low,
+                    absl::Span<const int32_t> edge_padding_high,
+                    absl::Span<const int32_t> interior_padding) {
+  // This implementation is not at all fast, as it iterates every index in the
+  // destination buffer individually. Potential improvements:
+  // 1. Fill the dst buffer with padded value initially. Only need to iterate
+  //    through source buffer and can exit early.
+  // 2. Use striding to advance through larger swaths of the buffer with a
+  //    memcpy from src and filling (or skipping) padded indices. Especially
+  //    useful when e.g. entire rows are padded.
+
+  // TODO(b/140836672) support negative padding
+
+  // The padding value is a scalar, so anything other than exactly one element
+  // (an empty buffer included) is rejected.
+  if (padding_value_buffer.size() != 1) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Padding value buffer must contain exactly one element.";
+  }
+  const T padding_value = padding_value_buffer.front();
+  absl::InlinedVector<int, 8> dst_indices(src_shape.size(), 0);
+  const T* src_ptr = src_buffer.begin();
+  T* dst_ptr = dst_buffer.begin();
+  while (dst_ptr != dst_buffer.end()) {  // One step per destination element.
+    if (impl::IsPadding(dst_indices, dst_shape, edge_padding_low,
+                        edge_padding_high, interior_padding)) {
+      *dst_ptr++ = padding_value;
+    } else {
+      DCHECK(src_ptr != src_buffer.end());
+      *dst_ptr++ = *src_ptr++;  // Source elements are consumed in order.
+    }
+    impl::IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape);
+  }
+
+  return OkStatus();
+}
+
+template <typename T>
+Status Reverse::Execute(absl::Span<const T> src_buffer,
+                        absl::Span<T> dst_buffer, const Shape& src_shape,
+                        absl::Span<const int32_t> dimensions) {
+  // Slow reference version: split each destination offset into per-dimension
+  // coordinates and mirror the coordinates of the reversed dimensions.
+  const int rank = src_shape.size();
+  absl::InlinedVector<int, 8> steps(rank);
+  size_t running = 1;
+  for (int d = rank - 1; d >= 0; --d) {
+    steps[d] = running;
+    running *= src_shape[d];
+  }
+  absl::flat_hash_set<int32_t> mirrored(dimensions.begin(), dimensions.end());
+  for (size_t out_i = 0; out_i < dst_buffer.size(); ++out_i) {
+    size_t in_i = 0;
+    size_t rest = out_i;
+    for (int d = 0; d < rank; ++d) {
+      size_t coord = rest / steps[d];
+      rest -= coord * steps[d];
+      if (mirrored.contains(d)) coord = src_shape[d] - 1 - coord;
+      in_i += coord * steps[d];
+    }
+    dst_buffer[out_i] = src_buffer[in_i];
+  }
+  return OkStatus();
+}
+
+// Fills every element of dst with src[0].
+template <typename T>
+Status Broadcast::Execute(absl::Span<const T> src_buffer,
+                          absl::Span<T> dst_buffer) {
+  // src[0] is re-read per element, so an empty dst never touches src.
+  for (T& out : dst_buffer) out = src_buffer[0];
+  return OkStatus();
+}
+
+template <typename T>
+Status Tile::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer,
+                     const Shape& src_shape, const Shape& dst_shape) {
+  // Slow reference version: each dst coordinate wraps modulo the src extent.
+  const int rank = dst_shape.size();
+  absl::InlinedVector<int, 8> in_strides(rank);
+  absl::InlinedVector<int, 8> out_strides(rank);
+  size_t in_step = 1;
+  size_t out_step = 1;
+  for (int d = rank - 1; d >= 0; --d) {
+    in_strides[d] = in_step;
+    out_strides[d] = out_step;
+    in_step *= src_shape[d];
+    out_step *= dst_shape[d];
+  }
+  for (size_t out_i = 0; out_i < dst_buffer.size(); ++out_i) {
+    size_t in_i = 0;
+    size_t rest = out_i;
+    for (int d = 0; d < rank; ++d) {
+      in_i += (rest / out_strides[d]) % src_shape[d] * in_strides[d];
+      rest %= out_strides[d];
+    }
+    dst_buffer[out_i] = src_buffer[in_i];
+  }
+  return OkStatus();
+}
+
+// dst[i] = ~src[i] for each i in [0, dst.size()).
+template <typename T>
+Status Not::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = ~src_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = lhs[i] & rhs[i] for each i in [0, dst.size()).
+template <typename T>
+Status And::Execute(absl::Span<const T> lhs_buffer,
+                    absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] & rhs_buffer[i];
+  return OkStatus();
+}
+
+// dst[i] = lhs[i] | rhs[i] for each i in [0, dst.size()).
+template <typename T>
+Status Or::Execute(absl::Span<const T> lhs_buffer,
+                   absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+  const size_t n = dst_buffer.size();
+  for (size_t i = 0; i < n; ++i) dst_buffer[i] = lhs_buffer[i] | rhs_buffer[i];
+  return OkStatus();
+}
+
+template <typename T>
+Status Xor::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] ^ rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status ShiftLeft::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] << rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status ShiftRight::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] >> rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Add::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Sub::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] - rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Abs::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::abs(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Mul::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] * rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Div::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = lhs_buffer[i] / rhs_buffer[i];
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status MulAdd::Execute(absl::Span<const T> a_buffer,
+ absl::Span<const T> b_buffer,
+ absl::Span<const T> c_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = a_buffer[i] + (b_buffer[i] * c_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Exp::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::exp(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Rsqrt::Execute(absl::Span<const T> src_buffer,
+                      absl::Span<T> dst_buffer) {
+  for (size_t i = 0; i < dst_buffer.size(); ++i) {
+    dst_buffer[i] = static_cast<T>(1) / std::sqrt(src_buffer[i]);  // 1/sqrt(x)
+  }
+  return OkStatus();
+}
+
+template <typename T>
+Status Log::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::log(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Cos::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::cos(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Sin::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::sin(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Tanh::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::tanh(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Atan2::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer,
+ absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::atan2(lhs_buffer[i], rhs_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Min::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::min(lhs_buffer[i], rhs_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Max::Execute(absl::Span<const T> lhs_buffer,
+ absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::max(lhs_buffer[i], rhs_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Clamp::Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> min_buffer,
+ absl::Span<const T> max_buffer,
+ absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ T src = src_buffer[i];
+ T min = min_buffer[i];
+ T max = max_buffer[i];
+ dst_buffer[i] = src <= min ? min : src >= max ? max : src;
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Floor::Execute(absl::Span<const T> src_buffer,
+ absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::floor(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename T>
+Status Ceil::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = std::ceil(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+template <typename SRC, typename DST>
+Status Convert::Execute(absl::Span<const SRC> src_buffer,
+ absl::Span<DST> dst_buffer) {
+ DCHECK_EQ(src_buffer.size(), dst_buffer.size());
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
+ dst_buffer[i] = static_cast<DST>(src_buffer[i]);
+ }
+ return OkStatus();
+}
+
+namespace impl {
+
+struct SumKernel {
+ template <typename T>
+ inline void operator()(T* value0, const T value1) {
+ *value0 += value1;
+ }
+};
+
+struct MinKernel {
+ template <typename T>
+ inline void operator()(T* value0, const T value1) {
+ *value0 = std::min(*value0, value1);
+ }
+};
+
+struct MaxKernel {
+ template <typename T>
+ inline void operator()(T* value0, const T value1) {
+ *value0 = std::max(*value0, value1);
+ }
+};
+
+template <typename T, typename KernelImpl>
+inline void ReduceDimension(absl::Span<const T> src_buffer,
+                            absl::Span<T> dst_buffer, const Shape& src_shape,
+                            absl::Span<const int32_t> reduce_dims,
+                            absl::Span<const int> dst_strides, int dim,
+                            absl::Span<int> src_indices, size_t flat_src_i,
+                            size_t src_stride) {
+  if (dim < 0) {
+    // Base case of the recursion: every source index is fixed, so locate the
+    // matching destination element and apply the reduction kernel to it.
+
+    // Derive destination indices from source indices.
+    // For example,
+    //   reduce_dims: [1, 2]
+    //   src_indices: [2, 1, 3, 0]
+    //                    ^  ^
+    //                    |  |
+    //                    |----- remove these dimensions
+    //   dst_indices: [2, 0]
+    //
+    // TODO(scotttodd): Clean this up somehow, share across recursion levels?
+    size_t dst_size = src_shape.size() - reduce_dims.size();
+    absl::InlinedVector<int, 8> dst_indices;
+    for (size_t i = 0; i < src_indices.size(); ++i) {
+      if (std::find(std::begin(reduce_dims), std::end(reduce_dims), i) ==
+          std::end(reduce_dims)) {
+        dst_indices.push_back(src_indices[i]);
+      }
+    }
+    // Flatten dst_indices; dst_strides is expected innermost-stride-first.
+    size_t dst_i = 0;
+    for (size_t i = 0; i < dst_indices.size(); ++i) {
+      dst_i += dst_indices[i] * dst_strides[dst_size - 1 - i];
+    }
+
+    // Flattened src and dst indices have been computed, invoke the kernel.
+    KernelImpl()(&dst_buffer[dst_i], src_buffer[flat_src_i]);
+    return;
+  }
+
+  // Iterate through the current dimension in the source shape, recursing
+  // down one dimension at a time.
+  //
+  // This touches each element in the source buffer once, tracking complete
+  // dimensions within the shaped source buffer and using them to compute
+  // the corresponding indices (shaped and flattened) within the destination
+  // buffer. Each element in the destination buffer will be touched multiple
+  // times.
+  //
+  // Note that cache locality isn't considered here, and some computations
+  // are redundant, so this could be optimized substantially.
+  for (size_t dim_i = 0; dim_i < src_shape[dim]; ++dim_i) {
+    src_indices[dim] = dim_i;
+
+    // Recurse down to the next dimension (e.g. 2 -> 1 -> 0 -> base case)
+    //   * Add the current stride to flat_src_i
+    //   * Multiply src_stride by this dimension's shape
+    ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape,
+                                   reduce_dims, dst_strides, dim - 1,
+                                   src_indices, flat_src_i + dim_i * src_stride,
+                                   src_stride * src_shape[dim]);
+  }
+}
+
+template <typename T, typename KernelImpl>
+Status GenericReduce(absl::Span<const T> src_buffer,
+                     absl::Span<const T> init_buffer, absl::Span<T> dst_buffer,
+                     int32_t dimension, const Shape& src_shape,
+                     const Shape& dst_shape) {
+  // Seed dst with init_buffer[0]; init_buffer is expected to be a scalar.
+  std::fill_n(dst_buffer.data(), dst_buffer.size(), init_buffer[0]);
+
+  // Precompute destination strides, innermost dimension first (reversed).
+  int dst_rank = dst_shape.size();
+  absl::InlinedVector<int, 8> dst_strides;
+  size_t dst_stride = 1;
+  for (int dim_i = dst_rank - 1; dim_i >= 0; --dim_i) {
+    dst_strides.push_back(dst_stride);
+    dst_stride *= dst_shape[dim_i];
+  }
+
+  // Call the helper (recursive) function, starting with:
+  //   * source index [0, 0, ..., 0]
+  //   * the innermost dimension (last in the shape)
+  //   * flat_src_i of 0 (corresponds to [0, 0, ..., 0] above)
+  //   * source stride 1
+  absl::InlinedVector<int, 8> src_indices(src_shape.size(), 0);
+  ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, {dimension},
+                                 absl::MakeSpan(dst_strides),
+                                 src_shape.size() - 1,
+                                 absl::MakeSpan(src_indices), 0, 1);
+
+  return OkStatus();
+}
+
+} // namespace impl
+
+template <typename T>
+Status ReduceSum::Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ const Shape& src_shape, const Shape& dst_shape) {
+ return impl::GenericReduce<T, impl::SumKernel>(
+ src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
+}
+
+template <typename T>
+Status ReduceMin::Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ const Shape& src_shape, const Shape& dst_shape) {
+ return impl::GenericReduce<T, impl::MinKernel>(
+ src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
+}
+
+template <typename T>
+Status ReduceMax::Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ const Shape& src_shape, const Shape& dst_shape) {
+ return impl::GenericReduce<T, impl::MaxKernel>(
+ src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
+}
+
+} // namespace kernels
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
diff --git a/hal/interpreter/bytecode_kernels_ruy.h b/hal/interpreter/bytecode_kernels_ruy.h
new file mode 100644
index 0000000..ae36844
--- /dev/null
+++ b/hal/interpreter/bytecode_kernels_ruy.h
@@ -0,0 +1,81 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
+#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
+
+#include "absl/base/thread_annotations.h"
+#include "absl/memory/memory.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "tensorflow/lite/experimental/ruy/context.h"
+#include "tensorflow/lite/experimental/ruy/ruy.h"
+
+namespace iree {
+namespace hal {
+namespace kernels {
+
+// TODO(benvanik): something more clever for making this shareable.
+// Maybe a factory fn based on the impl selected?
+struct MatMul::RuntimeState {
+ // TODO(benvanik): share the thread pool but keep context per-fiber?
+ ruy::Context context;
+};
+
+inline std::unique_ptr<MatMul::RuntimeState> MatMul::CreateRuntimeState() {
+ return absl::make_unique<RuntimeState>();
+}
+
+template <typename T, typename ACC>
+Status MatMul::Execute(RuntimeState* runtime_state,
+                       const Buffers<T, ACC>& buffers) {
+  ruy::Matrix<T> lhs_matrix;  // Row-major, backed by lhs_buffer.
+  ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1],
+                        ruy::Order::kRowMajor, &lhs_matrix.layout);
+  lhs_matrix.data.set(buffers.lhs_buffer.data());
+
+  ruy::Matrix<T> rhs_matrix;  // Row-major, backed by rhs_buffer.
+  ruy::MakeSimpleLayout(buffers.rhs_shape[0], buffers.rhs_shape[1],
+                        ruy::Order::kRowMajor, &rhs_matrix.layout);
+  rhs_matrix.data.set(buffers.rhs_buffer.data());
+
+  ruy::Matrix<T> dst_matrix;  // Row-major, backed by dst_buffer.
+  ruy::MakeSimpleLayout(buffers.dst_shape[0], buffers.dst_shape[1],
+                        ruy::Order::kRowMajor, &dst_matrix.layout);
+  dst_matrix.data.set(buffers.dst_buffer.data());
+
+  ruy::BasicSpec<ACC, T> spec;
+  spec.bias = buffers.bias_buffer.data();
+
+  if (buffers.multiplier_mantissa_buffer.size() == 1) {  // One multiplier.
+    spec.multiplier_fixedpoint = buffers.multiplier_mantissa_buffer[0];
+    spec.multiplier_exponent = buffers.multiplier_exponent_buffer[0];
+  } else {  // Otherwise hand ruy the per-channel multiplier arrays.
+    spec.multiplier_fixedpoint_perchannel =
+        buffers.multiplier_mantissa_buffer.data();
+    spec.multiplier_exponent_perchannel =
+        buffers.multiplier_exponent_buffer.data();
+  }
+
+  ruy::Mul<ruy::kAllPaths>(lhs_matrix, rhs_matrix, spec,
+                           &runtime_state->context, &dst_matrix);
+
+  return OkStatus();
+}
+
+} // namespace kernels
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
diff --git a/hal/interpreter/bytecode_kernels_test.cc b/hal/interpreter/bytecode_kernels_test.cc
new file mode 100644
index 0000000..db99add
--- /dev/null
+++ b/hal/interpreter/bytecode_kernels_test.cc
@@ -0,0 +1,407 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/bytecode_kernels.h"
+
+#include "base/memory.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace kernels {
+
+namespace {
+
+constexpr float kEpsilon = 0.0001f;
+
+template <typename T>
+std::vector<T> MakeIota(int size) {
+ std::vector<T> v(size);
+ std::iota(v.begin(), v.end(), static_cast<T>(1));
+ return v;
+}
+
+TEST(Copy, WholeBuffer) {
+ Shape src_shape = {2, 2};
+ auto src_buffer = MakeIota<uint8_t>(4);
+ std::vector<int32_t> src_indices = {0, 0};
+ Shape dst_shape = src_shape;
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {0, 0};
+ std::vector<int32_t> lengths = {2, 2};
+ auto expected_dst = src_buffer;
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, FirstRow) {
+ Shape src_shape = {3, 4};
+ auto src_buffer = MakeIota<uint8_t>(12);
+ std::vector<int32_t> src_indices = {0, 0};
+ Shape dst_shape = {1, 4};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {0, 0};
+ std::vector<int32_t> lengths = {1, 4};
+ std::vector<uint8_t> expected_dst = {1, 2, 3, 4};
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, RowPart) {
+ Shape src_shape = {3, 4};
+ auto src_buffer = MakeIota<uint8_t>(12);
+ std::vector<int32_t> src_indices = {1, 1};
+ Shape dst_shape = {1, 2};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {0, 0};
+ std::vector<int32_t> lengths = {1, 2};
+ std::vector<uint8_t> expected_dst = {6, 7};
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, MultiRow) {
+ Shape src_shape = {3, 4};
+ auto src_buffer = MakeIota<uint8_t>(12);
+ std::vector<int32_t> src_indices = {1, 0};
+ Shape dst_shape = {2, 4};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {0, 0};
+ std::vector<int32_t> lengths = {2, 4};
+ std::vector<uint8_t> expected_dst = {5, 6, 7, 8, 9, 10, 11, 12};
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, NonContiguous) {
+ Shape src_shape = {3, 4};
+ auto src_buffer = MakeIota<uint8_t>(12);
+ std::vector<int32_t> src_indices = {1, 1};
+ Shape dst_shape = {2, 2};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {0, 0};
+ std::vector<int32_t> lengths = {2, 2};
+ std::vector<uint8_t> expected_dst = {6, 7, 10, 11};
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, MultiByte) {
+ Shape src_shape = {3, 4};
+ auto src_vals = MakeIota<int32_t>(12);
+ auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals));
+ std::vector<int32_t> src_indices = {1, 1};
+ Shape dst_shape = {2, 2};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count() * sizeof(int32_t));
+ std::vector<int32_t> dst_indices = {0, 0};
+ std::vector<int32_t> lengths = {2, 2};
+ std::vector<int32_t> expected_dst = {6, 7, 10, 11};
+
+ EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+
+ absl::Span<int32_t> dst_buffer_int32_t =
+ ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer));
+
+ EXPECT_EQ(dst_buffer_int32_t, expected_dst);
+}
+
+TEST(Copy, NotFullDst) {
+ Shape src_shape = {3, 4};
+ auto src_buffer = MakeIota<uint8_t>(12);
+ std::vector<int32_t> src_indices = {0, 0};
+ Shape dst_shape = {4, 3};
+ std::vector<uint8_t> dst_buffer(12, 42);
+ std::vector<int32_t> dst_indices = {1, 1};
+ std::vector<int32_t> lengths = {2, 2};
+ // clang-format off
+ std::vector<uint8_t> expected_dst = {42, 42, 42,
+ 42, 1, 2,
+ 42, 5, 6,
+ 42, 42, 42};
+ // clang-format on
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, HighRank) {
+ Shape src_shape = {3, 3, 3, 3};
+ auto src_buffer = MakeIota<uint8_t>(81);
+ std::vector<int32_t> src_indices = {1, 1, 1, 1};
+ Shape dst_shape = {2, 2, 2, 2};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {0, 0, 0, 0};
+ std::vector<int32_t> lengths = {2, 2, 2, 2};
+ std::vector<uint8_t> expected_dst = {41, 42, 44, 45, 50, 51, 53, 54,
+ 68, 69, 71, 72, 77, 78, 80, 81};
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, Scalar) {
+ Shape src_shape = {};
+ std::vector<uint8_t> src_buffer = {42};
+ std::vector<int32_t> src_indices = {};
+ Shape dst_shape = {};
+ std::vector<uint8_t> dst_buffer(dst_shape.element_count());
+ std::vector<int32_t> dst_indices = {};
+ std::vector<int32_t> lengths = {};
+ std::vector<uint8_t> expected_dst = {42};
+
+ EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Copy, ScalarMultiByte) {
+ Shape src_shape = {};
+ std::vector<int32_t> src_vals = {INT32_MAX};
+ auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals));
+ std::vector<int32_t> src_indices = {};
+ Shape dst_shape = {};
+ std::vector<uint8_t> dst_buffer(sizeof(int32_t));
+ std::vector<int32_t> dst_indices = {};
+ std::vector<int32_t> lengths = {};
+ std::vector<int32_t> expected_dst = {INT32_MAX};
+
+ EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices,
+ absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
+ lengths));
+
+ absl::Span<int32_t> dst_buffer_int32_t =
+ ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer));
+
+ EXPECT_EQ(dst_buffer_int32_t, expected_dst);
+}
+
+TEST(Pad, NoPadding) {
+ Shape src_shape = {2, 3};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {0, 0};
+ std::vector<int32_t> edge_padding_high = {0, 0};
+ std::vector<int32_t> interior_padding = {0, 0};
+ Shape dst_shape = src_shape;
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ auto expected_dst = src_buffer;
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Pad, LowHighPadding) {
+ Shape src_shape = {2, 3};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {0, 1};
+ std::vector<int32_t> edge_padding_high = {1, 2};
+ std::vector<int32_t> interior_padding = {0, 0};
+ Shape dst_shape = {3, 6};
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ // clang-format off
+ std::vector<uint16_t> expected_dst = {0, 1, 2, 3, 0, 0,
+ 0, 4, 5, 6, 0, 0,
+ 0, 0, 0, 0, 0, 0};
+ // clang-format on
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Pad, OnlyHighPadding) {
+ Shape src_shape = {2, 3};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {0, 0};
+ std::vector<int32_t> edge_padding_high = {1, 3};
+ std::vector<int32_t> interior_padding = {0, 0};
+ Shape dst_shape = {3, 6};
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ // clang-format off
+ std::vector<uint16_t> expected_dst = {1, 2, 3, 0, 0, 0,
+ 4, 5, 6, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0};
+ // clang-format on
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Pad, OnlyLowPadding) {
+ Shape src_shape = {2, 3};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {1, 3};
+ std::vector<int32_t> edge_padding_high = {0, 0};
+ std::vector<int32_t> interior_padding = {0, 0};
+ Shape dst_shape = {3, 6};
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ // clang-format off
+ std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 2, 3,
+ 0, 0, 0, 4, 5, 6};
+ // clang-format on
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Pad, OnlyInteriorPadding) {
+ Shape src_shape = {2, 3};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {0, 0};
+ std::vector<int32_t> edge_padding_high = {0, 0};
+ std::vector<int32_t> interior_padding = {1, 1};
+ Shape dst_shape = {3, 5};
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ // clang-format off
+ std::vector<uint16_t> expected_dst = {1, 0, 2, 0, 3,
+ 0, 0, 0, 0, 0,
+ 4, 0, 5, 0, 6};
+ // clang-format on
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Pad, AllPaddingTypes) {
+ Shape src_shape = {2, 3};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {1, 1};
+ std::vector<int32_t> edge_padding_high = {1, 2};
+ std::vector<int32_t> interior_padding = {1, 1};
+ Shape dst_shape = {5, 8};
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ // clang-format off
+ std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 1, 0, 2, 0, 3, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 4, 0, 5, 0, 6, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0};
+ // clang-format on
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(Pad, HighRank) {
+ Shape src_shape = {2, 2, 2, 2};
+ auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
+ std::vector<uint16_t> pad_value_buffer = {0};
+ std::vector<int32_t> edge_padding_low = {1, 0, 0, 0};
+ std::vector<int32_t> edge_padding_high = {0, 1, 0, 0};
+ std::vector<int32_t> interior_padding = {0, 0, 1, 0};
+ Shape dst_shape = {3, 3, 3, 2};
+ std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
+ // clang-format off
+ std::vector<uint16_t> expected_dst = { 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0,
+
+ 1, 2, 0, 0, 3, 4,
+ 5, 6, 0, 0, 7, 8,
+ 0, 0, 0, 0, 0, 0,
+
+ 9, 10, 0, 0, 11, 12,
+ 13, 14, 0, 0, 15, 16,
+ 0, 0, 0, 0, 0, 0};
+ // clang-format on
+
+ ASSERT_EQ(dst_buffer.size(), expected_dst.size());
+
+ EXPECT_OK(Pad::Execute<uint16_t>(
+ src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
+ dst_shape, edge_padding_low, edge_padding_high, interior_padding));
+ EXPECT_EQ(dst_buffer, expected_dst);
+}
+
+TEST(ReduceSum, Scalar) {
+  Shape src_shape = {5};
+  int32_t dimension = 0;
+  Shape dst_shape = {1};
+  std::vector<float> src_buffer = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  std::vector<float> init_buffer = {0.0f};
+  std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f);
+  std::vector<float> expected_dst = {5.0f};
+  // Summing five ones along dimension 0 should leave a single 5.
+  EXPECT_OK(ReduceSum::Execute<float>(src_buffer, init_buffer,
+                                      absl::MakeSpan(dst_buffer), dimension,
+                                      src_shape, dst_shape));
+  ASSERT_EQ(dst_buffer.size(), expected_dst.size());  // Guards the loop below.
+  for (size_t i = 0; i < dst_buffer.size(); ++i) {
+    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
+  }
+}
+
+TEST(ReduceMin, TwoDimensionsToOne) {
+  Shape src_shape = {3, 3};
+  int32_t dimension = 0;
+  Shape dst_shape = {3};
+  std::vector<float> src_buffer = MakeIota<float>(src_shape.element_count());
+  std::vector<float> init_buffer = {std::numeric_limits<float>::max()};
+  std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f);
+  std::vector<float> expected_dst = {1.0f, 2.0f, 3.0f};
+  // Reducing dimension 0 of [[1,2,3],[4,5,6],[7,8,9]] keeps the column minima.
+  EXPECT_OK(ReduceMin::Execute<float>(src_buffer, init_buffer,
+                                      absl::MakeSpan(dst_buffer), dimension,
+                                      src_shape, dst_shape));
+  ASSERT_EQ(dst_buffer.size(), expected_dst.size());  // Guards the loop below.
+  for (size_t i = 0; i < dst_buffer.size(); ++i) {
+    EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
+  }
+}
+
+} // namespace
+} // namespace kernels
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/interpreter_command_processor.cc b/hal/interpreter/interpreter_command_processor.cc
new file mode 100644
index 0000000..ac8b0b1
--- /dev/null
+++ b/hal/interpreter/interpreter_command_processor.cc
@@ -0,0 +1,66 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/interpreter_command_processor.h"
+
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/buffer_view.h"
+#include "hal/interpreter/bytecode_executable.h"
+#include "rt/stack.h"
+
+namespace iree {
+namespace hal {
+
+InterpreterCommandProcessor::InterpreterCommandProcessor(
+ Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories)
+ : HostLocalCommandProcessor(allocator, mode, command_categories) {}
+
+InterpreterCommandProcessor::~InterpreterCommandProcessor() = default;
+
+Status InterpreterCommandProcessor::Dispatch(
+    const DispatchRequest& dispatch_request) {
+  IREE_TRACE_SCOPE0("InterpreterCommandProcessor::Dispatch");
+
+  // Resolve the exported entry point on the executable's bytecode module.
+  auto* executable =
+      static_cast<BytecodeExecutable*>(dispatch_request.executable);
+  const auto& module = executable->module();
+  ASSIGN_OR_RETURN(auto entry_function, module->LookupFunctionByOrdinal(
+                                            rt::Function::Linkage::kExport,
+                                            dispatch_request.entry_point));
+
+  rt::Stack stack(executable->context().get());
+
+  // TODO(benvanik): avoid this by directly referencing the bindings.
+  absl::InlinedVector<BufferView, 8> arguments;
+  arguments.reserve(dispatch_request.bindings.size());
+  for (const auto& binding : dispatch_request.bindings) {
+    arguments.push_back(BufferView{add_ref(binding.buffer), binding.shape,
+                                   binding.element_size});
+  }
+  absl::InlinedVector<BufferView, 8> results;  // Not read after Execute.
+
+  RETURN_IF_ERROR(
+      module->Execute(&stack, entry_function, std::move(arguments), &results));
+
+  return OkStatus();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/interpreter_command_processor.h b/hal/interpreter/interpreter_command_processor.h
new file mode 100644
index 0000000..0913517
--- /dev/null
+++ b/hal/interpreter/interpreter_command_processor.h
@@ -0,0 +1,36 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
+#define IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
+
+#include "hal/host/host_local_command_processor.h"
+
+namespace iree {
+namespace hal {
+
+class InterpreterCommandProcessor final : public HostLocalCommandProcessor {
+ public:
+ InterpreterCommandProcessor(Allocator* allocator,
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories);
+ ~InterpreterCommandProcessor() override;
+
+ Status Dispatch(const DispatchRequest& dispatch_request) override;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
diff --git a/hal/interpreter/interpreter_device.cc b/hal/interpreter/interpreter_device.cc
new file mode 100644
index 0000000..faab7af
--- /dev/null
+++ b/hal/interpreter/interpreter_device.cc
@@ -0,0 +1,173 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/interpreter_device.h"
+
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/command_buffer_validation.h"
+#include "hal/command_queue.h"
+#include "hal/fence.h"
+#include "hal/host/async_command_queue.h"
+#include "hal/host/host_event.h"
+#include "hal/host/host_submission_queue.h"
+#include "hal/host/inproc_command_buffer.h"
+#include "hal/interpreter/bytecode_cache.h"
+#include "hal/interpreter/interpreter_command_processor.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+
+// A CommandQueue that performs no synchronization (semaphores/fences) and just
+// directly executes command buffers inline.
+//
+// This is meant to be wrapped by SyncCommandQueue or AsyncCommandQueue that
+// themselves perform the synchronization/threading/etc. As such we ignore
+// all semaphores in the provided batches under the assumption that if Submit is
+// being called then all dependencies are valid. The wrapping queue is also
+// responsible for signaling the fence as well as propagating errors in a way
+// that is dependent on how it is performing its synchronization.
+class UnsynchronizedCommandQueue final : public CommandQueue {
+ public:
+ UnsynchronizedCommandQueue(Allocator* allocator, std::string name,
+ CommandCategoryBitfield supported_categories)
+ : CommandQueue(std::move(name), supported_categories),
+ allocator_(allocator) {}
+ ~UnsynchronizedCommandQueue() override = default;
+
+ Status Submit(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence) override {
+ IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::Submit");
+ DCHECK_EQ(nullptr, fence.first)
+ << "Fences must be handled by the wrapping queue";
+
+ // Process batches in order and return the first failure to the caller. No
+ // fence is touched here: the wrapping queue decides how that error reaches
+ // waiters so that sync and purely async drivers fail consistently.
+ for (auto& batch : batches) {
+ DCHECK(batch.wait_semaphores.empty() && batch.signal_semaphores.empty())
+ << "Semaphores must be handled by the wrapping queue";
+ RETURN_IF_ERROR(ProcessCommandBuffers(batch.command_buffers));
+ }
+
+ // NOTE: fence is ignored here; the wrapping queue signals it.
+ return OkStatus();
+ }
+
+ Status WaitIdle(absl::Time deadline) override {
+ // No-op: Submit runs command buffers inline so nothing is left in flight.
+ return OkStatus();
+ }
+
+ private:
+ // Processes each command buffer in-turn with a fresh processor, stopping at
+ // the first error. A fresh processor means no state carries across buffers.
+ Status ProcessCommandBuffers(
+ absl::Span<CommandBuffer* const> command_buffers) {
+ IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::ProcessCommandBuffers");
+ for (auto* command_buffer : command_buffers) {
+ auto* inproc_command_buffer =
+ static_cast<InProcCommandBuffer*>(command_buffer->impl());
+ InterpreterCommandProcessor command_processor(
+ allocator_, command_buffer->mode(), supported_categories());
+ RETURN_IF_ERROR(inproc_command_buffer->Process(&command_processor));
+ }
+ return OkStatus();
+ }
+
+ Allocator* const allocator_;
+};
+
+} // namespace
+
+InterpreterDevice::InterpreterDevice(DeviceInfo device_info)
+ : Device(std::move(device_info)), instance_(make_ref<rt::Instance>()) {
+ // A single inline queue handles both transfer and dispatch commands.
+ auto inline_queue = absl::make_unique<UnsynchronizedCommandQueue>(
+ &allocator_, "cpu0",
+ CommandCategory::kTransfer | CommandCategory::kDispatch);
+
+ // TODO(benvanik): allow injection of the wrapper type to support
+ // SyncCommandQueue without always linking in both.
+ auto wrapped_queue =
+ absl::make_unique<AsyncCommandQueue>(std::move(inline_queue));
+ command_queues_.push_back(std::move(wrapped_queue));
+}
+
+InterpreterDevice::~InterpreterDevice() = default;
+
+std::shared_ptr<ExecutableCache> InterpreterDevice::CreateExecutableCache() {
+ return std::make_shared<BytecodeCache>(add_ref(instance_), &allocator_);
+}
+
+StatusOr<ref_ptr<CommandBuffer>> InterpreterDevice::CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) {
+ // TODO(b/140026716): conditionally enable validation.
+ auto impl =
+ make_ref<InProcCommandBuffer>(&allocator_, mode, command_categories);
+ return WrapCommandBufferWithValidation(std::move(impl));
+}
+
+StatusOr<ref_ptr<Event>> InterpreterDevice::CreateEvent() {
+ return make_ref<HostEvent>();
+}
+
+StatusOr<ref_ptr<BinarySemaphore>> InterpreterDevice::CreateBinarySemaphore(
+ bool initial_value) {
+ IREE_TRACE_SCOPE0("InterpreterDevice::CreateBinarySemaphore");
+ return make_ref<HostBinarySemaphore>(initial_value);
+}
+
+StatusOr<ref_ptr<TimelineSemaphore>> InterpreterDevice::CreateTimelineSemaphore(
+ uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("InterpreterDevice::CreateTimelineSemaphore");
+
+ // TODO(b/140141417): implement timeline semaphores.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Timeline semaphores not yet implemented";
+}
+
+StatusOr<ref_ptr<Fence>> InterpreterDevice::CreateFence(
+ uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("InterpreterDevice::CreateFence");
+ return make_ref<HostFence>(initial_value);
+}
+
+Status InterpreterDevice::WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) {
+ IREE_TRACE_SCOPE0("InterpreterDevice::WaitAllFences");
+ return HostFence::WaitForFences(fences, /*wait_all=*/true, deadline);
+}
+
+StatusOr<int> InterpreterDevice::WaitAnyFence(
+ absl::Span<const FenceValue> fences, absl::Time deadline) {
+ IREE_TRACE_SCOPE0("InterpreterDevice::WaitAnyFence");
+ return HostFence::WaitForFences(fences, /*wait_all=*/false, deadline);
+}
+
+Status InterpreterDevice::WaitIdle(absl::Time deadline) {
+ for (auto& command_queue : command_queues_) {
+ RETURN_IF_ERROR(command_queue->WaitIdle(deadline));
+ }
+ return OkStatus();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/interpreter_device.h b/hal/interpreter/interpreter_device.h
new file mode 100644
index 0000000..c5bdbf2
--- /dev/null
+++ b/hal/interpreter/interpreter_device.h
@@ -0,0 +1,79 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
+#define IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
+
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "base/memory.h"
+#include "hal/device.h"
+#include "hal/host/host_local_allocator.h"
+#include "hal/interpreter/bytecode_kernels.h"
+#include "rt/instance.h"
+
+namespace iree {
+namespace hal {
+
+class InterpreterDevice final : public Device {
+ public:
+ explicit InterpreterDevice(DeviceInfo device_info);
+ ~InterpreterDevice() override;
+
+ kernels::RuntimeState* kernel_runtime_state() {
+ return &kernel_runtime_state_;
+ }
+
+ Allocator* allocator() const override { return &allocator_; }
+
+ absl::Span<CommandQueue*> dispatch_queues() const override {
+ return RawPtrSpan(absl::MakeSpan(command_queues_));
+ }
+
+ absl::Span<CommandQueue*> transfer_queues() const override {
+ return RawPtrSpan(absl::MakeSpan(command_queues_));  // Same as dispatch.
+ }
+
+ std::shared_ptr<ExecutableCache> CreateExecutableCache() override;
+
+ StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) override;
+
+ StatusOr<ref_ptr<Event>> CreateEvent() override;
+
+ StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
+ bool initial_value) override;
+ StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
+ uint64_t initial_value) override;
+
+ StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override;
+ Status WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) override;
+ StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Time deadline) override;
+
+ Status WaitIdle(absl::Time deadline) override;
+
+ private:
+ ref_ptr<rt::Instance> instance_;
+ kernels::RuntimeState kernel_runtime_state_;
+ mutable HostLocalAllocator allocator_;  // Declared before the queues using it.
+ mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
diff --git a/hal/interpreter/interpreter_driver.cc b/hal/interpreter/interpreter_driver.cc
new file mode 100644
index 0000000..f0364af
--- /dev/null
+++ b/hal/interpreter/interpreter_driver.cc
@@ -0,0 +1,62 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/interpreter_driver.h"
+
+#include <memory>
+
+#include "hal/device_info.h"
+#include "hal/interpreter/interpreter_device.h"
+
+namespace iree {
+namespace hal {
+
+namespace {
+
+DeviceInfo GetDefaultDeviceInfo() {
+ DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
+ // TODO(benvanik): implement debugging/profiling features.
+ // supported_features |= DeviceFeature::kDebugging;
+ // supported_features |= DeviceFeature::kCoverage;
+ // supported_features |= DeviceFeature::kProfiling;
+ DeviceInfo device_info("interpreter", supported_features);
+ // TODO(benvanik): device info.
+ return device_info;
+}
+
+} // namespace
+
+InterpreterDriver::InterpreterDriver() : Driver("interpreter") {}
+
+InterpreterDriver::~InterpreterDriver() = default;
+
+StatusOr<std::vector<DeviceInfo>>
+InterpreterDriver::EnumerateAvailableDevices() {
+ std::vector<DeviceInfo> device_infos;
+ device_infos.push_back(GetDefaultDeviceInfo());
+ return device_infos;
+}
+
+StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDefaultDevice() {
+ return CreateDevice(GetDefaultDeviceInfo());
+}
+
+StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDevice(
+ const DeviceInfo& device_info) {
+ auto device = std::make_shared<InterpreterDevice>(device_info);
+ return device;
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/interpreter_driver.h b/hal/interpreter/interpreter_driver.h
new file mode 100644
index 0000000..5389c5a
--- /dev/null
+++ b/hal/interpreter/interpreter_driver.h
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
+#define IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
+
+#include "hal/driver.h"
+
+namespace iree {
+namespace hal {
+
+class InterpreterDriver final : public Driver {
+ public:
+ InterpreterDriver();
+ ~InterpreterDriver() override;
+
+ StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
+
+ StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override;
+
+ StatusOr<std::shared_ptr<Device>> CreateDevice(
+ const DeviceInfo& device_info) override;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
diff --git a/hal/interpreter/interpreter_driver_module.cc b/hal/interpreter/interpreter_driver_module.cc
new file mode 100644
index 0000000..fc40ae6
--- /dev/null
+++ b/hal/interpreter/interpreter_driver_module.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+
+#include "base/init.h"
+#include "base/status.h"
+#include "hal/driver_registry.h"
+#include "hal/interpreter/interpreter_driver.h"
+
+namespace iree {
+namespace hal {
+namespace {
+
+StatusOr<std::shared_ptr<Driver>> CreateInterpreterDriver() {
+ return std::make_shared<InterpreterDriver>();
+}
+
+} // namespace
+} // namespace hal
+} // namespace iree
+
+IREE_REGISTER_MODULE_INITIALIZER(iree_hal_interpreter_driver, {
+ QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
+ "interpreter", ::iree::hal::CreateInterpreterDriver));
+});
+IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal,
+ iree_hal_interpreter_driver);
diff --git a/hal/interpreter/interpreter_module.cc b/hal/interpreter/interpreter_module.cc
new file mode 100644
index 0000000..178d1f6
--- /dev/null
+++ b/hal/interpreter/interpreter_module.cc
@@ -0,0 +1,80 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/interpreter/interpreter_module.h"
+
+#include "base/flatbuffer_util.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/interpreter/bytecode_dispatch.h"
+#include "vm/bytecode_tables_interpreter.h"
+
+namespace iree {
+namespace hal {
+
+// static
+StatusOr<ref_ptr<rt::Module>> InterpreterModule::FromDef(
+ hal::Allocator* allocator, const ModuleDef& module_def) {
+ ASSIGN_OR_RETURN(auto module_file,
+ vm::ModuleFile::Create(&module_def, []() {}));
+ if (module_file->root() == nullptr) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present";
+ }
+
+ auto module =
+ assign_ref(new InterpreterModule(allocator, std::move(module_file)));
+
+ // TODO(benvanik): validate internals here? or make explicit?
+
+ return {std::move(module)};
+}
+
+InterpreterModule::InterpreterModule(
+ hal::Allocator* allocator, std::unique_ptr<vm::ModuleFile> module_file)
+ : vm::BytecodeModule(std::move(module_file),
+ vm::interpreter_opcode_table()),
+ allocator_(allocator) {}
+
+Status InterpreterModule::Execute(
+ rt::Stack* stack, const rt::Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const {
+ IREE_TRACE_SCOPE0("InterpreterModule::Execute");
+
+ // Push stack frame for the function we are calling.
+ ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function));
+
+ // TODO(benvanik): rework register storage interface.
+ ASSIGN_OR_RETURN(const auto* function_def,
+ GetFunctionDef(function.linkage(), function.ordinal()));
+ auto* registers = callee_stack_frame->mutable_registers();
+ registers->buffer_views.resize(function_def->bytecode()->local_count());
+
+ // Marshal arguments into registers [0, n); assumes n <= local_count.
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ registers->buffer_views[i] = std::move(arguments[i]);
+ }
+
+ // Run main dispatch loop until it exits (or errors).
+ RETURN_IF_ERROR(Dispatch(allocator_, &kernel_runtime_state_, stack,
+ callee_stack_frame, absl::MakeSpan(*results)));
+
+ // Pop the callee frame to balance out the stack.
+ RETURN_IF_ERROR(stack->PopFrame());
+
+ return OkStatus();
+}
+
+} // namespace hal
+} // namespace iree
diff --git a/hal/interpreter/interpreter_module.h b/hal/interpreter/interpreter_module.h
new file mode 100644
index 0000000..1664589
--- /dev/null
+++ b/hal/interpreter/interpreter_module.h
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
+#define IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
+
+#include <memory>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/buffer_view.h"
+#include "hal/interpreter/bytecode_kernels.h"
+#include "rt/function.h"
+#include "rt/module.h"
+#include "rt/stack.h"
+#include "vm/bytecode_module.h"
+#include "vm/bytecode_tables_interpreter.h"
+
+namespace iree {
+namespace hal {
+
+class InterpreterModule final : public vm::BytecodeModule {
+ public:
+ static StatusOr<ref_ptr<rt::Module>> FromDef(hal::Allocator* allocator,
+ const ModuleDef& module_def);
+
+ Status Execute(
+ rt::Stack* stack, const rt::Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const override;
+
+ private:
+ InterpreterModule(hal::Allocator* allocator,
+ std::unique_ptr<vm::ModuleFile> module_file);
+
+ hal::Allocator* allocator_;
+ mutable kernels::RuntimeState kernel_runtime_state_;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
diff --git a/hal/resource.h b/hal/resource.h
new file mode 100644
index 0000000..8396f64
--- /dev/null
+++ b/hal/resource.h
@@ -0,0 +1,33 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_RESOURCE_H_
+#define IREE_HAL_RESOURCE_H_
+
+#include "base/ref_ptr.h"
+
+namespace iree {
+namespace hal {
+
+// Abstract resource type whose lifetime is managed by a ResourceSet.
+// Used mostly just to get a virtual dtor, though we could add nicer logging.
+class Resource : public RefObject<Resource> {
+ public:
+ virtual ~Resource() = default;
+};
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_RESOURCE_H_
diff --git a/hal/semaphore.h b/hal/semaphore.h
new file mode 100644
index 0000000..5ce0c25
--- /dev/null
+++ b/hal/semaphore.h
@@ -0,0 +1,61 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_SEMAPHORE_H_
+#define IREE_HAL_SEMAPHORE_H_
+
+#include "absl/types/variant.h"
+#include "hal/resource.h"
+
+namespace iree {
+namespace hal {
+
+// A synchronization primitive used to indicate submission dependencies.
+// Semaphores are either of type binary (signaled or unsignaled) or timeline
+// (uint64 payload with >= semantics).
+class Semaphore : public Resource {
+ public:
+};
+
+// Binary semaphores have strict ordering requirements and must be carefully
+// balanced. Each binary semaphore must only be waited on after a signal
+// operation has been issued and each wait requires exactly one signal. They
+// are commonly used only when interacting with external handles that may
+// cross device or process boundaries.
+class BinarySemaphore : public Semaphore {
+ public:
+};
+
+// Timeline semaphores act as a fence along a per-semaphore timeline where
+// signaling is done by setting the payload to a monotonically increasing
+// 64-bit integer and waiting is done by blocking until the payload is set
+// greater-than or equal-to the specified value. Timeline semaphores may be
+// waited on or signaled in any order and can be significantly more
+// efficient due to system-level coalescing.
+class TimelineSemaphore : public Semaphore {
+ public:
+ // TODO(benvanik): add value query support.
+ // TODO(benvanik): add host-side signal/wait.
+};
+
+// A reference to a strongly-typed semaphore and associated information.
+// For TimelineSemaphores the provided payload is used to specify either the
+// payload to wait for or new payload value.
+using SemaphoreValue =
+ absl::variant<BinarySemaphore*, std::pair<TimelineSemaphore*, uint64_t>>;
+
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_SEMAPHORE_H_
diff --git a/iree/hal/stack_trace.h b/hal/stack_trace.h
similarity index 100%
rename from iree/hal/stack_trace.h
rename to hal/stack_trace.h
diff --git a/hal/testing/BUILD b/hal/testing/BUILD
new file mode 100644
index 0000000..4f99b78
--- /dev/null
+++ b/hal/testing/BUILD
@@ -0,0 +1,36 @@
+# Test utilities for HAL-specific code.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "mock_allocator",
+ testonly = True,
+ hdrs = ["mock_allocator.h"],
+ deps = [
+ "//hal:allocator",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "mock_command_buffer",
+ testonly = True,
+ hdrs = ["mock_command_buffer.h"],
+ deps = [
+ "//hal:command_buffer",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "mock_command_queue",
+ testonly = True,
+ hdrs = ["mock_command_queue.h"],
+ deps = [
+ "//hal:command_queue",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/iree/hal/testing/CMakeLists.txt b/hal/testing/CMakeLists.txt
similarity index 100%
rename from iree/hal/testing/CMakeLists.txt
rename to hal/testing/CMakeLists.txt
diff --git a/hal/testing/mock_allocator.h b/hal/testing/mock_allocator.h
new file mode 100644
index 0000000..fa50bac
--- /dev/null
+++ b/hal/testing/mock_allocator.h
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
+#define IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
+
+#include "gmock/gmock.h"
+#include "hal/allocator.h"
+
+namespace iree {
+namespace hal {
+namespace testing {
+
+class MockAllocator : public ::testing::StrictMock<Allocator> {
+ public:
+ MockAllocator() : ::testing::StrictMock<Allocator>() {}
+
+ MOCK_CONST_METHOD4(CanUseBufferLike,
+ bool(Allocator* source_allocator,
+ MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage));
+
+ MOCK_CONST_METHOD3(CanAllocate, bool(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size));
+
+ MOCK_METHOD3(Allocate,
+ StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size));
+
+ MOCK_METHOD5(WrapMutable,
+ StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage,
+ void* data, size_t data_length));
+};
+
+} // namespace testing
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
diff --git a/hal/testing/mock_command_buffer.h b/hal/testing/mock_command_buffer.h
new file mode 100644
index 0000000..dace511
--- /dev/null
+++ b/hal/testing/mock_command_buffer.h
@@ -0,0 +1,80 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
+#define IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
+
+#include "gmock/gmock.h"
+#include "hal/command_buffer.h"
+
+namespace iree {
+namespace hal {
+namespace testing {
+
+class MockCommandBuffer : public ::testing::StrictMock<CommandBuffer> {
+ public:
+ MockCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories)
+ : ::testing::StrictMock<CommandBuffer>(allocator, mode,
+ command_categories) {}
+
+ bool is_recording() const override { return false; }
+
+ MOCK_METHOD0(Begin, Status());
+ MOCK_METHOD0(End, Status());
+
+ MOCK_METHOD4(ExecutionBarrier,
+ Status(ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers));
+
+ MOCK_METHOD2(SignalEvent,
+ Status(Event* event, ExecutionStageBitfield source_stage_mask));
+
+ MOCK_METHOD2(ResetEvent,
+ Status(Event* event, ExecutionStageBitfield source_stage_mask));
+
+ MOCK_METHOD5(WaitEvents,
+ Status(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers));
+
+ MOCK_METHOD5(FillBuffer,
+ Status(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length));
+
+ MOCK_METHOD1(DiscardBuffer, Status(Buffer* buffer));
+
+ MOCK_METHOD5(UpdateBuffer,
+ Status(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length));
+
+ MOCK_METHOD5(CopyBuffer,
+ Status(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length));
+
+ MOCK_METHOD1(Dispatch, Status(const DispatchRequest& dispatch_request));
+};
+
+} // namespace testing
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
diff --git a/hal/testing/mock_command_queue.h b/hal/testing/mock_command_queue.h
new file mode 100644
index 0000000..9195fb1
--- /dev/null
+++ b/hal/testing/mock_command_queue.h
@@ -0,0 +1,42 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
+#define IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
+
+#include "gmock/gmock.h"
+#include "hal/command_queue.h"
+
+namespace iree {
+namespace hal {
+namespace testing {
+
+class MockCommandQueue : public ::testing::StrictMock<CommandQueue> {
+ public:
+ MockCommandQueue(std::string name,
+ CommandCategoryBitfield supported_categories)
+ : ::testing::StrictMock<CommandQueue>(std::move(name),
+ supported_categories) {}
+
+ MOCK_METHOD2(Submit, Status(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence));
+
+ MOCK_METHOD1(WaitIdle, Status(absl::Time deadline));
+};
+
+} // namespace testing
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
diff --git a/hal/vulkan/BUILD b/hal/vulkan/BUILD
new file mode 100644
index 0000000..97880b7
--- /dev/null
+++ b/hal/vulkan/BUILD
@@ -0,0 +1,380 @@
+# HAL implementation using Vulkan and (likely) SPIR-V executables.
+
+load("//:build_defs.google.bzl", "PLATFORM_VULKAN_LOADER_COPTS", "PLATFORM_VULKAN_TEST_DEPS")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Matches builds invoked with --define=IREE_VK=native, selecting the native
+# Vulkan drivers (and real hardware).
+config_setting(
+    name = "native_vk",
+    define_values = {"IREE_VK": "native"},
+)
+
+# Matches builds invoked with --define=IREE_VK=swiftshader, selecting
+# SwiftShader.
+config_setting(
+    name = "swiftshader_vk",
+    define_values = {"IREE_VK": "swiftshader"},
+)
+
+# Routes VK_EXT_debug_utils / VK_EXT_debug_report messages to logging.
+cc_library(
+    name = "debug_reporter",
+    srcs = ["debug_reporter.cc"],
+    hdrs = ["debug_reporter.h"],
+    deps = [
+        ":dynamic_symbols",
+        ":status_util",
+        "//base:status",
+        "//base:tracing",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+# Pool of VkDescriptorPools handed out to command buffers during recording.
+cc_library(
+    name = "descriptor_pool_cache",
+    srcs = ["descriptor_pool_cache.cc"],
+    hdrs = ["descriptor_pool_cache.h"],
+    deps = [
+        ":dynamic_symbols",
+        ":handle_util",
+        ":status_util",
+        "//base:ref_ptr",
+        "//base:status",
+        "//base:tracing",
+        "@com_google_absl//absl/container:inlined_vector",
+    ],
+)
+
+# Arena for allocating, updating and binding descriptor sets.
+cc_library(
+    name = "descriptor_set_arena",
+    srcs = ["descriptor_set_arena.cc"],
+    hdrs = ["descriptor_set_arena.h"],
+    deps = [
+        ":descriptor_pool_cache",
+        ":pipeline_executable",
+        ":status_util",
+        ":vma_allocator",
+        "//base:arena",
+        "//base:math",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_buffer",
+    ],
+)
+
+cc_library(
+    name = "direct_command_buffer",
+    srcs = ["direct_command_buffer.cc"],
+    hdrs = ["direct_command_buffer.h"],
+    deps = [
+        ":descriptor_pool_cache",
+        ":descriptor_set_arena",
+        ":dynamic_symbols",
+        ":handle_util",
+        ":native_event",
+        ":pipeline_executable",
+        ":status_util",
+        ":vma_allocator",
+        "//base:math",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_buffer",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "direct_command_queue",
+    srcs = ["direct_command_queue.cc"],
+    hdrs = ["direct_command_queue.h"],
+    deps = [
+        ":direct_command_buffer",
+        ":dynamic_symbols",
+        ":handle_util",
+        ":legacy_fence",
+        ":native_binary_semaphore",
+        ":status_util",
+        "//base:arena",
+        "//base:memory",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:command_queue",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "dynamic_symbols",
+    srcs = ["dynamic_symbols.cc"],
+    hdrs = [
+        "dynamic_symbol_tables.h",
+        "dynamic_symbols.h",
+    ],
+    copts = PLATFORM_VULKAN_LOADER_COPTS,
+    linkopts = [
+        "-ldl",
+    ],
+    deps = [
+        "//base:file_path",
+        "//base:ref_ptr",
+        "//base:source_location",
+        "//base:status",
+        "//base:target_platform",
+        "//base:tracing",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/memory",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_test(
+    name = "dynamic_symbols_test",
+    srcs = ["dynamic_symbols_test.cc"],
+    deps = [
+        ":dynamic_symbols",
+        ":status_util",
+        "//base:status_matchers",
+    ] + PLATFORM_VULKAN_TEST_DEPS,
+)
+
+cc_library(
+    name = "extensibility_util",
+    srcs = ["extensibility_util.cc"],
+    hdrs = ["extensibility_util.h"],
+    deps = [
+        ":dynamic_symbols",
+        ":status_util",
+        "//base:memory",
+        "//base:status",
+        "//base:tracing",
+        "@com_google_absl//absl/types:span",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "handle_util",
+    hdrs = ["handle_util.h"],
+    deps = [
+        ":dynamic_symbols",
+        ":extensibility_util",
+        "//base:ref_ptr",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/utility",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "legacy_fence",
+    srcs = ["legacy_fence.cc"],
+    hdrs = ["legacy_fence.h"],
+    deps = [
+        ":handle_util",
+        ":status_util",
+        "//base:intrusive_list",
+        "//base:ref_ptr",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:fence",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "native_binary_semaphore",
+    srcs = ["native_binary_semaphore.cc"],
+    hdrs = ["native_binary_semaphore.h"],
+    deps = [
+        ":handle_util",
+        "//hal:semaphore",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "native_event",
+    srcs = ["native_event.cc"],
+    hdrs = ["native_event.h"],
+    deps = [
+        ":handle_util",
+        "//hal:event",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "pipeline_cache",
+    srcs = ["pipeline_cache.cc"],
+    hdrs = ["pipeline_cache.h"],
+    deps = [
+        ":handle_util",
+        ":pipeline_executable",
+        ":status_util",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:executable",
+        "//hal:executable_cache",
+        "//hal:executable_format",
+        "//schemas:spirv_executable_def_cc_fbs",
+        "@com_github_google_flatbuffers//:flatbuffers",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "pipeline_executable",
+    srcs = ["pipeline_executable.cc"],
+    hdrs = ["pipeline_executable.h"],
+    deps = [
+        ":handle_util",
+        ":status_util",
+        "//base:memory",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:executable",
+        "//hal:executable_cache",
+        "//hal:executable_spec",
+        "//schemas:spirv_executable_def_cc_fbs",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "status_util",
+    srcs = ["status_util.cc"],
+    hdrs = ["status_util.h"],
+    deps = [
+        "//base:status",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "vma_allocator",
+    srcs = [
+        "internal_vk_mem_alloc.cc",
+        "internal_vk_mem_alloc.h",
+        "vma_allocator.cc",
+        "vma_buffer.cc",
+    ],
+    hdrs = [
+        "vma_allocator.h",
+        "vma_buffer.h",
+    ],
+    copts = [
+        # Only needed in the implementation cc and not by external users.
+        "-DVMA_STATIC_VULKAN_FUNCTIONS=0",
+    ],
+    deps = [
+        ":dynamic_symbols",
+        ":handle_util",
+        ":status_util",
+        "//base:logging",
+        "//base:source_location",
+        "//base:status",
+        "//base:tracing",
+        "//hal:allocator",
+        "//hal:buffer",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/synchronization",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+        "@vulkan_memory_allocator//:impl_header_only",
+    ],
+)
+
+cc_library(
+    name = "vulkan_device",
+    srcs = ["vulkan_device.cc"],
+    hdrs = ["vulkan_device.h"],
+    deps = [
+        ":descriptor_pool_cache",
+        ":direct_command_buffer",
+        ":direct_command_queue",
+        ":dynamic_symbols",
+        ":extensibility_util",
+        ":handle_util",
+        ":legacy_fence",
+        ":native_binary_semaphore",
+        ":native_event",
+        ":pipeline_cache",
+        ":status_util",
+        ":vma_allocator",
+        "//base:memory",
+        "//base:status",
+        "//base:tracing",
+        "//hal:allocator",
+        "//hal:command_buffer_validation",
+        "//hal:command_queue",
+        "//hal:device",
+        "//hal:fence",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "vulkan_driver",
+    srcs = ["vulkan_driver.cc"],
+    hdrs = ["vulkan_driver.h"],
+    deps = [
+        ":debug_reporter",
+        ":dynamic_symbols",
+        ":extensibility_util",
+        ":status_util",
+        ":vulkan_device",
+        "//base:memory",
+        "//base:status",
+        "//base:tracing",
+        "//hal:device_info",
+        "//hal:driver",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@vulkan_headers//:vulkan_headers_no_prototypes",
+    ],
+)
+
+cc_library(
+    name = "vulkan_driver_module",
+    srcs = ["vulkan_driver_module.cc"],
+    deps = [
+        ":dynamic_symbols",
+        ":vulkan_driver",
+        "//base:init",
+        "//base:status",
+        "//base:tracing",
+        "//hal:driver_registry",
+        "@com_google_absl//absl/flags:flag",
+    ],
+    alwayslink = 1,
+)
diff --git a/iree/hal/vulkan/CMakeLists.txt b/hal/vulkan/CMakeLists.txt
similarity index 100%
rename from iree/hal/vulkan/CMakeLists.txt
rename to hal/vulkan/CMakeLists.txt
diff --git a/hal/vulkan/debug_reporter.cc b/hal/vulkan/debug_reporter.cc
new file mode 100644
index 0000000..4558d27
--- /dev/null
+++ b/hal/vulkan/debug_reporter.cc
@@ -0,0 +1,160 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/debug_reporter.h"
+
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// NOTE: |user_data| may be nullptr if we are being called during instance
+// creation. Otherwise it is a pointer to the DebugReporter instance.
+
+// NOTE: this callback must be thread safe and must be careful not to reach too
+// far outside of the call - it is called in-context from arbitrary threads with
+// some amount of Vulkan state on the stack. Assume that creating or deleting
+// Vulkan objects, issuing most Vulkan commands, etc are off-limits.
+
+// VK_EXT_debug_utils messenger callback.
+// Writes the message text from |callback_data| to the error log; the
+// severity, type and |user_data| arguments are currently not consulted.
+VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback(
+    VkDebugUtilsMessageSeverityFlagBitsEXT message_severity,
+    VkDebugUtilsMessageTypeFlagsEXT message_type,
+    const VkDebugUtilsMessengerCallbackDataEXT* callback_data,
+    void* user_data) {
+  // TODO(benvanik): better logging once we have switched logging APIs.
+  const char* message_text = callback_data->pMessage;
+  LOG(ERROR) << message_text;
+
+  // VK_TRUE is reserved for future use.
+  const VkBool32 should_abort = VK_FALSE;
+  return should_abort;
+}
+
+// VK_EXT_debug_report callback.
+// Writes |message| to the error log; all other arguments are currently not
+// consulted.
+VKAPI_ATTR VkBool32 VKAPI_CALL DebugReportCallback(
+    VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT object_type,
+    uint64_t object, size_t location, int32_t message_code,
+    const char* layer_prefix, const char* message, void* user_data) {
+  // TODO(benvanik): better logging once we have switched logging APIs.
+  const char* message_text = message;
+  LOG(ERROR) << message_text;
+
+  // VK_TRUE is reserved for future use.
+  const VkBool32 should_abort = VK_FALSE;
+  return should_abort;
+}
+
+} // namespace
+
+// static
+// Fills |create_info| for VK_EXT_debug_utils with DebugUtilsMessageCallback
+// as the callback, every severity and message type enabled, and no user data.
+void DebugReporter::PopulateStaticCreateInfo(
+    VkDebugUtilsMessengerCreateInfoEXT* create_info) {
+  // TODO(benvanik): only enable the severities that logging has enabled.
+  const auto enabled_severities =
+      VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
+      VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT |
+      VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT |
+      VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
+
+  // TODO(benvanik): allow filtering by category as a flag.
+  const auto enabled_types = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT |
+                             VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT |
+                             VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
+
+  create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
+  create_info->pNext = nullptr;
+  create_info->flags = 0;
+  create_info->messageSeverity = enabled_severities;
+  create_info->messageType = enabled_types;
+  create_info->pfnUserCallback = DebugUtilsMessageCallback;
+  create_info->pUserData = nullptr;
+}
+
+// static
+// Fills |create_info| for VK_EXT_debug_report with DebugReportCallback as the
+// callback, every report flag enabled, and no user data.
+void DebugReporter::PopulateStaticCreateInfo(
+    VkDebugReportCallbackCreateInfoEXT* create_info) {
+  // TODO(benvanik): only enable the severities that logging has enabled.
+  const auto enabled_reports =
+      VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT |
+      VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT |
+      VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_DEBUG_BIT_EXT;
+
+  create_info->sType = VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT;
+  create_info->pNext = nullptr;
+  create_info->flags = enabled_reports;
+  create_info->pfnCallback = DebugReportCallback;
+  create_info->pUserData = nullptr;
+}
+
+// static
+// Builds a DebugReporter for |instance| backed by a VK_EXT_debug_utils
+// messenger. The reporter pointer is passed to the callback as user data.
+// Returns the converted VkResult error if messenger creation fails.
+StatusOr<std::unique_ptr<DebugReporter>>
+DebugReporter::CreateDebugUtilsMessenger(
+    VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
+    const VkAllocationCallbacks* allocation_callbacks) {
+  IREE_TRACE_SCOPE0("DebugReporter::CreateDebugUtilsMessenger");
+
+  std::unique_ptr<DebugReporter> reporter(
+      new DebugReporter(instance, syms, allocation_callbacks));
+
+  // Start from the static configuration and attach the reporter to it.
+  VkDebugUtilsMessengerCreateInfoEXT messenger_info;
+  PopulateStaticCreateInfo(&messenger_info);
+  messenger_info.pUserData = reporter.get();
+
+  VkDebugUtilsMessengerEXT* out_messenger = &reporter->messenger_;
+  VK_RETURN_IF_ERROR(syms->vkCreateDebugUtilsMessengerEXT(
+      instance, &messenger_info, allocation_callbacks, out_messenger));
+
+  return reporter;
+}
+
+// static
+// Builds a DebugReporter for |instance| backed by a VK_EXT_debug_report
+// callback. The reporter pointer is passed to the callback as user data.
+// Returns the converted VkResult error if callback creation fails.
+StatusOr<std::unique_ptr<DebugReporter>>
+DebugReporter::CreateDebugReportCallback(
+    VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
+    const VkAllocationCallbacks* allocation_callbacks) {
+  IREE_TRACE_SCOPE0("DebugReporter::CreateDebugReportCallback");
+
+  std::unique_ptr<DebugReporter> reporter(
+      new DebugReporter(instance, syms, allocation_callbacks));
+
+  // Start from the static configuration and attach the reporter to it.
+  VkDebugReportCallbackCreateInfoEXT callback_info;
+  PopulateStaticCreateInfo(&callback_info);
+  callback_info.pUserData = reporter.get();
+
+  VkDebugReportCallbackEXT* out_callback = &reporter->callback_;
+  VK_RETURN_IF_ERROR(syms->vkCreateDebugReportCallbackEXT(
+      instance, &callback_info, allocation_callbacks, out_callback));
+
+  return reporter;
+}
+
+// Stores |instance| and |allocation_callbacks| and takes an additional
+// reference on |syms|. No Vulkan objects are created here; the static Create*
+// factories fill in the messenger or callback handle afterwards.
+DebugReporter::DebugReporter(VkInstance instance,
+ const ref_ptr<DynamicSymbols>& syms,
+ const VkAllocationCallbacks* allocation_callbacks)
+ : instance_(instance),
+ syms_(add_ref(syms)),
+ allocation_callbacks_(allocation_callbacks) {}
+
+// Destroys whichever of the messenger / report callback handles was created,
+// using the same instance and allocation callbacks they were created with.
+DebugReporter::~DebugReporter() {
+  IREE_TRACE_SCOPE0("DebugReporter::dtor");
+
+  const bool has_messenger = messenger_ != VK_NULL_HANDLE;
+  const bool has_callback = callback_ != VK_NULL_HANDLE;
+
+  if (has_messenger) {
+    syms_->vkDestroyDebugUtilsMessengerEXT(instance_, messenger_,
+                                           allocation_callbacks_);
+  }
+  if (has_callback) {
+    syms_->vkDestroyDebugReportCallbackEXT(instance_, callback_,
+                                           allocation_callbacks_);
+  }
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/debug_reporter.h b/hal/vulkan/debug_reporter.h
new file mode 100644
index 0000000..2eff99d
--- /dev/null
+++ b/hal/vulkan/debug_reporter.h
@@ -0,0 +1,87 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_
+#define IREE_HAL_VULKAN_DEBUG_REPORTER_H_
+
+#include <vulkan/vulkan.h>
+
+#include "base/status.h"
+#include "hal/vulkan/dynamic_symbols.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// A debug reporter that works with either the VK_EXT_debug_utils or the
+// VK_EXT_debug_report extension, depending on which factory created it.
+// One reporter should be created per VkInstance to receive callbacks from the
+// API and route them to our logging systems. In general VK_EXT_debug_utils
+// should be preferred if available as it provides a much cleaner interface and
+// more plug-points than VK_EXT_debug_report.
+//
+// Since creating a reporter requires a VkInstance it's not possible to report
+// on messages during instance creation. To work around this it's possible to
+// pass a *CreateInfo struct to vkCreateInstance as part of the
+// VkInstanceCreateInfo::pNext chain. The callback will only be used this way
+// during the creation call after which users can create the real
+// instance-specific reporter.
+class DebugReporter final {
+ public:
+  // Populates |create_info| with an instance-agnostic callback.
+  // This can be used during instance creation by chaining the |create_info| to
+  // VkInstanceCreateInfo::pNext.
+  //
+  // Only use if VK_EXT_debug_utils is present.
+  static void PopulateStaticCreateInfo(
+      VkDebugUtilsMessengerCreateInfoEXT* create_info);
+
+  // Populates |create_info| with an instance-agnostic callback.
+  // This can be used during instance creation by chaining the |create_info| to
+  // VkInstanceCreateInfo::pNext.
+  //
+  // Only use if VK_EXT_debug_report is present.
+  static void PopulateStaticCreateInfo(
+      VkDebugReportCallbackCreateInfoEXT* create_info);
+
+  // Creates a debug messenger for the given Vulkan |instance| with
+  // VK_EXT_debug_utils enabled.
+  static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugUtilsMessenger(
+      VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
+      const VkAllocationCallbacks* allocation_callbacks);
+
+  // Creates a debug report callback for the given Vulkan |instance| with
+  // VK_EXT_debug_report enabled.
+  static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugReportCallback(
+      VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
+      const VkAllocationCallbacks* allocation_callbacks);
+
+  // Not copyable: the destructor destroys the owned messenger/callback handle,
+  // so a copy would destroy the same Vulkan object a second time. Reporters
+  // are only handed out through std::unique_ptr by the factories above.
+  DebugReporter(const DebugReporter&) = delete;
+  DebugReporter& operator=(const DebugReporter&) = delete;
+
+  // Destroys the messenger or report callback owned by this reporter, if any.
+  ~DebugReporter();
+
+ private:
+  DebugReporter(VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
+                const VkAllocationCallbacks* allocation_callbacks);
+
+  VkInstance instance_ = VK_NULL_HANDLE;
+  ref_ptr<DynamicSymbols> syms_;
+  const VkAllocationCallbacks* allocation_callbacks_ = nullptr;
+
+  // At most one of these is set, depending on the factory used.
+  VkDebugUtilsMessengerEXT messenger_ = VK_NULL_HANDLE;
+  VkDebugReportCallbackEXT callback_ = VK_NULL_HANDLE;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_DEBUG_REPORTER_H_
diff --git a/hal/vulkan/descriptor_pool_cache.cc b/hal/vulkan/descriptor_pool_cache.cc
new file mode 100644
index 0000000..453fe78
--- /dev/null
+++ b/hal/vulkan/descriptor_pool_cache.cc
@@ -0,0 +1,102 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/descriptor_pool_cache.h"
+
+#include <array>
+
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Value used for VkDescriptorPoolCreateInfo::maxSets on every pool we create.
+// TODO(benvanik): be more conservative with descriptor set count or allow
+// chaining in the command buffer when pools run out.
+constexpr int kMaxDescriptorSets = 4096;
+
+}  // namespace
+
+// The destructor does not release pools itself; it only verifies that the
+// group holds no pools and CHECK-fails otherwise, so owners must call Reset()
+// before the group goes out of scope.
+DescriptorSetGroup::~DescriptorSetGroup() {
+ CHECK(descriptor_pools_.empty())
+ << "DescriptorSetGroup must be reset explicitly";
+}
+
+// Releases every descriptor pool held by the group back to the cache that
+// produced it and empties the group.
+//
+// Returns the error from DescriptorPoolCache::ReleaseDescriptorPools if the
+// release fails; in that case the pool list is left untouched.
+Status DescriptorSetGroup::Reset() {
+  IREE_TRACE_SCOPE0("DescriptorSetGroup::Reset");
+
+  // A default-constructed group (such as the one DescriptorSetArena::Flush
+  // returns when nothing was allocated) has no pools and no cache; there is
+  // nothing to release and the cache pointer must not be dereferenced.
+  if (descriptor_pools_.empty()) {
+    return OkStatus();
+  }
+
+  RETURN_IF_ERROR(descriptor_pool_cache_->ReleaseDescriptorPools(
+      absl::MakeSpan(descriptor_pools_)));
+  descriptor_pools_.clear();
+
+  return OkStatus();
+}
+
+// Takes ownership of the |logical_device| reference; all pools are created and
+// destroyed against this device.
+DescriptorPoolCache::DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device)
+ : logical_device_(std::move(logical_device)) {}
+
+// Creates a new VkDescriptorPool that can hold up to kMaxDescriptorSets sets
+// and |max_descriptor_count| descriptors of |descriptor_type|.
+// Returns the populated DescriptorPool, or the converted VkResult error if
+// vkCreateDescriptorPool fails.
+StatusOr<DescriptorPool> DescriptorPoolCache::AcquireDescriptorPool(
+    VkDescriptorType descriptor_type, int max_descriptor_count) {
+  IREE_TRACE_SCOPE0("DescriptorPoolCache::AcquireDescriptorPool");
+
+  // TODO(benvanik): lookup in cache.
+
+  // Each pool serves exactly one descriptor type.
+  VkDescriptorPoolSize pool_size;
+  pool_size.type = descriptor_type;
+  pool_size.descriptorCount = max_descriptor_count;
+
+  VkDescriptorPoolCreateInfo pool_info;
+  pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
+  pool_info.pNext = nullptr;
+  pool_info.flags = 0;
+  pool_info.maxSets = kMaxDescriptorSets;
+  pool_info.poolSizeCount = 1;
+  pool_info.pPoolSizes = &pool_size;
+
+  DescriptorPool acquired_pool;
+  acquired_pool.descriptor_type = descriptor_type;
+  acquired_pool.max_descriptor_count = max_descriptor_count;
+  acquired_pool.handle = VK_NULL_HANDLE;
+
+  VK_RETURN_IF_ERROR(syms().vkCreateDescriptorPool(
+      *logical_device_, &pool_info, logical_device_->allocator(),
+      &acquired_pool.handle));
+
+  return acquired_pool;
+}
+
+// Resets and then destroys each pool in |descriptor_pools|, in order.
+// Stops at the first pool whose reset fails and returns that error; pools
+// after it are left untouched.
+Status DescriptorPoolCache::ReleaseDescriptorPools(
+    absl::Span<DescriptorPool> descriptor_pools) {
+  IREE_TRACE_SCOPE0("DescriptorPoolCache::ReleaseDescriptorPools");
+
+  for (const DescriptorPool& pool : descriptor_pools) {
+    const VkDescriptorPool pool_handle = pool.handle;
+
+    // Always reset immediately. We could do this on allocation instead however
+    // this leads to better errors when using the validation layers as we'll
+    // throw if there are in-flight command buffers using the sets in the pool.
+    VK_RETURN_IF_ERROR(
+        syms().vkResetDescriptorPool(*logical_device_, pool_handle, 0));
+
+    // TODO(benvanik): release to cache.
+    syms().vkDestroyDescriptorPool(*logical_device_, pool_handle,
+                                   logical_device_->allocator());
+  }
+
+  return OkStatus();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/descriptor_pool_cache.h b/hal/vulkan/descriptor_pool_cache.h
new file mode 100644
index 0000000..8e54112
--- /dev/null
+++ b/hal/vulkan/descriptor_pool_cache.h
@@ -0,0 +1,103 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
+#define IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
+
+#include "absl/container/inlined_vector.h"
+#include "base/ref_ptr.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/handle_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+class DescriptorPoolCache;
+
+// A descriptor pool with a single descriptor type of some number.
+// We only support a single descriptor type for now as we only generate SPIR-V
+// that uses a single type.
+//
+// This is a plain value type with no destructor: copying it copies the raw
+// handle and destroying it does not release the underlying Vulkan pool.
+struct DescriptorPool {
+ // Type of the descriptor in the set.
+ VkDescriptorType descriptor_type = VK_DESCRIPTOR_TYPE_MAX_ENUM;
+ // Maximum number of descriptors of the given type per allocation.
+ int max_descriptor_count = 0;
+ // Pool handle.
+ VkDescriptorPool handle = VK_NULL_HANDLE;
+};
+
+// A group of descriptor sets allocated and released together.
+// The group must be explicitly reset with Reset() prior to disposing.
+// Move-only: the group holds the descriptor pools its sets came from plus a
+// reference to the cache those pools are released to.
+class DescriptorSetGroup final {
+ public:
+ // Creates an empty group with no pools and no cache.
+ DescriptorSetGroup() = default;
+ // Creates a group owning |descriptor_pools|, which will be released to
+ // |descriptor_pool_cache| on Reset().
+ DescriptorSetGroup(ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
+ absl::InlinedVector<DescriptorPool, 8> descriptor_pools)
+ : descriptor_pool_cache_(std::move(descriptor_pool_cache)),
+ descriptor_pools_(std::move(descriptor_pools)) {}
+ DescriptorSetGroup(const DescriptorSetGroup&) = delete;
+ DescriptorSetGroup& operator=(const DescriptorSetGroup&) = delete;
+ DescriptorSetGroup(DescriptorSetGroup&& other) noexcept
+ : descriptor_pool_cache_(std::move(other.descriptor_pool_cache_)),
+ descriptor_pools_(std::move(other.descriptor_pools_)) {}
+ // Move assignment swaps state with |other|, so whatever this group held
+ // before ends up in |other| rather than being released here.
+ DescriptorSetGroup& operator=(DescriptorSetGroup&& other) {
+ std::swap(descriptor_pool_cache_, other.descriptor_pool_cache_);
+ std::swap(descriptor_pools_, other.descriptor_pools_);
+ return *this;
+ }
+ // CHECK-fails if the group still holds pools; call Reset() first.
+ ~DescriptorSetGroup();
+
+ // Releases all held pools back to the cache and empties the group.
+ Status Reset();
+
+ private:
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache_;
+ absl::InlinedVector<DescriptorPool, 8> descriptor_pools_;
+};
+
+// A "cache" (or really, pool) of descriptor pools. These pools are allocated
+// as needed to satisfy different descriptor size requirements and are given
+// to command buffers during recording to write descriptor updates and bind
+// resources. After the descriptors in the pool are no longer used (all
+// command buffers using descriptor sets allocated from the pool have retired)
+// the pool is returned here to be reused in the future.
+class DescriptorPoolCache final : public RefObject<DescriptorPoolCache> {
+ public:
+ explicit DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device);
+
+ // Logical device that all pools are created against.
+ const ref_ptr<VkDeviceHandle>& logical_device() const {
+ return logical_device_;
+ }
+ // Vulkan entry points of the logical device.
+ const DynamicSymbols& syms() const { return *logical_device_->syms(); }
+
+ // Acquires a new descriptor pool for use by the caller.
+ // The pool will have been reset and have all descriptor sets available.
+ // When all sets allocated from the pool are no longer in use it must be
+ // returned to the cache with ReleaseDescriptorPools.
+ StatusOr<DescriptorPool> AcquireDescriptorPool(
+ VkDescriptorType descriptor_type, int max_descriptor_count);
+
+ // Releases descriptor pools back to the cache. The pools will be reset
+ // immediately and must no longer be in use by any in-flight command.
+ Status ReleaseDescriptorPools(absl::Span<DescriptorPool> descriptor_pools);
+
+ private:
+ ref_ptr<VkDeviceHandle> logical_device_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
diff --git a/hal/vulkan/descriptor_set_arena.cc b/hal/vulkan/descriptor_set_arena.cc
new file mode 100644
index 0000000..b31b583
--- /dev/null
+++ b/hal/vulkan/descriptor_set_arena.cc
@@ -0,0 +1,204 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/descriptor_set_arena.h"
+
+#include "base/math.h"
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+#include "hal/vulkan/vma_buffer.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Returns the allocated buffer backing |buffer| as a VmaBuffer.
+// The cast is unchecked.
+StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) {
+  // TODO(benvanik): assert that the buffer is from the right allocator and
+  // that it is compatible with our target queue family.
+  auto* allocated_buffer = buffer->allocated_buffer();
+  return static_cast<VmaBuffer*>(allocated_buffer);
+}
+
+// Builds one VkWriteDescriptorSet (with its VkDescriptorBufferInfo) per entry
+// of |bindings|, targeting |dst_set|. Binding i is written to the binding
+// index pipeline_descriptor_sets.buffer_binding_set_map[i].
+//
+// |arena| is reset first, so any storage previously handed out from it is
+// invalidated; the returned span lives in |arena| until its next reset.
+//
+// Returns an out-of-range error if more bindings are provided than the
+// binding set map has entries, as those bindings have no slot to map to.
+StatusOr<absl::Span<VkWriteDescriptorSet>> PopulateDescriptorSetWriteInfos(
+    const PipelineDescriptorSets& pipeline_descriptor_sets,
+    absl::Span<const BufferBinding> bindings, VkDescriptorSet dst_set,
+    Arena* arena) {
+  const auto& binding_set_map = pipeline_descriptor_sets.buffer_binding_set_map;
+  const size_t required_descriptor_count = binding_set_map.size();
+
+  // Both buffer_infos and binding_set_map are indexed by binding ordinal
+  // below; reject inputs that would index past either of them.
+  if (bindings.size() > required_descriptor_count) {
+    return OutOfRangeErrorBuilder(IREE_LOC)
+           << "Too many bindings provided: " << bindings.size()
+           << " (max=" << required_descriptor_count << ")";
+  }
+
+  arena->Reset();
+  auto buffer_infos =
+      arena->AllocateSpan<VkDescriptorBufferInfo>(required_descriptor_count);
+  auto write_infos = arena->AllocateSpan<VkWriteDescriptorSet>(bindings.size());
+
+  for (size_t i = 0; i < bindings.size(); ++i) {
+    const auto& binding = bindings[i];
+
+    auto& buffer_info = buffer_infos[i];
+    ASSIGN_OR_RETURN(auto buffer, CastBuffer(binding.buffer));
+    buffer_info.buffer = buffer->handle();
+    // TODO(benvanik): properly subrange (add to BufferBinding).
+    buffer_info.offset = binding.buffer->byte_offset();
+    buffer_info.range = binding.buffer->byte_length();
+
+    auto& write_info = write_infos[i];
+    write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+    write_info.pNext = nullptr;
+    write_info.dstSet = dst_set;
+    write_info.dstBinding = binding_set_map[i];
+    write_info.dstArrayElement = 0;
+    write_info.descriptorCount = 1;
+    write_info.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+    write_info.pImageInfo = nullptr;
+    write_info.pBufferInfo = &buffer_info;
+    write_info.pTexelBufferView = nullptr;
+  }
+
+  return write_infos;
+}
+
+} // namespace
+
+// Takes a reference on the logical device of |descriptor_pool_cache| and then
+// takes ownership of the cache reference itself.
+// NOTE(review): logical_device_ reads |descriptor_pool_cache| before it is
+// moved into descriptor_pool_cache_, which relies on logical_device_ being
+// declared first in the class - confirm against the header.
+DescriptorSetArena::DescriptorSetArena(
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache)
+ : logical_device_(add_ref(descriptor_pool_cache->logical_device())),
+ descriptor_pool_cache_(std::move(descriptor_pool_cache)) {}
+
+// Releases any pools that were acquired but never flushed. Failures cannot be
+// reported from a destructor, so the release status is deliberately ignored.
+DescriptorSetArena::~DescriptorSetArena() {
+  if (used_descriptor_pools_.empty()) {
+    return;
+  }
+
+  auto unflushed_pools = absl::MakeSpan(used_descriptor_pools_);
+  descriptor_pool_cache_->ReleaseDescriptorPools(unflushed_pools)
+      .IgnoreError();
+  used_descriptor_pools_.clear();
+}
+
+// Binds a descriptor set containing |bindings| to |command_buffer| for use
+// with |executable|.
+//
+// When the device has push descriptors enabled this forwards to
+// PushDescriptorSet. Otherwise a descriptor set is allocated from a pool
+// bucketed by descriptor count, updated with the buffer bindings, and bound
+// at the compute bind point.
+//
+// Returns an out-of-range error when more descriptors are required than the
+// largest bucket supports, and propagates any failure from acquiring a pool,
+// allocating the descriptor set or populating the write infos.
+Status DescriptorSetArena::BindDescriptorSet(
+    VkCommandBuffer command_buffer, PipelineExecutable* executable,
+    absl::Span<const BufferBinding> bindings) {
+  // Always prefer using push descriptors when available as we can avoid the
+  // additional API overhead of updating/resetting pools.
+  if (logical_device_->enabled_extensions().push_descriptors) {
+    return PushDescriptorSet(command_buffer, executable, bindings);
+  }
+
+  IREE_TRACE_SCOPE0("DescriptorSetArena::BindDescriptorSet");
+
+  // Pick a bucket based on the number of descriptors required.
+  // Bucket k serves pools with (8 << k) descriptors.
+  // NOTE: right now we are 1:1 with bindings.
+  int required_descriptor_count = bindings.size() * 1;
+  int max_descriptor_count =
+      std::max(8, RoundUpToNearestPow2(required_descriptor_count));
+  int bucket = TrailingZeros(max_descriptor_count >> 3);
+  if (bucket >= descriptor_pool_buckets_.size()) {
+    // The last bucket (index size - 1) holds 8 << (size - 1) descriptors,
+    // which is 1 << (size + 2).
+    return OutOfRangeErrorBuilder(IREE_LOC)
+           << "Too many descriptors required: " << required_descriptor_count
+           << " (max=" << (1 << (descriptor_pool_buckets_.size() + 2)) << ")";
+  }
+  if (descriptor_pool_buckets_[bucket].handle == VK_NULL_HANDLE) {
+    // Acquire a pool for this max_descriptor_count bucket.
+    ASSIGN_OR_RETURN(
+        descriptor_pool_buckets_[bucket],
+        descriptor_pool_cache_->AcquireDescriptorPool(
+            VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count));
+    used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]);
+  }
+
+  const auto& pipeline_descriptor_sets = executable->descriptor_sets();
+
+  VkDescriptorSetAllocateInfo allocate_info;
+  allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
+  allocate_info.pNext = nullptr;
+  allocate_info.descriptorPool = descriptor_pool_buckets_[bucket].handle;
+  allocate_info.descriptorSetCount = 1;
+  allocate_info.pSetLayouts =
+      &pipeline_descriptor_sets.buffer_binding_set_layout;
+  VkDescriptorSet descriptor_set = VK_NULL_HANDLE;
+  VkResult result = syms().vkAllocateDescriptorSets(
+      *logical_device_, &allocate_info, &descriptor_set);
+  if (result == VK_ERROR_OUT_OF_POOL_MEMORY) {
+    // Allocation failed because the pool is either out of descriptors or too
+    // fragmented. We'll just allocate another pool.
+    ASSIGN_OR_RETURN(
+        descriptor_pool_buckets_[bucket],
+        descriptor_pool_cache_->AcquireDescriptorPool(
+            VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count));
+    used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]);
+
+    // Try again, now against the freshly acquired pool. Without the retry the
+    // code below would update and bind a null descriptor set.
+    allocate_info.descriptorPool = descriptor_pool_buckets_[bucket].handle;
+    descriptor_set = VK_NULL_HANDLE;
+    result = syms().vkAllocateDescriptorSets(*logical_device_, &allocate_info,
+                                             &descriptor_set);
+  }
+  // Surface any remaining allocation failure instead of continuing with an
+  // invalid descriptor set.
+  VK_RETURN_IF_ERROR(result);
+
+  // Get a list of VkWriteDescriptorSet structs with all bound buffers.
+  ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos(
+                                         pipeline_descriptor_sets, bindings,
+                                         descriptor_set, &scratch_arena_));
+
+  // This is the reason why push descriptor sets are good.
+  // We can't batch these effectively as we don't know prior to recording what
+  // descriptor sets we will need and what buffers they will point to (without
+  // doing just as much work as actually recording the buffer to try to find
+  // out).
+  syms().vkUpdateDescriptorSets(*logical_device_, write_infos.size(),
+                                write_infos.data(), 0, nullptr);
+
+  // Bind the descriptor set.
+  syms().vkCmdBindDescriptorSets(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+                                 executable->pipeline_layout(),
+                                 pipeline_descriptor_sets.buffer_binding_set, 1,
+                                 &descriptor_set, 0, nullptr);
+
+  return OkStatus();
+}
+
+// Records |bindings| into |command_buffer| as push descriptors for
+// |executable| at the compute bind point. No descriptor set is allocated, so
+// the write infos target VK_NULL_HANDLE.
+// Propagates any error from populating the write infos.
+Status DescriptorSetArena::PushDescriptorSet(
+    VkCommandBuffer command_buffer, PipelineExecutable* executable,
+    absl::Span<const BufferBinding> bindings) {
+  IREE_TRACE_SCOPE0("DescriptorSetArena::PushDescriptorSet");
+
+  const auto& set_info = executable->descriptor_sets();
+
+  // Get a list of VkWriteDescriptorSet structs with all bound buffers.
+  ASSIGN_OR_RETURN(auto pending_writes,
+                   PopulateDescriptorSetWriteInfos(set_info, bindings,
+                                                   VK_NULL_HANDLE,
+                                                   &scratch_arena_));
+
+  // Fast path using push descriptors. These are pooled internally by the
+  // command buffer and prevent the need for our own pooling mechanisms.
+  syms().vkCmdPushDescriptorSetKHR(
+      command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+      executable->pipeline_layout(), set_info.buffer_binding_set,
+      pending_writes.size(), pending_writes.data());
+
+  return OkStatus();
+}
+
+// Hands every descriptor pool used since the last flush over to a
+// DescriptorSetGroup and clears the per-bucket pool slots so the next bind
+// acquires fresh pools.
+//
+// Returns an empty group when no pools were used. The returned group must be
+// Reset() by the caller once the descriptor sets are no longer in use.
+StatusOr<DescriptorSetGroup> DescriptorSetArena::Flush() {
+  IREE_TRACE_SCOPE0("DescriptorSetArena::Flush");
+
+  if (used_descriptor_pools_.empty()) {
+    // No resources to free.
+    return DescriptorSetGroup{};
+  }
+
+  for (auto& bucket : descriptor_pool_buckets_) {
+    bucket = {};
+  }
+
+  // Transfer the pools out and explicitly empty our list: a moved-from
+  // container is only guaranteed to be valid, not empty, and any entries left
+  // behind would be released a second time by a later Flush or the destructor.
+  auto flushed_pools = std::move(used_descriptor_pools_);
+  used_descriptor_pools_.clear();
+  return DescriptorSetGroup(add_ref(descriptor_pool_cache_),
+                            std::move(flushed_pools));
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/descriptor_set_arena.h b/hal/vulkan/descriptor_set_arena.h
new file mode 100644
index 0000000..5e07268
--- /dev/null
+++ b/hal/vulkan/descriptor_set_arena.h
@@ -0,0 +1,76 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
+#define IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
+
+#include <array>
+#include <vector>
+
+#include "base/arena.h"
+#include "base/status.h"
+#include "hal/command_buffer.h"
+#include "hal/vulkan/descriptor_pool_cache.h"
+#include "hal/vulkan/pipeline_executable.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// A reusable arena for allocating descriptor sets and batching updates.
+class DescriptorSetArena final {
+ public:
+ explicit DescriptorSetArena(
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache);
+ ~DescriptorSetArena();
+
+ // Allocates a descriptor set from the arena, points it at |bindings|, and
+ // binds it to |command_buffer| (compute bind point) using the pipeline
+ // layout of |executable|.
+ Status BindDescriptorSet(VkCommandBuffer command_buffer,
+ PipelineExecutable* executable,
+ absl::Span<const BufferBinding> bindings);
+
+ // Moves every descriptor pool used so far into the returned group (empty if
+ // none were used) and resets the pool buckets. The group - when dropped -
+ // releases the descriptor sets back to the pools they were allocated from.
+ StatusOr<DescriptorSetGroup> Flush();
+
+ private:
+ const DynamicSymbols& syms() const { return *logical_device_->syms(); }
+
+ // Writes |bindings| via vkCmdPushDescriptorSetKHR instead of a pooled set.
+ Status PushDescriptorSet(VkCommandBuffer command_buffer,
+ PipelineExecutable* executable,
+ absl::Span<const BufferBinding> bindings);
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache_;
+
+ // Arena handed to PopulateDescriptorSetWriteInfos for temporary binding info.
+ Arena scratch_arena_;
+
+ // Pools acquired on demand, one bucket per descriptor-count granularity:
+ // max_descriptor_count=[8, 16, 32, 64]. Flush() resets every bucket.
+ std::array<DescriptorPool, 4> descriptor_pool_buckets_;
+
+ // All pools used during allocation; moved into the group Flush() returns.
+ absl::InlinedVector<DescriptorPool, 8> used_descriptor_pools_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
diff --git a/hal/vulkan/direct_command_buffer.cc b/hal/vulkan/direct_command_buffer.cc
new file mode 100644
index 0000000..13398f7
--- /dev/null
+++ b/hal/vulkan/direct_command_buffer.cc
@@ -0,0 +1,403 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/direct_command_buffer.h"
+
+#include "absl/base/attributes.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "base/math.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+VkPipelineStageFlags ConvertPipelineStageFlags(
+    ExecutionStageBitfield stage_mask) {
+  // Translates a HAL execution stage bitfield into Vulkan pipeline stage
+  // flags. Each HAL stage that is set in |stage_mask| contributes exactly one
+  // Vulkan bit to the result; stages that are not set contribute nothing, so
+  // an empty mask yields 0.
+  VkPipelineStageFlags vk_stages = 0;
+  const auto add = [&](ExecutionStage stage, VkPipelineStageFlags vk_bit) {
+    if (AnyBitSet(stage_mask & stage)) {
+      vk_stages |= vk_bit;
+    }
+  };
+
+  // HAL execution stage -> Vulkan pipeline stage.
+  add(ExecutionStage::kCommandIssue, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
+  add(ExecutionStage::kCommandProcess, VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT);
+  add(ExecutionStage::kDispatch, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
+  add(ExecutionStage::kTransfer, VK_PIPELINE_STAGE_TRANSFER_BIT);
+  add(ExecutionStage::kCommandRetire, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT);
+  add(ExecutionStage::kHost, VK_PIPELINE_STAGE_HOST_BIT);
+
+  return vk_stages;
+}
+
+VkAccessFlags ConvertAccessMask(AccessScopeBitfield access_mask) {
+  // Translates a HAL access scope bitfield into Vulkan access flags. Each
+  // scope that is set in |access_mask| contributes exactly one Vulkan access
+  // bit to the result; scopes that are not set contribute nothing, so an empty
+  // mask yields 0.
+  VkAccessFlags vk_access = 0;
+  const auto add = [&](AccessScope scope, VkAccessFlags vk_bit) {
+    if (AnyBitSet(access_mask & scope)) {
+      vk_access |= vk_bit;
+    }
+  };
+
+  // Indirect command and constant (uniform) reads.
+  add(AccessScope::kIndirectCommandRead, VK_ACCESS_INDIRECT_COMMAND_READ_BIT);
+  add(AccessScope::kConstantRead, VK_ACCESS_UNIFORM_READ_BIT);
+
+  // Shader reads/writes performed by dispatches.
+  add(AccessScope::kDispatchRead, VK_ACCESS_SHADER_READ_BIT);
+  add(AccessScope::kDispatchWrite, VK_ACCESS_SHADER_WRITE_BIT);
+
+  // Transfer reads/writes.
+  add(AccessScope::kTransferRead, VK_ACCESS_TRANSFER_READ_BIT);
+  add(AccessScope::kTransferWrite, VK_ACCESS_TRANSFER_WRITE_BIT);
+
+  // Host reads/writes.
+  add(AccessScope::kHostRead, VK_ACCESS_HOST_READ_BIT);
+  add(AccessScope::kHostWrite, VK_ACCESS_HOST_WRITE_BIT);
+
+  // Generic memory reads/writes.
+  add(AccessScope::kMemoryRead, VK_ACCESS_MEMORY_READ_BIT);
+  add(AccessScope::kMemoryWrite, VK_ACCESS_MEMORY_WRITE_BIT);
+
+  return vk_access;
+}
+
+// Replicates a 1, 2, or 4 byte fill pattern across all four bytes of a
+// uint32_t. Any other |pattern_length| yields 0.
+uint32_t SplatPattern(const void* pattern, size_t pattern_length) {
+  if (pattern_length == 1) {
+    const uint32_t byte_value = *static_cast<const uint8_t*>(pattern);
+    // 0xAB -> 0xABABABAB.
+    return byte_value * 0x01010101u;
+  }
+  if (pattern_length == 2) {
+    const uint32_t half_value = *static_cast<const uint16_t*>(pattern);
+    // 0xABCD -> 0xABCDABCD.
+    return half_value * 0x00010001u;
+  }
+  if (pattern_length == 4) {
+    return *static_cast<const uint32_t*>(pattern);
+  }
+  // Callers are expected to have validated the length already, so this is
+  // only a defensive fallback.
+  return 0;
+}
+
+} // namespace
+
+DirectCommandBuffer::DirectCommandBuffer(
+ Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories,
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
+ ref_ptr<VkCommandPoolHandle> command_pool, VkCommandBuffer command_buffer)
+ : CommandBuffer(allocator, mode, command_categories),
+ command_pool_(std::move(command_pool)),  // Source of syms() and the mutex.
+ command_buffer_(command_buffer),  // Freed again by the destructor.
+ descriptor_set_arena_(std::move(descriptor_pool_cache)) {}
+
+DirectCommandBuffer::~DirectCommandBuffer() {
+ IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor");
+ descriptor_set_group_.Reset().IgnoreError();  // Best-effort; error dropped.
+ absl::MutexLock lock(command_pool_->mutex());  // Held across the free below.
+ syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_,
+ 1, &command_buffer_);
+}
+
+StatusOr<NativeEvent*> DirectCommandBuffer::CastEvent(Event* event) const {
+ // Unchecked downcast. TODO(benvanik): assert the event is valid.
+ return static_cast<NativeEvent*>(event);
+}
+
+StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer(Buffer* buffer) const {
+ // Unchecked cast of allocated_buffer(). TODO(benvanik): assert that it is
+ // from the right allocator and compatible with our target queue family.
+ return static_cast<VmaBuffer*>(buffer->allocated_buffer());
+}
+
+Status DirectCommandBuffer::Begin() {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin");
+
+  // Release the descriptor sets held for the previous recording. This is safe
+  // because command buffers must not be re-recorded while they are in-flight.
+  RETURN_IF_ERROR(descriptor_set_group_.Reset());
+
+  VkCommandBufferBeginInfo begin_info;
+  begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+  begin_info.pNext = nullptr;
+  begin_info.flags = AllBitsSet(mode(), CommandBufferMode::kOneShot)
+                         ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT
+                         : 0;
+  begin_info.pInheritanceInfo = nullptr;
+  VK_RETURN_IF_ERROR(
+      syms()->vkBeginCommandBuffer(command_buffer_, &begin_info));
+
+  // Only flag recording once every fallible step above has succeeded.
+  is_recording_ = true;
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::End() {
+ IREE_TRACE_SCOPE0("DirectCommandBuffer::End");
+
+ VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_));
+
+ // Keep the group returned by the arena's Flush() (empty if no pools used).
+ ASSIGN_OR_RETURN(descriptor_set_group_, descriptor_set_arena_.Flush());
+
+ is_recording_ = false;  // Not reached if either step above returned early.
+
+ return OkStatus();
+}
+
+Status DirectCommandBuffer::ExecutionBarrier(
+    ExecutionStageBitfield source_stage_mask,
+    ExecutionStageBitfield target_stage_mask,
+    absl::Span<const MemoryBarrier> memory_barriers,
+    absl::Span<const BufferBarrier> buffer_barriers) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier");
+
+  // Translate the global memory barriers into their Vulkan equivalents.
+  absl::InlinedVector<VkMemoryBarrier, 8> vk_memory_barriers(
+      memory_barriers.size());
+  for (size_t i = 0; i < memory_barriers.size(); ++i) {
+    VkMemoryBarrier& out = vk_memory_barriers[i];
+    out.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+    out.pNext = nullptr;
+    out.srcAccessMask = ConvertAccessMask(memory_barriers[i].source_scope);
+    out.dstAccessMask = ConvertAccessMask(memory_barriers[i].target_scope);
+  }
+
+  // Translate the buffer barriers; both queue family indices are IGNORED.
+  absl::InlinedVector<VkBufferMemoryBarrier, 8> vk_buffer_barriers(
+      buffer_barriers.size());
+  for (size_t i = 0; i < buffer_barriers.size(); ++i) {
+    ASSIGN_OR_RETURN(auto* vma_buffer, CastBuffer(buffer_barriers[i].buffer));
+    VkBufferMemoryBarrier& out = vk_buffer_barriers[i];
+    out.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
+    out.pNext = nullptr;
+    out.srcAccessMask = ConvertAccessMask(buffer_barriers[i].source_scope);
+    out.dstAccessMask = ConvertAccessMask(buffer_barriers[i].target_scope);
+    out.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
+    out.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
+    out.buffer = vma_buffer->handle();
+    out.offset = buffer_barriers[i].offset;
+    out.size = buffer_barriers[i].length;
+  }
+
+  syms()->vkCmdPipelineBarrier(
+      command_buffer_, ConvertPipelineStageFlags(source_stage_mask),
+      ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0,
+      vk_memory_barriers.size(), vk_memory_barriers.data(),
+      vk_buffer_barriers.size(), vk_buffer_barriers.data(), 0, nullptr);
+
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::SignalEvent(
+    Event* event, ExecutionStageBitfield source_stage_mask) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent");
+  const auto stage_flags = ConvertPipelineStageFlags(source_stage_mask);
+  ASSIGN_OR_RETURN(auto* native_event, CastEvent(event));
+  syms()->vkCmdSetEvent(command_buffer_, native_event->handle(), stage_flags);
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::ResetEvent(
+    Event* event, ExecutionStageBitfield source_stage_mask) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent");
+  const auto stage_flags = ConvertPipelineStageFlags(source_stage_mask);
+  ASSIGN_OR_RETURN(auto* native_event, CastEvent(event));
+  syms()->vkCmdResetEvent(command_buffer_, native_event->handle(), stage_flags);
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::WaitEvents(
+    absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
+    ExecutionStageBitfield target_stage_mask,
+    absl::Span<const MemoryBarrier> memory_barriers,
+    absl::Span<const BufferBarrier> buffer_barriers) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents");
+
+  absl::InlinedVector<VkEvent, 4> vk_events(events.size());
+  for (size_t i = 0; i < events.size(); ++i) {
+    ASSIGN_OR_RETURN(auto* native_event, CastEvent(events[i]));
+    vk_events[i] = native_event->handle();
+  }
+
+  // Translate the global memory barriers into their Vulkan equivalents.
+  absl::InlinedVector<VkMemoryBarrier, 8> vk_memory_barriers(
+      memory_barriers.size());
+  for (size_t i = 0; i < memory_barriers.size(); ++i) {
+    VkMemoryBarrier& out = vk_memory_barriers[i];
+    out.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+    out.pNext = nullptr;
+    out.srcAccessMask = ConvertAccessMask(memory_barriers[i].source_scope);
+    out.dstAccessMask = ConvertAccessMask(memory_barriers[i].target_scope);
+  }
+
+  // Translate the buffer barriers; both queue family indices are IGNORED.
+  absl::InlinedVector<VkBufferMemoryBarrier, 8> vk_buffer_barriers(
+      buffer_barriers.size());
+  for (size_t i = 0; i < buffer_barriers.size(); ++i) {
+    ASSIGN_OR_RETURN(auto* vma_buffer, CastBuffer(buffer_barriers[i].buffer));
+    VkBufferMemoryBarrier& out = vk_buffer_barriers[i];
+    out.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
+    out.pNext = nullptr;
+    out.srcAccessMask = ConvertAccessMask(buffer_barriers[i].source_scope);
+    out.dstAccessMask = ConvertAccessMask(buffer_barriers[i].target_scope);
+    out.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
+    out.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
+    out.buffer = vma_buffer->handle();
+    out.offset = buffer_barriers[i].offset;
+    out.size = buffer_barriers[i].length;
+  }
+
+  syms()->vkCmdWaitEvents(
+      command_buffer_, vk_events.size(), vk_events.data(),
+      ConvertPipelineStageFlags(source_stage_mask),
+      ConvertPipelineStageFlags(target_stage_mask),
+      vk_memory_barriers.size(), vk_memory_barriers.data(),
+      vk_buffer_barriers.size(), vk_buffer_barriers.data(), 0, nullptr);
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer,
+                                       device_size_t target_offset,
+                                       device_size_t length,
+                                       const void* pattern,
+                                       size_t pattern_length) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer");
+  ASSIGN_OR_RETURN(auto* vma_buffer, CastBuffer(target_buffer));
+
+  // vkCmdFillBuffer repeats one 32-bit word, so widen the 1/2/4 byte pattern.
+  // The offset is rebased by the byte offset of |target_buffer| itself.
+  const uint32_t fill_word = SplatPattern(pattern, pattern_length);
+  target_offset += target_buffer->byte_offset();
+  syms()->vkCmdFillBuffer(command_buffer_, vma_buffer->handle(), target_offset,
+                          length, fill_word);
+
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) {
+ IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer");
+ // Records nothing. NOTE: we could use this to prevent queue family transitions.
+ return OkStatus();
+}
+
+Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer,
+                                         device_size_t source_offset,
+                                         Buffer* target_buffer,
+                                         device_size_t target_offset,
+                                         device_size_t length) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer");
+  ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
+
+  // Vulkan only allows inline updates of <= 65536 bytes (they waste command
+  // buffer space and may be slower than write-through mapped memory), so
+  // larger updates are split into consecutive chunks covering the whole range.
+  // The host data to upload starts |source_offset| bytes into |source_buffer|.
+  const auto* source_buffer_ptr =
+      static_cast<const uint8_t*>(source_buffer) + source_offset;
+  target_offset += target_buffer->byte_offset();
+  while (length > 0) {
+    device_size_t chunk_length =
+        std::min(static_cast<device_size_t>(65536u), length);
+    syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(),
+                              target_offset, chunk_length, source_buffer_ptr);
+    source_buffer_ptr += chunk_length;
+    target_offset += chunk_length;
+    length -= chunk_length;
+  }
+
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer,
+                                       device_size_t source_offset,
+                                       Buffer* target_buffer,
+                                       device_size_t target_offset,
+                                       device_size_t length) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer");
+  ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer));
+  ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
+
+  // Record a single region; each offset is rebased by its buffer's offset.
+  VkBufferCopy region;
+  region.srcOffset = source_buffer->byte_offset() + source_offset;
+  region.dstOffset = target_buffer->byte_offset() + target_offset;
+  region.size = length;
+  syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(),
+                          target_device_buffer->handle(), 1, &region);
+  return OkStatus();
+}
+
+Status DirectCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) {
+  IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch");
+
+  // Look up the pipeline compiled for the requested entry point and bind it
+  // at the compute bind point.
+  auto* pipeline_executable =
+      static_cast<PipelineExecutable*>(dispatch_request.executable);
+  ASSIGN_OR_RETURN(VkPipeline pipeline,
+                   pipeline_executable->GetPipelineForEntryPoint(
+                       dispatch_request.entry_point));
+  syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
+                            pipeline);
+
+  // Make |dispatch_request.bindings| visible to the pipeline; the arena picks
+  // between allocating a descriptor set and using push descriptors.
+  RETURN_IF_ERROR(descriptor_set_arena_.BindDescriptorSet(
+      command_buffer_, pipeline_executable, dispatch_request.bindings));
+
+  // TODO(benvanik): divide workload by caps and issue multiple dispatches.
+  // TODO(benvanik): track local workgroup/subgroup size and divide into groups.
+  if (dispatch_request.workload_buffer) {
+    return UnimplementedErrorBuilder(IREE_LOC)
+           << "Dynamic dispatches not yet implemented";
+  }
+
+  // TODO(GH-67): pre-divide workload by tile size.
+  // Matmul executables divide X and Y by 16 (rounding up) and always use a
+  // single group in Z; everything else dispatches the workload verbatim.
+  const bool is_matmul = pipeline_executable->is_matmul();
+  const uint32_t workload_x = dispatch_request.workload[0];
+  const uint32_t workload_y = dispatch_request.workload[1];
+  const uint32_t workload_z = dispatch_request.workload[2];
+  syms()->vkCmdDispatch(command_buffer_,
+                        is_matmul ? (workload_x + 16 - 1) / 16 : workload_x,
+                        is_matmul ? (workload_y + 16 - 1) / 16 : workload_y,
+                        is_matmul ? 1 : workload_z);
+
+  return OkStatus();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/direct_command_buffer.h b/hal/vulkan/direct_command_buffer.h
new file mode 100644
index 0000000..9433dd3
--- /dev/null
+++ b/hal/vulkan/direct_command_buffer.h
@@ -0,0 +1,103 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
+#define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
+
+#include <vulkan/vulkan.h>
+
+#include "hal/command_buffer.h"
+#include "hal/vulkan/descriptor_pool_cache.h"
+#include "hal/vulkan/descriptor_set_arena.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/handle_util.h"
+#include "hal/vulkan/native_event.h"
+#include "hal/vulkan/pipeline_executable.h"
+#include "hal/vulkan/vma_buffer.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Command buffer implementation that directly maps to VkCommandBuffer.
+// This records the commands on the calling thread without additional threading
+// indirection.
+class DirectCommandBuffer final : public CommandBuffer {
+ public:
+ DirectCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories,
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
+ ref_ptr<VkCommandPoolHandle> command_pool,
+ VkCommandBuffer command_buffer);
+ ~DirectCommandBuffer() override;  // Frees the VkCommandBuffer to its pool.
+
+ VkCommandBuffer handle() const { return command_buffer_; }  // Native handle.
+
+ bool is_recording() const override { return is_recording_; }
+
+ Status Begin() override;  // Resets |descriptor_set_group_|, begins recording.
+ Status End() override;  // Ends recording, keeps the arena's flushed group.
+
+ Status ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+ Status SignalEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+ Status ResetEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+ Status WaitEvents(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) override;
+ Status DiscardBuffer(Buffer* buffer) override;  // Currently records nothing.
+ Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+ Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+
+ Status Dispatch(const DispatchRequest& dispatch_request) override;
+
+ private:
+ const ref_ptr<DynamicSymbols>& syms() const { return command_pool_->syms(); }
+
+ StatusOr<NativeEvent*> CastEvent(Event* event) const;
+ StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) const;
+
+ bool is_recording_ = false;  // Set by Begin() and cleared by End().
+ ref_ptr<VkCommandPoolHandle> command_pool_;
+ VkCommandBuffer command_buffer_;  // Freed in the destructor.
+
+ // TODO(b/140026716): may grow large - should try to reclaim or reuse.
+ DescriptorSetArena descriptor_set_arena_;
+
+ // The current descriptor set group in use by the command buffer, if any.
+ // Assigned from the arena's Flush() in End() and reset in Begin(); it must
+ // stay valid until all in-flight submissions of the command buffer complete.
+ DescriptorSetGroup descriptor_set_group_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
diff --git a/hal/vulkan/direct_command_queue.cc b/hal/vulkan/direct_command_queue.cc
new file mode 100644
index 0000000..c2c325e
--- /dev/null
+++ b/hal/vulkan/direct_command_queue.cc
@@ -0,0 +1,201 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/direct_command_queue.h"
+
+#include <cstdint>
+
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "base/memory.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/vulkan/direct_command_buffer.h"
+#include "hal/vulkan/legacy_fence.h"
+#include "hal/vulkan/native_binary_semaphore.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+DirectCommandQueue::DirectCommandQueue(
+ std::string name, CommandCategoryBitfield supported_categories,
+ const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue)
+ : CommandQueue(std::move(name), supported_categories),
+ logical_device_(add_ref(logical_device)),  // Backs syms() and fence calls.
+ queue_(queue) {}
+
+DirectCommandQueue::~DirectCommandQueue() {
+ IREE_TRACE_SCOPE0("DirectCommandQueue::dtor");
+ absl::MutexLock lock(&queue_mutex_);
+ syms()->vkQueueWaitIdle(queue_);  // Returned VkResult is not checked.
+}
+
+Status DirectCommandQueue::TranslateBatchInfo(const SubmissionBatch& batch,
+                                              VkSubmitInfo* submit_info,
+                                              Arena* arena) {
+  // TODO(benvanik): see if we can go to finer-grained stages.
+  // For example, if this was just queue ownership transfers then we can use
+  // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT.
+  const VkPipelineStageFlags wait_stage_mask =
+      VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
+
+  // Semaphores to wait on, each paired with the stages that wait on it.
+  const size_t wait_count = batch.wait_semaphores.size();
+  auto wait_handles = arena->AllocateSpan<VkSemaphore>(wait_count);
+  auto wait_stages = arena->AllocateSpan<VkPipelineStageFlags>(wait_count);
+  for (size_t i = 0; i < wait_count; ++i) {
+    const auto& semaphore_value = batch.wait_semaphores[i];
+    if (semaphore_value.index() != 0) {
+      // TODO(b/140141417): implement timeline semaphores.
+      return UnimplementedErrorBuilder(IREE_LOC)
+             << "Timeline semaphores not yet implemented";
+    }
+    auto* binary_semaphore =
+        static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value));
+    wait_handles[i] = binary_semaphore->handle();
+    wait_stages[i] = wait_stage_mask;
+  }
+
+  // Semaphores to signal.
+  const size_t signal_count = batch.signal_semaphores.size();
+  auto signal_handles = arena->AllocateSpan<VkSemaphore>(signal_count);
+  for (size_t i = 0; i < signal_count; ++i) {
+    const auto& semaphore_value = batch.signal_semaphores[i];
+    if (semaphore_value.index() != 0) {
+      // TODO(b/140141417): implement timeline semaphores.
+      return UnimplementedErrorBuilder(IREE_LOC)
+             << "Timeline semaphores not yet implemented";
+    }
+    auto* binary_semaphore =
+        static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value));
+    signal_handles[i] = binary_semaphore->handle();
+  }
+
+  // Native handles of the command buffers to execute.
+  auto command_buffer_handles =
+      arena->AllocateSpan<VkCommandBuffer>(batch.command_buffers.size());
+  for (size_t i = 0; i < batch.command_buffers.size(); ++i) {
+    auto* direct_command_buffer = static_cast<DirectCommandBuffer*>(
+        batch.command_buffers[i]->impl());
+    command_buffer_handles[i] = direct_command_buffer->handle();
+  }
+
+  // The arrays referenced below were all allocated from |arena|.
+  submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+  submit_info->pNext = nullptr;
+  submit_info->waitSemaphoreCount = wait_handles.size();
+  submit_info->pWaitSemaphores = wait_handles.data();
+  submit_info->pWaitDstStageMask = wait_stages.data();
+  submit_info->commandBufferCount = command_buffer_handles.size();
+  submit_info->pCommandBuffers = command_buffer_handles.data();
+  submit_info->signalSemaphoreCount = signal_handles.size();
+  submit_info->pSignalSemaphores = signal_handles.data();
+
+  return OkStatus();
+}
+
+Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches,
+                                  FenceValue fence) {
+  IREE_TRACE_SCOPE0("DirectCommandQueue::Submit");
+
+  // Translate every batch into a VkSubmitInfo. The arrays the infos point at
+  // are allocated from |arena|, which stays alive across vkQueueSubmit below.
+  Arena arena(4 * 1024);
+  auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batches.size());
+  for (size_t i = 0; i < batches.size(); ++i) {
+    RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i], &arena));
+  }
+
+  // TODO(b/140141417): implement timeline semaphore fences and switch here.
+  auto* legacy_fence = reinterpret_cast<LegacyFence*>(fence.first);
+  ASSIGN_OR_RETURN(VkFence fence_handle,
+                   legacy_fence->AcquireSignalFence(fence.second));
+
+  // The queue mutex is held only for the duration of the submit call.
+  {
+    absl::MutexLock lock(&queue_mutex_);
+    VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
+        queue_, submit_infos.size(), submit_infos.data(), fence_handle));
+  }
+
+  return OkStatus();
+}
+
+Status DirectCommandQueue::WaitIdle(absl::Time deadline) {
+  if (deadline == absl::InfiniteFuture()) {
+    // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it
+    // requires fewer calls into the driver).
+    IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle");
+    absl::MutexLock lock(&queue_mutex_);
+    VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_));
+    return OkStatus();
+  }
+
+  IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence");
+
+  // Convert the deadline to a relative timeout in nanoseconds before creating
+  // any Vulkan object, so an expired deadline costs no driver calls.
+  // InfiniteFuture was already handled by the fast path above.
+  uint64_t timeout = 0;  // absl::InfinitePast(): poll without waiting.
+  if (deadline != absl::InfinitePast()) {
+    // The implementation may not wait with this granularity (like, by 10000x).
+    absl::Time now = absl::Now();
+    if (deadline < now) {
+      return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline in the past";
+    }
+    timeout = static_cast<uint64_t>(absl::ToInt64Nanoseconds(deadline - now));
+  }
+
+  // Create a new fence just for this wait. This keeps us thread-safe as the
+  // behavior of wait+reset is racy.
+  VkFenceCreateInfo create_info;
+  create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
+  create_info.pNext = nullptr;
+  create_info.flags = 0;
+  VkFence fence = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(syms()->vkCreateFence(
+      *logical_device_, &create_info, logical_device_->allocator(), &fence));
+  // NOTE(review): on VK_TIMEOUT the fence may still be pending on the queue
+  // when this cleanup destroys it - confirm that is acceptable.
+  auto fence_cleanup = MakeCleanup([this, fence]() {
+    syms()->vkDestroyFence(*logical_device_, fence,
+                           logical_device_->allocator());
+  });
+
+  // An empty submission signals the fence once all previously submitted work
+  // on the queue has completed.
+  {
+    absl::MutexLock lock(&queue_mutex_);
+    VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence));
+  }
+
+  VkResult result =
+      syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout);
+  switch (result) {
+    case VK_SUCCESS:
+      return OkStatus();
+    case VK_TIMEOUT:
+      return DeadlineExceededErrorBuilder(IREE_LOC)
+             << "Deadline exceeded waiting for idle";
+    default:
+      return VkResultToStatus(result);
+  }
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/direct_command_queue.h b/hal/vulkan/direct_command_queue.h
new file mode 100644
index 0000000..85b6f60
--- /dev/null
+++ b/hal/vulkan/direct_command_queue.h
@@ -0,0 +1,69 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
+#define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <cstdint>
+#include <string>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "base/arena.h"
+#include "base/status.h"
+#include "hal/command_queue.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/handle_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Command queue implementation that maps directly onto a VkQueue.
+class DirectCommandQueue final : public CommandQueue {
+ public:
+ DirectCommandQueue(std::string name,
+ CommandCategoryBitfield supported_categories,
+ const ref_ptr<VkDeviceHandle>& logical_device,
+ VkQueue queue);
+ ~DirectCommandQueue() override;  // Calls vkQueueWaitIdle on the queue.
+
+ const ref_ptr<DynamicSymbols>& syms() const {
+ return logical_device_->syms();
+ }
+
+ Status Submit(absl::Span<const SubmissionBatch> batches,
+ FenceValue fence) override;
+
+ Status WaitIdle(absl::Time deadline) override;
+
+ private:
+ Status TranslateBatchInfo(const SubmissionBatch& batch,
+ VkSubmitInfo* submit_info, Arena* arena);
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ // VkQueue needs to be externally synchronized; calls on |queue_| hold this.
+ mutable absl::Mutex queue_mutex_;
+ VkQueue queue_ ABSL_GUARDED_BY(queue_mutex_);
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/dynamic_symbol_tables.h b/hal/vulkan/dynamic_symbol_tables.h
similarity index 100%
rename from iree/hal/vulkan/dynamic_symbol_tables.h
rename to hal/vulkan/dynamic_symbol_tables.h
diff --git a/hal/vulkan/dynamic_symbols.cc b/hal/vulkan/dynamic_symbols.cc
new file mode 100644
index 0000000..b4823f6
--- /dev/null
+++ b/hal/vulkan/dynamic_symbols.cc
@@ -0,0 +1,238 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/dynamic_symbols.h"
+
+#include <cstddef>
+#include <cstdlib>
+
+#include "absl/base/attributes.h"
+#include "absl/base/macros.h"
+#include "absl/memory/memory.h"
+#include "base/file_path.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/target_platform.h"
+#include "base/tracing.h"
+#include "hal/vulkan/dynamic_symbol_tables.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif // IREE_PLATFORM_WINDOWS
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// One read-only table entry describing a loadable Vulkan function; the tables
+// are designed to be in .rdata. To reduce binary size this structure is packed
+// and the three bit-fields below share 32 bits.
+struct FunctionPtrInfo {
+ // Name of the function (like 'vkSomeFunction').
+ const char* function_name;
+ // 1 if the function pointer can be resolved via vkGetDeviceProcAddr.
+ uint32_t is_device : 1;
+ // 1 if the function is required and the loader should bail if not found.
+ uint32_t is_required : 1;
+ // TODO(benvanik): remove from table by manually walking sizeof(uintptr_t).
+ // Offset in bytes (from offsetof) between the start of the DynamicSymbols
+ // object and its PFN_vkSomeFunction member.
+ uint32_t member_offset : 30;
+} ABSL_ATTRIBUTE_PACKED;
+
+namespace {
+
+#define REQUIRED_PFN_FUNCTION_PTR(function_name, is_device) \
+ {#function_name, is_device, 1, offsetof(DynamicSymbols, function_name)},  // Entry with is_required = 1.
+#define OPTIONAL_PFN_FUNCTION_PTR(function_name, is_device) \
+ {#function_name, is_device, 0, offsetof(DynamicSymbols, function_name)},  // Entry with is_required = 0.
+#define EXCLUDED_PFN_FUNCTION_PTR(function_name, is_device)  // Emits no table entry.
+#define INS_PFN_FUNCTION_PTR(requirement, function_name) \
+ requirement##_PFN_FUNCTION_PTR(function_name, 0)  // Instance-level: is_device = 0.
+#define DEV_PFN_FUNCTION_PTR(requirement, function_name) \
+ requirement##_PFN_FUNCTION_PTR(function_name, 1)  // Device-level: is_device = 1.
+
+// Table of the mandatory entry points that are looked up before any VkInstance
+// exists (they are resolved with a null instance handle). All loaders/ICDs
+// should export them.
+constexpr FunctionPtrInfo kInstancelessFunctionPtrInfos[] = {
+    INS_PFN_FUNCTION_PTR(REQUIRED, vkCreateInstance)                        //
+    INS_PFN_FUNCTION_PTR(REQUIRED, vkEnumerateInstanceLayerProperties)      //
+    INS_PFN_FUNCTION_PTR(REQUIRED, vkEnumerateInstanceExtensionProperties)  //
+};
+
+// Table of the entry points listed in dynamic_symbol_tables.h. These can only
+// be resolved once an instance has been created.
+constexpr FunctionPtrInfo kDynamicFunctionPtrInfos[] = {
+    IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(INS_PFN_FUNCTION_PTR,
+                                      DEV_PFN_FUNCTION_PTR)};
+
+} // namespace
+
+// static
+StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::Create(
+    const GetProcAddrFn& get_proc_addr) {
+  IREE_TRACE_SCOPE0("DynamicSymbols::Create");
+
+  auto symbols = make_ref<DynamicSymbols>();
+
+  // Bootstrap with vkGetInstanceProcAddr: some libraries export every symbol
+  // while others only export this one, so all other lookups go through it.
+  const auto instance_lookup = reinterpret_cast<PFN_vkGetInstanceProcAddr>(
+      get_proc_addr("vkGetInstanceProcAddr"));
+  symbols->vkGetInstanceProcAddr = instance_lookup;
+  if (instance_lookup == nullptr) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Required method vkGetInstanceProcAddr not "
+              "found in provided Vulkan library (did you pick the wrong file?)";
+  }
+
+  // Resolve the functions needed to create an instance. Each table entry holds
+  // the byte offset of its PFN_* member inside DynamicSymbols. A |get_proc_addr|
+  // that cannot provide these is not a loader/ICD we want to use.
+  uint8_t* const symbols_base = reinterpret_cast<uint8_t*>(symbols.get());
+  for (const FunctionPtrInfo& info : kInstancelessFunctionPtrInfos) {
+    const PFN_vkVoidFunction resolved =
+        symbols->vkGetInstanceProcAddr(VK_NULL_HANDLE, info.function_name);
+    *reinterpret_cast<PFN_vkVoidFunction*>(symbols_base + info.member_offset) =
+        resolved;
+    if (resolved == nullptr) {
+      return UnavailableErrorBuilder(IREE_LOC)
+             << "Mandatory Vulkan function " << info.function_name
+             << " not available; invalid loader/ICD?";
+    }
+  }
+
+  return symbols;
+}
+
+// static
+// Opens the system Vulkan loader (vulkan-1.dll or libvulkan.so.1), resolves the
+// instanceless entry points via Create() and installs a callback that closes
+// the library when the returned DynamicSymbols is destroyed.
+StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::CreateFromSystemLoader() {
+  IREE_TRACE_SCOPE0("DynamicSymbols::CreateFromSystemLoader");
+
+#if defined(IREE_VK_ICD_FILENAMES)
+#define IREE_STRINGIFY_(x) #x
+#define IREE_STRING_(x) IREE_STRINGIFY_(x)
+  std::string vk_icd_filenames = IREE_STRING_(IREE_VK_ICD_FILENAMES);
+#undef IREE_STRINGIFY_
+#undef IREE_STRING_
+#if defined(IREE_PLATFORM_WINDOWS)
+  // TODO(b/138220713): Set VK_ICD_FILENAMES on Windows
+#else
+  ::setenv("VK_ICD_FILENAMES", vk_icd_filenames.c_str(), 0);
+#endif  // IREE_PLATFORM_WINDOWS
+#else
+  // Leave VK_ICD_FILENAMES unchanged; the system Vulkan loader discovers ICDs.
+#endif  // IREE_VK_ICD_FILENAMES
+
+  // NOTE: library loading could be factored out into base if reused elsewhere.
+#if defined(IREE_PLATFORM_WINDOWS)
+  HMODULE library = ::LoadLibraryA("vulkan-1.dll");
+  if (!library) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Unable to open vulkan-1.dll; driver not installed/on PATH";
+  }
+  auto get_proc_addr = [library](const char* function_name) {
+    return reinterpret_cast<PFN_vkVoidFunction>(
+        ::GetProcAddress(library, function_name));
+  };
+  auto close_library = [library]() { ::FreeLibrary(library); };
+#else
+  void* library = ::dlopen("libvulkan.so.1", RTLD_LAZY | RTLD_LOCAL);
+  if (!library) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Unable to open libvulkan.so.1; driver not installed/on "
+              "LD_LIBRARY_PATH";
+  }
+  auto get_proc_addr = [library](const char* function_name) {
+    return reinterpret_cast<PFN_vkVoidFunction>(
+        ::dlsym(library, function_name));
+  };
+  auto close_library = [library]() { ::dlclose(library); };
+#endif  // IREE_PLATFORM_WINDOWS
+
+  auto syms_or = Create(get_proc_addr);
+  if (!syms_or.ok()) {
+    // Nothing owns |library| when symbol resolution fails, so release it here.
+    close_library();
+    return syms_or;
+  }
+  // TODO(benvanik): disable if we want to get profiling results; closing the
+  // library can prevent symbolization on crashes or in sampling profilers.
+  syms_or.ValueOrDie()->close_fn_ = close_library;
+  return syms_or;
+}
+
+Status DynamicSymbols::LoadFromInstance(VkInstance instance) {
+  IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromInstance");
+  return this->LoadFromDevice(instance, /*device=*/VK_NULL_HANDLE);
+}
+
+Status DynamicSymbols::LoadFromDevice(VkInstance instance, VkDevice device) {
+  IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromDevice");
+
+  if (!instance) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Instance must have been created and a default instance proc "
+              "lookup function is required";
+  }
+
+  // Resolve vkGetDeviceProcAddr first; the table walk below relies on it.
+  const auto device_lookup = reinterpret_cast<PFN_vkGetDeviceProcAddr>(
+      this->vkGetInstanceProcAddr(instance, "vkGetDeviceProcAddr"));
+  this->vkGetDeviceProcAddr = device_lookup;
+  if (device_lookup == nullptr) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Required Vulkan function vkGetDeviceProcAddr not available; "
+              "invalid driver handle?";
+  }
+
+  // Walk the table and store each resolved pointer into the member found at
+  // the entry's byte offset. Device-level entries use the device lookup when a
+  // |device| was provided; everything else goes through the instance.
+  uint8_t* const self_base = reinterpret_cast<uint8_t*>(this);
+  for (const FunctionPtrInfo& info : kDynamicFunctionPtrInfos) {
+    const bool via_device = info.is_device && device;
+    const PFN_vkVoidFunction resolved =
+        via_device ? this->vkGetDeviceProcAddr(device, info.function_name)
+                   : this->vkGetInstanceProcAddr(instance, info.function_name);
+    *reinterpret_cast<PFN_vkVoidFunction*>(self_base + info.member_offset) =
+        resolved;
+    // Optional functions that are missing simply stay nullptr.
+    if (info.is_required && resolved == nullptr) {
+      return UnavailableErrorBuilder(IREE_LOC)
+             << "Required Vulkan function " << info.function_name
+             << " not available";
+    }
+  }
+
+  return OkStatus();
+}
+
+DynamicSymbols::DynamicSymbols() = default;
+
+// Runs the optional close callback (e.g. the one that unloads the library).
+DynamicSymbols::~DynamicSymbols() {
+  if (!close_fn_) return;
+  close_fn_();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/dynamic_symbols.h b/hal/vulkan/dynamic_symbols.h
new file mode 100644
index 0000000..adc95c8
--- /dev/null
+++ b/hal/vulkan/dynamic_symbols.h
@@ -0,0 +1,129 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
+#define IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
+
+#include <vulkan/vulkan.h>
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "hal/vulkan/dynamic_symbol_tables.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+struct FunctionPtrInfo;
+
+// Dynamic Vulkan function loader for use with vulkan.hpp.
+// This loader is a subset of the DispatchLoaderDynamic implementation that only
+// loads functions we are interested in (a compute-specific subset) and avoids
+// extensions we will never use.
+//
+// This exposes all Vulkan methods as function pointer members. Optional
+// methods will be nullptr if not present. Excluded methods will be omitted.
+//
+// DynamicSymbols instances are designed to be passed to vulkan.hpp methods as
+// the last argument, though they may also be called directly.
+// **Always make sure to pass the loader to vulkan.hpp methods!**
+//
+// Loading is performed by walking a table of required and optional functions
+// (defined in dynamic_symbol_tables.h) and populating the member function
+// pointers exposed on this struct when available. For example, if the
+// vkSomeFunction method is marked in the table as OPTIONAL the loader will
+// attempt to lookup the function and if successful set the
+// DynamicSymbols::vkSomeFunction pointer to the resolved address. If the
+// function is not found then it will be set to nullptr so users can check for
+// function availability.
+//
+// Documentation:
+// https://github.com/KhronosGroup/Vulkan-Hpp#extensions--per-device-function-pointers
+//
+// Usage:
+// ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader());
+// VkInstance instance = VK_NULL_HANDLE;
+// syms->vkCreateInstance(..., &instance);
+// RETURN_IF_ERROR(syms->LoadFromInstance(instance));
+struct DynamicSymbols : public RefObject<DynamicSymbols> {
+  using GetProcAddrFn =
+      std::function<PFN_vkVoidFunction(const char* function_name)>;
+
+  DynamicSymbols();
+  ~DynamicSymbols();
+
+  // Creates the dynamic symbol table using the given |get_proc_addr| to resolve
+  // vkGetInstanceProcAddr and, through it, the instanceless functions such as
+  // vkCreateInstance. Fails if any of them cannot be resolved.
+  // After the instance is created the caller must use LoadFromInstance (or
+  // LoadFromDevice) to load the remaining symbols.
+  static StatusOr<ref_ptr<DynamicSymbols>> Create(
+      const GetProcAddrFn& get_proc_addr);
+
+  // Creates the symbol table from the system Vulkan loader. This opens the
+  // loader library (vulkan-1.dll or libvulkan.so.1) and resolves the same
+  // instanceless functions as Create(); the library is closed again when the
+  // DynamicSymbols is destroyed.
+  // The loaded function pointers will point to thunks in the ICD. This may
+  // enable additional debug checking and more readable stack traces (as
+  // errors come from within the ICD, where we have symbols).
+  static StatusOr<ref_ptr<DynamicSymbols>> CreateFromSystemLoader();
+
+  // Loads all required and optional Vulkan functions from the given instance.
+  // Equivalent to LoadFromDevice(instance, VK_NULL_HANDLE).
+  // The loaded function pointers will point to thunks in the ICD. This may
+  // enable additional debug checking and more readable stack traces (as
+  // errors come from within the ICD, where we have symbols).
+  Status LoadFromInstance(VkInstance instance);
+
+  // Loads all required and optional Vulkan functions from the given device,
+  // falling back to the instance when required.
+  //
+  // This attempts to directly query the methods from the device, bypassing any
+  // ICD or shim layers. These methods will generally have less overhead at
+  // runtime as they need not jump through the various trampolines.
+  Status LoadFromDevice(VkInstance instance, VkDevice device);
+
+  // Define members for each function pointer (see dynamic_symbol_tables.h for
+  // the full list). Each required and optional function expands to a member
+  // such as, for 'vkSomeFunction':
+  //   PFN_vkSomeFunction vkSomeFunction = nullptr;
+  // The explicit nullptr matters: the out-of-line defaulted constructor would
+  // otherwise leave the pointers indeterminate until a Load* call fills them.
+#define REQUIRED_PFN(function_name) PFN_##function_name function_name = nullptr
+#define OPTIONAL_PFN(function_name) PFN_##function_name function_name = nullptr
+#define EXCLUDED_PFN(function_name)
+#define PFN_MEMBER(requirement, function_name) requirement##_PFN(function_name);
+  REQUIRED_PFN(vkGetInstanceProcAddr);
+  REQUIRED_PFN(vkGetDeviceProcAddr);
+  IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(PFN_MEMBER, PFN_MEMBER);
+#undef REQUIRED_PFN
+#undef OPTIONAL_PFN
+#undef EXCLUDED_PFN
+#undef PFN_MEMBER
+
+ private:
+  // Optional callback invoked by the destructor (e.g. to unload the library).
+  std::function<void()> close_fn_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
diff --git a/hal/vulkan/dynamic_symbols_test.cc b/hal/vulkan/dynamic_symbols_test.cc
new file mode 100644
index 0000000..3bf0638
--- /dev/null
+++ b/hal/vulkan/dynamic_symbols_test.cc
@@ -0,0 +1,73 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/dynamic_symbols.h"
+
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+namespace {
+
+// Describes this test to the Vulkan implementation, requesting API 1.0.
+VkApplicationInfo GetApplicationInfo() {
+  // Zero everything first; only the fields below carry non-default values
+  // (pNext stays null and both version fields stay 0).
+  VkApplicationInfo info = {};
+  info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+  info.pApplicationName = "IREE-ML-TEST";
+  info.pEngineName = "IREE";
+  info.apiVersion = VK_API_VERSION_1_0;
+  return info;
+}
+
+// Builds the create info for a bare instance that enables no layers and no
+// extensions. |app_info| is referenced, not copied, so it must stay alive for
+// as long as the returned struct is used.
+VkInstanceCreateInfo GetInstanceCreateInfo(VkApplicationInfo* app_info) {
+  // Value-initialization zeroes pNext, flags, both counts and both name
+  // arrays; only the remaining two fields need explicit values.
+  VkInstanceCreateInfo info = {};
+  info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+  info.pApplicationInfo = app_info;
+
+  return info;
+}
+
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+  auto loader_result = DynamicSymbols::CreateFromSystemLoader();
+  ASSERT_OK(loader_result);
+  ref_ptr<DynamicSymbols> symbols = std::move(loader_result.ValueOrDie());
+
+  // Round-trip a VkInstance to check that both the instanceless symbols and
+  // the ones resolved by LoadFromInstance are actually usable.
+  VkApplicationInfo application_info = GetApplicationInfo();
+  VkInstanceCreateInfo instance_info = GetInstanceCreateInfo(&application_info);
+  VkInstance vk_instance = VK_NULL_HANDLE;
+  VK_CHECK_OK(symbols->vkCreateInstance(&instance_info, /*pAllocator=*/nullptr,
+                                        &vk_instance));
+
+  ASSERT_OK(symbols->LoadFromInstance(vk_instance));
+
+  symbols->vkDestroyInstance(vk_instance, /*pAllocator=*/nullptr);
+}
+
+} // namespace
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/extensibility_util.cc b/hal/vulkan/extensibility_util.cc
new file mode 100644
index 0000000..09d9f54
--- /dev/null
+++ b/hal/vulkan/extensibility_util.cc
@@ -0,0 +1,221 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/extensibility_util.h"
+
+#include "base/memory.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Returns |required_layers| followed by every entry of |optional_layers| that
+// is listed in |properties|; fails if a required layer is not listed.
+StatusOr<std::vector<const char*>> MatchAvailableLayers(
+    absl::Span<const char* const> required_layers,
+    absl::Span<const char* const> optional_layers,
+    absl::Span<const VkLayerProperties> properties) {
+  IREE_TRACE_SCOPE0("MatchAvailableLayers");
+
+  // True when |name| equals the layerName of one of the reported layers.
+  const auto is_available = [&properties](const char* name) {
+    for (const VkLayerProperties& candidate : properties) {
+      if (std::strcmp(name, candidate.layerName) == 0) return true;
+    }
+    return false;
+  };
+
+  std::vector<const char*> matched;
+  matched.reserve(required_layers.size() + optional_layers.size());
+
+  // A single missing required layer fails the whole match.
+  for (const char* name : required_layers) {
+    if (!is_available(name)) {
+      return UnavailableErrorBuilder(IREE_LOC)
+             << "Required layer " << name << " not available";
+    }
+    VLOG(1) << "Enabling required layer: " << name;
+    matched.push_back(name);
+  }
+
+  // Optional layers are appended when present and only logged when
+  // absent.
+  for (const char* name : optional_layers) {
+    if (is_available(name)) {
+      VLOG(1) << "Enabling optional layer: " << name;
+      matched.push_back(name);
+    } else {
+      VLOG(1) << "Optional layer " << name << " not available";
+    }
+  }
+
+  return matched;
+}
+
+// Returns |required_extensions| followed by every entry of
+// |optional_extensions| that is listed in |properties|. Fails with an
+// unavailable error as soon as a required extension is not listed.
+StatusOr<std::vector<const char*>> MatchAvailableExtensions(
+    absl::Span<const char* const> required_extensions,
+    absl::Span<const char* const> optional_extensions,
+    absl::Span<const VkExtensionProperties> properties) {
+  IREE_TRACE_SCOPE0("MatchAvailableExtensions");
+
+  // True when |name| equals the extensionName of one of the reported
+  // extensions.
+  const auto is_available = [&properties](const char* name) {
+    for (const VkExtensionProperties& candidate : properties) {
+      if (std::strcmp(name, candidate.extensionName) == 0) return true;
+    }
+    return false;
+  };
+
+  std::vector<const char*> matched;
+  matched.reserve(required_extensions.size() + optional_extensions.size());
+
+  // A single missing required extension fails the whole match; nothing is
+  // returned in that case.
+  for (const char* name : required_extensions) {
+    if (!is_available(name)) {
+      return UnavailableErrorBuilder(IREE_LOC)
+             << "Required extension " << name << " not available";
+    }
+    VLOG(1) << "Enabling required extension: " << name;
+    matched.push_back(name);
+  }
+
+  // Optional extensions are appended when present and only logged when
+  // absent.
+  for (const char* name : optional_extensions) {
+    if (is_available(name)) {
+      VLOG(1) << "Enabling optional extension: " << name;
+      matched.push_back(name);
+    } else {
+      VLOG(1) << "Optional extension " << name << " not available";
+    }
+  }
+
+  return matched;
+}
+
+} // namespace
+
+StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers(
+    const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
+  // Two-call enumeration: query the layer count, then fetch the properties.
+  uint32_t count = 0;
+  VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceLayerProperties(&count, nullptr));
+  std::vector<VkLayerProperties> available(count);
+  VK_RETURN_IF_ERROR(
+      syms.vkEnumerateInstanceLayerProperties(&count, available.data()));
+  ASSIGN_OR_RETURN(auto matched,
+                   MatchAvailableLayers(extensibility_spec.required_layers,
+                                        extensibility_spec.optional_layers,
+                                        available),
+                   _ << "Unable to find all required instance layers");
+  return matched;
+}
+
+StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions(
+    const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
+  // Scoped so leak checks are re-enabled on every exit, including errors.
+  struct ScopedLeakCheckDisable {
+    ScopedLeakCheckDisable() { IREE_DISABLE_LEAK_CHECKS(); }
+    ~ScopedLeakCheckDisable() { IREE_ENABLE_LEAK_CHECKS(); }
+  } leak_checks_disabled;
+  uint32_t extension_property_count = 0;
+  VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties(
+      nullptr, &extension_property_count, nullptr));
+  std::vector<VkExtensionProperties> available(extension_property_count);
+  VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties(
+      nullptr, &extension_property_count, available.data()));
+  ASSIGN_OR_RETURN(auto enabled_extensions,
+                   MatchAvailableExtensions(
+                       extensibility_spec.required_extensions,
+                       extensibility_spec.optional_extensions, available),
+                   _ << "Unable to find all required instance extensions");
+  return enabled_extensions;
+}
+
+StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers(
+    VkPhysicalDevice physical_device,
+    const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
+  uint32_t count = 0;  // Filled in by the first (count-only) call below.
+  VK_RETURN_IF_ERROR(
+      syms.vkEnumerateDeviceLayerProperties(physical_device, &count, nullptr));
+  std::vector<VkLayerProperties> available(count);
+  VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties(
+      physical_device, &count, available.data()));
+  ASSIGN_OR_RETURN(auto matched,
+                   MatchAvailableLayers(extensibility_spec.required_layers,
+                                        extensibility_spec.optional_layers,
+                                        available),
+                   _ << "Unable to find all required device layers");
+  return matched;
+}
+
+StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions(
+    VkPhysicalDevice physical_device,
+    const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
+  // Two-call enumeration: the first call only reports how many extensions
+  // there are, the second one fills in their properties.
+  uint32_t count = 0;
+  VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties(
+      physical_device, nullptr, &count, nullptr));
+  std::vector<VkExtensionProperties> available(count);
+  VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties(
+      physical_device, nullptr, &count, available.data()));
+  ASSIGN_OR_RETURN(
+      auto matched,
+      MatchAvailableExtensions(extensibility_spec.required_extensions,
+                               extensibility_spec.optional_extensions,
+                               available),
+      _ << "Unable to find all required device extensions");
+  return matched;
+}
+
+InstanceExtensions PopulateEnabledInstanceExtensions(
+    absl::Span<const char* const> extension_names) {
+  // Every flag starts out cleared and is switched on once its extension name
+  // shows up in the list; names that are not recognized are ignored.
+  InstanceExtensions enabled = {};
+  for (const char* name : extension_names) {
+    enabled.debug_report |=
+        (std::strcmp(name, VK_EXT_DEBUG_REPORT_EXTENSION_NAME) == 0);
+    enabled.debug_utils |=
+        (std::strcmp(name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == 0);
+  }
+  return enabled;
+}
+
+DeviceExtensions PopulateEnabledDeviceExtensions(
+    absl::Span<const char* const> extension_names) {
+  // The flag starts out cleared and is switched on once the push descriptor
+  // extension name shows up in the list; all other names are ignored.
+  DeviceExtensions enabled = {};
+  for (const char* name : extension_names) {
+    enabled.push_descriptors |=
+        (std::strcmp(name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == 0);
+  }
+  return enabled;
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/extensibility_util.h b/hal/vulkan/extensibility_util.h
new file mode 100644
index 0000000..3e4d725
--- /dev/null
+++ b/hal/vulkan/extensibility_util.h
@@ -0,0 +1,100 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities for working with layers and extensions.
+
+#ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
+#define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
+
+#include <vulkan/vulkan.h>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "hal/vulkan/dynamic_symbols.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Describes the layers and extensions a caller requires or would like to use.
+struct ExtensibilitySpec {
+ // Layer names; matching fails if a required one is missing (optional: logged).
+ std::vector<const char*> required_layers;  // Non-owning C strings.
+ std::vector<const char*> optional_layers;  // Non-owning C strings.
+
+ // A list of required and optional extensions (non-owning C strings).
+ // Prefer using the _EXTENSION_NAME macros to make tracking easier (such as
+ // 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME').
+ std::vector<const char*> required_extensions;
+ std::vector<const char*> optional_extensions;
+};
+
+// Returns the required plus the available optional instance layer names.
+// Fails if any required_layers are unavailable.
+StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers(
+ const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
+
+// Returns the required plus the available optional instance extension names.
+// Fails if any required_extensions are unavailable.
+StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions(
+ const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
+
+// Returns the required plus the available optional layer names for the given
+// |physical_device|. Fails if any required_layers are unavailable.
+StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers(
+ VkPhysicalDevice physical_device,
+ const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
+
+// Returns the required plus the available optional extension names for the
+// given |physical_device|. Fails if any required_extensions are unavailable.
+StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions(
+ VkPhysicalDevice physical_device,
+ const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
+
+// Bits for enabled instance extensions.
+// We must use this to query support instead of just detecting symbol names as
+// ICDs will resolve the functions sometimes even if they don't support the
+// extension (or we didn't ask for it to be enabled).
+struct InstanceExtensions {
+ // VK_EXT_debug_report is enabled and a callback is registered.
+ // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_report
+ bool debug_report : 1;
+
+ // VK_EXT_debug_utils is enabled and a debug messenger is registered.
+ // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_utils
+ bool debug_utils : 1;
+};
+
+// Returns a bitfield with the bits set for all recognized |extension_names|.
+InstanceExtensions PopulateEnabledInstanceExtensions(
+ absl::Span<const char* const> extension_names);
+
+// Bits for enabled device extensions.
+// We must use this to query support instead of just detecting symbol names as
+// ICDs will resolve the functions sometimes even if they don't support the
+// extension (or we didn't ask for it to be enabled).
+struct DeviceExtensions {
+ // VK_KHR_push_descriptor is enabled and vkCmdPushDescriptorSetKHR is valid.
+ bool push_descriptors : 1;  // Set by PopulateEnabledDeviceExtensions.
+};
+
+// Returns a bitfield with the bits set for all recognized |extension_names|.
+DeviceExtensions PopulateEnabledDeviceExtensions(
+ absl::Span<const char* const> extension_names);
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
diff --git a/hal/vulkan/handle_util.h b/hal/vulkan/handle_util.h
new file mode 100644
index 0000000..c12e4d1
--- /dev/null
+++ b/hal/vulkan/handle_util.h
@@ -0,0 +1,136 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Helpers for wrapping Vulkan handles that don't require us to wrap every type.
+// This keeps our compilation time reasonable (as the vulkancpp library is
+// insane) while giving us nice safety around cleanup and ensuring we use
+// dynamic symbols and consistent allocators.
+//
+// Do not add functionality beyond handle management to these types. Keep our
+// Vulkan usage mostly functional and C-like to ensure minimal code size and
+// readability.
+
+#ifndef IREE_HAL_VULKAN_HANDLE_UTIL_H_
+#define IREE_HAL_VULKAN_HANDLE_UTIL_H_
+
+#include <vulkan/vulkan.h>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/utility/utility.h"
+#include "base/ref_ptr.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/extensibility_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+// Ref-counted wrapper that destroys its VkDevice (if set) on destruction.
+class VkDeviceHandle : public RefObject<VkDeviceHandle> {
+ public:
+ VkDeviceHandle(const ref_ptr<DynamicSymbols>& syms,
+ DeviceExtensions enabled_extensions,
+ const VkAllocationCallbacks* allocator = nullptr)
+ : syms_(add_ref(syms)),
+ enabled_extensions_(enabled_extensions),
+ allocator_(allocator) {}
+ ~VkDeviceHandle() { reset(); }
+ // Not copyable. Move construction takes the device and nulls it in |other|.
+ VkDeviceHandle(const VkDeviceHandle&) = delete;
+ VkDeviceHandle& operator=(const VkDeviceHandle&) = delete;
+ VkDeviceHandle(VkDeviceHandle&& other) noexcept
+ : value_(absl::exchange(other.value_,
+ static_cast<VkDevice>(VK_NULL_HANDLE))),
+ syms_(std::move(other.syms_)),
+ enabled_extensions_(other.enabled_extensions_),
+ allocator_(other.allocator_) {}
+ // Destroys the device with vkDestroyDevice and the stored allocator, if set.
+ void reset() {
+ if (value_ == VK_NULL_HANDLE) return;
+ syms_->vkDestroyDevice(value_, allocator_);
+ value_ = VK_NULL_HANDLE;
+ }
+ // value_ starts as VK_NULL_HANDLE; mutable_value() exposes it for assignment.
+ VkDevice value() const noexcept { return value_; }
+ VkDevice* mutable_value() noexcept { return &value_; }
+ operator VkDevice() const noexcept { return value_; }
+
+ const ref_ptr<DynamicSymbols>& syms() const noexcept { return syms_; }
+ const VkAllocationCallbacks* allocator() const noexcept { return allocator_; }
+
+ const DeviceExtensions& enabled_extensions() const {
+ return enabled_extensions_;
+ }
+
+ private:
+ VkDevice value_ = VK_NULL_HANDLE;
+ ref_ptr<DynamicSymbols> syms_;
+ DeviceExtensions enabled_extensions_;
+ const VkAllocationCallbacks* allocator_ = nullptr;
+};
+// Ref-counted wrapper that destroys its VkCommandPool (if set) on destruction.
+class VkCommandPoolHandle : public RefObject<VkCommandPoolHandle> {
+ public:
+ explicit VkCommandPoolHandle(const ref_ptr<VkDeviceHandle>& logical_device)
+ : logical_device_(add_ref(logical_device)) {}
+ ~VkCommandPoolHandle() { reset(); }
+ // Not copyable. Moves transfer the device ref and pool only, not the mutex.
+ VkCommandPoolHandle(const VkCommandPoolHandle&) = delete;
+ VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete;
+ VkCommandPoolHandle(VkCommandPoolHandle&& other) noexcept
+ : logical_device_(std::move(other.logical_device_)),
+ value_(absl::exchange(other.value_,
+ static_cast<VkCommandPool>(VK_NULL_HANDLE))) {}
+ VkCommandPoolHandle& operator=(VkCommandPoolHandle&& other) {  // Swaps state with |other|.
+ std::swap(logical_device_, other.logical_device_);
+ std::swap(value_, other.value_);
+ return *this;
+ }
+ // Destroys the pool with vkDestroyCommandPool on logical_device(), if set.
+ void reset() {
+ if (value_ == VK_NULL_HANDLE) return;
+ syms()->vkDestroyCommandPool(*logical_device_, value_, allocator());
+ value_ = VK_NULL_HANDLE;
+ }
+
+ VkCommandPool value() const noexcept { return value_; }
+ VkCommandPool* mutable_value() noexcept { return &value_; }
+ operator VkCommandPool() const noexcept { return value_; }
+
+ const ref_ptr<VkDeviceHandle>& logical_device() const noexcept {
+ return logical_device_;
+ }
+ const ref_ptr<DynamicSymbols>& syms() const noexcept {
+ return logical_device_->syms();
+ }
+ const VkAllocationCallbacks* allocator() const noexcept {
+ return logical_device_->allocator();
+ }
+ // Mutex for callers to serialize access to the pool (see mutex_).
+ absl::Mutex* mutex() const { return &mutex_; }
+
+ private:
+ ref_ptr<VkDeviceHandle> logical_device_;
+ VkCommandPool value_ = VK_NULL_HANDLE;
+
+ // Vulkan command pools are not thread safe and require external
+ // synchronization. Since we allow arbitrary threads to allocate and
+ // deallocate the HAL command buffers we need to externally synchronize.
+ mutable absl::Mutex mutex_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_HANDLE_UTIL_H_
diff --git a/hal/vulkan/internal_vk_mem_alloc.cc b/hal/vulkan/internal_vk_mem_alloc.cc
new file mode 100644
index 0000000..f3d317f
--- /dev/null
+++ b/hal/vulkan/internal_vk_mem_alloc.cc
@@ -0,0 +1,68 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file configures VMA to use common Google/Abseil types in an effort to
+// better integrate with applications compiled using other Google code. By using
+// the same types that dependers are likely using we can often reduce binary
+// size and ease debugging (such as by using absl::Mutex to get better tsan
+// warnings).
+
+// Only compile if an external implementation has not been otherwise linked.
+#if !defined(VULKAN_MEMORY_ALLOCATOR_EXTERNAL_IMPL)
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "base/logging.h"
+
+// Use std::vector instead of the VMA version.
+#define VMA_USE_STL_VECTOR 1
+
+// TODO(benvanik): figure out why std::list cannot be used.
+// #define VMA_USE_STL_LIST 1
+
+// Use absl::flat_hash_map instead of std::unordered_map.
+#define VmaPair std::pair
+#define VMA_MAP_TYPE(KeyT, ValueT) \
+ absl::flat_hash_map<KeyT, ValueT, std::hash<KeyT>, std::equal_to<KeyT>, \
+ VmaStlAllocator<std::pair<KeyT, ValueT> > >
+
+// Use CHECK for assertions.
+#define VMA_ASSERT CHECK
+#define VMA_HEAVY_ASSERT DCHECK
+
+// Use ABSL_RAW_LOG for debug logging; logging is compiled out under NDEBUG.
+#ifndef NDEBUG
+#define VMA_DEBUG_LOG(...) ABSL_RAW_LOG(INFO, __VA_ARGS__)
+#else
+#define VMA_DEBUG_LOG(...)
+#endif // !NDEBUG
+
+// Use absl::Mutex for VMA_MUTEX.
+#define VMA_MUTEX absl::Mutex
+// Adapter exposing an absl::Mutex through LockRead/UnlockRead/LockWrite/
+// UnlockWrite methods. It is handed to VMA through the VMA_RW_MUTEX define
+// that follows this class.
+class AbslVmaRWMutex {
+ public:
+ // Shared (reader) lock acquire/release.
+ void LockRead() ABSL_SHARED_LOCK_FUNCTION() { mutex_.ReaderLock(); }
+ void UnlockRead() ABSL_UNLOCK_FUNCTION() { mutex_.ReaderUnlock(); }
+ // Exclusive (writer) lock acquire/release.
+ void LockWrite() ABSL_EXCLUSIVE_LOCK_FUNCTION() { mutex_.WriterLock(); }
+ void UnlockWrite() ABSL_UNLOCK_FUNCTION() { mutex_.WriterUnlock(); }
+
+ private:
+ absl::Mutex mutex_;
+};
+#define VMA_RW_MUTEX AbslVmaRWMutex
+
+#define VMA_IMPLEMENTATION
+#include "vk_mem_alloc.h"
+
+#endif
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.h b/hal/vulkan/internal_vk_mem_alloc.h
similarity index 100%
rename from iree/hal/vulkan/internal_vk_mem_alloc.h
rename to hal/vulkan/internal_vk_mem_alloc.h
diff --git a/hal/vulkan/legacy_fence.cc b/hal/vulkan/legacy_fence.cc
new file mode 100644
index 0000000..7639817
--- /dev/null
+++ b/hal/vulkan/legacy_fence.cc
@@ -0,0 +1,396 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/legacy_fence.h"
+
+#include <cstdint>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "base/intrusive_list.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Adds |fence_signal| to |list| while keeping the list sorted by ascending
+// value: it is inserted at the first entry whose value is strictly greater,
+// or appended to the tail when no such entry exists.
+void InsertOutstandingFenceSignal(OutstandingFenceSignal* fence_signal,
+                                  IntrusiveList<OutstandingFenceSignal>* list) {
+  for (auto* candidate : *list) {
+    if (candidate->value <= fence_signal->value) continue;
+    list->insert(candidate, fence_signal);
+    return;
+  }
+  list->push_back(fence_signal);
+}
+
+} // namespace
+
+// static
+// Builds a pool bound to |logical_device| and eagerly creates all of its
+// fences. Returns the creation error if any fence cannot be created.
+StatusOr<ref_ptr<LegacyFencePool>> LegacyFencePool::Create(
+    ref_ptr<VkDeviceHandle> logical_device) {
+  IREE_TRACE_SCOPE0("LegacyFencePool::Create");
+  ref_ptr<LegacyFencePool> pool(
+      new LegacyFencePool(std::move(logical_device)));
+  RETURN_IF_ERROR(pool->PreallocateFences());
+  return pool;
+}
+
+// Stores the device reference only; fences are created separately by
+// PreallocateFences() (see Create()).
+LegacyFencePool::LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device)
+ : logical_device_(std::move(logical_device)) {}
+
+// Destroys every fence slot in the backing storage and empties the
+// bookkeeping lists.
+LegacyFencePool::~LegacyFencePool() {
+  IREE_TRACE_SCOPE0("LegacyFencePool::dtor");
+
+  absl::MutexLock lock(&mutex_);
+  for (auto& slot : storage_) {
+    syms()->vkDestroyFence(*logical_device_, slot.fence,
+                           logical_device_->allocator());
+  }
+  unused_fences_.clear();
+  unresolved_fences_.clear();
+}
+
+// Creates one unsignaled VkFence per storage slot and places each slot on the
+// unused list. Stops at, and returns, the first creation failure.
+Status LegacyFencePool::PreallocateFences() {
+  IREE_TRACE_SCOPE0("LegacyFencePool::PreallocateFences");
+
+  VkFenceCreateInfo create_info;
+  create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
+  create_info.pNext = nullptr;
+  create_info.flags = 0;  // Created unsignaled.
+
+  absl::MutexLock lock(&mutex_);
+  // |storage_| holds exactly kMaxInFlightFenceCount slots.
+  for (auto& slot : storage_) {
+    VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info,
+                                             logical_device_->allocator(),
+                                             &slot.fence));
+    unused_fences_.push_back(&slot);
+  }
+
+  return OkStatus();
+}
+
+// Hands out an unused fence entry, or RESOURCE_EXHAUSTED when the pool has
+// none left.
+//
+// The entry's bookkeeping is reset before it is returned. Entries released
+// through ReleaseResolved() come back with |is_pending| still set; without
+// the reset a recycled entry reserved by a waiter would look like it had
+// already been submitted.
+StatusOr<OutstandingFenceSignal*> LegacyFencePool::Acquire() {
+  IREE_TRACE_SCOPE0("LegacyFencePool::Acquire");
+
+  absl::MutexLock lock(&mutex_);
+  if (unused_fences_.empty()) {
+    return ResourceExhaustedErrorBuilder(IREE_LOC)
+           << "Fence pool out of unused fences";
+  }
+
+  auto* fence_signal = unused_fences_.front();
+  unused_fences_.pop_front();
+  // Restore the defaults declared on OutstandingFenceSignal.
+  fence_signal->value = UINT64_MAX;
+  fence_signal->is_pending = false;
+  return fence_signal;
+}
+
+// Returns |fence_signals| to the unused list. Fences that were submitted
+// (is_pending) are reset with a single vkResetFences call first; fences that
+// were never submitted are skipped.
+void LegacyFencePool::ReleaseResolved(
+    IntrusiveList<OutstandingFenceSignal>* fence_signals) {
+  IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseResolved");
+
+  absl::InlinedVector<VkFence, 8> submitted_handles;
+  submitted_handles.reserve(fence_signals->size());
+  for (auto* fence_signal : *fence_signals) {
+    if (!fence_signal->is_pending) continue;
+    submitted_handles.push_back(fence_signal->fence);
+  }
+  if (!submitted_handles.empty()) {
+    syms()->vkResetFences(*logical_device_, submitted_handles.size(),
+                          submitted_handles.data());
+  }
+
+  absl::MutexLock lock(&mutex_);
+  unused_fences_.merge_from(fence_signals);
+}
+
+// Drains |fence_signals| back into the pool. Entries that were submitted may
+// still have a signal in flight, so they are parked on the unresolved list;
+// entries that were never submitted go straight back to the unused list.
+void LegacyFencePool::ReleaseUnresolved(
+    IntrusiveList<OutstandingFenceSignal>* fence_signals) {
+  IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseUnresolved");
+
+  absl::MutexLock lock(&mutex_);
+  while (!fence_signals->empty()) {
+    auto* fence_signal = fence_signals->front();
+    fence_signals->pop_front();
+    // TODO(benvanik): fix the parked fences by reallocating? We aren't leaking
+    // here (technically) but we will exhaust the pool pretty quickly.
+    auto* destination =
+        fence_signal->is_pending ? &unresolved_fences_ : &unused_fences_;
+    destination->push_back(fence_signal);
+  }
+}
+
+// static
+// Blocks until the given fences reach their requested values or |deadline|
+// passes.
+//
+// |wait_all| == true:  returns OK once every fence has reached its value.
+// |wait_all| == false: returns OK as soon as at least one fence has reached
+//                      its value.
+// An empty |fences| span returns OK immediately. Errors from acquiring a wait
+// fence and non-success results from vkWaitForFences are returned to the
+// caller.
+Status LegacyFence::WaitForFences(VkDeviceHandle* logical_device,
+                                  absl::Span<const FenceValue> fences,
+                                  bool wait_all, absl::Time deadline) {
+  IREE_TRACE_SCOPE0("LegacyFence::WaitForFences");
+
+  // NOTE: we could pool this state too (probably right on the LegacyFencePool)
+  // or be smarter about using stack-allocated storage. The best idea is to use
+  // real timeline semaphores, though, so not much effort has been spent on
+  // optimizing this.
+  absl::InlinedVector<VkFence, 4> handles;
+  handles.reserve(fences.size());
+
+  while (true) {
+    // Collect the native fences for every timeline fence that has not yet
+    // reached its requested value.
+    for (const auto& fence_value : fences) {
+      auto* fence = reinterpret_cast<LegacyFence*>(fence_value.first);
+      ASSIGN_OR_RETURN(VkFence handle,
+                       fence->AcquireWaitFence(fence_value.second));
+      if (handle != VK_NULL_HANDLE) {
+        // Fence is unresolved and we need to really wait for it.
+        handles.push_back(handle);
+      }
+    }
+    if (handles.empty()) {
+      // All fences resolved.
+      return OkStatus();
+    }
+    if (!wait_all && handles.size() < fences.size()) {
+      // Wait-any: at least one fence already resolved. Without this check the
+      // loop would keep waiting until every fence resolved, turning a
+      // wait-any request into a wait-all.
+      return OkStatus();
+    }
+
+    // The remaining time is recomputed on every pass so repeated waits never
+    // exceed |deadline|.
+    uint64_t timeout_nanos;
+    if (deadline == absl::InfiniteFuture()) {
+      timeout_nanos = UINT64_MAX;
+    } else if (deadline == absl::InfinitePast()) {
+      timeout_nanos = 0;
+    } else {
+      auto relative_nanos = absl::ToInt64Nanoseconds(deadline - absl::Now());
+      timeout_nanos = relative_nanos < 0 ? 0 : relative_nanos;
+    }
+
+    // Wait on the fences we still need.
+    // Note that waking does not actually indicate all fences were hit! We need
+    // to do another pass above on the next iteration to make sure that we don't
+    // need to wait again on another fence.
+    VK_RETURN_IF_ERROR(logical_device->syms()->vkWaitForFences(
+        *logical_device, handles.size(), handles.data(), wait_all,
+        timeout_nanos));
+    handles.clear();
+  }
+
+  // Unreachable: the loop above only exits by returning.
+  return OkStatus();
+}
+
+// Creates a fence whose timeline starts at |initial_value| and that draws its
+// native VkFences from |fence_pool|.
+LegacyFence::LegacyFence(ref_ptr<LegacyFencePool> fence_pool,
+ uint64_t initial_value)
+ : fence_pool_(std::move(fence_pool)), value_(initial_value) {}
+
+// Makes one last attempt to resolve all outstanding signals. CHECK-fails if
+// that attempt returns an error or if any signal is still outstanding
+// afterwards, so callers must wait for their signals before destruction.
+LegacyFence::~LegacyFence() {
+ IREE_TRACE_SCOPE0("LegacyFence::dtor");
+ CHECK_OK(TryResolveOutstandingFences(UINT64_MAX));
+ absl::MutexLock lock(&mutex_);
+ CHECK(outstanding_signals_.empty())
+ << "Destroying a fence without first waiting on outstanding signals";
+}
+
+// Returns the sticky failure status, or OK when the fence has not failed.
+// A failure stores UINT64_MAX into |value_|, so the lock is only taken when
+// that sentinel is observed.
+Status LegacyFence::status() const {
+  const bool at_failure_sentinel = value_.load() == UINT64_MAX;
+  if (at_failure_sentinel) {
+    absl::MutexLock lock(&mutex_);
+    return status_;
+  }
+  return OkStatus();
+}
+
+// Resolves as many outstanding signals as possible and then returns the
+// current timeline value. Returns the resolution error instead if resolving
+// fails.
+StatusOr<uint64_t> LegacyFence::QueryValue() {
+ RETURN_IF_ERROR(TryResolveOutstandingFences(UINT64_MAX));
+ return value_.load();
+}
+
+// Returns the VkFence to submit for signaling the timeline to |value|.
+//
+// Fails with FAILED_PRECONDITION when |value| does not advance the timeline
+// or when a signal for exactly |value| has already been handed out. Pool
+// exhaustion is propagated from LegacyFencePool::Acquire().
+StatusOr<VkFence> LegacyFence::AcquireSignalFence(uint64_t value) {
+  absl::MutexLock lock(&mutex_);
+
+  // Signals must strictly advance the timeline; out of order signaling would
+  // require a lot more tracking to get right.
+  if (value_.load() >= value) {
+    return FailedPreconditionErrorBuilder(IREE_LOC)
+           << "Attempting to signal a timeline fence out of order; value="
+           << value_ << ", new_value=" << value;
+  }
+
+  // A waiter may have reserved an entry for this exact value before the
+  // signal was submitted; reuse that entry if it exists.
+  OutstandingFenceSignal* reserved = nullptr;
+  for (auto* candidate : outstanding_signals_) {
+    if (candidate->value != value) continue;
+    if (candidate->is_pending) {
+      // Already have signaled to this value - that's a paddlin'.
+      return FailedPreconditionErrorBuilder(IREE_LOC)
+             << "Duplicate signal of timeline fence for value=" << value;
+    }
+    reserved = candidate;
+    break;
+  }
+  if (reserved == nullptr) {
+    // Nothing reserved: take a fresh entry from the pool and queue it in
+    // sorted position.
+    // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush.
+    ASSIGN_OR_RETURN(reserved, fence_pool_->Acquire());
+    reserved->value = value;
+    InsertOutstandingFenceSignal(reserved, &outstanding_signals_);
+  }
+
+  reserved->is_pending = true;
+  return reserved->fence;
+}
+
+// Returns the VkFence a caller must wait on for the timeline to reach |value|.
+//
+// Returns VK_NULL_HANDLE when the timeline has already reached |value| and no
+// wait is needed. Returns the sticky error when the fence has failed. A
+// failure stores UINT64_MAX into |value_|, which would otherwise satisfy
+// every ">= value" check and make a failed fence look resolved.
+StatusOr<VkFence> LegacyFence::AcquireWaitFence(uint64_t value) {
+  // If we've already resolved then we want to avoid doing any kind of wait.
+  // Since the value is monotonically increasing we can do a lock-free peek
+  // here to see if we need to bother taking a full lock.
+  if (value_.load() >= value) {
+    // status() only takes the lock when |value_| is the failure sentinel, so
+    // the common path stays lock-free.
+    RETURN_IF_ERROR(status());
+    return VK_NULL_HANDLE;
+  }
+
+  absl::MutexLock lock(&mutex_);
+
+  // Try to resolve any outstanding fence signals.
+  RETURN_IF_ERROR(TryResolveOutstandingFencesLocked(value));
+  if (value_.load() >= value) {
+    // Another thread may have recorded a failure between the lock-free peek
+    // above and taking the lock.
+    RETURN_IF_ERROR(status_);
+    return VK_NULL_HANDLE;
+  }
+
+  // Try to find an existing fence we can reuse based on the required value.
+  OutstandingFenceSignal* signal_state = nullptr;
+  for (auto* fence_signal : outstanding_signals_) {
+    if (fence_signal->value >= value) {
+      // Fence is going to be signaled at or above the required value.
+      signal_state = fence_signal;
+      break;  // |outstanding_signals_| is in sorted order.
+    }
+  }
+  if (!signal_state) {
+    // Reserve an entry and a VkFence that a future signal is expected to use.
+    // The entry is queued now, not yet pending, in sorted position so that
+    // AcquireSignalFence() for the same value can adopt it.
+    // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush.
+    ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire());
+    signal_state->value = value;
+    InsertOutstandingFenceSignal(signal_state, &outstanding_signals_);
+  }
+
+  return signal_state->fence;
+}
+
+// Locking wrapper: takes |mutex_| and forwards to
+// TryResolveOutstandingFencesLocked().
+Status LegacyFence::TryResolveOutstandingFences(uint64_t upper_value) {
+ absl::MutexLock lock(&mutex_);
+ return TryResolveOutstandingFencesLocked(upper_value);
+}
+
+// Walks |outstanding_signals_| in ascending value order and advances |value_|
+// past every signal whose VkFence has completed, stopping at the first signal
+// above |upper_value| or the first submitted signal that is still in flight.
+//
+// Signals that were reserved by a waiter but never submitted are stepped over.
+// If a later signal completes they are released back to the pool, since the
+// timeline has passed their value; otherwise they are put back in the list.
+//
+// On the first device error the sticky |status_| is set, |value_| becomes
+// UINT64_MAX, resolution stops, and all remaining signals are handed to the
+// pool as unresolved. Returns |status_|.
+//
+// Requires |mutex_| to be held.
+Status LegacyFence::TryResolveOutstandingFencesLocked(uint64_t upper_value) {
+  // Fast-path for when we have no outstanding fences.
+  // NOTE: we hold the lock during the entire resolve process so that any waiter
+  // will only be woken once we have resolved to the furthest possible value.
+  if (outstanding_signals_.empty() || value_ > upper_value) {
+    return OkStatus();
+  }
+
+  IREE_TRACE_SCOPE0("LegacyFence::TryResolveOutstandingFences");
+
+  IntrusiveList<OutstandingFenceSignal> resolved_fences;
+  IntrusiveList<OutstandingFenceSignal> unresolved_fences;
+  // Never-submitted signals stepped over during this pass. They are pulled
+  // off the front of |outstanding_signals_| so the loop always makes progress;
+  // leaving one at the front would make the loop query it forever.
+  IntrusiveList<OutstandingFenceSignal> skipped_fences;
+  VkDevice device = *fence_pool_->logical_device();
+  const auto& syms = fence_pool_->syms();
+  bool keep_resolving = true;
+  while (keep_resolving && !outstanding_signals_.empty()) {
+    auto* fence_signal = outstanding_signals_.front();
+    if (fence_signal->value > upper_value) {
+      // Signal is for a value beyond our upper limit - early exit so that we
+      // don't spend time dealing with signals we don't yet care about. This can
+      // prevent live lock where one thread is signaling fences as fast/faster
+      // than another thread can consume them.
+      break;
+    }
+    VkResult fence_status = syms->vkGetFenceStatus(device, fence_signal->fence);
+    switch (fence_status) {
+      case VK_SUCCESS: {
+        // Fence has signaled meaning that we have reached this point in the
+        // timeline and can advance the value.
+        value_.store(fence_signal->value);
+        outstanding_signals_.erase(fence_signal);
+        resolved_fences.push_back(fence_signal);
+
+        // Everything stepped over so far has a value at or below the one just
+        // reached and was never submitted, so it will never be used.
+        unresolved_fences.merge_from(&skipped_fences);
+        break;
+      }
+      case VK_NOT_READY:
+        if (fence_signal->is_pending) {
+          // Fence has not yet been signaled. We stop here and wait for future
+          // attempts at resolution.
+          keep_resolving = false;
+        } else {
+          // Fence is not even pending yet - we may have skipped it. Set it
+          // aside and keep resolving to see if there's a higher value we can
+          // use.
+          outstanding_signals_.erase(fence_signal);
+          skipped_fences.push_back(fence_signal);
+        }
+        break;
+      default:
+        // Fence indicates an error (device lost, out of memory, etc).
+        // Propagate this back to our status (and thus any waiters).
+        // Only the first error is taken: stop here so a later completed fence
+        // cannot overwrite the UINT64_MAX failure sentinel.
+        status_ = VkResultToStatus(fence_status);
+        value_.store(UINT64_MAX);
+        outstanding_signals_.erase(fence_signal);
+        resolved_fences.push_back(fence_signal);
+        keep_resolving = false;
+        break;
+    }
+  }
+
+  // Signals that were stepped over but not superseded are still waiting for
+  // their submit; put them back in sorted position.
+  while (!skipped_fences.empty()) {
+    auto* skipped_signal = skipped_fences.front();
+    skipped_fences.pop_front();
+    InsertOutstandingFenceSignal(skipped_signal, &outstanding_signals_);
+  }
+
+  // Release resolved fences back to the pool. Note that we can only do this
+  // to fences we know have actually completed: unresolved fences after an error
+  // may still be in-flight and we don't want to reuse them.
+  fence_pool_->ReleaseResolved(&resolved_fences);
+  fence_pool_->ReleaseUnresolved(&unresolved_fences);
+  if (!status_.ok()) {
+    fence_pool_->ReleaseUnresolved(&outstanding_signals_);
+  }
+
+  return status_;
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/legacy_fence.h b/hal/vulkan/legacy_fence.h
new file mode 100644
index 0000000..1ab4768
--- /dev/null
+++ b/hal/vulkan/legacy_fence.h
@@ -0,0 +1,200 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO(b/140141417): share the pool (and possibly most of the fence impl) with
+// the timeline semaphores fallback.
+
+#ifndef IREE_HAL_VULKAN_LEGACY_FENCE_H_
+#define IREE_HAL_VULKAN_LEGACY_FENCE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <array>
+#include <atomic>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "base/intrusive_list.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "hal/fence.h"
+#include "hal/vulkan/handle_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// An outstanding legacy fence signal for a particular timeline value.
+// Each signal to a new value gets a new VkFence and these are stored in a
+// LegacyFence to quickly scan and process signaled fences.
+//
+// Instances are list nodes: they live in LegacyFencePool storage and are
+// linked into exactly one IntrusiveList at a time.
+//
+// Must be externally synchronized via the LegacyFence mutex.
+struct OutstandingFenceSignal : public IntrusiveLinkBase<void> {
+ // Allocated fence that is passed to vkQueueSubmit/vkWaitForFences.
+ // Represents a point in the timeline of value.
+ VkFence fence = VK_NULL_HANDLE;
+
+ // Value that the fence payload should be when the fence is signaled.
+ // Note that since fences may resolve out of order we still need to check that
+ // we are only ever advancing the timeline and not just setting this value.
+ // Defaults to UINT64_MAX until a LegacyFence assigns the real value.
+ uint64_t value = UINT64_MAX;
+
+ // True when the fence has been submitted and is pending on the device.
+ // False for entries that were only reserved by a waiter.
+ bool is_pending = false;
+};
+
+// A pool of VkFences that can be used by LegacyFence to simulate individual
+// payload value signaling. Note that we prefer a pool instead of a ringbuffer
+// as we want to allow out-of-order completion.
+class LegacyFencePool final : public RefObject<LegacyFencePool> {
+ public:
+ // Total number of VkFences the pool creates up front; Acquire() fails once
+ // all of them are handed out.
+ static constexpr int kMaxInFlightFenceCount = 64;
+
+ // Allocates a new fence pool and all fences.
+ static StatusOr<ref_ptr<LegacyFencePool>> Create(
+ ref_ptr<VkDeviceHandle> logical_device);
+
+ // Destroys all fences owned by the pool.
+ ~LegacyFencePool();
+
+ // Device the fences were created on.
+ const ref_ptr<VkDeviceHandle>& logical_device() const {
+ return logical_device_;
+ }
+ // Symbol table forwarded from the logical device.
+ const ref_ptr<DynamicSymbols>& syms() const {
+ return logical_device_->syms();
+ }
+
+ // Acquires a fence from the pool for use by the caller.
+ // The fence is guaranteed to not be in-flight and will have been reset to an
+ // unsignaled state.
+ //
+ // Returns RESOURCE_EXHAUSTED if the pool has no more available fences.
+ // Callers are expected to handle this by waiting on previous fences or for
+ // complete device idle. Yes, that's as bad as it sounds, and if we start
+ // seeing that we should bump up the max count.
+ StatusOr<OutstandingFenceSignal*> Acquire();
+
+ // Releases one or more fences back to the pool.
+ // The fences must either be signaled or not be in-flight.
+ // |fence_signals| is emptied by the call.
+ void ReleaseResolved(IntrusiveList<OutstandingFenceSignal>* fence_signals);
+
+ // Releases one or more unresolved fences back to the pool.
+ // These may be in any state and will be assumed as untouchable.
+ // |fence_signals| is emptied by the call.
+ void ReleaseUnresolved(IntrusiveList<OutstandingFenceSignal>* fence_signals);
+
+ private:
+ // Use Create() instead; construction alone does not create any fences.
+ explicit LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device);
+
+ // Creates the VkFence for every slot in |storage_|.
+ Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_);
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ absl::Mutex mutex_;
+ // Backing storage for every fence entry; the lists below link into it.
+ std::array<OutstandingFenceSignal, kMaxInFlightFenceCount> storage_
+ ABSL_GUARDED_BY(mutex_);
+ // Entries available to Acquire().
+ IntrusiveList<OutstandingFenceSignal> unused_fences_ ABSL_GUARDED_BY(mutex_);
+ // Entries released while possibly still in flight; not handed out again.
+ IntrusiveList<OutstandingFenceSignal> unresolved_fences_
+ ABSL_GUARDED_BY(mutex_);
+};
+
+// A fence implemented using a pool of native VkFences.
+// This is supported unconditionally on all versions of Vulkan. When timeline
+// semaphores are available we prefer using those instead and this is only
+// present as a fallback. We keep this implementation separate so that it can be
+// compiled out when the target is known to have the extension.
+//
+// Simulation of timeline semaphore-based fences is done via a pool of native
+// VkFences that each represent a single signaled value. This means that worst
+// case we are using one fence per submit however that's no different than if
+// we did anything else. Though we can't cancel previously-queued fences when
+// increasing values are signaled we can be clever when querying and releasing
+// by always walking in reverse relying on the monotonically increasing values.
+//
+// Valid usage patterns we need to handle:
+// 1. fence signaled and waited on (common case)
+// 2. fence waited on before beginning signaling
+// 3. fence signaled and never waited on
+//
+// Case 1 is fairly straightforward: we acquire a VkFence, pass that to the
+// queue submit, and then vkWaitForFences/query it for completion.
+//
+// Case 2 requires that we reserve a fence during the wait so that we can pass
+// it to vkWaitForFences and track it such that we can reuse it during a future
+// signal operation. Since we don't know during signaling if the specific value
+// we waited on will ever have its own dedicated signal operation we need to be
+// conservative and try to coalesce for correctness. This means that if a wait
+// for a value of 1 is performed and we get a signal for a value of 2 we need to
+// combine the two. If a signal for a value of 1 is later performed it then
+// becomes a no-op. This could lead to some additional latency however that's a
+// risk (or benefit!) of using timelines. Rule of thumb: don't do out of order
+// signaling.
+//
+// Case 3 is like case 2 where we need to reserve a fence to wait on, however
+// since we don't know if it will ever be signaled we need to take care to
+// properly release the VkFence back to the pool for reuse: we don't want to
+// return it while there are still waiters for its original event. For this
+// reason we track the waiters on a given fence during their wait operation and
+// if a fence is released with waiters active we put it in a special
+// unresolved list until the waiters continue on.
+class LegacyFence final : public Fence {
+ public:
+ // Waits for one or more (or all) fences to reach or exceed the given values.
+ // Every Fence in |fences| is treated as a LegacyFence.
+ static Status WaitForFences(VkDeviceHandle* logical_device,
+ absl::Span<const FenceValue> fences,
+ bool wait_all, absl::Time deadline);
+
+ // Creates a fence starting at |initial_value| that draws VkFences from
+ // |fence_pool|.
+ LegacyFence(ref_ptr<LegacyFencePool> fence_pool, uint64_t initial_value);
+ // CHECK-fails if signals are still outstanding at destruction.
+ ~LegacyFence() override;
+
+ // Returns the sticky failure status, or OK if the fence has not failed.
+ Status status() const override;
+
+ // Resolves completed signals and returns the current timeline value.
+ StatusOr<uint64_t> QueryValue() override;
+
+ // Acquires a new fence for signaling a specific value.
+ // Fails if |value| does not advance the timeline or was already signaled.
+ StatusOr<VkFence> AcquireSignalFence(uint64_t value);
+
+ private:
+ // Acquires a new fence for waiting on a specific value.
+ // Returns VK_NULL_HANDLE if the fence already resolved and the sticky error
+ // if the fence is in an error state.
+ StatusOr<VkFence> AcquireWaitFence(uint64_t value);
+
+ // Runs down the outstanding fences list and resolves to the latest signaled
+ // value. Will early exit if the value moves beyond |upper_value|.
+ // The first form takes |mutex_|; the Locked form requires it to be held.
+ Status TryResolveOutstandingFences(uint64_t upper_value)
+ ABSL_LOCKS_EXCLUDED(mutex_);
+ Status TryResolveOutstandingFencesLocked(uint64_t upper_value)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ ref_ptr<LegacyFencePool> fence_pool_;
+
+ // The current highest value of the fence as verified during a wait or query.
+ // Kept outside of |mutex_| so that queries do not require a lock.
+ // Set to UINT64_MAX when a failure is recorded in |status_|.
+ std::atomic<uint64_t> value_;
+
+ mutable absl::Mutex mutex_;
+
+ // Sticky status failure value set on first failure.
+ Status status_ ABSL_GUARDED_BY(mutex_);
+
+ // Outstanding VkFences representing signal values.
+ // Expected to be sorted in ascending order by value.
+ IntrusiveList<OutstandingFenceSignal> outstanding_signals_
+ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_LEGACY_FENCE_H_
diff --git a/hal/vulkan/native_binary_semaphore.cc b/hal/vulkan/native_binary_semaphore.cc
new file mode 100644
index 0000000..bb96897
--- /dev/null
+++ b/hal/vulkan/native_binary_semaphore.cc
@@ -0,0 +1,32 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/native_binary_semaphore.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Wraps an existing |handle|; the destructor destroys it on |logical_device|.
+NativeBinarySemaphore::NativeBinarySemaphore(
+ ref_ptr<VkDeviceHandle> logical_device, VkSemaphore handle)
+ : logical_device_(std::move(logical_device)), handle_(handle) {}
+
+// Destroys the wrapped VkSemaphore using the device's symbols and allocator.
+NativeBinarySemaphore::~NativeBinarySemaphore() {
+  const auto& syms = logical_device_->syms();
+  syms->vkDestroySemaphore(*logical_device_, handle_,
+                           logical_device_->allocator());
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/native_binary_semaphore.h b/hal/vulkan/native_binary_semaphore.h
new file mode 100644
index 0000000..19ac293
--- /dev/null
+++ b/hal/vulkan/native_binary_semaphore.h
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
+#define IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
+
+#include <vulkan/vulkan.h>
+
+#include "hal/semaphore.h"
+#include "hal/vulkan/handle_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// A binary semaphore implemented using the native VkSemaphore type.
+// This is supported unconditionally on all versions of Vulkan.
+class NativeBinarySemaphore final : public BinarySemaphore {
+ public:
+ // Wraps |handle|, which is destroyed on |logical_device| by the destructor.
+ NativeBinarySemaphore(ref_ptr<VkDeviceHandle> logical_device,
+ VkSemaphore handle);
+ ~NativeBinarySemaphore() override;
+
+ // Returns the wrapped native handle.
+ VkSemaphore handle() const { return handle_; }
+
+ private:
+ ref_ptr<VkDeviceHandle> logical_device_;
+ VkSemaphore handle_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
diff --git a/hal/vulkan/native_event.cc b/hal/vulkan/native_event.cc
new file mode 100644
index 0000000..4cdb8c6
--- /dev/null
+++ b/hal/vulkan/native_event.cc
@@ -0,0 +1,31 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/native_event.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Wraps an existing |handle|; the destructor destroys it on |logical_device|.
+NativeEvent::NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle)
+ : logical_device_(std::move(logical_device)), handle_(handle) {}
+
+// Destroys the wrapped VkEvent using the device's symbols and allocator.
+NativeEvent::~NativeEvent() {
+  const auto& syms = logical_device_->syms();
+  syms->vkDestroyEvent(*logical_device_, handle_,
+                       logical_device_->allocator());
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/native_event.h b/hal/vulkan/native_event.h
new file mode 100644
index 0000000..2cc4c4c
--- /dev/null
+++ b/hal/vulkan/native_event.h
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_
+#define IREE_HAL_VULKAN_NATIVE_EVENT_H_
+
+#include <vulkan/vulkan.h>
+
+#include "hal/event.h"
+#include "hal/vulkan/handle_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// An event implemented with the native VkEvent type.
+class NativeEvent final : public Event {
+ public:
+ // Wraps |handle|, which is destroyed on |logical_device| by the destructor.
+ NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle);
+ ~NativeEvent() override;
+
+ // Returns the wrapped native handle.
+ VkEvent handle() const { return handle_; }
+
+ private:
+ ref_ptr<VkDeviceHandle> logical_device_;
+ VkEvent handle_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_NATIVE_EVENT_H_
diff --git a/hal/vulkan/pipeline_cache.cc b/hal/vulkan/pipeline_cache.cc
new file mode 100644
index 0000000..9ed48c3
--- /dev/null
+++ b/hal/vulkan/pipeline_cache.cc
@@ -0,0 +1,235 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/pipeline_cache.h"
+
+#include "absl/synchronization/mutex.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "flatbuffers/flatbuffers.h"
+#include "hal/executable_format.h"
+#include "hal/vulkan/status_util.h"
+#include "schemas/spirv_executable_def_generated.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Creates a cache with no layouts, holding |logical_device| (via add_ref) for
+// all later Vulkan calls.
+PipelineCache::PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device)
+ : logical_device_(add_ref(logical_device)) {}
+
+// Destroys every cached pipeline layout and descriptor set layout.
+PipelineCache::~PipelineCache() {
+ IREE_TRACE_SCOPE0("PipelineCache::dtor");
+ ClearLayoutCaches();
+}
+
+// Returns true only when |format| is kExecutableFormatSpirV; no other
+// executable format is accepted by this cache.
+bool PipelineCache::CanPrepareFormat(ExecutableFormat format) const {
+ return format == kExecutableFormatSpirV;
+}
+
+// Prepares a SPIR-V executable for use on the logical device.
+//
+// |spec.executable_data| must hold a SpirVExecutableDef flatbuffer carrying
+// its file identifier and a pipeline layout def. The pipeline layout is looked
+// up in (or inserted into) the layout cache and a PipelineExecutable is then
+// created from the def using that layout.
+//
+// Returns:
+//   UnimplementedError if |spec.format| is not supported by this cache.
+//   InvalidArgumentError if the data is too small to carry a file identifier,
+//   carries a different identifier, or has no pipeline layout def.
+//   Any error produced while creating the layouts or the executable.
+StatusOr<ref_ptr<Executable>> PipelineCache::PrepareExecutable(
+    ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
+  IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable");
+  if (!CanPrepareFormat(spec.format)) {
+    return UnimplementedErrorBuilder(IREE_LOC)
+           << "Unsupported 4CC format: 0x" << std::hex << spec.format;
+  }
+
+  // A flatbuffer with a file identifier starts with a 4-byte root offset
+  // followed by the 4-byte identifier. The identifier check reads all of those
+  // bytes, so anything shorter must be rejected before it is called.
+  constexpr size_t kMinIdentifiedBufferSize = 8;
+  if (spec.executable_data.size() < kMinIdentifiedBufferSize ||
+      !SpirVExecutableDefBufferHasIdentifier(spec.executable_data.data())) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Supplied executable data does not contain a SpirVExecutableDef";
+  }
+
+  // Get the SPIR-V executable def flatbuffer.
+  const auto& spirv_executable_def =
+      *::flatbuffers::GetRoot<SpirVExecutableDef>(spec.executable_data.data());
+
+  // Create (or reuse) a pipeline layout.
+  if (!spirv_executable_def.pipeline_layout()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Missing pipeline layout def";
+  }
+  ASSIGN_OR_RETURN(
+      auto pipeline_layout_entry,
+      LookupOrInsertPipelineLayout(*spirv_executable_def.pipeline_layout()));
+
+  // Create the executable (which may itself own many pipelines).
+  ASSIGN_OR_RETURN(auto executable, PipelineExecutable::Create(
+                                        logical_device_,
+                                        /*pipeline_cache=*/VK_NULL_HANDLE,
+                                        pipeline_layout_entry->pipeline_layout,
+                                        pipeline_layout_entry->descriptor_sets,
+                                        mode, spirv_executable_def));
+  return executable;
+}
+
+// Returns the cached pipeline layout matching |pipeline_layout_def|, creating
+// and caching a new VkPipelineLayout when no existing entry matches.
+//
+// Acquires |mutex_| for the duration of the call. The returned pointer refers
+// to an entry owned by this cache.
+//
+// Returns InvalidArgumentError if a descriptor set layout def or push constant
+// range def is missing, or the error reported while creating Vulkan objects.
+StatusOr<const PipelineCache::CachedPipelineLayout*>
+PipelineCache::LookupOrInsertPipelineLayout(
+    const VkPipelineLayoutDef& pipeline_layout_def) {
+  IREE_TRACE_SCOPE0("PipelineCache::LookupOrInsertPipelineLayout");
+  absl::MutexLock lock(&mutex_);
+
+  // Build a list of the required descriptor set layouts and push constants.
+  // If we were being fast about this we would just hash the def and directly
+  // look up the pipeline layout.
+  PipelineDescriptorSets descriptor_sets;
+  descriptor_sets.buffer_binding_set = pipeline_layout_def.buffer_binding_set();
+  descriptor_sets.buffer_binding_set_layout = VK_NULL_HANDLE;
+  absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts;
+  if (pipeline_layout_def.descriptor_set_layouts()) {
+    const auto& layout_defs = *pipeline_layout_def.descriptor_set_layouts();
+    descriptor_set_layouts.resize(layout_defs.size());
+    for (int i = 0; i < descriptor_set_layouts.size(); ++i) {
+      if (!layout_defs[i]) {
+        return InvalidArgumentErrorBuilder(IREE_LOC) << "Missing layout def";
+      }
+      ASSIGN_OR_RETURN(descriptor_set_layouts[i],
+                       LookupOrInsertDescriptorSetLayout(*layout_defs[i]));
+      if (i == pipeline_layout_def.buffer_binding_set()) {
+        descriptor_sets.buffer_binding_set_layout = descriptor_set_layouts[i];
+        // A def without a bindings vector is treated as an empty set (just as
+        // LookupOrInsertDescriptorSetLayout does) rather than dereferenced.
+        const auto* binding_defs = layout_defs[i]->bindings();
+        const int binding_count = binding_defs ? binding_defs->size() : 0;
+        descriptor_sets.buffer_binding_set_map.resize(binding_count);
+        for (int j = 0; j < binding_count; ++j) {
+          descriptor_sets.buffer_binding_set_map[j] =
+              binding_defs->Get(j)->binding();
+        }
+      }
+    }
+  }
+
+  absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges;
+  if (pipeline_layout_def.push_constant_ranges()) {
+    const auto& range_defs = *pipeline_layout_def.push_constant_ranges();
+    push_constant_ranges.resize(range_defs.size());
+    for (int i = 0; i < push_constant_ranges.size(); ++i) {
+      if (!range_defs[i]) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << "Missing push constant range def";
+      }
+      push_constant_ranges[i].stageFlags = range_defs[i]->stage_flags();
+      push_constant_ranges[i].offset = range_defs[i]->offset();
+      push_constant_ranges[i].size = range_defs[i]->size();
+    }
+  }
+
+  // Scan for an existing pipeline layout that matches the descriptor sets.
+  // Entries are compared bytewise, so both lists must match element for
+  // element and in the same order.
+  for (auto& entry : pipeline_layout_cache_) {
+    if (entry.descriptor_set_layouts.size() != descriptor_set_layouts.size() ||
+        entry.push_constant_ranges.size() != push_constant_ranges.size()) {
+      continue;
+    }
+    if (std::memcmp(
+            descriptor_set_layouts.data(), entry.descriptor_set_layouts.data(),
+            descriptor_set_layouts.size() * sizeof(VkDescriptorSetLayout)) ==
+            0 &&
+        std::memcmp(
+            push_constant_ranges.data(), entry.push_constant_ranges.data(),
+            push_constant_ranges.size() * sizeof(VkPushConstantRange)) == 0) {
+      return &entry;
+    }
+  }
+
+  VkPipelineLayoutCreateInfo create_info;
+  create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
+  create_info.pNext = nullptr;
+  create_info.flags = 0;
+  create_info.setLayoutCount = descriptor_set_layouts.size();
+  create_info.pSetLayouts = descriptor_set_layouts.data();
+  create_info.pushConstantRangeCount = push_constant_ranges.size();
+  create_info.pPushConstantRanges = push_constant_ranges.data();
+
+  // Create and insert into the cache.
+  VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(syms()->vkCreatePipelineLayout(
+      *logical_device_, &create_info, logical_device_->allocator(),
+      &pipeline_layout));
+  pipeline_layout_cache_.push_back({std::move(descriptor_set_layouts),
+                                    std::move(push_constant_ranges),
+                                    pipeline_layout, descriptor_sets});
+  return &pipeline_layout_cache_.back();
+}
+
+// Finds or creates the VkDescriptorSetLayout described by
+// |descriptor_set_layout_def|.
+//
+// The def's bindings are translated into VkDescriptorSetLayoutBinding values
+// and compared bytewise against every cached entry. On a miss a new layout is
+// created on the logical device and appended to the cache. The caller must
+// already hold |mutex_|.
+StatusOr<VkDescriptorSetLayout>
+PipelineCache::LookupOrInsertDescriptorSetLayout(
+    const VkDescriptorSetLayoutDef& descriptor_set_layout_def) {
+  // Translate the flatbuffer bindings into their Vulkan form. Hashing the
+  // bindings would allow a direct lookup without this allocation.
+  absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> wanted_bindings;
+  if (descriptor_set_layout_def.bindings()) {
+    const auto& binding_defs = *descriptor_set_layout_def.bindings();
+    wanted_bindings.resize(binding_defs.size());
+    for (int i = 0; i < binding_defs.size(); ++i) {
+      const auto* binding_def = binding_defs[i];
+      auto& wanted = wanted_bindings[i];
+      wanted.binding = binding_def->binding();
+      wanted.descriptorType =
+          static_cast<VkDescriptorType>(binding_def->descriptor_type());
+      wanted.descriptorCount = binding_def->descriptor_count();
+      wanted.stageFlags = binding_def->stage_flags();
+      wanted.pImmutableSamplers = nullptr;
+    }
+  }
+
+  // Reuse a cached layout whose bindings are byte-for-byte identical.
+  for (auto& cached : descriptor_set_layout_cache_) {
+    const bool same_bindings =
+        cached.bindings.size() == wanted_bindings.size() &&
+        std::memcmp(wanted_bindings.data(), cached.bindings.data(),
+                    wanted_bindings.size() *
+                        sizeof(VkDescriptorSetLayoutBinding)) == 0;
+    if (same_bindings) {
+      return cached.descriptor_set_layout;
+    }
+  }
+
+  // Push descriptor sets can *only* be used with layouts created with the
+  // push descriptor flag. That's fine, though, as the command buffer recording
+  // logic always prefers the extension if available.
+  VkDescriptorSetLayoutCreateInfo layout_info;
+  layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
+  layout_info.pNext = nullptr;
+  layout_info.flags = 0;
+  if (logical_device_->enabled_extensions().push_descriptors) {
+    layout_info.flags |=
+        VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
+  }
+  layout_info.bindingCount = wanted_bindings.size();
+  layout_info.pBindings = wanted_bindings.data();
+
+  // No match: create the layout and remember it for later lookups.
+  VkDescriptorSetLayout new_layout = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(syms()->vkCreateDescriptorSetLayout(
+      *logical_device_, &layout_info, logical_device_->allocator(),
+      &new_layout));
+  descriptor_set_layout_cache_.push_back(
+      {std::move(wanted_bindings), new_layout});
+  return new_layout;
+}
+
+// Destroys all cached pipeline layouts and then all cached descriptor set
+// layouts, leaving both caches empty. Acquires |mutex_|.
+void PipelineCache::ClearLayoutCaches() {
+  absl::MutexLock lock(&mutex_);
+
+  // Pipeline layouts are released before the descriptor set layouts.
+  for (auto& cached_pipeline_layout : pipeline_layout_cache_) {
+    syms()->vkDestroyPipelineLayout(*logical_device_,
+                                    cached_pipeline_layout.pipeline_layout,
+                                    logical_device_->allocator());
+  }
+  pipeline_layout_cache_.clear();
+
+  for (auto& cached_set_layout : descriptor_set_layout_cache_) {
+    syms()->vkDestroyDescriptorSetLayout(
+        *logical_device_, cached_set_layout.descriptor_set_layout,
+        logical_device_->allocator());
+  }
+  descriptor_set_layout_cache_.clear();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/pipeline_cache.h b/hal/vulkan/pipeline_cache.h
new file mode 100644
index 0000000..5941ab8
--- /dev/null
+++ b/hal/vulkan/pipeline_cache.h
@@ -0,0 +1,85 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_
+#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <deque>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "hal/executable.h"
+#include "hal/executable_cache.h"
+#include "hal/vulkan/handle_util.h"
+#include "hal/vulkan/pipeline_executable.h"
+#include "schemas/spirv_executable_def_generated.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// ExecutableCache implementation that prepares SPIR-V executables as Vulkan
+// pipelines and caches the pipeline and descriptor set layouts they use.
+//
+// Access to the layout caches is serialized with |mutex_|.
+class PipelineCache final : public ExecutableCache {
+ public:
+  explicit PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device);
+  ~PipelineCache() override;
+
+  // Dynamic Vulkan symbols of the logical device this cache operates on.
+  const ref_ptr<DynamicSymbols>& syms() const {
+    return logical_device_->syms();
+  }
+
+  // Returns true only for the SPIR-V executable format.
+  bool CanPrepareFormat(ExecutableFormat format) const override;
+
+  // Creates an executable from the SpirVExecutableDef flatbuffer carried by
+  // |spec|, reusing cached layouts where possible.
+  StatusOr<ref_ptr<Executable>> PrepareExecutable(
+      ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override;
+
+ private:
+  // A descriptor set layout together with the bindings it was created from;
+  // the bindings act as the lookup key.
+  struct CachedDescriptorSetLayout {
+    absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings;
+    VkDescriptorSetLayout descriptor_set_layout;
+  };
+  // A pipeline layout together with the set layouts and push constant ranges
+  // it was created from; those two lists act as the lookup key.
+  struct CachedPipelineLayout {
+    absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts;
+    absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges;
+    VkPipelineLayout pipeline_layout;
+    PipelineDescriptorSets descriptor_sets;
+  };
+
+  // Returns the cache entry matching |pipeline_layout_def|, creating it if
+  // needed. Takes |mutex_| itself.
+  StatusOr<const CachedPipelineLayout*> LookupOrInsertPipelineLayout(
+      const VkPipelineLayoutDef& pipeline_layout_def)
+      ABSL_LOCKS_EXCLUDED(mutex_);
+  // Returns the layout matching |descriptor_set_layout_def|, creating it if
+  // needed. Must be called with |mutex_| held.
+  StatusOr<VkDescriptorSetLayout> LookupOrInsertDescriptorSetLayout(
+      const VkDescriptorSetLayoutDef& descriptor_set_layout_def)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+  // Destroys every cached layout and empties both caches.
+  void ClearLayoutCaches() ABSL_LOCKS_EXCLUDED(mutex_);
+
+  ref_ptr<VkDeviceHandle> logical_device_;
+
+  // A "cache" of descriptor set and pipeline layouts for various values.
+  // We never evict and just do a simple linear scan on lookup. This is fine for
+  // now as we only support a single descriptor type and really we only need to
+  // check for binding count. As we go toward more general usage of descriptors
+  // (images/etc) we will likely want to change this to a real cache.
+  absl::Mutex mutex_;
+  std::vector<CachedDescriptorSetLayout> descriptor_set_layout_cache_
+      ABSL_GUARDED_BY(mutex_);
+  // Stored in a std::deque because LookupOrInsertPipelineLayout hands out
+  // pointers to entries that are read after |mutex_| has been released.
+  // Unlike std::vector, appending to a deque never relocates existing
+  // elements, so a concurrent insertion cannot leave such a pointer dangling.
+  std::deque<CachedPipelineLayout> pipeline_layout_cache_
+      ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_PIPELINE_CACHE_H_
diff --git a/hal/vulkan/pipeline_executable.cc b/hal/vulkan/pipeline_executable.cc
new file mode 100644
index 0000000..f4f7a61
--- /dev/null
+++ b/hal/vulkan/pipeline_executable.cc
@@ -0,0 +1,191 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/pipeline_executable.h"
+
+#include "absl/container/inlined_vector.h"
+#include "base/memory.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Bakes the specialization constants described by |info_def| into the form
+// VkSpecializationInfo expects: a list of map entries plus the raw data blob
+// the entries index into. Only uint32_t constants are supported, so every
+// entry occupies sizeof(uint32_t) bytes at consecutive offsets.
+//
+// Returns a pair of empty vectors when |info_def| is null or has no entries.
+// A VkSpecializationInfo built from the result points into the returned
+// vectors, so they must stay alive for as long as that info is in use.
+std::pair<std::vector<VkSpecializationMapEntry>, std::vector<uint8_t>>
+PopulateSpecializationInfo(const VkSpecializationInfoDef* info_def) {
+  const auto* entry_defs = info_def ? info_def->map_entries() : nullptr;
+  const int entry_count = entry_defs ? entry_defs->size() : 0;
+  if (entry_count == 0) {
+    return {};
+  }
+
+  std::vector<VkSpecializationMapEntry> entries;
+  entries.reserve(entry_count);
+  // Zero-filled storage with one uint32_t slot per declared entry.
+  std::vector<uint8_t> data(entry_count * sizeof(uint32_t));
+
+  uint32_t write_offset = 0;
+  for (const auto* entry_def : *entry_defs) {
+    if (!entry_def) continue;
+
+    const uint32_t value = entry_def->uint32_value();
+    std::memcpy(data.data() + write_offset, &value, sizeof(value));
+
+    VkSpecializationMapEntry entry;
+    entry.constantID = entry_def->constant_id();
+    entry.offset = write_offset;
+    entry.size = sizeof(uint32_t);
+    entries.push_back(entry);
+
+    write_offset += sizeof(uint32_t);
+  }
+
+  return {std::move(entries), std::move(data)};
+}
+
+} // namespace
+
+// Creates an executable with one compute pipeline per entry point of
+// |spirv_executable_def|.
+//
+// A temporary shader module is built from the def's SPIR-V code and destroyed
+// again before returning. All pipelines share |pipeline_layout| and the def's
+// specialization constants. |pipeline_cache| may be VK_NULL_HANDLE.
+//
+// Returns InvalidArgumentError if the def has no entry points or no code, or
+// the error reported by Vulkan while creating the shader module or pipelines.
+// If pipeline creation fails, any pipelines that were created are destroyed
+// before the error is returned.
+// static
+StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create(
+    const ref_ptr<VkDeviceHandle>& logical_device,
+    VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout,
+    PipelineDescriptorSets descriptor_sets, ExecutableCachingModeBitfield mode,
+    const SpirVExecutableDef& spirv_executable_def) {
+  IREE_TRACE_SCOPE0("PipelineExecutable::Create");
+  const auto& syms = logical_device->syms();
+  if (!spirv_executable_def.entry_points() ||
+      spirv_executable_def.entry_points()->size() == 0) {
+    return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
+  }
+  if (!spirv_executable_def.code()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC) << "No SPIR-V code present";
+  }
+  const auto& code = *spirv_executable_def.code();
+
+  // Create the shader module.
+  VkShaderModuleCreateInfo shader_module_create_info;
+  shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+  shader_module_create_info.pNext = nullptr;
+  shader_module_create_info.flags = 0;
+  shader_module_create_info.codeSize = code.size() * sizeof(uint32_t);
+  shader_module_create_info.pCode = code.data();
+  VkShaderModule shader_module = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(
+      syms->vkCreateShaderModule(*logical_device, &shader_module_create_info,
+                                 logical_device->allocator(), &shader_module));
+
+  // We only need to keep this around during pipeline creation so ensure we
+  // always clean it up when we exit this function.
+  auto shader_module_cleanup = MakeCleanup([&logical_device, shader_module]() {
+    logical_device->syms()->vkDestroyShaderModule(
+        *logical_device, shader_module, logical_device->allocator());
+  });
+
+  // Specialization info is currently constant against all entry points.
+  // |specialization_info| points into these vectors, so they must outlive the
+  // vkCreateComputePipelines call below.
+  std::vector<VkSpecializationMapEntry> spec_entries;
+  std::vector<uint8_t> spec_data;
+  std::tie(spec_entries, spec_data) =
+      PopulateSpecializationInfo(spirv_executable_def.specialization_info());
+  VkSpecializationInfo specialization_info;
+  specialization_info.mapEntryCount = spec_entries.size();
+  specialization_info.pMapEntries = spec_entries.data();
+  specialization_info.dataSize = spec_data.size();
+  specialization_info.pData = spec_data.data();
+
+  // Create pipelines for each entry point.
+  const auto& entry_points = *spirv_executable_def.entry_points();
+  absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos;
+  pipeline_create_infos.resize(entry_points.size());
+  for (int entry_ordinal = 0; entry_ordinal < entry_points.size();
+       ++entry_ordinal) {
+    auto& pipeline_create_info = pipeline_create_infos[entry_ordinal];
+    pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+    pipeline_create_info.pNext = nullptr;
+    pipeline_create_info.flags = 0;
+    if (!AllBitsSet(mode, ExecutableCachingMode::kAllowOptimization)) {
+      pipeline_create_info.flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT;
+    }
+    // The first pipeline is the base; every later one is declared a
+    // derivative of create info index 0 (see basePipelineIndex below).
+    if (entry_ordinal == 0) {
+      pipeline_create_info.flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT;
+    } else {
+      pipeline_create_info.flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT;
+    }
+    pipeline_create_info.layout = pipeline_layout;
+    pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE;
+    pipeline_create_info.basePipelineIndex = 0;
+    auto& stage_create_info = pipeline_create_info.stage;
+    stage_create_info.sType =
+        VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
+    stage_create_info.pNext = nullptr;
+    stage_create_info.flags = 0;
+    stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT;
+    stage_create_info.module = shader_module;
+    stage_create_info.pName = entry_points[entry_ordinal]->c_str();
+    stage_create_info.pSpecializationInfo = &specialization_info;
+  }
+  // resize() value-initializes, so every slot starts out as VK_NULL_HANDLE.
+  absl::InlinedVector<VkPipeline, 1> pipelines;
+  pipelines.resize(entry_points.size());
+
+  // Some ICDs appear to leak in here, out of our control.
+  // The result is captured first so that leak checks are re-enabled on both
+  // the success and the failure path.
+  IREE_DISABLE_LEAK_CHECKS();
+  const VkResult create_result = syms->vkCreateComputePipelines(
+      *logical_device, pipeline_cache, pipeline_create_infos.size(),
+      pipeline_create_infos.data(), logical_device->allocator(),
+      pipelines.data());
+  IREE_ENABLE_LEAK_CHECKS();
+
+  // Vulkan error codes are negative. A failed batch creation may still have
+  // produced some valid pipelines; release them so they do not leak.
+  if (create_result < 0) {
+    for (VkPipeline pipeline : pipelines) {
+      if (pipeline != VK_NULL_HANDLE) {
+        syms->vkDestroyPipeline(*logical_device, pipeline,
+                                logical_device->allocator());
+      }
+    }
+  }
+  VK_RETURN_IF_ERROR(create_result);
+
+  auto executable = make_ref<PipelineExecutable>(
+      CtorKey{}, logical_device, pipeline_layout, std::move(descriptor_sets),
+      std::move(pipelines));
+  executable->tag_ =
+      spirv_executable_def.tag() ? spirv_executable_def.tag()->str() : "";
+  return executable;
+}
+
+// Constructs an executable that takes over |pipelines| (one per entry point)
+// and holds |logical_device| via add_ref. The CtorKey parameter can only be
+// created by PipelineExecutable itself, so construction goes through Create().
+PipelineExecutable::PipelineExecutable(
+    CtorKey ctor_key, const ref_ptr<VkDeviceHandle>& logical_device,
+    VkPipelineLayout pipeline_layout, PipelineDescriptorSets descriptor_sets,
+    absl::InlinedVector<VkPipeline, 1> pipelines)
+    : logical_device_(add_ref(logical_device)),
+      pipeline_layout_(pipeline_layout),
+      // By-value sink parameter: move it into place instead of copying it.
+      descriptor_sets_(std::move(descriptor_sets)),
+      pipelines_(std::move(pipelines)) {}
+
+// Destroys every pipeline held by this executable. The pipeline layout is not
+// destroyed here.
+PipelineExecutable::~PipelineExecutable() {
+  IREE_TRACE_SCOPE0("PipelineExecutable::dtor");
+  for (VkPipeline owned_pipeline : pipelines_) {
+    syms()->vkDestroyPipeline(*logical_device_, owned_pipeline,
+                              logical_device_->allocator());
+  }
+  pipelines_.clear();
+}
+
+// Returns the pipeline created for the entry point at |entry_ordinal|, or
+// OutOfRangeError when the ordinal is negative or not below the pipeline
+// count.
+StatusOr<VkPipeline> PipelineExecutable::GetPipelineForEntryPoint(
+    int entry_ordinal) const {
+  const bool ordinal_in_range =
+      entry_ordinal >= 0 &&
+      static_cast<size_t>(entry_ordinal) < pipelines_.size();
+  if (!ordinal_in_range) {
+    return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal";
+  }
+  return pipelines_[entry_ordinal];
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/pipeline_executable.h b/hal/vulkan/pipeline_executable.h
new file mode 100644
index 0000000..51fe650
--- /dev/null
+++ b/hal/vulkan/pipeline_executable.h
@@ -0,0 +1,91 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
+#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "base/status.h"
+#include "hal/executable.h"
+#include "hal/executable_cache.h"
+#include "hal/executable_spec.h"
+#include "hal/vulkan/handle_util.h"
+#include "schemas/spirv_executable_def_generated.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Describes the descriptor set of a pipeline layout that carries the buffer
+// bindings. Populated by PipelineCache while it builds a pipeline layout.
+struct PipelineDescriptorSets {
+ // Index of the buffer binding set within the pipeline layout's descriptor
+ // set layouts.
+ uint32_t buffer_binding_set;
+ // Layout of that set; VK_NULL_HANDLE when no set layout with that index was
+ // present in the pipeline layout def.
+ VkDescriptorSetLayout buffer_binding_set_layout;
+ // Entry j holds the binding number of the j-th binding declared in the set.
+ absl::InlinedVector<uint32_t, 8> buffer_binding_set_map;
+};
+
+// An Executable backed by Vulkan pipelines, one per entry point of the
+// SpirVExecutableDef it was created from.
+//
+// The pipelines are destroyed with the executable; the pipeline layout handle
+// is only stored and is not destroyed by this class.
+class PipelineExecutable final : public Executable {
+ public:
+ // Creates the pipelines for every entry point in |spirv_executable_def|
+ // using |pipeline_layout| and returns the resulting executable.
+ // |pipeline_cache| is passed through to pipeline creation.
+ static StatusOr<ref_ptr<PipelineExecutable>> Create(
+ const ref_ptr<VkDeviceHandle>& logical_device,
+ VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout,
+ PipelineDescriptorSets descriptor_sets,
+ ExecutableCachingModeBitfield mode,
+ const SpirVExecutableDef& spirv_executable_def);
+
+ // Private constructor.
+ // The constructor itself is public, but it requires a CtorKey and only
+ // PipelineExecutable can construct one.
+ struct CtorKey {
+ private:
+ friend class PipelineExecutable;
+ CtorKey() = default;
+ };
+ PipelineExecutable(CtorKey ctor_key,
+ const ref_ptr<VkDeviceHandle>& logical_device,
+ VkPipelineLayout pipeline_layout,
+ PipelineDescriptorSets descriptor_sets,
+ absl::InlinedVector<VkPipeline, 1> pipelines);
+ ~PipelineExecutable() override;
+
+ // Dynamic Vulkan symbols of the logical device the pipelines belong to.
+ const ref_ptr<DynamicSymbols>& syms() const {
+ return logical_device_->syms();
+ }
+
+ // Debugging is never supported by this executable type.
+ bool supports_debugging() const override { return false; }
+
+ // Layout shared by all pipelines of this executable.
+ VkPipelineLayout pipeline_layout() const { return pipeline_layout_; }
+ // Descriptor set information supplied when the executable was created.
+ const PipelineDescriptorSets& descriptor_sets() const {
+ return descriptor_sets_;
+ }
+
+ // True when the executable def's tag is exactly "__matmul__".
+ bool is_matmul() const { return tag_ == "__matmul__"; }
+
+ // Returns the pipeline for |entry_ordinal|, or OutOfRangeError if the
+ // ordinal does not name an entry point.
+ StatusOr<VkPipeline> GetPipelineForEntryPoint(int entry_ordinal) const;
+
+ private:
+ ref_ptr<VkDeviceHandle> logical_device_;
+ VkPipelineLayout pipeline_layout_;
+ PipelineDescriptorSets descriptor_sets_;
+ // Tag copied from the executable def; empty when the def has none.
+ std::string tag_;
+
+ // One pipeline per entry point.
+ absl::InlinedVector<VkPipeline, 1> pipelines_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
diff --git a/hal/vulkan/status_util.cc b/hal/vulkan/status_util.cc
new file mode 100644
index 0000000..b9ed027
--- /dev/null
+++ b/hal/vulkan/status_util.cc
@@ -0,0 +1,231 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/status_util.h"
+
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Maps |result| to a Status.
+//
+// Every Vulkan success code - including VK_NOT_READY, VK_TIMEOUT and
+// VK_INCOMPLETE - maps to OkStatus(), so callers that need to tell those apart
+// must inspect the VkResult before converting it. Known error codes map to a
+// Status whose message is the name of the Vulkan enum value; any other value
+// maps to UnknownError carrying the numeric value as text.
+Status VkResultToStatus(VkResult result) {
+ switch (result) {
+ // Success codes.
+ case VK_SUCCESS:
+ // Command successfully completed.
+ return OkStatus();
+ case VK_NOT_READY:
+ // A fence or query has not yet completed.
+ return OkStatus();
+ case VK_TIMEOUT:
+ // A wait operation has not completed in the specified time.
+ return OkStatus();
+ case VK_EVENT_SET:
+ // An event is signaled.
+ return OkStatus();
+ case VK_EVENT_RESET:
+ // An event is unsignaled.
+ return OkStatus();
+ case VK_INCOMPLETE:
+ // A return array was too small for the result.
+ return OkStatus();
+ case VK_SUBOPTIMAL_KHR:
+ // A swapchain no longer matches the surface properties exactly, but can
+ // still be used to present to the surface successfully.
+ return OkStatus();
+
+ // Error codes.
+ case VK_ERROR_OUT_OF_HOST_MEMORY:
+ // A host memory allocation has failed.
+ return ResourceExhaustedError("VK_ERROR_OUT_OF_HOST_MEMORY");
+ case VK_ERROR_OUT_OF_DEVICE_MEMORY:
+ // A device memory allocation has failed.
+ return ResourceExhaustedError("VK_ERROR_OUT_OF_DEVICE_MEMORY");
+ case VK_ERROR_INITIALIZATION_FAILED:
+ // Initialization of an object could not be completed for
+ // implementation-specific reasons.
+ return InternalError("VK_ERROR_INITIALIZATION_FAILED");
+ case VK_ERROR_DEVICE_LOST:
+ // The logical or physical device has been lost.
+ //
+ // A logical device may become lost for a number of
+ // implementation-specific reasons, indicating that pending and future
+ // command execution may fail and cause resources and backing memory to
+ // become undefined.
+ //
+ // Typical reasons for device loss will include things like execution
+ // timing out (to prevent denial of service), power management events,
+ // platform resource management, or implementation errors.
+ //
+ // When this happens, certain commands will return
+ // VK_ERROR_DEVICE_LOST (see Error Codes for a list of such
+ // commands). After any such event, the logical device is considered lost.
+ // It is not possible to reset the logical device to a non-lost state,
+ // however the lost state is specific to a logical device (VkDevice), and
+ // the corresponding physical device (VkPhysicalDevice) may be otherwise
+ // unaffected.
+ //
+ // In some cases, the physical device may also be lost, and attempting to
+ // create a new logical device will fail, returning VK_ERROR_DEVICE_LOST.
+ // This is usually indicative of a problem with the underlying
+ // implementation, or its connection to the host. If the physical device
+ // has not been lost, and a new logical device is successfully created
+ // from that physical device, it must be in the non-lost state.
+ //
+ // Whilst logical device loss may be recoverable, in the case of physical
+ // device loss, it is unlikely that an application will be able to recover
+ // unless additional, unaffected physical devices exist on the system. The
+ // error is largely informational and intended only to inform the user
+ // that a platform issue has occurred, and should be investigated further.
+ // For example, underlying hardware may have developed a fault or become
+ // physically disconnected from the rest of the system. In many cases,
+ // physical device loss may cause other more serious issues such as the
+ // operating system crashing; in which case it may not be reported via the
+ // Vulkan API.
+ //
+ // Undefined behavior caused by an application error may cause a device to
+ // become lost. However, such undefined behavior may also cause
+ // unrecoverable damage to the process, and it is then not guaranteed that
+ // the API objects, including the VkPhysicalDevice or the VkInstance are
+ // still valid or that the error is recoverable.
+ //
+ // When a device is lost, its child objects are not implicitly destroyed
+ // and their handles are still valid. Those objects must still be
+ // destroyed before their parents or the device can be destroyed (see the
+ // Object Lifetime section). The host address space corresponding to
+ // device memory mapped using vkMapMemory is still valid, and host memory
+ // accesses to these mapped regions are still valid, but the contents are
+ // undefined. It is still legal to call any API command on the device and
+ // child objects.
+ //
+ // Once a device is lost, command execution may fail, and commands that
+ // return a VkResult may return VK_ERROR_DEVICE_LOST.
+ // Commands that do not allow run-time errors must still operate correctly
+ // for valid usage and, if applicable, return valid data.
+ //
+ // Commands that wait indefinitely for device execution (namely
+ // vkDeviceWaitIdle, vkQueueWaitIdle, vkWaitForFences with a maximum
+ // timeout, and vkGetQueryPoolResults with the VK_QUERY_RESULT_WAIT_BIT
+ // bit set in flags) must return in finite time even in the case
+ // of a lost device, and return either VK_SUCCESS or
+ // VK_ERROR_DEVICE_LOST. For any command that may return
+ // VK_ERROR_DEVICE_LOST, for the purpose of determining whether a
+ // command buffer is in the pending state, or whether resources are
+ // considered in-use by the device, a return value of
+ // VK_ERROR_DEVICE_LOST is equivalent to VK_SUCCESS.
+ return InternalError("VK_ERROR_DEVICE_LOST");
+ case VK_ERROR_MEMORY_MAP_FAILED:
+ // Mapping of a memory object has failed.
+ return InternalError("VK_ERROR_MEMORY_MAP_FAILED");
+ case VK_ERROR_LAYER_NOT_PRESENT:
+ // A requested layer is not present or could not be loaded.
+ return UnimplementedError("VK_ERROR_LAYER_NOT_PRESENT");
+ case VK_ERROR_EXTENSION_NOT_PRESENT:
+ // A requested extension is not supported.
+ return UnimplementedError("VK_ERROR_EXTENSION_NOT_PRESENT");
+ case VK_ERROR_FEATURE_NOT_PRESENT:
+ // A requested feature is not supported.
+ return UnimplementedError("VK_ERROR_FEATURE_NOT_PRESENT");
+ case VK_ERROR_INCOMPATIBLE_DRIVER:
+ // The requested version of Vulkan is not supported by the driver or is
+ // otherwise incompatible for implementation-specific reasons.
+ return FailedPreconditionError("VK_ERROR_INCOMPATIBLE_DRIVER");
+ case VK_ERROR_TOO_MANY_OBJECTS:
+ // Too many objects of the type have already been created.
+ return ResourceExhaustedError("VK_ERROR_TOO_MANY_OBJECTS");
+ case VK_ERROR_FORMAT_NOT_SUPPORTED:
+ // A requested format is not supported on this device.
+ return UnimplementedError("VK_ERROR_FORMAT_NOT_SUPPORTED");
+ case VK_ERROR_FRAGMENTED_POOL:
+ // A pool allocation has failed due to fragmentation of the pool's memory.
+ // This must only be returned if no attempt to allocate host or device
+ // memory was made to accommodate the new allocation.
+ return ResourceExhaustedError("VK_ERROR_FRAGMENTED_POOL");
+ case VK_ERROR_OUT_OF_POOL_MEMORY:
+ // A pool memory allocation has failed. This must only be returned if no
+ // attempt to allocate host or device memory was made to accommodate the
+ // new allocation. If the failure was definitely due to fragmentation of
+ // the pool, VK_ERROR_FRAGMENTED_POOL should be returned instead.
+ return ResourceExhaustedError("VK_ERROR_OUT_OF_POOL_MEMORY");
+ case VK_ERROR_INVALID_EXTERNAL_HANDLE:
+ // An external handle is not a valid handle of the specified type.
+ return InvalidArgumentError("VK_ERROR_INVALID_EXTERNAL_HANDLE");
+ case VK_ERROR_SURFACE_LOST_KHR:
+ // A surface is no longer available.
+ return UnavailableError("VK_ERROR_SURFACE_LOST_KHR");
+ case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR:
+ // The requested window is already in use by Vulkan or another API in a
+ // manner which prevents it from being used again.
+ return InvalidArgumentError("VK_ERROR_NATIVE_WINDOW_IN_USE_KHR");
+ case VK_ERROR_OUT_OF_DATE_KHR:
+ // A surface has changed in such a way that it is no longer compatible
+ // with the swapchain, and further presentation requests using the
+ // swapchain will fail. Applications must query the new surface properties
+ // and recreate their swapchain if they wish to continue presenting to the
+ // surface.
+ return FailedPreconditionError("VK_ERROR_OUT_OF_DATE_KHR");
+ case VK_ERROR_INCOMPATIBLE_DISPLAY_KHR:
+ // The display used by a swapchain does not use the same presentable image
+ // layout, or is incompatible in a way that prevents sharing an image.
+ return InvalidArgumentError("VK_ERROR_INCOMPATIBLE_DISPLAY_KHR");
+ case VK_ERROR_VALIDATION_FAILED_EXT:
+ // Validation layer testing failed. It is not expected that an
+ // application would see this error code during normal use of the
+ // validation layers.
+ return InvalidArgumentError("VK_ERROR_VALIDATION_FAILED_EXT");
+ case VK_ERROR_INVALID_SHADER_NV:
+ // One or more shaders failed to compile or link. More details are
+ // reported back to the application when the validation layer is enabled
+ // using the extension VK_EXT_debug_report.
+ return InvalidArgumentError("VK_ERROR_INVALID_SHADER_NV");
+ case VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT:
+ // When creating an image with
+ // VkImageDrmFormatModifierExplicitCreateInfoEXT, it is the application's
+ // responsibility to satisfy all Valid Usage requirements. However, the
+ // implementation must validate that the provided pPlaneLayouts, when
+ // combined with the provided drmFormatModifier and other creation
+ // parameters in VkImageCreateInfo and its pNext chain, produce a valid
+ // image. (This validation is necessarily implementation-dependent and
+ // outside the scope of Vulkan, and therefore not described by Valid Usage
+ // requirements). If this validation fails, then vkCreateImage returns
+ // VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT.
+ return InvalidArgumentError(
+ "VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT");
+ case VK_ERROR_FRAGMENTATION_EXT:
+ // A descriptor pool creation has failed due to fragmentation.
+ return ResourceExhaustedError("VK_ERROR_FRAGMENTATION_EXT");
+ case VK_ERROR_NOT_PERMITTED_EXT:
+ // When creating a queue, the caller does not have sufficient privileges
+ // to request to acquire a priority above the default priority
+ // (VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_EXT).
+ return PermissionDeniedError("VK_ERROR_NOT_PERMITTED_EXT");
+ case VK_ERROR_INVALID_DEVICE_ADDRESS_EXT:
+ // A buffer creation failed because the requested address is not
+ // available.
+ return OutOfRangeError("VK_ERROR_INVALID_DEVICE_ADDRESS_EXT");
+ case VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT:
+ // An operation on a swapchain created with
+ // VK_FULL_SCREEN_EXCLUSIVE_APPLICATION_CONTROLLED_EXT failed as it did
+ // not have exclusive full-screen access. This may occur due to
+ // implementation-dependent reasons, outside of the application's control.
+ return UnavailableError("VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT");
+ default:
+ // Unrecognized value: report the numeric code.
+ return UnknownError(std::to_string(result));
+ }
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/status_util.h b/hal/vulkan/status_util.h
new file mode 100644
index 0000000..8ff07f5
--- /dev/null
+++ b/hal/vulkan/status_util.h
@@ -0,0 +1,87 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_STATUS_UTIL_H_
+#define IREE_HAL_VULKAN_STATUS_UTIL_H_
+
+#include <vulkan/vulkan.h>
+
+#include "base/status.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// Evaluates |expr| (a VkResult), converts it with VkResultToStatus, and passes
+// the resulting Status to RETURN_IF_ERROR.
+//
+// Usage:
+// VK_RETURN_IF_ERROR(vkDoThing(...));
+#define VK_RETURN_IF_ERROR(expr) \
+ RETURN_IF_ERROR(::iree::hal::vulkan::VkResultToStatus(expr))
+
+// Evaluates |expr| (a VkResult), converts it with VkResultToStatus, and passes
+// the resulting Status to CHECK_OK.
+//
+// Usage:
+// VK_CHECK_OK(vkDoThing(...));
+#define VK_CHECK_OK(expr) CHECK_OK(::iree::hal::vulkan::VkResultToStatus(expr))
+
+// Converts a VkResult to a Status object.
+//
+// Vulkan considers the following as "success codes" and users should ensure
+// they first check the result prior to converting:
+//
+// - VK_SUCCESS -> OkStatus()
+// - VK_NOT_READY -> OkStatus()
+// - VK_TIMEOUT -> OkStatus()
+// - VK_EVENT_SET -> OkStatus()
+// - VK_EVENT_RESET -> OkStatus()
+// - VK_INCOMPLETE -> OkStatus()
+// - VK_SUBOPTIMAL_KHR -> OkStatus()
+//
+// The rest are considered as "error codes":
+//
+// - VK_ERROR_OUT_OF_HOST_MEMORY -> ResourceExhaustedError("VK...")
+// - VK_ERROR_OUT_OF_DEVICE_MEMORY -> ResourceExhaustedError("VK...")
+// - VK_ERROR_INITIALIZATION_FAILED -> InternalError("VK...")
+// - VK_ERROR_DEVICE_LOST -> InternalError("VK...")
+// - VK_ERROR_MEMORY_MAP_FAILED -> InternalError("VK...")
+// - VK_ERROR_LAYER_NOT_PRESENT -> NotFoundError("VK...")
+// - VK_ERROR_EXTENSION_NOT_PRESENT -> NotFoundError("VK...")
+// - VK_ERROR_FEATURE_NOT_PRESENT -> NotFoundError("VK...")
+// - VK_ERROR_INCOMPATIBLE_DRIVER -> FailedPreconditionError("VK...")
+// - VK_ERROR_TOO_MANY_OBJECTS -> ResourceExhaustedError("VK...")
+// - VK_ERROR_FORMAT_NOT_SUPPORTED -> UnimplementedError("VK...")
+// - VK_ERROR_FRAGMENTED_POOL -> ResourceExhaustedError("VK...")
+// - VK_ERROR_OUT_OF_POOL_MEMORY -> ResourceExhaustedError("VK...")
+// - VK_ERROR_INVALID_EXTERNAL_HANDLE -> InvalidArgumentError("VK...")
+// - VK_ERROR_SURFACE_LOST_KHR -> InternalError("VK...")
+// - VK_ERROR_NATIVE_WINDOW_IN_USE_KHR -> InternalError("VK...")
+// - VK_ERROR_OUT_OF_DATE_KHR -> InternalError("VK...")
+// - VK_ERROR_INCOMPATIBLE_DISPLAY_KHR -> InternalError("VK...")
+// - VK_ERROR_VALIDATION_FAILED_EXT -> InternalError("VK...")
+// - VK_ERROR_INVALID_SHADER_NV -> InternalError("VK...")
+// - VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT -> InvalidArgumentError
+// - VK_ERROR_FRAGMENTATION_EXT -> ResourceExhaustedError("VK...")
+// - VK_ERROR_NOT_PERMITTED_EXT -> PermissionDeniedError("VK...")
+// - VK_ERROR_INVALID_DEVICE_ADDRESS_EXT -> OutOfRangeError("VK...")
+// - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT -> UnavailableError("VK...")
+Status VkResultToStatus(VkResult result);
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_STATUS_UTIL_H_
diff --git a/hal/vulkan/vma_allocator.cc b/hal/vulkan/vma_allocator.cc
new file mode 100644
index 0000000..a40dc00
--- /dev/null
+++ b/hal/vulkan/vma_allocator.cc
@@ -0,0 +1,248 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/vma_allocator.h"
+
+#include "absl/flags/flag.h"
+#include "absl/memory/memory.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/buffer.h"
+#include "hal/vulkan/status_util.h"
+#include "hal/vulkan/vma_buffer.h"
+
+#if VMA_RECORDING_ENABLED  // Flags exist only when VMA recording is built in.
+ABSL_FLAG(std::string, vma_recording_file, "",
+ "File path to write a CSV containing the VMA recording.");
+ABSL_FLAG(bool, vma_recording_flush_after_call, false,
+ "Flush the VMA recording file after every call (useful if "
+ "crashing/not exiting cleanly).");
+#endif // VMA_RECORDING_ENABLED
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// static
+// Creates the underlying VMA instance and wraps it in a VmaAllocator.
+StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create(
+    VkPhysicalDevice physical_device,
+    const ref_ptr<VkDeviceHandle>& logical_device) {
+  IREE_TRACE_SCOPE0("VmaAllocator::Create");
+  const auto& syms = logical_device->syms();
+  VmaVulkanFunctions vulkan_fns = {};  // Unassigned entry points stay null.
+  vulkan_fns.vkGetPhysicalDeviceProperties =
+      syms->vkGetPhysicalDeviceProperties;
+  vulkan_fns.vkGetPhysicalDeviceMemoryProperties =
+      syms->vkGetPhysicalDeviceMemoryProperties;
+  vulkan_fns.vkAllocateMemory = syms->vkAllocateMemory;
+  vulkan_fns.vkFreeMemory = syms->vkFreeMemory;
+  vulkan_fns.vkMapMemory = syms->vkMapMemory;
+  vulkan_fns.vkUnmapMemory = syms->vkUnmapMemory;
+  vulkan_fns.vkFlushMappedMemoryRanges = syms->vkFlushMappedMemoryRanges;
+  vulkan_fns.vkInvalidateMappedMemoryRanges =
+      syms->vkInvalidateMappedMemoryRanges;
+  vulkan_fns.vkBindBufferMemory = syms->vkBindBufferMemory;
+  vulkan_fns.vkBindImageMemory = syms->vkBindImageMemory;
+  vulkan_fns.vkGetBufferMemoryRequirements =
+      syms->vkGetBufferMemoryRequirements;
+  vulkan_fns.vkGetImageMemoryRequirements = syms->vkGetImageMemoryRequirements;
+  vulkan_fns.vkCreateBuffer = syms->vkCreateBuffer;
+  vulkan_fns.vkDestroyBuffer = syms->vkDestroyBuffer;
+  vulkan_fns.vkCreateImage = syms->vkCreateImage;
+  vulkan_fns.vkDestroyImage = syms->vkDestroyImage;
+  vulkan_fns.vkCmdCopyBuffer = syms->vkCmdCopyBuffer;
+
+  VmaRecordSettings record_settings;
+  record_settings.flags = 0;
+  record_settings.pFilePath = nullptr;
+#if VMA_RECORDING_ENABLED
+  // absl::GetFlag returns a temporary std::string; keep a named copy so the
+  // c_str() pointer is still valid when vmaCreateAllocator reads it below.
+  const std::string recording_file = absl::GetFlag(FLAGS_vma_recording_file);
+  record_settings.pFilePath = recording_file.c_str();
+  if (absl::GetFlag(FLAGS_vma_recording_flush_after_call)) {
+    record_settings.flags = VMA_RECORD_FLUSH_AFTER_CALL_BIT;
+  }
+#endif  // VMA_RECORDING_ENABLED
+
+  VmaAllocatorCreateInfo create_info;
+  create_info.flags = 0;
+  create_info.physicalDevice = physical_device;
+  create_info.device = *logical_device;
+  create_info.preferredLargeHeapBlockSize = 64 * 1024 * 1024;
+  create_info.pAllocationCallbacks = logical_device->allocator();
+  create_info.pDeviceMemoryCallbacks = nullptr;
+  create_info.frameInUseCount = 0;
+  create_info.pHeapSizeLimit = nullptr;
+  create_info.pVulkanFunctions = &vulkan_fns;
+  create_info.pRecordSettings = &record_settings;
+  ::VmaAllocator vma = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(vmaCreateAllocator(&create_info, &vma));
+  // TODO(benvanik): query memory properties/types.
+  return absl::WrapUnique(
+      new VmaAllocator(physical_device, logical_device, vma));
+}
+
+// Stores the device handles and the already-created VMA instance |vma|.
+VmaAllocator::VmaAllocator(
+    VkPhysicalDevice physical_device,
+    const ref_ptr<VkDeviceHandle>& logical_device, ::VmaAllocator vma)
+    : physical_device_(physical_device),
+      logical_device_(add_ref(logical_device)), vma_(vma) {}
+
+VmaAllocator::~VmaAllocator() {
+ IREE_TRACE_SCOPE0("VmaAllocator::dtor");
+ vmaDestroyAllocator(vma_);  // Tears down the VMA instance held in vma_.
+}
+
+bool VmaAllocator::CanUseBufferLike(Allocator* source_allocator,
+ MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage) const {
+ // TODO(benvanik): ensure there is a memory type that can satisfy the request.
+ return source_allocator == this;  // Only buffers from this allocator qualify.
+}
+
+bool VmaAllocator::CanAllocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) const {
+ // TODO(benvanik): ensure there is a memory type that can satisfy the request.
+ return true;  // Every request is currently reported as allocatable.
+}
+
+Status VmaAllocator::MakeCompatible(MemoryTypeBitfield* memory_type,
+ BufferUsageBitfield* buffer_usage) const {
+ // TODO(benvanik): mutate to match supported memory types.
+ return OkStatus();  // Both in/out values are currently left untouched.
+}
+
+// Creates a VkBuffer and its VMA allocation of |allocation_size| bytes from
+// the HAL |memory_type| / |buffer_usage| bits and wraps them in a VmaBuffer.
+StatusOr<ref_ptr<VmaBuffer>> VmaAllocator::AllocateInternal(
+    MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+    MemoryAccessBitfield allowed_access, size_t allocation_size,
+    VmaAllocationCreateFlags flags) {
+  IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal");
+
+  // HAL buffer usage -> Vulkan buffer usage bits.
+  VkBufferUsageFlags vk_usage = 0;
+  if (AllBitsSet(buffer_usage, BufferUsage::kTransfer)) {
+    vk_usage |=
+        VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+  }
+  if (AllBitsSet(buffer_usage, BufferUsage::kDispatch)) {
+    vk_usage |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT |
+                VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
+                VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT;
+  }
+
+  // HAL memory type -> VMA usage class and required/preferred property bits.
+  VmaMemoryUsage vma_usage = VMA_MEMORY_USAGE_CPU_ONLY;  // Host-local only.
+  VkMemoryPropertyFlags required_flags = 0;
+  VkMemoryPropertyFlags preferred_flags = 0;
+  const bool device_local = AllBitsSet(memory_type, MemoryType::kDeviceLocal);
+  if (device_local && AllBitsSet(memory_type, MemoryType::kHostVisible)) {
+    // Device-local, host-visible.
+    vma_usage = VMA_MEMORY_USAGE_CPU_TO_GPU;
+    preferred_flags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
+  } else if (device_local) {
+    // Device-local only.
+    vma_usage = VMA_MEMORY_USAGE_GPU_ONLY;
+    required_flags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
+  } else if (AllBitsSet(memory_type, MemoryType::kDeviceVisible)) {
+    // Host-local, device-visible.
+    vma_usage = VMA_MEMORY_USAGE_GPU_TO_CPU;
+  }
+  if (AllBitsSet(memory_type, MemoryType::kHostCached)) {
+    required_flags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
+  }
+  if (AllBitsSet(memory_type, MemoryType::kHostCoherent)) {
+    required_flags |= VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
+  }
+  if (AllBitsSet(memory_type, MemoryType::kTransient)) {
+    preferred_flags |= VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT;
+  }
+  if (AllBitsSet(buffer_usage, BufferUsage::kMapping)) {
+    required_flags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
+  }
+
+  VkBufferCreateInfo buffer_create_info;
+  buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+  buffer_create_info.pNext = nullptr;
+  buffer_create_info.flags = 0;
+  buffer_create_info.size = allocation_size;
+  buffer_create_info.usage = vk_usage;
+  buffer_create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+  buffer_create_info.queueFamilyIndexCount = 0;
+  buffer_create_info.pQueueFamilyIndices = nullptr;
+
+  VmaAllocationCreateInfo allocation_create_info;
+  allocation_create_info.flags = flags;
+  allocation_create_info.usage = vma_usage;
+  allocation_create_info.requiredFlags = required_flags;
+  allocation_create_info.preferredFlags = preferred_flags;
+  allocation_create_info.memoryTypeBits = 0;  // Automatic selection.
+  allocation_create_info.pool = VK_NULL_HANDLE;
+  allocation_create_info.pUserData = nullptr;
+
+  VkBuffer buffer = VK_NULL_HANDLE;
+  VmaAllocation allocation = VK_NULL_HANDLE;
+  VmaAllocationInfo allocation_info;
+  VK_RETURN_IF_ERROR(vmaCreateBuffer(vma_, &buffer_create_info,
+                                     &allocation_create_info, &buffer,
+                                     &allocation, &allocation_info));
+
+  return make_ref<VmaBuffer>(this, memory_type, allowed_access, buffer_usage,
+                             allocation_size, 0, allocation_size, buffer,
+                             allocation, allocation_info);
+}
+
+StatusOr<ref_ptr<Buffer>> VmaAllocator::Allocate(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ size_t allocation_size) {  // Full access, no extra VMA creation flags.
+ IREE_TRACE_SCOPE0("VmaAllocator::Allocate");
+ return AllocateInternal(memory_type, buffer_usage, MemoryAccess::kAll,
+ allocation_size, /*flags=*/0);
+}
+
+// Copies |source_buffer| into a new buffer that is then made read-only.
+StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateConstant(
+    BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) {
+  IREE_TRACE_SCOPE0("VmaAllocator::AllocateConstant");
+  // TODO(benvanik): import memory to avoid the copy.
+  const auto memory_type = MemoryType::kDeviceLocal | MemoryType::kHostVisible;
+  // Write access is only granted for the initial upload below.
+  const auto upload_access = MemoryAccess::kRead | MemoryAccess::kDiscardWrite;
+  ASSIGN_OR_RETURN(auto buffer,
+                   AllocateInternal(memory_type, buffer_usage, upload_access,
+                                    source_buffer->byte_length(), /*flags=*/0));
+  RETURN_IF_ERROR(buffer->CopyData(0, source_buffer.get(), 0, kWholeBuffer));
+  buffer->set_allowed_access(MemoryAccess::kRead);
+  return buffer;
+}
+
+StatusOr<ref_ptr<Buffer>> VmaAllocator::WrapMutable(
+ MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage, void* data, size_t data_length) {
+ IREE_TRACE_SCOPE0("VmaAllocator::WrapMutable");
+ // TODO(benvanik): import memory. Until then every call reports Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Wrapping host memory is not yet implemented";
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/vma_allocator.h b/hal/vulkan/vma_allocator.h
new file mode 100644
index 0000000..ab908b3
--- /dev/null
+++ b/hal/vulkan/vma_allocator.h
@@ -0,0 +1,110 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
+#define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
+
+#include <vulkan/vulkan.h>
+
+#include <memory>
+
+#include "base/status.h"
+#include "hal/allocator.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/handle_util.h"
+#include "hal/vulkan/internal_vk_mem_alloc.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+class VmaBuffer;
+
+// A HAL allocator using the Vulkan Memory Allocator (VMA) to manage memory.
+// VMA (//third_party/vulkan_memory_allocator) provides dlmalloc-like behavior
+// with suballocations made with various policies (best fit, first fit, etc).
+// This reduces the number of allocations we need from the Vulkan implementation
+// (which can sometimes be limited to as little as 4096 total allowed) and
+// manages higher level allocation semantics like slab allocation and
+// defragmentation.
+//
+// VMA is internally synchronized and the functionality exposed on the HAL
+// interface is thread-safe.
+//
+// More information:
+// https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator
+// https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/
+class VmaAllocator final : public Allocator {
+ public:
+ static StatusOr<std::unique_ptr<VmaAllocator>> Create(
+ VkPhysicalDevice physical_device,
+ const ref_ptr<VkDeviceHandle>& logical_device);
+
+ ~VmaAllocator() override;  // Destroys the underlying VMA instance.
+
+ const ref_ptr<DynamicSymbols>& syms() const {
+ return logical_device_->syms();
+ }
+
+ ::VmaAllocator vma() const { return vma_; }  // Raw VMA handle.
+
+ bool CanUseBufferLike(Allocator* source_allocator,
+ MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ BufferUsageBitfield intended_usage) const override;
+
+ bool CanAllocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) const override;
+
+ Status MakeCompatible(MemoryTypeBitfield* memory_type,
+ BufferUsageBitfield* buffer_usage) const override;
+
+ StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
+ BufferUsageBitfield buffer_usage,
+ size_t allocation_size) override;
+
+ StatusOr<ref_ptr<Buffer>> AllocateConstant(
+ BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) override;
+
+ StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield buffer_usage,
+ void* data,
+ size_t data_length) override;  // Not yet implemented.
+
+ private:
+ VmaAllocator(VkPhysicalDevice physical_device,
+ const ref_ptr<VkDeviceHandle>& logical_device,
+ ::VmaAllocator vma);
+
+ StatusOr<ref_ptr<VmaBuffer>> AllocateInternal(
+ MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
+ MemoryAccessBitfield allowed_access, size_t allocation_size,
+ VmaAllocationCreateFlags flags);  // Shared by Allocate/AllocateConstant.
+
+ VkPhysicalDevice physical_device_;
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ // Internally synchronized. We could externally synchronize if we thought it
+ // was worth it, however I'm not sure we'd be able to do much better with the
+ // current Allocator API.
+ ::VmaAllocator vma_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
diff --git a/hal/vulkan/vma_buffer.cc b/hal/vulkan/vma_buffer.cc
new file mode 100644
index 0000000..b7d3b6b
--- /dev/null
+++ b/hal/vulkan/vma_buffer.cc
@@ -0,0 +1,163 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/vma_buffer.h"
+
+#include "base/source_location.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/vulkan/status_util.h"
+#include "hal/vulkan/vma_allocator.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+VmaBuffer::VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access,
+ BufferUsageBitfield usage, device_size_t allocation_size,
+ device_size_t byte_offset, device_size_t byte_length,
+ VkBuffer buffer, VmaAllocation allocation,
+ VmaAllocationInfo allocation_info)
+ : Buffer(allocator, memory_type, allowed_access, usage, allocation_size,
+ byte_offset, byte_length),
+ vma_(allocator->vma()),  // Cache the allocator's VMA handle.
+ buffer_(buffer),
+ allocation_(allocation),
+ allocation_info_(allocation_info) {
+ // Point the allocation's user data at this buffer. TODO(benvanik): set debug
+ // name instead and use VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT.
+ vmaSetAllocationUserData(vma_, allocation_, this);
+}
+
+VmaBuffer::~VmaBuffer() {
+ IREE_TRACE_SCOPE0("VmaBuffer::dtor");
+ vmaDestroyBuffer(vma_, buffer_, allocation_);  // Buffer and its allocation.
+}
+
+// Fills the given byte range with a repeating 1-, 2-, or 4-byte |pattern|.
+Status VmaBuffer::FillImpl(device_size_t byte_offset, device_size_t byte_length,
+                           const void* pattern, device_size_t pattern_length) {
+  ASSIGN_OR_RETURN(auto mapping, MapMemory<uint8_t>(MemoryAccess::kDiscardWrite,
+                                                    byte_offset, byte_length));
+  // The mapping already starts at |byte_offset|; re-adding it would overrun.
+  void* data_ptr = static_cast<void*>(mapping.mutable_data());
+  switch (pattern_length) {
+    case 1: {
+      uint8_t* data = static_cast<uint8_t*>(data_ptr);
+      uint8_t value_bits = *static_cast<const uint8_t*>(pattern);
+      std::fill_n(data, byte_length, value_bits);
+      break;
+    }
+    case 2: {
+      uint16_t* data = static_cast<uint16_t*>(data_ptr);
+      uint16_t value_bits = *static_cast<const uint16_t*>(pattern);
+      std::fill_n(data, byte_length / sizeof(uint16_t), value_bits);
+      break;
+    }
+    case 4: {
+      uint32_t* data = static_cast<uint32_t*>(data_ptr);
+      uint32_t value_bits = *static_cast<const uint32_t*>(pattern);
+      std::fill_n(data, byte_length / sizeof(uint32_t), value_bits);
+      break;
+    }
+    default:
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Unsupported scalar data size: " << pattern_length;
+  }
+  return OkStatus();
+}
+
+Status VmaBuffer::ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) {  // Reads through a kRead mapping.
+ ASSIGN_OR_RETURN(
+ auto mapping,
+ MapMemory<uint8_t>(MemoryAccess::kRead, source_offset, data_length));
+ std::memcpy(data, mapping.data(), mapping.byte_length());
+ return OkStatus();
+}
+
+Status VmaBuffer::WriteDataImpl(device_size_t target_offset, const void* data,
+ device_size_t data_length) {  // Writes through a kDiscardWrite mapping.
+ ASSIGN_OR_RETURN(auto mapping,
+ MapMemory<uint8_t>(MemoryAccess::kDiscardWrite,
+ target_offset, data_length));
+ std::memcpy(mapping.mutable_data(), data, mapping.byte_length());
+ return OkStatus();
+}
+
+// Host-side copy through two mappings. This is pretty terrible; let's not.
+Status VmaBuffer::CopyDataImpl(
+    device_size_t target_offset, Buffer* source_buffer,
+    device_size_t source_offset, device_size_t data_length) {
+  // TODO(benvanik): a way for allocators to indicate transfer compat.
+  ASSIGN_OR_RETURN(auto source_mapping,
+                   source_buffer->MapMemory<uint8_t>(
+                       MemoryAccess::kRead, source_offset, data_length));
+  CHECK_EQ(data_length, source_mapping.size());
+  ASSIGN_OR_RETURN(auto target_mapping,
+                   MapMemory<uint8_t>(MemoryAccess::kDiscardWrite,
+                                      target_offset, data_length));
+  CHECK_EQ(data_length, target_mapping.size());
+  // Each mapping already starts at its offset, so copy between mapping bases.
+  std::memcpy(target_mapping.mutable_data(), source_mapping.data(),
+              data_length);
+  return OkStatus();
+}
+
+// Maps the allocation and returns a pointer |local_byte_offset| bytes into it.
+Status VmaBuffer::MapMemoryImpl(MappingMode mapping_mode,
+                                MemoryAccessBitfield memory_access,
+                                device_size_t local_byte_offset,
+                                device_size_t local_byte_length,
+                                void** out_data) {
+  void* mapped_base = nullptr;
+  VK_RETURN_IF_ERROR(vmaMapMemory(vma_, allocation_, &mapped_base));
+  *out_data = static_cast<uint8_t*>(mapped_base) + local_byte_offset;
+
+  // Debug builds scribble over ranges mapped for discard. This is not mandated
+  // behavior but makes stale-data bugs easier to spot. For heap buffers we
+  // could instead reallocate so that ASAN yells, but that only works when the
+  // entire buffer is discarded.
+#ifndef NDEBUG
+  if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) {
+    std::memset(*out_data, 0xCD, local_byte_length);
+  }
+#endif  // !NDEBUG
+
+  return OkStatus();
+}
+
+Status VmaBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data) {
+ vmaUnmapMemory(vma_, allocation_);  // The range arguments are not used.
+ return OkStatus();
+}
+
+Status VmaBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) {
+ vmaInvalidateAllocation(vma_, allocation_, local_byte_offset,
+ local_byte_length);
+ return OkStatus();  // No failure is surfaced from the VMA call.
+}
+
+Status VmaBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) {
+ vmaFlushAllocation(vma_, allocation_, local_byte_offset, local_byte_length);
+ return OkStatus();  // No failure is surfaced from the VMA call.
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/vma_buffer.h b/hal/vulkan/vma_buffer.h
new file mode 100644
index 0000000..5642c87
--- /dev/null
+++ b/hal/vulkan/vma_buffer.h
@@ -0,0 +1,79 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_
+#define IREE_HAL_VULKAN_VMA_BUFFER_H_
+
+#include <vulkan/vulkan.h>
+
+#include "hal/buffer.h"
+#include "vk_mem_alloc.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+class VmaAllocator;
+
+// A buffer implementation representing an allocation made from within a pool of
+// a Vulkan Memory Allocator instance. See VmaAllocator for more information.
+class VmaBuffer final : public Buffer {
+ public:
+ VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type,
+ MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
+ device_size_t allocation_size, device_size_t byte_offset,
+ device_size_t byte_length, VkBuffer buffer,
+ VmaAllocation allocation, VmaAllocationInfo allocation_info);
+ ~VmaBuffer() override;  // Destroys the VkBuffer and its VMA allocation.
+
+ VkBuffer handle() const { return buffer_; }  // Underlying VkBuffer.
+ VmaAllocation allocation() const { return allocation_; }
+ const VmaAllocationInfo& allocation_info() const { return allocation_info_; }
+
+ // Exposed so that VmaAllocator can reset access after initial mapping.
+ using Buffer::set_allowed_access;
+
+ private:
+ Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
+ const void* pattern, device_size_t pattern_length) override;
+ Status ReadDataImpl(device_size_t source_offset, void* data,
+ device_size_t data_length) override;
+ Status WriteDataImpl(device_size_t target_offset, const void* data,
+ device_size_t data_length) override;
+ Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
+ device_size_t source_offset,
+ device_size_t data_length) override;
+ Status MapMemoryImpl(MappingMode mapping_mode,
+ MemoryAccessBitfield memory_access,
+ device_size_t local_byte_offset,
+ device_size_t local_byte_length,
+ void** out_data) override;
+ Status UnmapMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length, void* data) override;
+ Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override;
+ Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
+ device_size_t local_byte_length) override;
+
+ ::VmaAllocator vma_;  // Copied from the owning VmaAllocator at construction.
+ VkBuffer buffer_;
+ VmaAllocation allocation_;
+ VmaAllocationInfo allocation_info_;  // Snapshot passed in at construction.
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_VMA_BUFFER_H_
diff --git a/hal/vulkan/vulkan_device.cc b/hal/vulkan/vulkan_device.cc
new file mode 100644
index 0000000..8fe5e60
--- /dev/null
+++ b/hal/vulkan/vulkan_device.cc
@@ -0,0 +1,500 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/vulkan_device.h"
+
+#include <functional>
+#include <utility>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/command_buffer_validation.h"
+#include "hal/command_queue.h"
+#include "hal/fence.h"
+#include "hal/vulkan/direct_command_buffer.h"
+#include "hal/vulkan/direct_command_queue.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/extensibility_util.h"
+#include "hal/vulkan/legacy_fence.h"
+#include "hal/vulkan/native_binary_semaphore.h"
+#include "hal/vulkan/native_event.h"
+#include "hal/vulkan/pipeline_cache.h"
+#include "hal/vulkan/status_util.h"
+#include "hal/vulkan/vma_allocator.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Sentinel queue family index (all 32 bits set) meaning "no family selected".
+constexpr uint32_t kInvalidQueueFamilyIndex = 0xFFFFFFFFu;
+
+// Result of queue family selection: which family to use for each kind of
+// queue and how many queues to take from it. An index that is still
+// kInvalidQueueFamilyIndex means that no family was chosen for that kind.
+struct QueueFamilyInfo {
+  // Family index and queue count used for dispatch queues.
+  uint32_t dispatch_index = kInvalidQueueFamilyIndex;
+  uint32_t dispatch_queue_count = 0;
+  // Family index and queue count used for transfer queues.
+  uint32_t transfer_index = kInvalidQueueFamilyIndex;
+  uint32_t transfer_queue_count = 0;
+};
+
+// Finds the first queue family in the listing (which is usually the
+// driver-preferred one) whose queueFlags contain all of the
+// |required_queue_flags| bits and none of the |excluded_queue_flags| bits.
+//
+// Returns the position of that family within |queue_family_properties|, or
+// kInvalidQueueFamilyIndex if no family matches (which includes the case of an
+// empty listing).
+uint32_t FindFirstQueueFamilyWithFlags(
+    absl::Span<const VkQueueFamilyProperties> queue_family_properties,
+    uint32_t required_queue_flags, uint32_t excluded_queue_flags) {
+  // The index is unsigned so that the comparison against size() does not mix
+  // signed and unsigned operands and the return needs no implicit conversion.
+  for (uint32_t queue_family_index = 0;
+       queue_family_index < queue_family_properties.size();
+       ++queue_family_index) {
+    const auto& properties = queue_family_properties[queue_family_index];
+    const bool has_all_required =
+        (properties.queueFlags & required_queue_flags) == required_queue_flags;
+    const bool has_any_excluded =
+        (properties.queueFlags & excluded_queue_flags) != 0;
+    if (has_all_required && !has_any_excluded) {
+      return queue_family_index;
+    }
+  }
+  return kInvalidQueueFamilyIndex;
+}
+
+// Chooses the queue families used for dispatch (compute) and transfer queues
+// on |physical_device| and records how many queues each family offers.
+//
+// Both selections may name the same family if only one family is suitable.
+// Fails with a NotFound error if no family supports compute at all; a missing
+// transfer family is not an error (transfer_index is left as
+// kInvalidQueueFamilyIndex and transfer_queue_count as 0).
+StatusOr<QueueFamilyInfo> SelectQueueFamilies(
+    VkPhysicalDevice physical_device, const ref_ptr<DynamicSymbols>& syms) {
+  // Two-call enumeration: first query the family count, then fetch the
+  // properties into storage of that size.
+  uint32_t family_count = 0;
+  syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device,
+                                                 &family_count, nullptr);
+  absl::InlinedVector<VkQueueFamilyProperties, 4> families(family_count);
+  syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device,
+                                                 &family_count, families.data());
+
+  // Shorthand for searching the enumerated families.
+  const auto find_family = [&families](uint32_t required_flags,
+                                       uint32_t excluded_flags) {
+    return FindFirstQueueFamilyWithFlags(families, required_flags,
+                                         excluded_flags);
+  };
+
+  QueueFamilyInfo selection;
+
+  // Dispatch: prefer a compute family without graphics capabilities (it may
+  // still support transfer), otherwise accept any family that can compute.
+  selection.dispatch_index =
+      find_family(VK_QUEUE_COMPUTE_BIT, VK_QUEUE_GRAPHICS_BIT);
+  if (selection.dispatch_index == kInvalidQueueFamilyIndex) {
+    selection.dispatch_index = find_family(VK_QUEUE_COMPUTE_BIT, 0);
+  }
+  if (selection.dispatch_index == kInvalidQueueFamilyIndex) {
+    return NotFoundErrorBuilder(IREE_LOC)
+           << "Unable to find any queue family support compute operations";
+  }
+  selection.dispatch_queue_count = families[selection.dispatch_index].queueCount;
+
+  // Transfer: try progressively less exclusive families. First one with
+  // neither compute nor graphics, then one without graphics, then any family
+  // that supports transfer. If nothing matches no transfer queue is selected
+  // and the dispatch queues are expected to be used for all operations.
+  selection.transfer_index = find_family(
+      VK_QUEUE_TRANSFER_BIT, VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT);
+  if (selection.transfer_index == kInvalidQueueFamilyIndex) {
+    selection.transfer_index =
+        find_family(VK_QUEUE_TRANSFER_BIT, VK_QUEUE_GRAPHICS_BIT);
+  }
+  if (selection.transfer_index == kInvalidQueueFamilyIndex) {
+    selection.transfer_index = find_family(VK_QUEUE_TRANSFER_BIT, 0);
+  }
+  if (selection.transfer_index != kInvalidQueueFamilyIndex) {
+    selection.transfer_queue_count =
+        families[selection.transfer_index].queueCount;
+  }
+
+  // When both kinds share one family the transfer queues may only use what the
+  // dispatch queues have not already claimed.
+  // NOTE(review): dispatch_queue_count was set to the family's full queueCount
+  // above, so the remainder computed here is always 0 and a shared family
+  // never yields transfer queues. Confirm whether that is intended.
+  if (selection.dispatch_index == selection.transfer_index) {
+    const uint32_t unclaimed_queue_count =
+        families[selection.dispatch_index].queueCount -
+        selection.dispatch_queue_count;
+    selection.transfer_queue_count =
+        std::min(unclaimed_queue_count, selection.transfer_queue_count);
+  }
+
+  return selection;
+}
+
+// Creates a command pool flagged VK_COMMAND_POOL_CREATE_TRANSIENT_BIT for the
+// queue family |queue_family_index| on |logical_device|.
+// Command buffers allocated from the pool must only be issued on queues
+// belonging to that family.
+// Returns the pool handle wrapper, or the error produced from the VkResult if
+// vkCreateCommandPool fails.
+StatusOr<ref_ptr<VkCommandPoolHandle>> CreateTransientCommandPool(
+    const ref_ptr<VkDeviceHandle>& logical_device,
+    uint32_t queue_family_index) {
+  // Zero-initialization covers pNext; the remaining fields are set explicitly.
+  VkCommandPoolCreateInfo pool_info = {};
+  pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+  pool_info.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT;
+  pool_info.queueFamilyIndex = queue_family_index;
+
+  // The wrapper is created first so that the driver writes the new handle
+  // straight into it.
+  auto pool_handle = make_ref<VkCommandPoolHandle>(logical_device);
+  VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateCommandPool(
+      *logical_device, &pool_info, logical_device->allocator(),
+      pool_handle->mutable_value()));
+  return pool_handle;
+}
+
+} // namespace
+
+// Creates a logical Vulkan device on |physical_device| together with the HAL
+// objects layered on top of it: the memory allocator, the per-family transient
+// command pools, the command queues and the legacy fence pool.
+//
+// |device_info| supplies the name used to label the queues,
+// |extensibility_spec| describes the device layers/extensions to match and
+// |syms| provides the Vulkan entry points.
+//
+// Returns the new device, or the first error hit while matching
+// layers/extensions, selecting queue families or creating any Vulkan object.
+// static
+StatusOr<std::shared_ptr<VulkanDevice>> VulkanDevice::Create(
+    const DeviceInfo& device_info, VkPhysicalDevice physical_device,
+    const ExtensibilitySpec& extensibility_spec,
+    const ref_ptr<DynamicSymbols>& syms) {
+  IREE_TRACE_SCOPE0("VulkanDevice::Create");
+
+  // Find the layers and extensions we need (or want) that are also available
+  // on the device. This will fail when required ones are not present.
+  ASSIGN_OR_RETURN(
+      auto enabled_layer_names,
+      MatchAvailableDeviceLayers(physical_device, extensibility_spec, *syms));
+  ASSIGN_OR_RETURN(auto enabled_extension_names,
+                   MatchAvailableDeviceExtensions(physical_device,
+                                                  extensibility_spec, *syms));
+  auto enabled_device_extensions =
+      PopulateEnabledDeviceExtensions(enabled_extension_names);
+
+  // Find queue families we will expose as HAL queues.
+  ASSIGN_OR_RETURN(auto queue_family_info,
+                   SelectQueueFamilies(physical_device, syms));
+
+  // Limit the number of queues we create (for now).
+  // We may want to allow this to grow, but each queue adds overhead and we need
+  // to measure to make sure we can effectively use them all.
+  queue_family_info.dispatch_queue_count =
+      std::min(2u, queue_family_info.dispatch_queue_count);
+  queue_family_info.transfer_queue_count =
+      std::min(1u, queue_family_info.transfer_queue_count);
+  bool has_dedicated_transfer_queues =
+      queue_family_info.transfer_queue_count > 0;
+
+  // Setup the queue info we'll be using.
+  // Each queue here (created from within a family) will map to a HAL queue.
+  //
+  // Note that we need to handle the case where we have transfer queues that are
+  // of the same queue family as the dispatch queues: Vulkan requires that all
+  // queues created from the same family are done in the same
+  // VkDeviceQueueCreateInfo struct.
+  DVLOG(1) << "Creating " << queue_family_info.dispatch_queue_count
+           << " dispatch queue(s) in queue family "
+           << queue_family_info.dispatch_index;
+  absl::InlinedVector<VkDeviceQueueCreateInfo, 2> queue_create_info;
+  absl::InlinedVector<float, 4> dispatch_queue_priorities;
+  absl::InlinedVector<float, 4> transfer_queue_priorities;
+  queue_create_info.push_back({});
+  // NOTE: this reference is used again after the push_back() below; that
+  // relies on the vector holding at most two entries, which matches its
+  // inlined capacity of 2.
+  auto& dispatch_queue_info = queue_create_info.back();
+  dispatch_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+  dispatch_queue_info.pNext = nullptr;
+  dispatch_queue_info.flags = 0;
+  dispatch_queue_info.queueFamilyIndex = queue_family_info.dispatch_index;
+  dispatch_queue_info.queueCount = queue_family_info.dispatch_queue_count;
+  if (has_dedicated_transfer_queues) {
+    if (queue_family_info.dispatch_index == queue_family_info.transfer_index) {
+      // Shared family: grow the single create info to cover both kinds.
+      DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count
+               << " dedicated transfer queue(s) in shared queue family "
+               << queue_family_info.transfer_index;
+      dispatch_queue_info.queueCount += queue_family_info.transfer_queue_count;
+    } else {
+      // Independent family: it gets its own create info.
+      DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count
+               << " dedicated transfer queue(s) in independent queue family "
+               << queue_family_info.transfer_index;
+      queue_create_info.push_back({});
+      auto& transfer_queue_info = queue_create_info.back();
+      transfer_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+      transfer_queue_info.pNext = nullptr;
+      transfer_queue_info.queueFamilyIndex = queue_family_info.transfer_index;
+      transfer_queue_info.queueCount = queue_family_info.transfer_queue_count;
+      transfer_queue_info.flags = 0;
+      transfer_queue_priorities.resize(transfer_queue_info.queueCount);
+      transfer_queue_info.pQueuePriorities = transfer_queue_priorities.data();
+    }
+  }
+  // Sized last so that it accounts for transfer queues folded into the
+  // dispatch family above. resize() value-initializes, so every priority is
+  // 0.0f.
+  dispatch_queue_priorities.resize(dispatch_queue_info.queueCount);
+  dispatch_queue_info.pQueuePriorities = dispatch_queue_priorities.data();
+
+  // TODO(benvanik): specify features with VkPhysicalDeviceFeatures.
+
+  // Create device and its queues.
+  VkDeviceCreateInfo device_create_info = {};
+  device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+  device_create_info.pNext = nullptr;
+  device_create_info.enabledLayerCount = enabled_layer_names.size();
+  device_create_info.ppEnabledLayerNames = enabled_layer_names.data();
+  device_create_info.enabledExtensionCount = enabled_extension_names.size();
+  device_create_info.ppEnabledExtensionNames = enabled_extension_names.data();
+  device_create_info.queueCreateInfoCount = queue_create_info.size();
+  device_create_info.pQueueCreateInfos = queue_create_info.data();
+  device_create_info.pEnabledFeatures = nullptr;
+  auto logical_device = make_ref<VkDeviceHandle>(
+      syms, enabled_device_extensions, /*allocator=*/nullptr);
+  VK_RETURN_IF_ERROR(syms->vkCreateDevice(physical_device, &device_create_info,
+                                          logical_device->allocator(),
+                                          logical_device->mutable_value()));
+
+  // Create the device memory allocator.
+  // TODO(benvanik): allow other types to be plugged in.
+  ASSIGN_OR_RETURN(auto allocator,
+                   VmaAllocator::Create(physical_device, logical_device));
+
+  // Create command pools for each queue family. If we don't have a transfer
+  // queue then we'll ignore that one and just use the dispatch pool.
+  // If we wanted to expose the pools through the HAL to allow the VM to more
+  // effectively manage them (pool per fiber, etc) we could, however I doubt the
+  // overhead of locking the pool will be even a blip.
+  ASSIGN_OR_RETURN(auto dispatch_command_pool,
+                   CreateTransientCommandPool(
+                       logical_device, queue_family_info.dispatch_index));
+  ref_ptr<VkCommandPoolHandle> transfer_command_pool;
+  if (has_dedicated_transfer_queues) {
+    ASSIGN_OR_RETURN(transfer_command_pool,
+                     CreateTransientCommandPool(
+                         logical_device, queue_family_info.transfer_index));
+  }
+
+  // Get the queues and create the HAL wrappers.
+  // Dispatch queues occupy indices [0, dispatch_queue_count) of their family
+  // and are advertised as able to run both dispatch and transfer commands.
+  absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues;
+  for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) {
+    VkQueue queue = VK_NULL_HANDLE;
+    syms->vkGetDeviceQueue(*logical_device, queue_family_info.dispatch_index, i,
+                           &queue);
+    std::string queue_name = absl::StrCat(device_info.name(), ":d", i);
+    command_queues.push_back(absl::make_unique<DirectCommandQueue>(
+        std::move(queue_name),
+        CommandCategory::kDispatch | CommandCategory::kTransfer, logical_device,
+        queue));
+  }
+  if (has_dedicated_transfer_queues) {
+    uint32_t base_queue_index = 0;
+    if (queue_family_info.dispatch_index == queue_family_info.transfer_index) {
+      // Sharing a family, so transfer queues follow compute queues: the first
+      // transfer queue sits right after the last dispatch queue. The offset is
+      // therefore the number of dispatch queues, not the family index.
+      base_queue_index = queue_family_info.dispatch_queue_count;
+    }
+    for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) {
+      VkQueue queue = VK_NULL_HANDLE;
+      syms->vkGetDeviceQueue(*logical_device, queue_family_info.transfer_index,
+                             base_queue_index + i, &queue);
+      std::string queue_name = absl::StrCat(device_info.name(), ":t", i);
+      command_queues.push_back(absl::make_unique<DirectCommandQueue>(
+          std::move(queue_name), CommandCategory::kTransfer, logical_device,
+          queue));
+    }
+  }
+
+  // TODO(b/140141417): implement timeline semaphore fences and switch here.
+  ASSIGN_OR_RETURN(auto legacy_fence_pool,
+                   LegacyFencePool::Create(add_ref(logical_device)));
+
+  return std::make_shared<VulkanDevice>(
+      CtorKey{}, device_info, physical_device, std::move(logical_device),
+      std::move(allocator), std::move(command_queues),
+      std::move(dispatch_command_pool), std::move(transfer_command_pool),
+      std::move(legacy_fence_pool));
+}
+
+// Takes ownership of the objects built by VulkanDevice::Create() and sorts
+// the command queues into the dispatch and transfer lists exposed by
+// dispatch_queues() and transfer_queues().
+// |transfer_command_pool| may be null when no dedicated transfer queues were
+// created.
+VulkanDevice::VulkanDevice(
+    CtorKey ctor_key, const DeviceInfo& device_info,
+    VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device,
+    std::unique_ptr<Allocator> allocator,
+    absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues,
+    ref_ptr<VkCommandPoolHandle> dispatch_command_pool,
+    ref_ptr<VkCommandPoolHandle> transfer_command_pool,
+    ref_ptr<LegacyFencePool> legacy_fence_pool)
+    : Device(device_info),
+      physical_device_(physical_device),
+      logical_device_(std::move(logical_device)),
+      allocator_(std::move(allocator)),
+      command_queues_(std::move(command_queues)),
+      // Uses the member logical_device_ (initialized above), not the
+      // moved-from constructor parameter.
+      descriptor_pool_cache_(
+          make_ref<DescriptorPoolCache>(add_ref(logical_device_))),
+      dispatch_command_pool_(std::move(dispatch_command_pool)),
+      transfer_command_pool_(std::move(transfer_command_pool)),
+      legacy_fence_pool_(std::move(legacy_fence_pool)) {
+  // A missing transfer pool means there are no dedicated transfer queues.
+  // The pool is a ref_ptr, so it is tested for null like any other smart
+  // pointer (as CreateCommandBuffer does) rather than being compared against
+  // the Vulkan handle constant VK_NULL_HANDLE.
+  const bool has_dedicated_transfer_pool = !!transfer_command_pool_;
+
+  // Populate the queue lists based on queue capabilities.
+  for (auto& command_queue : command_queues_) {
+    if (command_queue->can_dispatch()) {
+      dispatch_queues_.push_back(command_queue.get());
+      if (!has_dedicated_transfer_pool) {
+        // Dispatch queues double as transfer queues in this configuration.
+        transfer_queues_.push_back(command_queue.get());
+      }
+    } else {
+      transfer_queues_.push_back(command_queue.get());
+    }
+  }
+}
+
+// Tears the device down in dependency order: queues, then command pools, then
+// descriptor pools and finally the reference to the logical device itself.
+VulkanDevice::~VulkanDevice() {
+  IREE_TRACE_SCOPE0("VulkanDevice::dtor");
+
+  // Forget the non-owning per-capability queue lists, then destroy the queues
+  // they pointed at. Destroying the queues may wait until they are idle.
+  dispatch_queues_.clear();
+  transfer_queues_.clear();
+  command_queues_.clear();
+
+  // With the queues gone there are no more outstanding command buffers, so
+  // the command pools can be released.
+  dispatch_command_pool_.reset();
+  transfer_command_pool_.reset();
+
+  // No commands are outstanding anymore, so all descriptor sets can go too.
+  descriptor_pool_cache_.reset();
+
+  // Last, drop our reference to the logical device.
+  logical_device_.reset();
+}
+
+// Returns a new PipelineCache bound to this device's logical device, exposed
+// through the ExecutableCache interface.
+std::shared_ptr<ExecutableCache> VulkanDevice::CreateExecutableCache() {
+  IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableCache");
+  std::shared_ptr<ExecutableCache> executable_cache =
+      std::make_shared<PipelineCache>(logical_device_);
+  return executable_cache;
+}
+
+// Allocates one primary Vulkan command buffer and wraps it in a
+// DirectCommandBuffer (itself wrapped for validation).
+//
+// |mode| and |command_categories| are forwarded to the DirectCommandBuffer;
+// |command_categories| also decides which command pool the buffer comes from.
+// Returns the error produced from the VkResult if the allocation fails.
+StatusOr<ref_ptr<CommandBuffer>> VulkanDevice::CreateCommandBuffer(
+    CommandBufferModeBitfield mode,
+    CommandCategoryBitfield command_categories) {
+  IREE_TRACE_SCOPE0("VulkanDevice::CreateCommandBuffer");
+
+  // Select the command pool to use based on the types of commands requested.
+  // The transfer pool is only used when it exists (there may be no dedicated
+  // transfer queues) and the buffer will not record dispatch commands.
+  const bool use_transfer_pool =
+      transfer_command_pool_ &&
+      !AllBitsSet(command_categories, CommandCategory::kDispatch);
+  ref_ptr<VkCommandPoolHandle> command_pool = add_ref(
+      use_transfer_pool ? transfer_command_pool_ : dispatch_command_pool_);
+
+  // Zero-initialization covers pNext; the remaining fields are set explicitly.
+  VkCommandBufferAllocateInfo buffer_info = {};
+  buffer_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+  buffer_info.commandPool = *command_pool;
+  buffer_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+  buffer_info.commandBufferCount = 1;
+
+  VkCommandBuffer command_buffer = VK_NULL_HANDLE;
+  {
+    // The pool's mutex is held for the duration of the allocation.
+    absl::MutexLock lock(command_pool->mutex());
+    VK_RETURN_IF_ERROR(syms()->vkAllocateCommandBuffers(
+        *logical_device_, &buffer_info, &command_buffer));
+  }
+
+  // TODO(b/140026716): conditionally enable validation.
+  auto direct_command_buffer = make_ref<DirectCommandBuffer>(
+      allocator(), mode, command_categories, add_ref(descriptor_pool_cache_),
+      add_ref(command_pool), command_buffer);
+  return WrapCommandBufferWithValidation(std::move(direct_command_buffer));
+}
+
+// Creates a new VkEvent on the logical device and wraps it in a NativeEvent.
+// Returns the error produced from the VkResult if vkCreateEvent fails.
+StatusOr<ref_ptr<Event>> VulkanDevice::CreateEvent() {
+  IREE_TRACE_SCOPE0("VulkanDevice::CreateEvent");
+
+  // TODO(b/138729892): pool events.
+  // Zero-initialization covers pNext and flags.
+  VkEventCreateInfo event_info = {};
+  event_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO;
+
+  VkEvent native_event = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(syms()->vkCreateEvent(*logical_device_, &event_info,
+                                           logical_device_->allocator(),
+                                           &native_event));
+
+  return make_ref<NativeEvent>(add_ref(logical_device_), native_event);
+}
+
+// Creates a binary semaphore backed by a native VkSemaphore.
+//
+// |initial_value| must be false: Vulkan has no way to create a semaphore that
+// is already signaled. VkSemaphoreCreateInfo::flags is reserved and must be 0,
+// and VK_FENCE_CREATE_SIGNALED_BIT only has meaning for fences. Requesting a
+// signaled semaphore therefore returns an Unimplemented error instead of
+// passing an invalid flag to the driver.
+//
+// Returns the error produced from the VkResult if vkCreateSemaphore fails.
+StatusOr<ref_ptr<BinarySemaphore>> VulkanDevice::CreateBinarySemaphore(
+    bool initial_value) {
+  IREE_TRACE_SCOPE0("VulkanDevice::CreateBinarySemaphore");
+
+  if (initial_value) {
+    return UnimplementedErrorBuilder(IREE_LOC)
+           << "Binary semaphores cannot be created in the signaled state";
+  }
+
+  VkSemaphoreCreateInfo create_info;
+  create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
+  create_info.pNext = nullptr;
+  create_info.flags = 0;  // Reserved for future use; must be zero.
+  VkSemaphore semaphore_handle = VK_NULL_HANDLE;
+  VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info,
+                                               logical_device_->allocator(),
+                                               &semaphore_handle));
+
+  return make_ref<NativeBinarySemaphore>(add_ref(logical_device_),
+                                         semaphore_handle);
+}
+
+// Timeline semaphores are not supported yet: this always returns an
+// Unimplemented error and ignores |initial_value|.
+StatusOr<ref_ptr<TimelineSemaphore>> VulkanDevice::CreateTimelineSemaphore(
+ uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("VulkanDevice::CreateTimelineSemaphore");
+
+ // TODO(b/140141417): implement timeline semaphores.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Timeline semaphores not yet implemented";
+}
+
+// Creates a fence implemented by LegacyFence, constructed from a new reference
+// to this device's legacy fence pool and |initial_value|.
+StatusOr<ref_ptr<Fence>> VulkanDevice::CreateFence(uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("VulkanDevice::CreateFence");
+
+ // TODO(b/140141417): implement timeline semaphore fences and switch here.
+ // NOTE: we'll want some magic factory so that we can cleanly compile out the
+ // legacy implementation and pool.
+
+ return make_ref<LegacyFence>(add_ref(legacy_fence_pool_), initial_value);
+}
+
+// Waits on |fences| by delegating to LegacyFence::WaitForFences with
+// wait_all=true, passing |deadline| through unchanged.
+// Returns whatever status that call produces.
+Status VulkanDevice::WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) {
+ IREE_TRACE_SCOPE0("VulkanDevice::WaitAllFences");
+
+ // TODO(b/140141417): implement timeline semaphore fences and switch here.
+
+ return LegacyFence::WaitForFences(logical_device_.get(), fences,
+ /*wait_all=*/true, deadline);
+}
+
+// Waits on |fences| by delegating to LegacyFence::WaitForFences with
+// wait_all=false, passing |deadline| through unchanged.
+// The int in the result comes from that call.
+// NOTE(review): presumably the index of the fence that was signaled; confirm
+// against LegacyFence::WaitForFences.
+StatusOr<int> VulkanDevice::WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Time deadline) {
+ IREE_TRACE_SCOPE0("VulkanDevice::WaitAnyFence");
+
+ // TODO(b/140141417): implement timeline semaphore fences and switch here.
+
+ return LegacyFence::WaitForFences(logical_device_.get(), fences,
+ /*wait_all=*/false, deadline);
+}
+
+// Waits for the device to become idle.
+// With a finite |deadline| each command queue is waited on in turn with that
+// deadline; with an infinite one a single vkDeviceWaitIdle call is used.
+// Returns the first error encountered, or OkStatus().
+Status VulkanDevice::WaitIdle(absl::Time deadline) {
+  if (deadline != absl::InfiniteFuture()) {
+    // Bounded wait: go queue by queue so that the deadline can be honored.
+    IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#Fences");
+    for (auto& queue : command_queues_) {
+      RETURN_IF_ERROR(queue->WaitIdle(deadline));
+    }
+    return OkStatus();
+  }
+
+  // Unbounded wait: vkDeviceWaitIdle is usually cheaper as it requires fewer
+  // calls into the driver.
+  IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#vkDeviceWaitIdle");
+  VK_RETURN_IF_ERROR(syms()->vkDeviceWaitIdle(*logical_device_));
+  return OkStatus();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/vulkan_device.h b/hal/vulkan/vulkan_device.h
new file mode 100644
index 0000000..b208bb2
--- /dev/null
+++ b/hal/vulkan/vulkan_device.h
@@ -0,0 +1,120 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_
+#define IREE_HAL_VULKAN_VULKAN_DEVICE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <functional>
+#include <memory>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "base/memory.h"
+#include "hal/allocator.h"
+#include "hal/device.h"
+#include "hal/vulkan/descriptor_pool_cache.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/extensibility_util.h"
+#include "hal/vulkan/handle_util.h"
+#include "hal/vulkan/legacy_fence.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// HAL Device implementation backed by a Vulkan logical device.
+// Instances are created through Create(); the class holds the logical device
+// handle, the allocator, the command queues and the command/descriptor/fence
+// pools.
+class VulkanDevice final : public Device {
+ public:
+ // Creates a device on |physical_device|, matching layers and extensions
+ // against |extensibility_spec| and using |syms| for all Vulkan calls.
+ static StatusOr<std::shared_ptr<VulkanDevice>> Create(
+ const DeviceInfo& device_info, VkPhysicalDevice physical_device,
+ const ExtensibilitySpec& extensibility_spec,
+ const ref_ptr<DynamicSymbols>& syms);
+
+ // Private constructor.
+ // The constructor itself is public (Create() uses std::make_shared), but it
+ // requires a CtorKey, which only VulkanDevice can construct.
+ struct CtorKey {
+ private:
+ friend class VulkanDevice;
+ CtorKey() = default;
+ };
+ VulkanDevice(
+ CtorKey ctor_key, const DeviceInfo& device_info,
+ VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device,
+ std::unique_ptr<Allocator> allocator,
+ absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues,
+ ref_ptr<VkCommandPoolHandle> dispatch_command_pool,
+ ref_ptr<VkCommandPoolHandle> transfer_command_pool,
+ ref_ptr<LegacyFencePool> legacy_fence_pool);
+ ~VulkanDevice() override;
+
+ // Vulkan entry points, as held by the logical device handle.
+ const ref_ptr<DynamicSymbols>& syms() const {
+ return logical_device_->syms();
+ }
+
+ // Non-owning pointer to the allocator owned by this device.
+ Allocator* allocator() const override { return allocator_.get(); }
+
+ // Non-owning views over the queue lists; the queues themselves are owned by
+ // command_queues_.
+ absl::Span<CommandQueue*> dispatch_queues() const override {
+ return absl::MakeSpan(dispatch_queues_);
+ }
+
+ absl::Span<CommandQueue*> transfer_queues() const override {
+ return absl::MakeSpan(transfer_queues_);
+ }
+
+ std::shared_ptr<ExecutableCache> CreateExecutableCache() override;
+
+ StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) override;
+
+ StatusOr<ref_ptr<Event>> CreateEvent() override;
+
+ StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
+ bool initial_value) override;
+ StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
+ uint64_t initial_value) override;
+
+ StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override;
+ Status WaitAllFences(absl::Span<const FenceValue> fences,
+ absl::Time deadline) override;
+ StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
+ absl::Time deadline) override;
+
+ Status WaitIdle(absl::Time deadline) override;
+
+ private:
+ // NOTE: members are initialized in declaration order; the constructor builds
+ // descriptor_pool_cache_ from logical_device_, so logical_device_ must stay
+ // declared before it.
+ VkPhysicalDevice physical_device_;
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ std::unique_ptr<Allocator> allocator_;
+
+ // command_queues_ owns the queues; the other two lists hold raw pointers
+ // into it.
+ // NOTE(review): presumably `mutable` so that the const accessors above can
+ // return spans of non-const pointers; confirm.
+ mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues_;
+ mutable absl::InlinedVector<CommandQueue*, 4> dispatch_queues_;
+ mutable absl::InlinedVector<CommandQueue*, 4> transfer_queues_;
+
+ ref_ptr<DescriptorPoolCache> descriptor_pool_cache_;
+
+ // transfer_command_pool_ may be null; the implementation then falls back to
+ // dispatch_command_pool_ for all command buffers.
+ ref_ptr<VkCommandPoolHandle> dispatch_command_pool_;
+ ref_ptr<VkCommandPoolHandle> transfer_command_pool_;
+
+ // TODO(b/140141417): implement timeline semaphore fences and conditionally
+ // compile the legacy fence pool out.
+ ref_ptr<LegacyFencePool> legacy_fence_pool_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_VULKAN_DEVICE_H_
diff --git a/hal/vulkan/vulkan_driver.cc b/hal/vulkan/vulkan_driver.cc
new file mode 100644
index 0000000..f4d5e34
--- /dev/null
+++ b/hal/vulkan/vulkan_driver.cc
@@ -0,0 +1,229 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hal/vulkan/vulkan_driver.h"
+
+#include <memory>
+
+#include "absl/container/inlined_vector.h"
+#include "base/memory.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/device_info.h"
+#include "hal/vulkan/extensibility_util.h"
+#include "hal/vulkan/status_util.h"
+#include "hal/vulkan/vulkan_device.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Builds the VkApplicationInfo used when creating the Vulkan instance:
+// application "IREE-ML", engine "IREE", both at version 0, requesting Vulkan
+// API version 1.0.
+// We may allow hosting applications to override this via weak-linkage if it's
+// useful, otherwise this is enough to create the application.
+VkApplicationInfo GetDefaultApplicationInfo() {
+  // Zero-initialization covers pNext and both version fields.
+  VkApplicationInfo app_info = {};
+  app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+  app_info.pApplicationName = "IREE-ML";
+  app_info.pEngineName = "IREE";
+  app_info.apiVersion = VK_API_VERSION_1_0;
+  return app_info;
+}
+
+// Builds the HAL DeviceInfo for |physical_device|.
+// The device name is taken verbatim from the Vulkan device properties, no
+// features are reported yet (DeviceFeature::kNone) and the physical device
+// handle is stored as the driver handle.
+StatusOr<DeviceInfo> PopulateDeviceInfo(VkPhysicalDevice physical_device,
+                                        const ref_ptr<DynamicSymbols>& syms) {
+  // Queried but not inspected yet.
+  // TODO(benvanik): check and optionally require these features:
+  // - features.robustBufferAccess
+  // - features.shaderInt16
+  // - features.shaderInt64
+  // - features.shaderFloat64
+  VkPhysicalDeviceFeatures features;
+  syms->vkGetPhysicalDeviceFeatures(physical_device, &features);
+
+  // TODO(benvanik): check and optionally require reasonable limits.
+  VkPhysicalDeviceProperties properties;
+  syms->vkGetPhysicalDeviceProperties(physical_device, &properties);
+
+  // TODO(benvanik): more clever/sanitized device naming.
+  std::string device_name(properties.deviceName);
+
+  // TODO(benvanik): implement debugging/profiling features.
+  // TODO(benvanik): use props to determine if we have timing info.
+  // feature_bits |= DeviceFeature::kDebugging;
+  // feature_bits |= DeviceFeature::kCoverage;
+  // feature_bits |= DeviceFeature::kProfiling;
+  DeviceFeatureBitfield feature_bits = DeviceFeature::kNone;
+
+  return DeviceInfo(std::move(device_name), feature_bits, physical_device);
+}
+
+} // namespace
+
+// Creates a Vulkan instance configured from |options| and wraps it in a
+// VulkanDriver.
+//
+// |options| supplies the requested API version and the instance/device
+// extensibility specs; |syms| provides the Vulkan entry points and has its
+// instance-level symbols loaded once the instance exists.
+//
+// Returns the driver, or the first error hit while matching
+// layers/extensions, creating the instance, loading symbols or creating the
+// debug reporter.
+// static
+StatusOr<std::shared_ptr<VulkanDriver>> VulkanDriver::Create(
+    Options options, ref_ptr<DynamicSymbols> syms) {
+  IREE_TRACE_SCOPE0("VulkanDriver::Create");
+
+  // Find the layers and extensions we need (or want) that are also available
+  // on the instance. This will fail when required ones are not present.
+  ASSIGN_OR_RETURN(
+      auto enabled_layer_names,
+      MatchAvailableInstanceLayers(options.instance_extensibility, *syms));
+  ASSIGN_OR_RETURN(
+      auto enabled_extension_names,
+      MatchAvailableInstanceExtensions(options.instance_extensibility, *syms));
+  auto instance_extensions =
+      PopulateEnabledInstanceExtensions(enabled_extension_names);
+
+  // Create the instance this driver will use for all requests.
+  VkApplicationInfo app_info = GetDefaultApplicationInfo();
+  app_info.apiVersion = options.api_version;
+  VkInstanceCreateInfo create_info;
+  create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+  create_info.pNext = nullptr;
+  create_info.flags = 0;
+  create_info.pApplicationInfo = &app_info;
+  create_info.enabledLayerCount = enabled_layer_names.size();
+  create_info.ppEnabledLayerNames = enabled_layer_names.data();
+  create_info.enabledExtensionCount = enabled_extension_names.size();
+  create_info.ppEnabledExtensionNames = enabled_extension_names.data();
+
+  // If we have the debug_utils extension then we can chain a one-shot messenger
+  // callback that we can use to log out the instance creation errors. Once we
+  // have the real instance we can then register a real messenger.
+  // Only one of the two structs is ever used, so they share storage.
+  union {
+    VkDebugUtilsMessengerCreateInfoEXT debug_utils_create_info;
+    VkDebugReportCallbackCreateInfoEXT debug_report_create_info;
+  };
+  if (instance_extensions.debug_utils) {
+    create_info.pNext = &debug_utils_create_info;
+    DebugReporter::PopulateStaticCreateInfo(&debug_utils_create_info);
+  } else if (instance_extensions.debug_report) {
+    create_info.pNext = &debug_report_create_info;
+    DebugReporter::PopulateStaticCreateInfo(&debug_report_create_info);
+  }
+
+  // Some ICDs appear to leak in here, out of our control.
+  // The result is captured first so that leak checks are re-enabled on both
+  // the success and the failure path before any early return.
+  IREE_DISABLE_LEAK_CHECKS();
+  VkInstance instance = VK_NULL_HANDLE;
+  const auto create_result =
+      syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance);
+  IREE_ENABLE_LEAK_CHECKS();
+  VK_RETURN_IF_ERROR(create_result) << "Unable to create Vulkan instance";
+
+  // TODO(benvanik): enable validation layers if needed.
+
+  // NOTE(review): the early returns below do not destroy |instance|; it looks
+  // like the instance leaks if symbol loading or debug reporter creation
+  // fails. Confirm and add cleanup.
+
+  // Now that the instance has been created we can fetch all of the instance
+  // symbols.
+  RETURN_IF_ERROR(syms->LoadFromInstance(instance));
+
+  // The real debug messenger (not just the static one used above) can now be
+  // created as we've loaded all the required symbols.
+  // TODO(benvanik): strip in release builds.
+  std::unique_ptr<DebugReporter> debug_reporter;
+  if (instance_extensions.debug_utils) {
+    ASSIGN_OR_RETURN(debug_reporter, DebugReporter::CreateDebugUtilsMessenger(
+                                         instance, syms,
+                                         /*allocation_callbacks=*/nullptr));
+  } else if (instance_extensions.debug_report) {
+    ASSIGN_OR_RETURN(debug_reporter,
+                     DebugReporter::CreateDebugReportCallback(
+                         instance, syms, /*allocation_callbacks=*/nullptr));
+  }
+
+  return std::make_shared<VulkanDriver>(
+      CtorKey{}, std::move(syms), instance, std::move(debug_reporter),
+      std::move(options.device_extensibility));
+}
+
+// Stores the already-created |instance| together with the symbols, the
+// optional debug reporter (may be null) and the extensibility spec that
+// CreateDevice() passes to VulkanDevice::Create(). The driver is registered
+// with the base class under the name "vulkan".
+VulkanDriver::VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms,
+ VkInstance instance,
+ std::unique_ptr<DebugReporter> debug_reporter,
+ ExtensibilitySpec device_extensibility_spec)
+ : Driver("vulkan"),
+ syms_(std::move(syms)),
+ instance_(instance),
+ debug_reporter_(std::move(debug_reporter)),
+ device_extensibility_spec_(std::move(device_extensibility_spec)) {}
+
+// Releases the debug reporter first and then destroys the Vulkan instance.
+VulkanDriver::~VulkanDriver() {
+ IREE_TRACE_SCOPE0("VulkanDriver::dtor");
+ // The reporter was created from instance_, so it is dropped before the
+ // instance is destroyed.
+ debug_reporter_.reset();
+ syms()->vkDestroyInstance(instance_, /*pAllocator=*/nullptr);
+}
+
+// Lists the physical devices currently visible through the instance and
+// converts each one to a HAL DeviceInfo.
+// Returns the list (possibly empty), or the first error produced by the
+// Vulkan enumeration or by PopulateDeviceInfo.
+StatusOr<std::vector<DeviceInfo>> VulkanDriver::EnumerateAvailableDevices() {
+  IREE_TRACE_SCOPE0("VulkanDriver::EnumerateAvailableDevices");
+
+  // Query all available devices (at this moment, note that this may change!).
+  uint32_t physical_device_count = 0;
+  VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices(
+      instance_, &physical_device_count, nullptr));
+  absl::InlinedVector<VkPhysicalDevice, 2> physical_devices(
+      physical_device_count);
+  VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices(
+      instance_, &physical_device_count, physical_devices.data()));
+  // The second call writes back how many handles it actually stored. If the
+  // set of devices shrank between the two calls, the tail of the vector still
+  // holds null handles, so trim it before touching the entries.
+  physical_devices.resize(physical_device_count);
+
+  // Convert to our HAL structure.
+  std::vector<DeviceInfo> device_infos;
+  device_infos.reserve(physical_devices.size());
+  for (auto physical_device : physical_devices) {
+    // TODO(benvanik): if we fail should we just ignore the device in the list?
+    ASSIGN_OR_RETURN(auto device_info,
+                     PopulateDeviceInfo(physical_device, syms()));
+    device_infos.push_back(std::move(device_info));
+  }
+  return device_infos;
+}
+
+// Creates a device for the first entry returned by
+// EnumerateAvailableDevices().
+// Returns a NotFound error when the enumeration yields no devices, or the
+// error from the enumeration or from CreateDevice().
+StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDefaultDevice() {
+  IREE_TRACE_SCOPE0("VulkanDriver::CreateDefaultDevice");
+
+  // See what is available right now.
+  ASSIGN_OR_RETURN(auto device_infos, EnumerateAvailableDevices());
+  if (!device_infos.empty()) {
+    // No ranking is applied: the first enumerated device wins.
+    return CreateDevice(device_infos.front());
+  }
+
+  return NotFoundErrorBuilder(IREE_LOC) << "No devices are available";
+}
+
+// Creates a VulkanDevice for the physical device recorded in |device_info|
+// (its driver handle), using this driver's device extensibility spec and
+// symbols.
+// Returns the device, or the error from VulkanDevice::Create.
+StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDevice(
+    const DeviceInfo& device_info) {
+  IREE_TRACE_SCOPE0("VulkanDriver::CreateDevice");
+
+  // Attempt to create the device.
+  // This may fail if the device was enumerated but is in exclusive use,
+  // disabled by the system, or permission is denied.
+  ASSIGN_OR_RETURN(
+      auto vulkan_device,
+      VulkanDevice::Create(
+          device_info,
+          static_cast<VkPhysicalDevice>(device_info.driver_handle()),
+          device_extensibility_spec_, syms()));
+
+  return vulkan_device;
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/hal/vulkan/vulkan_driver.h b/hal/vulkan/vulkan_driver.h
new file mode 100644
index 0000000..34ec759
--- /dev/null
+++ b/hal/vulkan/vulkan_driver.h
@@ -0,0 +1,84 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_
+#define IREE_HAL_VULKAN_VULKAN_DRIVER_H_
+
+#include <vulkan/vulkan.h>
+
+#include <memory>
+#include <vector>
+
+#include "hal/driver.h"
+#include "hal/vulkan/debug_reporter.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/extensibility_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+class VulkanDriver final : public Driver {
+ public:
+ struct Options {
+ // Vulkan version that will be requested.
+ // Driver creation will fail if the required version is not available.
+ uint32_t api_version = VK_API_VERSION_1_0;
+
+ // Extensibility descriptions for instances and devices.
+ // Device descriptions will be used for all devices created by the driver.
+ ExtensibilitySpec instance_extensibility;
+ ExtensibilitySpec device_extensibility;
+ };
+
+ static StatusOr<std::shared_ptr<VulkanDriver>> Create(
+ Options options, ref_ptr<DynamicSymbols> syms);
+
+ // TODO(benvanik): method to wrap an existing instance/device (interop).
+
+ // Passkey: the ctor below is public, but only VulkanDriver can make a CtorKey.
+ struct CtorKey {
+ private:
+ friend class VulkanDriver;
+ CtorKey() {}  // User-provided: keeps CtorKey a non-aggregate (no `{}` bypass).
+ };
+ VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms,
+ VkInstance instance,
+ std::unique_ptr<DebugReporter> debug_reporter,
+ ExtensibilitySpec device_extensibility_spec);
+ ~VulkanDriver() override;
+
+ const ref_ptr<DynamicSymbols>& syms() const { return syms_; }
+
+ VkInstance instance() const { return instance_; }
+
+ StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
+
+ StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override;
+
+ StatusOr<std::shared_ptr<Device>> CreateDevice(
+ const DeviceInfo& device_info) override;
+
+ private:
+ ref_ptr<DynamicSymbols> syms_;
+ VkInstance instance_;
+ std::unique_ptr<DebugReporter> debug_reporter_;
+ ExtensibilitySpec device_extensibility_spec_;
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_VULKAN_DRIVER_H_
diff --git a/hal/vulkan/vulkan_driver_module.cc b/hal/vulkan/vulkan_driver_module.cc
new file mode 100644
index 0000000..51da626
--- /dev/null
+++ b/hal/vulkan/vulkan_driver_module.cc
@@ -0,0 +1,102 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+
+#include "absl/flags/flag.h"
+#include "base/init.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/driver_registry.h"
+#include "hal/vulkan/dynamic_symbols.h"
+#include "hal/vulkan/vulkan_driver.h"
+
+ABSL_FLAG(bool, vulkan_validation_layers, true,
+ "Enables standard Vulkan validation layers.");
+ABSL_FLAG(bool, vulkan_debug_utils, true,
+ "Enables VK_EXT_debug_utils, records markers, and logs errors.");
+ABSL_FLAG(bool, vulkan_debug_report, false,
+ "Enables VK_EXT_debug_report and logs errors.");
+ABSL_FLAG(bool, vulkan_push_descriptors, true,
+ "Enables use of vkCmdPushDescriptorSetKHR, if available.");
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+namespace {
+
+StatusOr<std::shared_ptr<Driver>> CreateVulkanDriver() {
+ IREE_TRACE_SCOPE0("CreateVulkanDriver");
+
+ // Load the Vulkan library. This will fail if the library cannot be found or
+ // does not have the expected functions.
+ ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader());
+
+ // Set up driver options from flags. We do this here (not in VulkanDriver) so
+ // that other consumers that may not be using modules/command line flags can
+ // set their options however they want.
+ VulkanDriver::Options options;
+
+ // TODO: validation layers have bugs when used with VK_EXT_debug_report, so
+ // if the user requested both we force validation off (with a warning).
+ // Prefer using VK_EXT_debug_utils when available.
+ if (absl::GetFlag(FLAGS_vulkan_debug_report) &&
+ absl::GetFlag(FLAGS_vulkan_validation_layers)) {
+ LOG(WARNING) << "VK_EXT_debug_report has issues with modern validation "
+ "layers; disabling validation";
+ absl::SetFlag(&FLAGS_vulkan_validation_layers, false);
+ }
+
+ // REQUIRED: these are required extensions that must be present for IREE to
+ // work (such as those relied upon by SPIR-V kernels, etc).
+ options.device_extensibility.required_extensions.push_back(
+ VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME);
+
+ if (absl::GetFlag(FLAGS_vulkan_validation_layers)) {
+ options.instance_extensibility.optional_layers.push_back(
+ "VK_LAYER_LUNARG_standard_validation");
+ }
+
+ if (absl::GetFlag(FLAGS_vulkan_debug_report)) {
+ options.instance_extensibility.optional_extensions.push_back(
+ VK_EXT_DEBUG_REPORT_EXTENSION_NAME);
+ }
+ if (absl::GetFlag(FLAGS_vulkan_debug_utils)) {
+ options.instance_extensibility.optional_extensions.push_back(
+ VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
+ }
+
+ if (absl::GetFlag(FLAGS_vulkan_push_descriptors)) {
+ options.instance_extensibility.optional_extensions.push_back(
+ VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
+ options.device_extensibility.optional_extensions.push_back(
+ VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME);
+ }
+
+ // Create the driver and VkInstance.
+ ASSIGN_OR_RETURN(auto driver, VulkanDriver::Create(options, std::move(syms)));
+
+ return driver;
+}
+
+} // namespace
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vulkan_driver, {
+ QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
+ "vulkan", ::iree::hal::vulkan::CreateVulkanDriver));
+});
+IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vulkan_driver);
diff --git a/iree/BUILD b/iree/BUILD
deleted file mode 100644
index d21f74c..0000000
--- a/iree/BUILD
+++ /dev/null
@@ -1,31 +0,0 @@
-# Main IREE build file.
-# Note that project-wide, bazel repo aliases are used:
-# "@com_google_absl//absl/python"
-# "@com_google_absl//absl"
-# "@com_google_benchmark//:benchmark"
-# "@local_config_mlir//"
-# "@llvm//"
-# "@com_github_google_flatbuffers//:flatbuffers"
-# "@org_tensorflow//tensorflow"
-#
-# Various scripts and helpers operate on these prefixes textually, so
-# avoid doing any systematic construction that would break the matching.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-# Enables the debug service and other profiling features.
-# $ bazel build --define=IREE_DEBUG=1 :some_target
-config_setting(
- name = "debug",
- define_values = {"IREE_DEBUG": "1"},
-)
-
-# Marker library which can be extended to provide flags for things that
-# need to know the platform target.
-cc_library(
- name = "target_config",
- defines = ["IREE_UNSPECIFIED_TARGET=1"],
-)
diff --git a/iree/base/BUILD b/iree/base/BUILD
deleted file mode 100644
index 3b8249e..0000000
--- a/iree/base/BUILD
+++ /dev/null
@@ -1,362 +0,0 @@
-# Common types and utilities used in the IREE codebase.
-
-load("//iree:build_defs.bzl", "platform_trampoline_deps")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "api",
- srcs = ["api.cc"],
- hdrs = ["api.h"],
- visibility = ["//visibility:public"],
- deps = [
- ":api_hdrs",
- ":api_util",
- ":file_mapping",
- ":tracing",
- ],
-)
-
-cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
-)
-
-cc_library(
- name = "api_util",
- hdrs = ["api_util.h"],
- deps = [
- ":api_hdrs",
- ":logging",
- ":shape",
- ":status",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/time",
- ],
-)
-
-cc_library(
- name = "arena",
- srcs = ["arena.cc"],
- hdrs = ["arena.h"],
- deps = [
- ":logging",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_test(
- name = "arena_test",
- srcs = ["arena_test.cc"],
- deps = [
- ":arena",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "bitfield",
- hdrs = ["bitfield.h"],
- deps = [
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_test(
- name = "bitfield_test",
- srcs = ["bitfield_test.cc"],
- deps = [
- ":bitfield",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "file_io",
- hdrs = ["file_io.h"],
- deps = [
- ":status",
- ":target_platform",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ] + platform_trampoline_deps("file_io"),
-)
-
-cc_library(
- name = "file_io_hdrs",
- hdrs = ["file_io.h"],
- deps = [":status"],
-)
-
-cc_library(
- name = "file_mapping",
- hdrs = ["file_mapping.h"],
- deps = [
- ":ref_ptr",
- ":status",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ] + platform_trampoline_deps("file_mapping"),
-)
-
-cc_library(
- name = "file_mapping_hdrs",
- hdrs = ["file_mapping.h"],
- deps = [
- ":ref_ptr",
- ":status",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "file_path",
- srcs = ["file_path.cc"],
- hdrs = ["file_path.h"],
- deps = [
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "flatbuffer_util",
- srcs = ["flatbuffer_util.cc"],
- hdrs = ["flatbuffer_util.h"],
- deps = [
- ":file_mapping",
- ":memory",
- ":source_location",
- ":status",
- ":tracing",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "init",
- hdrs = ["init.h"],
- deps = platform_trampoline_deps("init"),
-)
-
-cc_library(
- name = "intrusive_list",
- hdrs = [
- "intrusive_list.h",
- "intrusive_list_ref_ptr.inc",
- "intrusive_list_unique_ptr.inc",
- ],
- deps = [
- ":logging",
- ":ref_ptr",
- ],
-)
-
-cc_test(
- name = "intrusive_list_test",
- srcs = [
- "intrusive_list_ref_ptr_test.cc",
- "intrusive_list_test.cc",
- "intrusive_list_unique_ptr_test.cc",
- ],
- deps = [
- ":intrusive_list",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "logging",
- hdrs = ["logging.h"],
- deps = platform_trampoline_deps("logging"),
-)
-
-cc_library(
- name = "math",
- hdrs = ["math.h"],
- deps = [
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "memory",
- hdrs = ["memory.h"],
- deps = [
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "ref_ptr",
- hdrs = ["ref_ptr.h"],
- deps = [
- ":logging",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_test(
- name = "ref_ptr_test",
- size = "small",
- srcs = ["ref_ptr_test.cc"],
- deps = [
- ":ref_ptr",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "shape",
- srcs = ["shape.cc"],
- hdrs = ["shape.h"],
- deps = [
- ":logging",
- ":source_location",
- ":status",
- "@com_google_absl//absl/meta:type_traits",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_test(
- name = "shape_test",
- srcs = ["shape_test.cc"],
- deps = [
- ":shape",
- ":status",
- ":status_matchers",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "source_location",
- hdrs = ["source_location.h"],
- deps = platform_trampoline_deps("source_location"),
-)
-
-cc_library(
- name = "status",
- hdrs = ["status.h"],
- deps = [
- ":source_location",
- ] + platform_trampoline_deps("status"),
-)
-
-cc_library(
- name = "status_matchers",
- testonly = 1,
- hdrs = ["status_matchers.h"],
- deps = platform_trampoline_deps("status_matchers"),
-)
-
-cc_library(
- name = "target_platform",
- hdrs = ["target_platform.h"],
-)
-
-cc_library(
- name = "time",
- hdrs = ["time.h"],
- deps = [
- "@com_google_absl//absl/time",
- ],
-)
-
-cc_library(
- name = "tracing",
- hdrs = ["tracing.h"],
- deps = [
- "//iree:target_config",
- "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
- ] + select({
- "@com_google_tracing_framework_cpp//:wtf_enable": [":tracing_enabled"],
- "//conditions:default": [":tracing_disabled"],
- }),
-)
-
-cc_library(
- name = "tracing_disabled",
- srcs = [
- "tracing.h",
- "tracing_disabled.cc",
- ],
- visibility = ["//visibility:private"],
- deps = [
- ":init",
- ":logging",
- "@com_google_absl//absl/flags:flag",
- "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "tracing_enabled",
- srcs = [
- "tracing.cc",
- "tracing.h",
- ],
- visibility = ["//visibility:private"],
- deps = [
- ":file_io",
- ":file_path",
- ":init",
- ":logging",
- ":status",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
- ],
- alwayslink = 1,
-)
-
-# Dependent code has been removed and wait_handle is currently incompatible
-# with Windows, so excluding entirely.
-# See google/iree/65
-# cc_library(
-# name = "wait_handle",
-# srcs = ["wait_handle.cc"],
-# hdrs = ["wait_handle.h"],
-# deps = [
-# ":logging",
-# ":ref_ptr",
-# ":source_location",
-# ":status",
-# ":time",
-# "@com_google_absl//absl/base:core_headers",
-# "@com_google_absl//absl/container:fixed_array",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/time",
-# "@com_google_absl//absl/types:span",
-# ],
-# )
-
-# cc_test(
-# name = "wait_handle_test",
-# srcs = ["wait_handle_test.cc"],
-# deps = [
-# ":status",
-# ":status_matchers",
-# ":wait_handle",
-# "@com_google_absl//absl/time",
-# "@com_google_googletest//:gtest_main",
-# ],
-# )
diff --git a/iree/base/api.cc b/iree/base/api.cc
deleted file mode 100644
index cff917f..0000000
--- a/iree/base/api.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/api.h"
-
-#include <cstdlib>
-#include <string>
-
-#include "iree/base/api_util.h"
-#include "iree/base/file_mapping.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-
-//===----------------------------------------------------------------------===//
-// iree Core API
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_api_version_check(iree_api_version_t expected_version,
- iree_api_version_t* out_actual_version) {
- iree_api_version_t actual_version = IREE_API_VERSION_0;
- *out_actual_version = actual_version;
- return expected_version == actual_version ? IREE_STATUS_OK
- : IREE_STATUS_OUT_OF_RANGE;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_allocator_alloc(void* self, iree_host_size_t byte_length, void** out_ptr) {
- IREE_TRACE_SCOPE0("iree_allocator_alloc");
-
- if (!out_ptr) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_ptr = nullptr;
-
- if (byte_length <= 0) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- *out_ptr = std::malloc(byte_length);
- if (!*out_ptr) {
- return IREE_STATUS_RESOURCE_EXHAUSTED;
- }
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_free(void* self,
- void* ptr) {
- IREE_TRACE_SCOPE0("iree_allocator_free");
- if (ptr) {
- std::free(ptr);
- }
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::FileMapping
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_file_mapping_open_read(iree_string_view_t path, iree_allocator_t allocator,
- iree_file_mapping_t** out_file_mapping) {
- IREE_TRACE_SCOPE0("iree_file_mapping_open_read");
-
- if (!out_file_mapping) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_file_mapping = nullptr;
-
- IREE_API_ASSIGN_OR_RETURN(
- auto file_mapping,
- FileMapping::OpenRead(std::string(path.data, path.size)));
-
- *out_file_mapping =
- reinterpret_cast<iree_file_mapping_t*>(file_mapping.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_file_mapping_retain(iree_file_mapping_t* file_mapping) {
- IREE_TRACE_SCOPE0("iree_file_mapping_retain");
- auto* handle = reinterpret_cast<FileMapping*>(file_mapping);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_file_mapping_release(iree_file_mapping_t* file_mapping) {
- IREE_TRACE_SCOPE0("iree_file_mapping_release");
- auto* handle = reinterpret_cast<FileMapping*>(file_mapping);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_byte_span_t IREE_API_CALL
-iree_file_mapping_data(iree_file_mapping_t* file_mapping) {
- IREE_TRACE_SCOPE0("iree_file_mapping_data");
- auto* handle = reinterpret_cast<FileMapping*>(file_mapping);
- CHECK(handle) << "NULL file_mapping handle";
- auto data = handle->data();
- return {const_cast<uint8_t*>(data.data()), data.size()};
-}
-
-} // namespace iree
diff --git a/iree/base/api_util.h b/iree/base/api_util.h
deleted file mode 100644
index 5eaa5e5..0000000
--- a/iree/base/api_util.h
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_API_UTIL_H_
-#define IREE_BASE_API_UTIL_H_
-
-#include "absl/base/macros.h"
-#include "absl/time/time.h"
-#include "iree/base/api.h"
-#include "iree/base/logging.h"
-#include "iree/base/shape.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-inline iree_status_t ToApiStatus(Status status) {
- DLOG(ERROR) << status;
- return static_cast<iree_status_t>(status.code());
-}
-
-inline Status FromApiStatus(iree_status_t status_code, SourceLocation loc) {
- return StatusBuilder(static_cast<StatusCode>(status_code), loc);
-}
-
-// Internal helper for concatenating macro values.
-#define IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
-#define IREE_API_STATUS_MACROS_IMPL_CONCAT_(x, y) \
- IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
-
-// clang-format off
-#define IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT
-// clang-format on
-
-namespace status_macro_internal {
-class StatusAdaptorForApiMacros {
- public:
- StatusAdaptorForApiMacros(const Status& status) : status_(status) {}
- StatusAdaptorForApiMacros(Status&& status) : status_(std::move(status)) {}
- StatusAdaptorForApiMacros(const StatusAdaptorForApiMacros&) = delete;
- StatusAdaptorForApiMacros& operator=(const StatusAdaptorForApiMacros&) =
- delete;
- explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); }
- Status&& Consume() { return std::move(status_); }
-
- private:
- Status status_;
-};
-} // namespace status_macro_internal
-
-#define IREE_API_RETURN_IF_ERROR(expr) \
- IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
- if (::iree::status_macro_internal::StatusAdaptorForApiMacros \
- status_adaptor = {expr}) { \
- } else /* NOLINT */ \
- return ::iree::ToApiStatus(status_adaptor.Consume())
-
-#define IREE_API_RETURN_IF_API_ERROR(expr) \
- IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
- if (iree_status_t status = (expr)) { \
- return status; \
- }
-
-#define IREE_API_ASSIGN_OR_RETURN(...) \
- IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_( \
- (__VA_ARGS__, IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \
- IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \
- (__VA_ARGS__)
-
-#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, \
- ...) \
- NAME
-#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_(args) \
- IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args
-
-#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
- IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_))
-#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \
- error_expression) \
- IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
- IREE_API_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \
- rexpr, error_expression)
-#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
- error_expression) \
- auto statusor = (rexpr); \
- if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
- return ::iree::ToApiStatus(std::move(statusor).status()); \
- } \
- lhs = std::move(statusor).ValueOrDie()
-
-// Converts an iree_time_t to its equivalent absl::Time.
-inline absl::Time ToAbslTime(iree_time_t time) {
- if (time == IREE_TIME_INFINITE_PAST) {
- return absl::InfinitePast();
- } else if (time == IREE_TIME_INFINITE_FUTURE) {
- return absl::InfiniteFuture();
- } else {
- return absl::FromUnixNanos(time);
- }
-}
-
-// Converts a Shape to an iree_shape_t.
-inline iree_status_t ToApiShape(const Shape& shape, iree_shape_t* out_shape) {
- out_shape->rank = shape.size();
- if (shape.size() > ABSL_ARRAYSIZE(out_shape->dims)) {
- return IREE_STATUS_OUT_OF_RANGE;
- }
- for (int i = 0; i < out_shape->rank; ++i) {
- out_shape->dims[i] = shape[i];
- }
- return IREE_STATUS_OK;
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_API_UTIL_H_
diff --git a/iree/base/arena.cc b/iree/base/arena.cc
deleted file mode 100644
index 9a6d7c2..0000000
--- a/iree/base/arena.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/arena.h"
-
-#include <memory>
-
-#include "absl/base/attributes.h"
-#include "iree/base/logging.h"
-
-namespace iree {
-
-namespace {
-
-// Rounds up to the next alignment value, if it is not already aligned.
-template <typename T>
-ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr T RoundToAlignment(
- T value, T alignment) noexcept {
- return ((value + alignment - 1) / alignment) * alignment;
-}
-
-} // namespace
-
-Arena::Arena(size_t block_size) : block_size_(block_size) {}
-
-Arena::~Arena() { Clear(); }
-
-void Arena::Clear() {
- // Deallocate all memory.
- auto block_header = block_list_head_;
- while (block_header) {
- auto next_block = block_header->next_block;
- std::free(block_header);
- block_header = next_block;
- }
- block_list_head_ = nullptr;
- block_header = unused_block_list_head_;
- while (block_header) {
- auto next_block = block_header->next_block;
- std::free(block_header);
- block_header = next_block;
- }
- unused_block_list_head_ = nullptr;
-
- bytes_allocated_ = 0;
- block_bytes_allocated_ = 0;
-}
-
-void Arena::Reset() {
- // Move all blocks to the unused list and reset allocation count only.
- auto block_header = block_list_head_;
- while (block_header) {
- auto next_block = block_header->next_block;
- block_header->bytes_allocated = 0;
- block_header->next_block = unused_block_list_head_;
- unused_block_list_head_ = block_header;
- block_header = next_block;
- }
- block_list_head_ = nullptr;
-
- bytes_allocated_ = 0;
-}
-
-uint8_t* Arena::AllocateBytes(size_t length) {
- if (!length) {
- // Guarantee zero-length allocations return nullptr.
- return nullptr;
- }
-
- // Pad length allocated so we are machine word aligned.
- // This ensures the next allocation starts at the right boundary.
- size_t aligned_length = RoundToAlignment(length, sizeof(uintptr_t));
-
- if (aligned_length > block_size_) {
- // This allocation is larger than an entire block. That's bad.
- // We could allocate this with malloc (and then keep track of those to free
- // things), but for now let's just die.
- CHECK(false);
- return nullptr;
- }
-
- if (!block_list_head_ ||
- block_list_head_->bytes_allocated + aligned_length > block_size_) {
- // Check to see if we have an existing unused block we can use.
- if (unused_block_list_head_) {
- // Move block from unused list to main list.
- auto block_header = unused_block_list_head_;
- unused_block_list_head_ = block_header->next_block;
- block_header->next_block = block_list_head_;
- block_header->bytes_allocated = 0;
- block_list_head_ = block_header;
- } else {
- // Allocate a new block.
- auto block_ptr = reinterpret_cast<uint8_t*>(
- std::malloc(sizeof(BlockHeader) + block_size_));
- auto block_header = reinterpret_cast<BlockHeader*>(block_ptr);
- block_header->next_block = block_list_head_;
- block_header->bytes_allocated = 0;
- block_list_head_ = block_header;
- block_bytes_allocated_ += sizeof(BlockHeader) + block_size_;
- }
- }
-
- BlockHeader* target_block = block_list_head_;
- auto data_ptr = reinterpret_cast<uint8_t*>(target_block) +
- sizeof(BlockHeader) + target_block->bytes_allocated;
- target_block->bytes_allocated += aligned_length;
-
- bytes_allocated_ += length;
-
- return data_ptr;
-}
-
-} // namespace iree
diff --git a/iree/base/arena_test.cc b/iree/base/arena_test.cc
deleted file mode 100644
index fb0e801..0000000
--- a/iree/base/arena_test.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/arena.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace iree {
-namespace {
-
-// Tests basic block allocations.
-TEST(ArenaTest, BasicAllocation) {
- Arena arena(64);
- EXPECT_EQ(64, arena.block_size());
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- // Zero byte allocations should return nullptr and not allocate bytes.
- auto zero_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(0));
- EXPECT_EQ(0, zero_ptr);
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- arena.Clear();
-
- // Allocations must be machine word aligned.
- auto one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1));
- EXPECT_NE(0, one_ptr);
- EXPECT_EQ(0, one_ptr % sizeof(uintptr_t));
- one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1));
- EXPECT_NE(0, one_ptr);
- EXPECT_EQ(0, one_ptr % sizeof(uintptr_t));
- EXPECT_EQ(2, arena.bytes_allocated());
- EXPECT_LT(2, arena.block_bytes_allocated());
-
- arena.Clear();
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-}
-
-// Tests typed allocations.
-TEST(ArenaTest, TypedAllocations) {
- Arena arena(64);
-
- EXPECT_NE(nullptr, arena.Allocate<int>());
- EXPECT_EQ(4, arena.bytes_allocated());
- EXPECT_EQ(64 + Arena::kBlockOverhead, arena.block_bytes_allocated());
- arena.Clear();
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- struct MyType {
- MyType() {}
- explicit MyType(int initial_value) : value(initial_value) {}
-
- int value = 5;
- };
- auto my_type_ptr = arena.Allocate<MyType>();
- EXPECT_NE(nullptr, my_type_ptr);
- EXPECT_EQ(sizeof(MyType), arena.bytes_allocated());
- EXPECT_EQ(5, my_type_ptr->value); // Default ctor must be called.
- arena.Clear();
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- my_type_ptr = arena.Allocate<MyType>(10);
- EXPECT_NE(nullptr, my_type_ptr);
- EXPECT_EQ(sizeof(MyType), arena.bytes_allocated());
- EXPECT_EQ(10, my_type_ptr->value); // Ctor should have been called.
- arena.Clear();
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-}
-
-// Tests multiple blocks.
-TEST(ArenaTest, MultipleBlocks) {
- Arena arena(16);
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- // Allocate one entire block.
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(16, arena.bytes_allocated());
- EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated());
-
- // Allocate into the next block.
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(32, arena.bytes_allocated());
- EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
-
- // Clear.
- arena.Clear();
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- // Allocate again.
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(16, arena.bytes_allocated());
- EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated());
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(32, arena.bytes_allocated());
- EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
-}
-
-// Tests fast reset.
-TEST(ArenaTest, FastReset) {
- Arena arena(16);
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(0, arena.block_bytes_allocated());
-
- // Allocate one entire block.
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(16, arena.bytes_allocated());
- EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated());
-
- // Allocate into the next block.
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(32, arena.bytes_allocated());
- EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
-
- // Reset (without deallocating).
- arena.Reset();
- EXPECT_EQ(0, arena.bytes_allocated());
- EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
-
- // Allocate again.
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(16, arena.bytes_allocated());
- EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
- EXPECT_NE(nullptr, arena.AllocateBytes(16));
- EXPECT_EQ(32, arena.bytes_allocated());
- EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated());
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/bitfield_test.cc b/iree/base/bitfield_test.cc
deleted file mode 100644
index 0e545a6..0000000
--- a/iree/base/bitfield_test.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/bitfield.h"
-
-#include <cstdint>
-#include <vector>
-
-#include "gtest/gtest.h"
-
-namespace iree {
-
-// NOTE: define here so that we don't get internal linkage warnings.
-enum class MyValue : uint32_t {
- kNone = 0,
- kA = 1 << 0,
- kB = 1 << 1,
- kAll = kA | kB,
-};
-IREE_BITFIELD(MyValue);
-
-namespace {
-
-// Tests general usage.
-TEST(BitfieldTest, FormatBitfieldValue) {
- std::vector<std::pair<MyValue, const char *>> mappings = {
- {MyValue::kA, "kA"},
- {MyValue::kB, "kB"},
- };
- EXPECT_EQ("",
- FormatBitfieldValue(MyValue::kNone, absl::MakeConstSpan(mappings)));
- EXPECT_EQ("kA",
- FormatBitfieldValue(MyValue::kA, absl::MakeConstSpan(mappings)));
- EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB,
- absl::MakeConstSpan(mappings)));
-}
-
-// Tests that empty mapping tables are fine.
-TEST(BitfieldTest, FormatBitfieldValueEmpty) {
- EXPECT_EQ("", FormatBitfieldValue(MyValue::kNone, {}));
-}
-
-// Tests that values not found in the mappings are still displayed.
-TEST(BitfieldTest, FormatBitfieldValueUnhandledValues) {
- EXPECT_EQ("kA|2h", FormatBitfieldValue(MyValue::kA | MyValue::kB,
- {
- {MyValue::kA, "kA"},
- }));
-}
-
-// Tests priority order in the mapping table.
-TEST(BitfieldTest, FormatBitfieldValuePriority) {
- // No priority, will do separate.
- EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB,
- {
- {MyValue::kA, "kA"},
- {MyValue::kB, "kB"},
- {MyValue::kAll, "kAll"},
- }));
-
- // Priority on the combined flag, use that instead.
- EXPECT_EQ("kAll", FormatBitfieldValue(MyValue::kA | MyValue::kB,
- {
- {MyValue::kAll, "kAll"},
- {MyValue::kA, "kA"},
- {MyValue::kB, "kB"},
- }));
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/file_io.h b/iree/base/file_io.h
deleted file mode 100644
index 0a2003d..0000000
--- a/iree/base/file_io.h
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_FILE_IO_H_
-#define IREE_BASE_FILE_IO_H_
-
-#include <string>
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace file_io {
-
-// Checks if a file exists at the provided path.
-//
-// Returns an OK status if the file definitely exists.
-// Errors can include PermissionDeniedError, NotFoundError, etc.
-Status FileExists(const std::string& path);
-
-// Synchronously reads a file's contents into a string.
-StatusOr<std::string> GetFileContents(const std::string& path);
-
-// Deletes the file at the provided path.
-Status DeleteFile(const std::string& path);
-
-// Moves a file from 'source_path' to 'destination_path'.
-//
-// This may simply rename the file, but may fall back to a full copy and delete
-// of the original if renaming is not possible (for example when moving between
-// physical storage locations).
-Status MoveFile(const std::string& source_path,
- const std::string& destination_path);
-
-} // namespace file_io
-} // namespace iree
-
-#endif // IREE_BASE_FILE_IO_H_
diff --git a/iree/base/file_mapping.h b/iree/base/file_mapping.h
deleted file mode 100644
index e871370..0000000
--- a/iree/base/file_mapping.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_FILE_MAPPING_H_
-#define IREE_BASE_FILE_MAPPING_H_
-
-#include <cstdint>
-#include <string>
-
-#include "absl/types/span.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-// A memory-mapped file handle.
-class FileMapping : public RefObject<FileMapping> {
- public:
- // Opens a file and maps it into the calling process memory.
- // The file will be opened for shared read access.
- static StatusOr<ref_ptr<FileMapping>> OpenRead(std::string path);
-
- virtual ~FileMapping() = default;
-
- // Read-only contents of the file.
- inline absl::Span<const uint8_t> data() const noexcept { return data_; }
-
- protected:
- explicit FileMapping(absl::Span<const uint8_t> data) : data_(data) {}
-
- absl::Span<const uint8_t> data_;
-
- private:
- FileMapping(const FileMapping&) = delete;
- FileMapping& operator=(const FileMapping&) = delete;
-};
-
-} // namespace iree
-
-#endif // IREE_BASE_FILE_MAPPING_H_
diff --git a/iree/base/file_path.cc b/iree/base/file_path.cc
deleted file mode 100644
index 13230ce..0000000
--- a/iree/base/file_path.cc
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/file_path.h"
-
-#include "absl/strings/str_cat.h"
-
-namespace iree {
-namespace file_path {
-
-namespace {
-
-std::pair<absl::string_view, absl::string_view> SplitPath(
- absl::string_view path) {
- size_t pos = path.find_last_of('/');
- // Handle the case with no '/' in 'path'.
- if (pos == absl::string_view::npos) {
- return std::make_pair(path.substr(0, 0), path);
- }
- // Handle the case with a single leading '/' in 'path'.
- if (pos == 0) {
- return std::make_pair(path.substr(0, 1), absl::ClippedSubstr(path, 1));
- }
- return std::make_pair(path.substr(0, pos),
- absl::ClippedSubstr(path, pos + 1));
-}
-
-// Return the parts of the basename of path, split on the final ".".
-// If there is no "." in the basename or "." is the final character in the
-// basename, the second value will be empty.
-std::pair<absl::string_view, absl::string_view> SplitBasename(
- absl::string_view path) {
- path = Basename(path);
- size_t pos = path.find_last_of('.');
- if (pos == absl::string_view::npos)
- return std::make_pair(path, absl::ClippedSubstr(path, path.size(), 0));
- return std::make_pair(path.substr(0, pos),
- absl::ClippedSubstr(path, pos + 1));
-}
-
-} // namespace
-
-std::string JoinPaths(absl::string_view path1, absl::string_view path2) {
- if (path1.empty()) return std::string(path2);
- if (path2.empty()) return std::string(path1);
- if (path1.back() == '/') {
- if (path2.front() == '/')
- return absl::StrCat(path1, absl::ClippedSubstr(path2, 1));
- } else {
- if (path2.front() != '/') return absl::StrCat(path1, "/", path2);
- }
- return absl::StrCat(path1, path2);
-}
-
-absl::string_view DirectoryName(absl::string_view path) {
- return SplitPath(path).first;
-}
-
-absl::string_view Basename(absl::string_view path) {
- return SplitPath(path).second;
-}
-
-absl::string_view Stem(absl::string_view path) {
- return SplitBasename(path).first;
-}
-
-absl::string_view Extension(absl::string_view path) {
- return SplitBasename(path).second;
-}
-
-} // namespace file_path
-} // namespace iree
diff --git a/iree/base/flatbuffer_util.cc b/iree/base/flatbuffer_util.cc
deleted file mode 100644
index 7ce44f4..0000000
--- a/iree/base/flatbuffer_util.cc
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/flatbuffer_util.h"
-
-#include <cerrno>
-#include <cstring>
-
-#include "absl/memory/memory.h"
-#include "iree/base/file_mapping.h"
-#include "iree/base/memory.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-
-FlatBufferFileBase::~FlatBufferFileBase() {
- if (deleter_) {
- deleter_();
- deleter_ = []() {};
- }
-}
-
-Status FlatBufferFileBase::Create(const void* root_ptr,
- std::function<void()> deleter) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
-
- root_ptr_ = root_ptr;
- deleter_ = std::move(deleter);
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::CreateWithBackingBuffer(
- const void* root_ptr, ::flatbuffers::DetachedBuffer backing_buffer) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
-
- root_ptr_ = root_ptr;
-
- // Pass along the buffer provided so we keep it alive until the
- // FlatBufferFileBase is destructed.
- auto backing_buffer_baton = IreeMoveToLambda(backing_buffer);
- deleter_ = [backing_buffer_baton]() { (void)backing_buffer_baton.value; };
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::Wrap(const void* root_ptr) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Wrap");
- return Create(root_ptr, []() {});
-}
-
-Status FlatBufferFileBase::FromBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE("FlatBufferFileBase::FromBuffer:size", int)
- (static_cast<int>(buffer_data.size()));
-
- // Sanity check buffer for the minimum size as FlatBuffers doesn't.
- if (buffer_data.size() < 16) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Provided serialized flatbuffer buffer is too small to be legit "
- "at size="
- << buffer_data.size();
- }
-
- // Ensure the buffer has the BIPE magic bytes.
- if (identifier.has_value() && !::flatbuffers::BufferHasIdentifier(
- buffer_data.data(), identifier.value())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Provided serialized buffer does not contain the expected type; "
- "magic bytes mismatch (expected "
- << identifier.value() << ")";
- }
-
- // Verify the FlatBuffer contains valid offsets and won't try to read out of
- // bounds of the buffer. We inline a bit of VerifyBufferFromStart so this code
- // can stay generic.
- {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::FromBufferVerification");
- ::flatbuffers::Verifier verifier{buffer_data.data(), buffer_data.size()};
- if (!verifier_fn(identifier.value_or(nullptr), &verifier)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "FlatBuffer failed to verify as expected type; possibly "
- "corrupt input";
- }
- }
-
- // Resolve the root pointer in the buffer.
- // This is GetMutableRoot such that we don't need to know T.
- root_ptr_ = buffer_data.data() +
- ::flatbuffers::EndianScalar(
- *reinterpret_cast<const ::flatbuffers::uoffset_t*>(
- buffer_data.data()));
- if (!root_ptr_) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Unable to resolve root table";
- }
- deleter_ = std::move(deleter);
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::WrapBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::WrapBuffer");
- return FromBuffer(
- identifier, buffer_data, []() {}, root_type_size, verifier_fn);
-}
-
-Status FlatBufferFileBase::LoadFile(Identifier identifier, std::string path,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::LoadFile");
-
- ASSIGN_OR_RETURN(auto file_mapping, FileMapping::OpenRead(path));
- auto buffer_data = file_mapping->data();
-
- auto handle_baton = IreeMoveToLambda(file_mapping);
- return FromBuffer(
- identifier, buffer_data,
- [handle_baton]() {
- // Keeping the mmap handle alive.
- (void)handle_baton.value;
- },
- root_type_size, verifier_fn);
-}
-
-} // namespace iree
diff --git a/iree/base/flatbuffer_util.h b/iree/base/flatbuffer_util.h
deleted file mode 100644
index 0e9344f..0000000
--- a/iree/base/flatbuffer_util.h
+++ /dev/null
@@ -1,321 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_FLATBUFFER_UTIL_H_
-#define IREE_BASE_FLATBUFFER_UTIL_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-// Wraps a FlatBuffer String in an absl::string_view.
-// Returns empty-string ("") for nullptr values.
-inline absl::string_view WrapString(const ::flatbuffers::String* value) {
- return value ? absl::string_view{value->data(), value->size()} : "";
-}
-
-// Base type for FlatBufferFile<T>. See below.
-class FlatBufferFileBase {
- public:
- using Identifier = absl::optional<const char*>;
-
- virtual ~FlatBufferFileBase();
-
- protected:
- template <typename T>
- friend class FlatBufferFile;
-
- using VerifierFn = bool (*)(const char* identifier,
- ::flatbuffers::Verifier* verifier);
-
- FlatBufferFileBase() = default;
-
- const void* root_ptr() const { return root_ptr_; }
-
- // Redirections of template static methods on FlatBufferFile so we can put the
- // implementations in a shared compilation unit.
- // See FlatBufferFile<T> for doc comments.
- Status Create(const void* root_ptr, std::function<void()> deleter);
- Status CreateWithBackingBuffer(const void* root_ptr,
- ::flatbuffers::DetachedBuffer backing_buffer);
- Status Wrap(const void* root);
- Status FromBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter, size_t root_type_size,
- VerifierFn verifier_fn);
- // Initializes from an STL byte based container (string and vector of
- // char/byte should be compatible).
- template <typename Container>
- Status FromContainer(Identifier identifier, Container container,
- size_t root_type_size, VerifierFn verifier_fn);
- Status WrapBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- size_t root_type_size, VerifierFn verifier_fn);
- Status LoadFile(Identifier identifier, std::string path,
- size_t root_type_size, VerifierFn verifier_fn);
-
- private:
- const void* root_ptr_ = nullptr;
- std::function<void()> deleter_;
-};
-
-// Immutable root FlatBuffer type wrapper with support for loading and backing
-// buffer management.
-//
-// Immutable and thread-safe.
-template <typename T>
-class FlatBufferFile final : public FlatBufferFileBase {
- public:
- // Creates a FlatBufferFile from an in-memory root pointer.
- // The provided |deleter| will be called when the FlatBufferFile is destructed
- // and can be used to deallocate/clean up resources.
- //
- // This assumes that the root pointer has already been verified as valid.
- // If verification is required instead use FromBuffer on the original buffer.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Create(
- const T* root, std::function<void()> deleter);
-
- // Creates a FlatBufferFile from an in-memory root pointer and the detached
- // backing buffer storing it.
- //
- // Example:
- // FlatBufferBuilder fbb;
- // MyTypeBuilder mtb(fbb);
- // fbb.Finish(mtb.Finish());
- // auto my_type = FlatBufferFile<MyType>::CreateWithBackingBuffer(
- // fbb.Release());
- // my_type->foo();
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> CreateWithBackingBuffer(
- ::flatbuffers::DetachedBuffer backing_buffer);
-
- // Wraps a caller-owned in-memory root pointer.
- // The provided |root| must remain valid for the lifetime of the returned
- // FlatBufferFile.
- //
- // This assumes that the root pointer has already been verified as valid.
- // If verification is required instead use FromBuffer on the original buffer.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Wrap(const T* root);
-
- // Creates a FlatBufferFile wrapping an external data buffer with a deleter
- // function that will be called when the FlatBufferFile is destructed.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter);
-
- // Creates a FlatBufferFile from a serialized data buffer.
- // The FlatBufferFile takes ownership of the vector.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer(
- Identifier identifier, std::vector<uint8_t> buffer_data);
-
- // Loads a FlatBufferFile from an external buffer owned by the caller.
- // The buffer must remain valid until the Pipeline is destroyed.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> WrapBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data);
-
- // Loads the FlatBufferFile from a serialized byte-based STL container.
- template <typename Container>
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromContainer(
- Identifier identifier, Container buffer_data);
-
- // Loads a FlatBufferFile from a serialized string.
- // The FlatBufferFile takes ownership of the string.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromString(
- Identifier identifier, std::string buffer_data) {
- return FromContainer(identifier, std::move(buffer_data));
- }
-
- // Loads a FlatBufferFile from a serialized byte vector.
- // The FlatBufferFile takes ownership of the vector.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromVector(
- Identifier identifier, std::vector<uint8_t> buffer_data) {
- return FromContainer(identifier, std::move(buffer_data));
- }
-
- // Loads a FlatBufferFile from a serialized file on the file system.
- // This will attempt to mmap the file and is the preferred way of loading as
- // only those pages that contain requested tables will be read.
- static StatusOr<std::unique_ptr<FlatBufferFile<T>>> LoadFile(
- Identifier identifier, std::string path);
-
- // Returns a vector of file references that share the same underlying data
- // buffer. The buffer will be kept alive until the last file is released.
- static StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>>
- CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file, int count);
-
- ~FlatBufferFile() override = default;
-
- // Typed root pointer of the file.
- const T* root() const { return reinterpret_cast<const T*>(root_ptr()); }
-
- private:
- FlatBufferFile() = default;
-
- // Conforms to VerifierFn.
- static bool VerifierFnT(const char* identifier,
- ::flatbuffers::Verifier* verifier) {
- return verifier->VerifyBuffer<T>(identifier);
- }
-};
-
-template <typename Container>
-Status FlatBufferFileBase::FromContainer(Identifier identifier,
- Container container,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- static_assert(sizeof(*container.data()) == 1,
- "Expected container of byte sized elements");
- auto buffer_data = absl::MakeConstSpan(
- // Double static_cast through void is safer than reinterpret_cast.
- static_cast<const uint8_t*>(static_cast<const void*>(container.data())),
- container.size());
- // Use a baton to keep the container alive until the FlatBufferFileBase is
- // destroyed.
- auto buffer_data_baton = IreeMoveToLambda(container);
- return FromBuffer(
- identifier, buffer_data,
- [buffer_data_baton]() {
- // Keeping the container alive.
- (void)buffer_data_baton.value;
- },
- root_type_size, verifier_fn);
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Create(
- const T* root, std::function<void()> deleter) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- RETURN_IF_ERROR(base_file->Create(root, std::move(deleter)));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>>
-FlatBufferFile<T>::CreateWithBackingBuffer(
- ::flatbuffers::DetachedBuffer backing_buffer) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- auto* root_ptr = ::flatbuffers::GetRoot<T>(backing_buffer.data());
- RETURN_IF_ERROR(
- base_file->CreateWithBackingBuffer(root_ptr, std::move(backing_buffer)));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Wrap(
- const T* root) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- RETURN_IF_ERROR(base_file->Wrap(root));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- RETURN_IF_ERROR(base_file->FromBuffer(
- identifier, buffer_data, std::move(deleter), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
- Identifier identifier, std::vector<uint8_t> buffer_data) {
- auto* buffer_data_ptr = new decltype(buffer_data);
- (*buffer_data_ptr) = std::move(buffer_data);
- return FromBuffer(identifier, absl::MakeConstSpan(*buffer_data_ptr),
- [buffer_data_ptr]() { delete buffer_data_ptr; });
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::WrapBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- RETURN_IF_ERROR(
- base_file->WrapBuffer(identifier, buffer_data, sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-template <typename Container>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromContainer(
- Identifier identifier, Container buffer_data) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- RETURN_IF_ERROR(base_file->FromContainer(identifier, std::move(buffer_data),
- sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::LoadFile(
- Identifier identifier, std::string path) {
- std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- RETURN_IF_ERROR(
- base_file->LoadFile(identifier, std::move(path), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>>
-FlatBufferFile<T>::CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file,
- int count) {
- // Create a shared_ptr wrapper for the base file that will be.
- std::shared_ptr<FlatBufferFile<T>> shared_file{file.release()};
-
- // Create N files. We wrap and keep the shared_ptr alive in the deleter
- // capture. By wrapping we avoid reverifying the entire buffer.
- std::vector<std::unique_ptr<FlatBufferFile<T>>> list;
- for (int i = 0; i < count; ++i) {
- ASSIGN_OR_RETURN(auto new_file, FlatBufferFile<T>::Create(
- shared_file->root(), [shared_file]() {
- // Each new file keeps a reference to
- // the shared file to keep it alive.
- (void)shared_file;
- }));
- list.push_back(std::move(new_file));
- }
- return std::move(list);
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_FLATBUFFER_UTIL_H_
diff --git a/iree/base/init.h b/iree/base/init.h
deleted file mode 100644
index 4e8f345..0000000
--- a/iree/base/init.h
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INIT_H_
-#define IREE_BASE_INIT_H_
-
-// Initializer macros are defined in separate files:
-// IREE_DECLARE_MODULE_INITIALIZER(name)
-// IREE_REGISTER_MODULE_INITIALIZER(name, body)
-// IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2)
-// IREE_REQUIRE_MODULE_INITIALIZED(name)
-// IREE_RUN_MODULE_INITIALIZERS()
-// IREE_REQUIRE_MODULE_LINKED(name)
-//
-// These macros allow for arranging pieces of initialization code to be
-// executed at a well-defined time and in a well-defined order.
-//
-// Initialization happens automatically during InitializeEnvironment(), which
-// should be called early in main(), before other code runs.
-
-#ifdef IREE_CONFIG_GOOGLE_INTERNAL
-#include "iree/base/google/init_google.h"
-#else
-#include "iree/base/internal/init_internal.h"
-#endif // IREE_CONFIG_GOOGLE_INTERNAL
-
-namespace iree {
-
-// Initializes the system environment in a binary.
-//
-// This first parses command line flags, then resolves module initializers
-// by calling IREE_RUN_MODULE_INITIALIZERS().
-//
-// 'argc' and 'argv' are the command line flags to parse.
-//
-// This should typically be called early in main(), before other code runs.
-void InitializeEnvironment(int* argc, char*** argv);
-
-} // namespace iree
-
-#endif // IREE_BASE_INIT_H_
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD
deleted file mode 100644
index 6ffe4ac..0000000
--- a/iree/base/internal/BUILD
+++ /dev/null
@@ -1,118 +0,0 @@
-# Implementations for iree/base/
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "file_handle_win32",
- srcs = ["file_handle_win32.cc"],
- hdrs = ["file_handle_win32.h"],
- deps = [
- "//iree/base:status",
- "//iree/base:target_platform",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "file_io_internal",
- srcs = [
- "file_io_posix.cc",
- "file_io_win32.cc",
- ],
- deps = [
- ":file_handle_win32",
- "//iree/base:file_io_hdrs",
- "//iree/base:status",
- "//iree/base:target_platform",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "file_mapping_internal",
- srcs = [
- "file_mapping_posix.cc",
- "file_mapping_win32.cc",
- ],
- deps = [
- ":file_handle_win32",
- "//iree/base:file_mapping_hdrs",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "init_internal",
- srcs = ["init_internal.cc"],
- hdrs = ["init_internal.h"],
- deps = [
- "//iree/base:target_platform",
- "@com_google_absl//absl/flags:parse",
- ],
-)
-
-cc_library(
- name = "logging_internal",
- srcs = ["logging.cc"],
- hdrs = ["logging.h"],
- deps = [
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/flags:flag",
- ],
-)
-
-cc_library(
- name = "source_location_internal",
- hdrs = ["source_location.h"],
-)
-
-cc_library(
- name = "status_internal",
- srcs = [
- "status.cc",
- "status_builder.cc",
- "status_errno.cc",
- "status_errors.cc",
- "status_win32_errors.cc",
- "statusor.cc",
- ],
- hdrs = [
- "status.h",
- "status_builder.h",
- "status_errno.h",
- "status_errors.h",
- "status_macros.h",
- "status_win32_errors.h",
- "statusor.h",
- ],
- deps = [
- ":logging_internal",
- "//iree/base:source_location",
- "//iree/base:target_platform",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/debugging:stacktrace",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "status_matchers_internal",
- testonly = 1,
- hdrs = ["status_matchers.h"],
- deps = [
- "//iree/base:status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_googletest//:gtest",
- ],
-)
diff --git a/iree/base/internal/file_handle_win32.cc b/iree/base/internal/file_handle_win32.cc
deleted file mode 100644
index b4bf05a..0000000
--- a/iree/base/internal/file_handle_win32.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/file_handle_win32.h"
-
-#include "absl/memory/memory.h"
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#include <windows.h>
-
-namespace iree {
-
-// static
-StatusOr<std::unique_ptr<FileHandle>> FileHandle::OpenRead(std::string path,
- DWORD file_flags) {
- HANDLE handle = ::CreateFileA(
- /*lpFileName=*/path.c_str(), /*dwDesiredAccess=*/GENERIC_READ,
- /*dwShareMode=*/FILE_SHARE_READ, /*lpSecurityAttributes=*/nullptr,
- /*dwCreationDisposition=*/OPEN_EXISTING,
- /*dwFlagsAndAttributes=*/FILE_ATTRIBUTE_NORMAL | file_flags,
- /*hTemplateFile=*/nullptr);
- if (handle == INVALID_HANDLE_VALUE) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Unable to open file " << path;
- }
-
- BY_HANDLE_FILE_INFORMATION file_info;
- if (::GetFileInformationByHandle(handle, &file_info) == FALSE) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Unable to query file info for " << path;
- }
-
- uint64_t file_size = (static_cast<uint64_t>(file_info.nFileSizeHigh) << 32) |
- file_info.nFileSizeLow;
- return absl::make_unique<FileHandle>(handle, file_size);
-}
-
-FileHandle::~FileHandle() { ::CloseHandle(handle_); }
-
-} // namespace iree
-
-#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/file_handle_win32.h b/iree/base/internal/file_handle_win32.h
deleted file mode 100644
index 793e5c6..0000000
--- a/iree/base/internal/file_handle_win32.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
-#define IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
-
-#include <memory>
-#include <string>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/status.h"
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#include <windows.h>
-
-namespace iree {
-
-class FileHandle {
- public:
- static StatusOr<std::unique_ptr<FileHandle>> OpenRead(std::string path,
- DWORD file_flags);
-
- FileHandle(HANDLE handle, size_t size) : handle_(handle), size_(size) {}
- ~FileHandle();
-
- absl::string_view path() const { return path_; }
- HANDLE handle() const { return handle_; }
- size_t size() const { return size_; }
-
- private:
- FileHandle(const FileHandle&) = delete;
- FileHandle& operator=(const FileHandle&) = delete;
-
- std::string path_;
- HANDLE handle_;
- size_t size_;
-};
-
-} // namespace iree
-
-#endif // IREE_PLATFORM_WINDOWS
-
-#endif // IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
diff --git a/iree/base/internal/file_io_posix.cc b/iree/base/internal/file_io_posix.cc
deleted file mode 100644
index 3674509..0000000
--- a/iree/base/internal/file_io_posix.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <cstdio>
-
-#include "iree/base/file_io.h"
-#include "iree/base/status.h"
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \
- defined(IREE_PLATFORM_LINUX)
-
-#include <sys/stat.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-namespace iree {
-namespace file_io {
-
-Status FileExists(const std::string& path) {
- struct stat stat_buf;
- return stat(path.c_str(), &stat_buf) == 0 ? OkStatus()
- : NotFoundErrorBuilder(IREE_LOC);
-}
-
-StatusOr<std::string> GetFileContents(const std::string& path) {
- std::unique_ptr<FILE, void (*)(FILE*)> file = {std::fopen(path.c_str(), "r"),
- +[](FILE* file) {
- if (file) fclose(file);
- }};
- if (file == nullptr) {
- return ErrnoToCanonicalStatusBuilder(errno, "Failed to open file",
- IREE_LOC);
- }
- if (std::fseek(file.get(), 0, SEEK_END) == -1) {
- return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file",
- IREE_LOC);
- }
- size_t file_size = std::ftell(file.get());
- if (file_size == -1L) {
- return ErrnoToCanonicalStatusBuilder(errno, "Failed to read file length",
- IREE_LOC);
- }
- if (std::fseek(file.get(), 0, SEEK_SET) == -1) {
- return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file",
- IREE_LOC);
- }
- std::string contents;
- contents.resize(file_size);
- if (std::fread(const_cast<char*>(contents.data()), file_size, 1,
- file.get()) != file_size) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Unable to read entire file contents";
- }
- return contents;
-}
-
-Status DeleteFile(const std::string& path) {
- if (::remove(path.c_str()) == -1) {
- return ErrnoToCanonicalStatusBuilder(errno, "Failed to delete file",
- IREE_LOC);
- }
- return OkStatus();
-}
-
-Status MoveFile(const std::string& source_path,
- const std::string& destination_path) {
- if (::rename(source_path.c_str(), destination_path.c_str()) == -1) {
- return ErrnoToCanonicalStatusBuilder(errno, "Failed to rename file",
- IREE_LOC);
- }
- return OkStatus();
-}
-
-} // namespace file_io
-} // namespace iree
-
-#endif // IREE_PLATFORM_*
diff --git a/iree/base/internal/file_io_win32.cc b/iree/base/internal/file_io_win32.cc
deleted file mode 100644
index bb06b5f..0000000
--- a/iree/base/internal/file_io_win32.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "iree/base/file_io.h"
-#include "iree/base/internal/file_handle_win32.h"
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#include <windows.h>
-
-namespace iree {
-namespace file_io {
-
-Status FileExists(const std::string& path) {
- DWORD attrs = ::GetFileAttributesA(path.c_str());
- if (attrs == INVALID_FILE_ATTRIBUTES) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Unable to find/access file: " << path;
- }
- return OkStatus();
-}
-
-StatusOr<std::string> GetFileContents(const std::string& path) {
- ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path),
- FILE_FLAG_SEQUENTIAL_SCAN));
- std::string result;
- result.resize(file->size());
- DWORD bytes_read = 0;
- if (::ReadFile(file->handle(), const_cast<char*>(result.data()),
- result.size(), &bytes_read, nullptr) == FALSE) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Unable to read file span of " << result.size() << " bytes";
- } else if (bytes_read != file->size()) {
- return ResourceExhaustedErrorBuilder(IREE_LOC)
- << "Unable to read all " << file->size()
- << " bytes from the file (got " << bytes_read << ")";
- }
- return result;
-}
-
-Status DeleteFile(const std::string& path) {
- if (::DeleteFileA(path.c_str()) == FALSE) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Unable to delete/access file: " << path;
- }
- return OkStatus();
-}
-
-Status MoveFile(const std::string& source_path,
- const std::string& destination_path) {
- if (::MoveFileA(source_path.c_str(), destination_path.c_str()) == FALSE) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Unable to move file " << source_path << " to "
- << destination_path;
- }
- return OkStatus();
-}
-
-} // namespace file_io
-} // namespace iree
-
-#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/file_mapping_posix.cc b/iree/base/internal/file_mapping_posix.cc
deleted file mode 100644
index 38006ee..0000000
--- a/iree/base/internal/file_mapping_posix.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/file_mapping.h"
-#include "iree/base/target_platform.h"
-#include "iree/base/tracing.h"
-
-#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \
- defined(IREE_PLATFORM_LINUX)
-
-#include <fcntl.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <cerrno>
-
-namespace iree {
-
-namespace {
-
-class FileDescriptor {
- public:
- static StatusOr<std::unique_ptr<FileDescriptor>> OpenRead(std::string path) {
- struct stat buf;
- if (::lstat(path.c_str(), &buf) == -1) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Unable to stat file " << path << ": " << ::strerror(errno);
- }
- uint64_t file_size = static_cast<size_t>(buf.st_size);
-
- int fd = ::open(path.c_str(), O_RDONLY);
- if (fd == -1) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Unable to open file " << path << ": " << ::strerror(errno);
- }
-
- return absl::make_unique<FileDescriptor>(std::move(path), fd, file_size);
- }
-
- FileDescriptor(std::string path, int fd, size_t size)
- : path_(std::move(path)), fd_(fd), size_(size) {}
- ~FileDescriptor() { ::close(fd_); }
-
- absl::string_view path() const { return path_; }
- int fd() const { return fd_; }
- size_t size() const { return size_; }
-
- private:
- FileDescriptor(const FileDescriptor&) = delete;
- FileDescriptor& operator=(const FileDescriptor&) = delete;
-
- std::string path_;
- int fd_;
- size_t size_;
-};
-
-class MMapMapping : public FileMapping {
- public:
- MMapMapping(void* data, size_t data_size)
- : FileMapping(
- absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)) {}
-
- ~MMapMapping() override {
- if (::munmap(const_cast<uint8_t*>(data_.data()), data_.size()) != 0) {
- LOG(WARNING) << "Unable to unmap file: " << strerror(errno);
- }
- }
-};
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) {
- IREE_TRACE_SCOPE0("FileMapping::Open");
-
- // Open the file for reading. Note that we only need to keep it open long
- // enough to map it and we can close the descriptor after that.
- ASSIGN_OR_RETURN(auto file, FileDescriptor::OpenRead(std::move(path)));
-
- // Map the file from the file descriptor.
- void* data =
- ::mmap(nullptr, file->size(), PROT_READ, MAP_SHARED, file->fd(), 0);
- if (data == MAP_FAILED) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Mapping failed on file (ensure uncompressed): " << file->path();
- }
-
- return make_ref<MMapMapping>(data, file->size());
-}
-
-} // namespace iree
-
-#endif // IREE_PLATFORM_*
diff --git a/iree/base/internal/file_mapping_win32.cc b/iree/base/internal/file_mapping_win32.cc
deleted file mode 100644
index ebe8fb7..0000000
--- a/iree/base/internal/file_mapping_win32.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "iree/base/file_mapping.h"
-#include "iree/base/internal/file_handle_win32.h"
-#include "iree/base/target_platform.h"
-#include "iree/base/tracing.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#include <windows.h>
-
-namespace iree {
-
-namespace {
-
-class Win32FileMapping : public FileMapping {
- public:
- Win32FileMapping(HANDLE mapping_handle, void* data, size_t data_size)
- : FileMapping(
- absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)),
- mapping_handle_(mapping_handle) {}
-
- ~Win32FileMapping() override {
- if (!data_.empty()) {
- if (::UnmapViewOfFile(data_.data()) == FALSE) {
- LOG(WARNING) << "Unable to unmap file: " << GetLastError();
- }
- data_ = {};
- }
- if (mapping_handle_) {
- ::CloseHandle(mapping_handle_);
- mapping_handle_ = nullptr;
- }
- }
-
- private:
- HANDLE mapping_handle_;
-};
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) {
- IREE_TRACE_SCOPE0("FileMapping::Open");
-
- // Open the file for reading. Note that we only need to keep it open long
- // enough to map it and we can close the descriptor after that.
- ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path),
- FILE_FLAG_RANDOM_ACCESS));
-
- HANDLE mapping_handle = ::CreateFileMappingA(
- /*hFile=*/file->handle(), /*lpFileMappingAttributes=*/nullptr,
- /*flProtect=*/PAGE_READONLY, /*dwMaximumSizeHigh=*/0,
- /*dwMaximumSizeLow=*/0, /*lpName=*/nullptr);
- if (!mapping_handle) {
- return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
- << "Failed to create mapping on file (ensure uncompressed): "
- << file->path();
- }
-
- void* data =
- ::MapViewOfFileEx(/*hFileMappingObject=*/mapping_handle,
- /*dwDesiredAccess=*/FILE_MAP_READ,
- /*dwFileOffsetHigh=*/0, /*dwFileOffsetLow=*/0,
- /*dwNumberOfBytesToMap=*/0, /*lpBaseAddress=*/nullptr);
- if (!data) {
- DWORD map_view_error = GetLastError();
- ::CloseHandle(mapping_handle);
- return Win32ErrorToCanonicalStatusBuilder(map_view_error, IREE_LOC)
- << "Failed to map view of file: " << file->path();
- }
-
- auto result = make_ref<Win32FileMapping>(mapping_handle, data, file->size());
-
- // NOTE: file mappings hold references to the file, so we don't need to keep
- // the file around any longer than this function.
- file.reset();
-
- return result;
-}
-
-} // namespace iree
-
-#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/init_internal.cc b/iree/base/internal/init_internal.cc
deleted file mode 100644
index c379199..0000000
--- a/iree/base/internal/init_internal.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/init_internal.h"
-
-#include <string.h>
-
-#include <set>
-
-#include "absl/flags/parse.h"
-
-namespace iree {
-
-static Initializer::NameMap* static_name_map = nullptr;
-
-struct Initializer::InitializerData {
- Initializer* initializer_obj;
- std::set<std::string> dependency_names;
-
- InitializerData() : initializer_obj(nullptr) {}
- explicit InitializerData(Initializer* i) : initializer_obj(i) {}
-};
-
-Initializer::DependencyRegisterer::DependencyRegisterer(
- const char* name, Initializer* initializer, const Dependency& dependency) {
- NameMap* name_map = InitializerNameMap();
-
- // Insert 'dependency' into the 'dependency_names' set for 'initializer'.
- InitializerData* initializer_data = &(*name_map)[name];
- initializer_data->dependency_names.insert(dependency.name);
-
- // Ensure that 'dependency' exists in the map.
- InitializerData* dependency_data = &(*name_map)[dependency.name];
- dependency_data->initializer_obj = dependency.initializer;
-}
-
-Initializer::Initializer(const char* name, InitializerFunc function)
- : name_(name), function_(function), done_(false) {
- // Register this Initializer instance (wrapped by an InitializerData) within
- // the static name map.
- NameMap* name_map = InitializerNameMap();
- InitializerData* initializer_data = &(*name_map)[name];
- initializer_data->initializer_obj = this;
-}
-
-void Initializer::RunInitializers() {
- // Run each registered Initializer, in lexicographic order of their names.
- // Initializer dependencies will be run first as needed.
- NameMap* name_map = InitializerNameMap();
- for (auto& p : *name_map) {
- RunInitializer(&p.second);
- }
-}
-
-void Initializer::Require() {
- NameMap* name_map = InitializerNameMap();
- InitializerData* initializer_data = &(name_map->find(name_)->second);
- RunInitializer(initializer_data);
-}
-
-Initializer::NameMap* Initializer::InitializerNameMap() {
- if (static_name_map == nullptr) {
- static_name_map = new Initializer::NameMap;
- }
- return static_name_map;
-}
-
-void Initializer::RunInitializer(InitializerData* initializer_data) {
- if (initializer_data->initializer_obj->done_) {
- return;
- }
-
- // Run Initializer dependencies first.
- NameMap* name_map = InitializerNameMap();
- for (const auto& dependency_name : initializer_data->dependency_names) {
- auto dep_init = name_map->find(dependency_name);
- RunInitializer(&dep_init->second);
- }
-
- // Finally run the Initializer itself.
- initializer_data->initializer_obj->function_();
- initializer_data->initializer_obj->done_ = true;
-}
-
-void InitializeEnvironment(int* argc, char*** argv) {
- auto positional_args = absl::ParseCommandLine(*argc, *argv);
- if (positional_args.size() < *argc) {
- // Edit the passed argument refs to only include positional args.
- *argc = positional_args.size();
- for (int i = 0; i < *argc; ++i) {
- (*argv)[i] = positional_args[i];
- }
- (*argv)[*argc + 1] = nullptr;
- }
-
- IREE_RUN_MODULE_INITIALIZERS();
-}
-
-} // namespace iree
diff --git a/iree/base/internal/init_internal.h b/iree/base/internal/init_internal.h
deleted file mode 100644
index 79431d8..0000000
--- a/iree/base/internal/init_internal.h
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_INIT_INTERNAL_H_
-#define IREE_BASE_INTERNAL_INIT_INTERNAL_H_
-
-#include <map>
-#include <string>
-
-#include "iree/base/target_platform.h"
-
-namespace iree {
-
-// A static instance of this class is declared for each piece of initialization
-// code using the initializer macros.
-class Initializer {
- public:
- typedef void (*InitializerFunc)();
-
- Initializer(const char* name, InitializerFunc function);
-
- // Runs all registered initializers that have not yet run.
- // The initializers are invoked in lexicographically increasing order by name,
- // except as necessary to satisfy dependencies.
- //
- // This is normally called by InitializeEnvironment(), so application code
- // typically should not call it directly.
- static void RunInitializers();
-
- // Runs this initializer if it has not yet run, including any dependencies.
- void Require();
-
- struct Dependency {
- Dependency(const char* n, Initializer* i) : name(n), initializer(i) {}
- const char* const name;
- Initializer* const initializer;
- };
-
- // A static instance of this class is declared for each piece of
- // initializer ordering definition.
- struct DependencyRegisterer {
- DependencyRegisterer(const char* name, Initializer* initializer,
- const Dependency& dependency);
- };
-
- struct InitializerData;
- typedef std::map<std::string, InitializerData> NameMap;
-
- private:
- static NameMap* InitializerNameMap();
- static void RunInitializer(InitializerData* initializer_data);
-
- const std::string name_;
- InitializerFunc function_;
- bool done_;
-};
-
-// In iree/base/init.h:
-void InitializeEnvironment(int* argc, char*** argv);
-
-} // namespace iree
-
-#define IREE_DECLARE_MODULE_INITIALIZER(name) \
- extern ::iree::Initializer iree_initializer_##name
-
-#define IREE_REGISTER_MODULE_INITIALIZER(name, body) \
- static void iree_init_##name() { body; } \
- ::iree::Initializer iree_initializer_##name(#name, iree_init_##name)
-
-#define IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) \
- namespace { \
- static ::iree::Initializer::DependencyRegisterer \
- iree_initializer_dependency_##name1##_##name2( \
- #name2, &iree_initializer_##name2, \
- ::iree::Initializer::Dependency(#name1, &iree_initializer_##name1)); \
- }
-
-#define IREE_REQUIRE_MODULE_INITIALIZED(name) \
- do { \
- IREE_DECLARE_MODULE_INITIALIZER(name); \
- iree_initializer_##name.Require(); \
- } while (0)
-
-#define IREE_RUN_MODULE_INITIALIZERS() \
- do { \
- ::iree::Initializer::RunInitializers(); \
- } while (0)
-
-#if !defined(IREE_COMPILER_MSVC)
-#define IREE_ATTRIBUTE_USED __attribute__((used))
-#else
-#define IREE_ATTRIBUTE_USED
-#endif // IREE_COMPILER_MSVC
-
-#define IREE_REQUIRE_MODULE_LINKED(name) \
- IREE_ATTRIBUTE_USED static ::iree::Initializer* iree_module_ref_##name = \
- &iree_initializer_##name
-
-#endif // IREE_BASE_INTERNAL_INIT_INTERNAL_H_
diff --git a/iree/base/internal/logging.cc b/iree/base/internal/logging.cc
deleted file mode 100644
index 29a2604..0000000
--- a/iree/base/internal/logging.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/logging.h"
-
-#include <string>
-
-#include "absl/flags/flag.h"
-
-ABSL_FLAG(int, iree_minloglevel, 0,
- "Minimum logging level. 0 = INFO and above.");
-ABSL_FLAG(int, iree_v, 0,
- "Verbosity level maximum. 1 = VLOG(0-1), 2 = VLOG(0-2).");
-ABSL_FLAG(bool, iree_logtostderr, false, "Logs to stderr instead of stdout");
-
-namespace iree {
-namespace internal {
-
-namespace {
-
-// Parse log level (int64_t) from environment variable (char*).
-// Returns true if the value was present and parsed successfully.
-bool LogLevelStrToInt(const char* iree_env_var_val, int64_t* out_level) {
- *out_level = 0;
- if (iree_env_var_val == nullptr) {
- return false;
- }
-
- std::string min_log_level(iree_env_var_val);
- std::istringstream ss(min_log_level);
- int64_t level;
- if (!(ss >> level)) {
- // Invalid vlog level setting, set level to default (0).
- return false;
- }
-
- *out_level = level;
- return true;
-}
-
-int64_t MinLogLevelFromEnv() {
- const char* iree_env_var_val = getenv("IREE_MIN_LOG_LEVEL");
- int64_t level = 0;
- if (LogLevelStrToInt(iree_env_var_val, &level)) {
- return level;
- }
- return absl::GetFlag(FLAGS_iree_minloglevel);
-}
-
-int64_t MinVLogLevelFromEnv() {
- const char* iree_env_var_val = getenv("IREE_MIN_VLOG_LEVEL");
- int64_t level = 0;
- if (LogLevelStrToInt(iree_env_var_val, &level)) {
- return level;
- }
- return absl::GetFlag(FLAGS_iree_v);
-}
-
-} // namespace
-
-LogMessage::LogMessage(const char* file_name, int line, int severity)
- : file_name_(file_name), line_(line), severity_(severity) {}
-
-LogMessage::~LogMessage() {
- // Read the min log level once during the first call to logging.
- static int64_t min_log_level = MinLogLevelFromEnv();
- if (ABSL_PREDICT_TRUE(severity_ >= min_log_level)) {
- EmitLogMessage();
- }
-}
-
-int64_t LogMessage::MinVLogLevel() {
- static int64_t min_vlog_level = MinVLogLevelFromEnv();
- return min_vlog_level;
-}
-
-void LogMessage::EmitLogMessage() {
- // TODO(scotttodd): Include current system time
- fprintf(absl::GetFlag(FLAGS_iree_logtostderr) ? stderr : stdout,
- "%c %s:%d] %s\n", "IWEF"[severity_], file_name_, line_,
- str().c_str());
-}
-
-LogMessageFatal::LogMessageFatal(const char* file, int line)
- : LogMessage(file, line, FATAL) {}
-
-LogMessageFatal::~LogMessageFatal() {
- EmitLogMessage();
-
- // abort() ensures we don't return (as promised via ATTRIBUTE_NORETURN).
- abort();
-}
-
-} // namespace internal
-} // namespace iree
diff --git a/iree/base/internal/status.cc b/iree/base/internal/status.cc
deleted file mode 100644
index f4f4d57..0000000
--- a/iree/base/internal/status.cc
+++ /dev/null
@@ -1,178 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/status.h"
-
-#include <atomic>
-#include <memory>
-
-#include "absl/base/attributes.h"
-#include "absl/debugging/stacktrace.h"
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-
-ABSL_FLAG(bool, iree_status_save_stack_trace, false,
- "Save and display the full stack trace of the point of error")
- .OnUpdate([]() {
- iree::StatusSavesStackTrace(
- absl::GetFlag(FLAGS_iree_status_save_stack_trace));
- });
-
-namespace iree {
-
-namespace status_internal {
-
-ABSL_CONST_INIT std::atomic<bool> iree_save_stack_trace{false};
-
-} // namespace status_internal
-
-bool DoesStatusSaveStackTrace() {
- return status_internal::iree_save_stack_trace.load(std::memory_order_relaxed);
-}
-void StatusSavesStackTrace(bool on_off) {
- status_internal::iree_save_stack_trace.store(on_off,
- std::memory_order_relaxed);
-}
-
-std::string StatusCodeToString(StatusCode code) {
- switch (code) {
- case StatusCode::kOk:
- return "OK";
- case StatusCode::kCancelled:
- return "CANCELLED";
- case StatusCode::kUnknown:
- return "UNKNOWN";
- case StatusCode::kInvalidArgument:
- return "INVALID_ARGUMENT";
- case StatusCode::kDeadlineExceeded:
- return "DEADLINE_EXCEEDED";
- case StatusCode::kNotFound:
- return "NOT_FOUND";
- case StatusCode::kAlreadyExists:
- return "ALREADY_EXISTS";
- case StatusCode::kPermissionDenied:
- return "PERMISSION_DENIED";
- case StatusCode::kUnauthenticated:
- return "UNAUTHENTICATED";
- case StatusCode::kResourceExhausted:
- return "RESOURCE_EXHAUSTED";
- case StatusCode::kFailedPrecondition:
- return "FAILED_PRECONDITION";
- case StatusCode::kAborted:
- return "ABORTED";
- case StatusCode::kOutOfRange:
- return "OUT_OF_RANGE";
- case StatusCode::kUnimplemented:
- return "UNIMPLEMENTED";
- case StatusCode::kInternal:
- return "INTERNAL";
- case StatusCode::kUnavailable:
- return "UNAVAILABLE";
- case StatusCode::kDataLoss:
- return "DATA_LOSS";
- default:
- return "";
- }
-}
-
-Status::Status() {}
-
-Status::Status(StatusCode code, absl::string_view message) {
- state_ = absl::make_unique<State>();
- state_->code = code;
- state_->message = std::string(message);
-}
-
-Status::Status(const Status& x) {
- if (x.ok()) return;
-
- state_ = absl::make_unique<State>();
- state_->code = x.state_->code;
- state_->message = x.state_->message;
-}
-
-Status& Status::operator=(const Status& x) {
- if (x.ok()) {
- state_ = nullptr;
- } else {
- state_ = absl::make_unique<State>();
- state_->code = x.state_->code;
- state_->message = x.state_->message;
- }
- return *this;
-}
-
-Status::~Status() {}
-
-bool Status::ok() const { return state_ == nullptr; }
-
-StatusCode Status::code() const {
- return ok() ? StatusCode::kOk : state_->code;
-}
-
-absl::string_view Status::message() const {
- return ok() ? absl::string_view() : absl::string_view(state_->message);
-}
-
-std::string Status::ToString() const {
- if (ok()) {
- return "OK";
- }
-
- std::string text;
- absl::StrAppend(&text, StatusCodeToString(state_->code), ": ",
- state_->message);
- // TODO(scotttodd): Payloads (stack traces)
- return text;
-}
-
-void Status::IgnoreError() const {
- // no-op
-}
-
-bool Status::EqualsSlow(const Status& a, const Status& b) {
- if (a.code() != b.code()) return false;
- if (a.message() != b.message()) return false;
- // TODO(scotttodd): Payloads
- return true;
-}
-
-bool operator==(const Status& lhs, const Status& rhs) {
- return lhs.state_ == rhs.state_ || Status::EqualsSlow(lhs, rhs);
-}
-
-bool operator!=(const Status& lhs, const Status& rhs) { return !(lhs == rhs); }
-
-std::ostream& operator<<(std::ostream& os, const Status& x) {
- os << x.ToString();
- return os;
-}
-
-Status OkStatus() { return Status(); }
-
-Status Annotate(const Status& s, absl::string_view msg) {
- if (s.ok() || msg.empty()) return s;
-
- absl::string_view new_msg = msg;
- std::string annotated;
- if (!s.message().empty()) {
- absl::StrAppend(&annotated, s.message(), "; ", msg);
- new_msg = annotated;
- }
- Status result(s.code(), new_msg);
- // TODO(scotttodd): Copy payload(s) into the new Status
- return result;
-}
-
-} // namespace iree
diff --git a/iree/base/internal/status.h b/iree/base/internal/status.h
deleted file mode 100644
index b83e249..0000000
--- a/iree/base/internal/status.h
+++ /dev/null
@@ -1,130 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_H_
-#define IREE_BASE_INTERNAL_STATUS_H_
-
-#include <atomic>
-#include <string>
-
-#include "absl/base/attributes.h"
-#include "absl/flags/flag.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/internal/logging.h"
-
-ABSL_DECLARE_FLAG(bool, iree_status_save_stack_trace);
-
-namespace iree {
-
-// True if Status objects will capture stack traces on init for non-ok Statuses.
-bool DoesStatusSaveStackTrace();
-
-// Enables/disables status stack trace saving. This is global for the process.
-// While useful for debugging, stack traces can impact performance severely.
-void StatusSavesStackTrace(bool on_off);
-
-enum class StatusCode : int {
- kOk = 0,
- kCancelled = 1,
- kUnknown = 2,
- kInvalidArgument = 3,
- kDeadlineExceeded = 4,
- kNotFound = 5,
- kAlreadyExists = 6,
- kPermissionDenied = 7,
- kResourceExhausted = 8,
- kFailedPrecondition = 9,
- kAborted = 10,
- kOutOfRange = 11,
- kUnimplemented = 12,
- kInternal = 13,
- kUnavailable = 14,
- kDataLoss = 15,
- kUnauthenticated = 16,
- kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20
-};
-
-std::string StatusCodeToString(StatusCode code);
-
-class ABSL_MUST_USE_RESULT Status;
-
-// A Status value can be either OK or not-OK
-// * OK indicates that the operation succeeded.
-// * A not-OK value indicates that the operation failed and contains details
-// about the error.
-class Status final {
- public:
- // Creates an OK status with no message.
- Status();
-
- // Creates a status with the specified code and error message.
- Status(StatusCode code, absl::string_view message);
-
- Status(const Status&);
- Status& operator=(const Status& x);
-
- ~Status();
-
- // Returns true if the Status is OK.
- ABSL_MUST_USE_RESULT bool ok() const;
-
- // Returns the error code.
- StatusCode code() const;
-
- // Returns the error message. Note: prefer ToString() for debug logging.
- // This message rarely describes the error code. It is not unusual for the
- // error message to be the empty string.
- absl::string_view message() const;
-
- // Return a combination of the error code name and message.
- std::string ToString() const;
-
- // Compatibility with upstream API. Equiv to ToString().
- std::string error_message() const { return ToString(); }
-
- friend bool operator==(const Status&, const Status&);
- friend bool operator!=(const Status&, const Status&);
-
- // Ignores any errors, potentially suppressing complaints from any tools.
- void IgnoreError() const;
-
- private:
- static bool EqualsSlow(const Status& a, const Status& b);
-
- struct State {
- StatusCode code;
- std::string message;
- };
- // OK status has a nullptr state_. Otherwise, 'state_' points to
- // a 'State' structure containing the error code and message(s).
- std::unique_ptr<State> state_;
-};
-
-// Returns an OK status, equivalent to a default constructed instance.
-Status OkStatus();
-
-// Prints a human-readable representation of `x` to `os`.
-std::ostream& operator<<(std::ostream& os, const Status& x);
-
-// Returns a Status that is identical to `s` except that the message()
-// has been augmented by adding `msg` to the end of the original message.
-Status Annotate(const Status& s, absl::string_view msg);
-
-#define CHECK_OK(val) CHECK_EQ(::iree::OkStatus(), (val))
-#define QCHECK_OK(val) QCHECK_EQ(::iree::OkStatus(), (val))
-#define DCHECK_OK(val) DCHECK_EQ(::iree::OkStatus(), (val))
-
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUS_H_
diff --git a/iree/base/internal/status_builder.cc b/iree/base/internal/status_builder.cc
deleted file mode 100644
index 7ac0e87..0000000
--- a/iree/base/internal/status_builder.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/status_builder.h"
-
-#include <cstdio>
-
-#include "iree/base/internal/status_errors.h"
-
-namespace iree {
-
-StatusBuilder::StatusBuilder(const Status& original_status,
- SourceLocation location)
- : status_(original_status), loc_(location) {}
-
-StatusBuilder::StatusBuilder(Status&& original_status, SourceLocation location)
- : status_(original_status), loc_(location) {}
-
-StatusBuilder::StatusBuilder(const StatusBuilder& sb)
- : status_(sb.status_), loc_(sb.loc_), message_(sb.message_) {}
-
-StatusBuilder::StatusBuilder(StatusCode code, SourceLocation location)
- : status_(code, ""), loc_(location) {}
-
-StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
- status_ = sb.status_;
- loc_ = sb.loc_;
- message_ = sb.message_;
- return *this;
-}
-
-StatusBuilder::operator Status() const& {
- return StatusBuilder(*this).CreateStatus();
-}
-StatusBuilder::operator Status() && { return std::move(*this).CreateStatus(); }
-
-bool StatusBuilder::ok() const { return status_.ok(); }
-
-StatusCode StatusBuilder::code() const { return status_.code(); }
-
-SourceLocation StatusBuilder::source_location() const { return loc_; }
-
-Status StatusBuilder::CreateStatus() && {
- Status result = JoinMessageToStatus(status_, message_);
-
- // Reset the status after consuming it.
- status_ = UnknownError("");
- message_ = "";
- return result;
-}
-
-Status StatusBuilder::JoinMessageToStatus(Status s, absl::string_view msg) {
- if (msg.empty()) return s;
- return Annotate(s, msg);
-}
-
-std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder) {
- return os << static_cast<Status>(builder);
-}
-
-std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder) {
- return os << static_cast<Status>(std::move(builder));
-}
-
-StatusBuilder AbortedErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kAborted, location);
-}
-
-StatusBuilder AlreadyExistsErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kAlreadyExists, location);
-}
-
-StatusBuilder CancelledErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kCancelled, location);
-}
-
-StatusBuilder DataLossErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kDataLoss, location);
-}
-
-StatusBuilder DeadlineExceededErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kDeadlineExceeded, location);
-}
-
-StatusBuilder FailedPreconditionErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kFailedPrecondition, location);
-}
-
-StatusBuilder InternalErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kInternal, location);
-}
-
-StatusBuilder InvalidArgumentErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kInvalidArgument, location);
-}
-
-StatusBuilder NotFoundErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kNotFound, location);
-}
-
-StatusBuilder OutOfRangeErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kOutOfRange, location);
-}
-
-StatusBuilder PermissionDeniedErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kPermissionDenied, location);
-}
-
-StatusBuilder UnauthenticatedErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kUnauthenticated, location);
-}
-
-StatusBuilder ResourceExhaustedErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kResourceExhausted, location);
-}
-
-StatusBuilder UnavailableErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kUnavailable, location);
-}
-
-StatusBuilder UnimplementedErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kUnimplemented, location);
-}
-
-StatusBuilder UnknownErrorBuilder(SourceLocation location) {
- return StatusBuilder(StatusCode::kUnknown, location);
-}
-
-} // namespace iree
diff --git a/iree/base/internal/status_builder.h b/iree/base/internal/status_builder.h
deleted file mode 100644
index e05665b..0000000
--- a/iree/base/internal/status_builder.h
+++ /dev/null
@@ -1,137 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_BUILDER_H_
-#define IREE_BASE_INTERNAL_STATUS_BUILDER_H_
-
-#include "iree/base/internal/status.h"
-#include "iree/base/source_location.h"
-
-namespace iree {
-
-// Creates a status based on an original_status, but enriched with additional
-// information. The builder implicitly converts to Status and StatusOr<T>
-// allowing for it to be returned directly.
-class ABSL_MUST_USE_RESULT StatusBuilder {
- public:
- // Creates a `StatusBuilder` based on an original status.
- explicit StatusBuilder(const Status& original_status,
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
- explicit StatusBuilder(Status&& original_status,
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-
- // Creates a `StatusBuilder` from a status code.
- // A typical user will not specify `location`, allowing it to default to the
- // current location.
- explicit StatusBuilder(StatusCode code,
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-
- StatusBuilder(const StatusBuilder& sb);
- StatusBuilder& operator=(const StatusBuilder& sb);
- StatusBuilder(StatusBuilder&&) = default;
- StatusBuilder& operator=(StatusBuilder&&) = default;
-
- // Appends to the extra message that will be added to the original status.
- template <typename T>
- StatusBuilder& operator<<(const T& value) &;
- template <typename T>
- StatusBuilder&& operator<<(const T& value) &&;
-
- // No-op functions that may be added later.
- StatusBuilder& LogError() & { return *this; }
- StatusBuilder&& LogError() && { return std::move(LogError()); }
- StatusBuilder& LogWarning() & { return *this; }
- StatusBuilder&& LogWarning() && { return std::move(LogWarning()); }
- StatusBuilder& LogInfo() & { return *this; }
- StatusBuilder&& LogInfo() && { return std::move(LogInfo()); }
-
- // Returns true if the Status created by this builder will be ok().
- bool ok() const;
-
- // Returns the error code for the Status created by this builder.
- StatusCode code() const;
-
- // Returns the source location used to create this builder.
- SourceLocation source_location() const;
-
- // Implicit conversion to Status.
- operator Status() const&;
- operator Status() &&;
-
- private:
- Status CreateStatus() &&;
-
- static Status JoinMessageToStatus(Status s, absl::string_view msg);
-
- // The status that the result will be based on.
- Status status_;
-
- // The location to record if this status is logged.
- SourceLocation loc_;
-
- // The message that will be added to the original status.
- std::string message_;
-};
-
-template <typename T>
-StatusBuilder& StatusBuilder::operator<<(const T& value) & {
- return *this;
-}
-template <typename T>
-StatusBuilder&& StatusBuilder::operator<<(const T& value) && {
- return std::move(operator<<(value));
-}
-
-// Implicitly converts `builder` to `Status` and write it to `os`.
-std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder);
-std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder);
-
-// Each of the functions below creates StatusBuilder with a canonical error.
-// The error code of the StatusBuilder matches the name of the function.
-StatusBuilder AbortedErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder AlreadyExistsErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder CancelledErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder DataLossErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder DeadlineExceededErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder FailedPreconditionErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder InternalErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder InvalidArgumentErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder NotFoundErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder OutOfRangeErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder PermissionDeniedErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder UnauthenticatedErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder ResourceExhaustedErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder UnavailableErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder UnimplementedErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-StatusBuilder UnknownErrorBuilder(
- SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG);
-
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUS_BUILDER_H_
diff --git a/iree/base/internal/status_errno.cc b/iree/base/internal/status_errno.cc
deleted file mode 100644
index deaa5c1..0000000
--- a/iree/base/internal/status_errno.cc
+++ /dev/null
@@ -1,175 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/status_errno.h"
-
-#include <cerrno>
-
-#include "absl/strings/str_cat.h"
-
-namespace iree {
-
-StatusCode ErrnoToCanonicalCode(int error_number) {
- switch (error_number) {
- case 0:
- return StatusCode::kOk;
- case EINVAL: // Invalid argument
- case ENAMETOOLONG: // Filename too long
- case E2BIG: // Argument list too long
- case EDESTADDRREQ: // Destination address required
- case EDOM: // Mathematics argument out of domain of function
- case EFAULT: // Bad address
- case EILSEQ: // Illegal byte sequence
- case ENOPROTOOPT: // Protocol not available
- case ENOSTR: // Not a STREAM
- case ENOTSOCK: // Not a socket
- case ENOTTY: // Inappropriate I/O control operation
- case EPROTOTYPE: // Protocol wrong type for socket
- case ESPIPE: // Invalid seek
- return StatusCode::kInvalidArgument;
- case ETIMEDOUT: // Connection timed out
- case ETIME: // Timer expired
- return StatusCode::kDeadlineExceeded;
- case ENODEV: // No such device
- case ENOENT: // No such file or directory
-#ifdef ENOMEDIUM
- case ENOMEDIUM: // No medium found
-#endif
- case ENXIO: // No such device or address
- case ESRCH: // No such process
- return StatusCode::kNotFound;
- case EEXIST: // File exists
- case EADDRNOTAVAIL: // Address not available
- case EALREADY: // Connection already in progress
-#ifdef ENOTUNIQ
- case ENOTUNIQ: // Name not unique on network
-#endif
- return StatusCode::kAlreadyExists;
- case EPERM: // Operation not permitted
- case EACCES: // Permission denied
-#ifdef ENOKEY
- case ENOKEY: // Required key not available
-#endif
- case EROFS: // Read only file system
- return StatusCode::kPermissionDenied;
- case ENOTEMPTY: // Directory not empty
- case EISDIR: // Is a directory
- case ENOTDIR: // Not a directory
- case EADDRINUSE: // Address already in use
- case EBADF: // Invalid file descriptor
-#ifdef EBADFD
- case EBADFD: // File descriptor in bad state
-#endif
- case EBUSY: // Device or resource busy
- case ECHILD: // No child processes
- case EISCONN: // Socket is connected
-#ifdef EISNAM
- case EISNAM: // Is a named type file
-#endif
-#ifdef ENOTBLK
- case ENOTBLK: // Block device required
-#endif
- case ENOTCONN: // The socket is not connected
- case EPIPE: // Broken pipe
-#ifdef ESHUTDOWN
- case ESHUTDOWN: // Cannot send after transport endpoint shutdown
-#endif
- case ETXTBSY: // Text file busy
-#ifdef EUNATCH
- case EUNATCH: // Protocol driver not attached
-#endif
- return StatusCode::kFailedPrecondition;
- case ENOSPC: // No space left on device
-#ifdef EDQUOT
- case EDQUOT: // Disk quota exceeded
-#endif
- case EMFILE: // Too many open files
- case EMLINK: // Too many links
- case ENFILE: // Too many open files in system
- case ENOBUFS: // No buffer space available
- case ENODATA: // No message is available on the STREAM read queue
- case ENOMEM: // Not enough space
- case ENOSR: // No STREAM resources
-#ifdef EUSERS
- case EUSERS: // Too many users
-#endif
- return StatusCode::kResourceExhausted;
-#ifdef ECHRNG
- case ECHRNG: // Channel number out of range
-#endif
- case EFBIG: // File too large
- case EOVERFLOW: // Value too large to be stored in data type
- case ERANGE: // Result too large
- return StatusCode::kOutOfRange;
-#ifdef ENOPKG
- case ENOPKG: // Package not installed
-#endif
- case ENOSYS: // Function not implemented
- case ENOTSUP: // Operation not supported
- case EAFNOSUPPORT: // Address family not supported
-#ifdef EPFNOSUPPORT
- case EPFNOSUPPORT: // Protocol family not supported
-#endif
- case EPROTONOSUPPORT: // Protocol not supported
-#ifdef ESOCKTNOSUPPORT
- case ESOCKTNOSUPPORT: // Socket type not supported
-#endif
- case EXDEV: // Improper link
- return StatusCode::kUnimplemented;
- case EAGAIN: // Resource temporarily unavailable
-#ifdef ECOMM
- case ECOMM: // Communication error on send
-#endif
- case ECONNREFUSED: // Connection refused
- case ECONNABORTED: // Connection aborted
- case ECONNRESET: // Connection reset
- case EINTR: // Interrupted function call
-#ifdef EHOSTDOWN
- case EHOSTDOWN: // Host is down
-#endif
- case EHOSTUNREACH: // Host is unreachable
- case ENETDOWN: // Network is down
- case ENETRESET: // Connection aborted by network
- case ENETUNREACH: // Network unreachable
- case ENOLCK: // No locks available
- case ENOLINK: // Link has been severed
-#ifdef ENONET
- case ENONET: // Machine is not on the network
-#endif
- return StatusCode::kUnavailable;
- case EDEADLK: // Resource deadlock avoided
-#ifdef ESTALE
- case ESTALE: // Stale file handle
-#endif
- return StatusCode::kAborted;
- case ECANCELED: // Operation cancelled
- return StatusCode::kCancelled;
- default:
- return StatusCode::kUnknown;
- }
-}
-
-Status ErrnoToCanonicalStatus(int error_number, absl::string_view message) {
- // TODO(scotttodd): convert error number to a string
- return Status(ErrnoToCanonicalCode(error_number),
- absl::StrCat(message, ": ", error_number));
-}
-
-StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number,
- absl::string_view message,
- SourceLocation location) {
- return StatusBuilder(ErrnoToCanonicalStatus(error_number, message), location);
-}
-
-} // namespace iree
diff --git a/iree/base/internal/status_errno.h b/iree/base/internal/status_errno.h
deleted file mode 100644
index 8b3d74b..0000000
--- a/iree/base/internal/status_errno.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_ERRNO_H_
-#define IREE_BASE_INTERNAL_STATUS_ERRNO_H_
-
-#include "absl/strings/string_view.h"
-#include "iree/base/internal/status.h"
-#include "iree/base/internal/statusor.h"
-#include "iree/base/source_location.h"
-
-namespace iree {
-
-// Returns the code for |error_number|, which should be an |errno| value.
-// See https://en.cppreference.com/w/cpp/error/errno_macros and similar refs.
-StatusCode ErrnoToCanonicalCode(int error_number);
-
-// Returns a Status, using a code of `ErrnoToCode(error_number)`, and a
-// |message| with the result of `StrError(error_number)` appended.
-Status ErrnoToCanonicalStatus(int error_number, absl::string_view message);
-
-// Returns a StatusBuilder using a status of
-// `ErrnoToCanonicalStatus(error_number, message)` and |location|.
-StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number,
- absl::string_view message,
- SourceLocation location);
-
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUS_ERRNO_H_
diff --git a/iree/base/internal/status_errors.cc b/iree/base/internal/status_errors.cc
deleted file mode 100644
index 28acbbc..0000000
--- a/iree/base/internal/status_errors.cc
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/status_errors.h"
-
-namespace iree {
-
-Status AbortedError(absl::string_view message) {
- return Status(StatusCode::kAborted, message);
-}
-
-Status AlreadyExistsError(absl::string_view message) {
- return Status(StatusCode::kAlreadyExists, message);
-}
-
-Status CancelledError(absl::string_view message) {
- return Status(StatusCode::kCancelled, message);
-}
-
-Status DataLossError(absl::string_view message) {
- return Status(StatusCode::kDataLoss, message);
-}
-
-Status DeadlineExceededError(absl::string_view message) {
- return Status(StatusCode::kDeadlineExceeded, message);
-}
-
-Status FailedPreconditionError(absl::string_view message) {
- return Status(StatusCode::kFailedPrecondition, message);
-}
-
-Status InternalError(absl::string_view message) {
- return Status(StatusCode::kInternal, message);
-}
-
-Status InvalidArgumentError(absl::string_view message) {
- return Status(StatusCode::kInvalidArgument, message);
-}
-
-Status NotFoundError(absl::string_view message) {
- return Status(StatusCode::kNotFound, message);
-}
-
-Status OutOfRangeError(absl::string_view message) {
- return Status(StatusCode::kOutOfRange, message);
-}
-
-Status PermissionDeniedError(absl::string_view message) {
- return Status(StatusCode::kPermissionDenied, message);
-}
-
-Status ResourceExhaustedError(absl::string_view message) {
- return Status(StatusCode::kResourceExhausted, message);
-}
-
-Status UnauthenticatedError(absl::string_view message) {
- return Status(StatusCode::kUnauthenticated, message);
-}
-
-Status UnavailableError(absl::string_view message) {
- return Status(StatusCode::kUnavailable, message);
-}
-
-Status UnimplementedError(absl::string_view message) {
- return Status(StatusCode::kUnimplemented, message);
-}
-
-Status UnknownError(absl::string_view message) {
- return Status(StatusCode::kUnknown, message);
-}
-
-bool IsAborted(const Status& status) {
- return status.code() == StatusCode::kAborted;
-}
-
-bool IsAlreadyExists(const Status& status) {
- return status.code() == StatusCode::kAlreadyExists;
-}
-
-bool IsCancelled(const Status& status) {
- return status.code() == StatusCode::kCancelled;
-}
-
-bool IsDataLoss(const Status& status) {
- return status.code() == StatusCode::kDataLoss;
-}
-
-bool IsDeadlineExceeded(const Status& status) {
- return status.code() == StatusCode::kDeadlineExceeded;
-}
-
-bool IsFailedPrecondition(const Status& status) {
- return status.code() == StatusCode::kFailedPrecondition;
-}
-
-bool IsInternal(const Status& status) {
- return status.code() == StatusCode::kInternal;
-}
-
-bool IsInvalidArgument(const Status& status) {
- return status.code() == StatusCode::kInvalidArgument;
-}
-
-bool IsNotFound(const Status& status) {
- return status.code() == StatusCode::kNotFound;
-}
-
-bool IsOutOfRange(const Status& status) {
- return status.code() == StatusCode::kOutOfRange;
-}
-
-bool IsPermissionDenied(const Status& status) {
- return status.code() == StatusCode::kPermissionDenied;
-}
-
-bool IsResourceExhausted(const Status& status) {
- return status.code() == StatusCode::kResourceExhausted;
-}
-
-bool IsUnauthenticated(const Status& status) {
- return status.code() == StatusCode::kUnauthenticated;
-}
-
-bool IsUnavailable(const Status& status) {
- return status.code() == StatusCode::kUnavailable;
-}
-
-bool IsUnimplemented(const Status& status) {
- return status.code() == StatusCode::kUnimplemented;
-}
-
-bool IsUnknown(const Status& status) {
- return status.code() == StatusCode::kUnknown;
-}
-
-} // namespace iree
diff --git a/iree/base/internal/status_errors.h b/iree/base/internal/status_errors.h
deleted file mode 100644
index ac23d74..0000000
--- a/iree/base/internal/status_errors.h
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_ERRORS_H_
-#define IREE_BASE_INTERNAL_STATUS_ERRORS_H_
-
-#include "absl/base/attributes.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/internal/status.h"
-
-namespace iree {
-
-Status AbortedError(absl::string_view message);
-Status AlreadyExistsError(absl::string_view message);
-Status CancelledError(absl::string_view message);
-Status DataLossError(absl::string_view message);
-Status DeadlineExceededError(absl::string_view message);
-Status FailedPreconditionError(absl::string_view message);
-Status InternalError(absl::string_view message);
-Status InvalidArgumentError(absl::string_view message);
-Status NotFoundError(absl::string_view message);
-Status OutOfRangeError(absl::string_view message);
-Status PermissionDeniedError(absl::string_view message);
-Status ResourceExhaustedError(absl::string_view message);
-Status UnauthenticatedError(absl::string_view message);
-Status UnavailableError(absl::string_view message);
-Status UnimplementedError(absl::string_view message);
-Status UnknownError(absl::string_view message);
-
-ABSL_MUST_USE_RESULT bool IsAborted(const Status& status);
-ABSL_MUST_USE_RESULT bool IsAlreadyExists(const Status& status);
-ABSL_MUST_USE_RESULT bool IsCancelled(const Status& status);
-ABSL_MUST_USE_RESULT bool IsDataLoss(const Status& status);
-ABSL_MUST_USE_RESULT bool IsDeadlineExceeded(const Status& status);
-ABSL_MUST_USE_RESULT bool IsFailedPrecondition(const Status& status);
-ABSL_MUST_USE_RESULT bool IsInternal(const Status& status);
-ABSL_MUST_USE_RESULT bool IsInvalidArgument(const Status& status);
-ABSL_MUST_USE_RESULT bool IsNotFound(const Status& status);
-ABSL_MUST_USE_RESULT bool IsOutOfRange(const Status& status);
-ABSL_MUST_USE_RESULT bool IsPermissionDenied(const Status& status);
-ABSL_MUST_USE_RESULT bool IsResourceExhausted(const Status& status);
-ABSL_MUST_USE_RESULT bool IsUnauthenticated(const Status& status);
-ABSL_MUST_USE_RESULT bool IsUnavailable(const Status& status);
-ABSL_MUST_USE_RESULT bool IsUnimplemented(const Status& status);
-ABSL_MUST_USE_RESULT bool IsUnknown(const Status& status);
-
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUS_ERRORS_H_
diff --git a/iree/base/internal/status_macros.h b/iree/base/internal/status_macros.h
deleted file mode 100644
index def1e6e..0000000
--- a/iree/base/internal/status_macros.h
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_MACROS_H_
-#define IREE_BASE_INTERNAL_STATUS_MACROS_H_
-
-#include "iree/base/internal/status.h"
-#include "iree/base/internal/status_builder.h"
-#include "iree/base/internal/statusor.h"
-#include "iree/base/source_location.h"
-
-// Evaluates an expression that produces a `iree::Status`. If the status is not
-// ok, returns it from the current function.
-#define RETURN_IF_ERROR(expr) \
- STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
- if (iree::status_macro_internal::StatusAdaptorForMacros \
- status_macro_internal_adaptor = {(expr), IREE_LOC}) { \
- } else /* NOLINT */ \
- return status_macro_internal_adaptor.Consume()
-
-// Executes an expression `rexpr` that returns a `iree::StatusOr<T>`. On OK,
-// moves its value into the variable defined by `lhs`, otherwise returns
-// from the current function.
-#define ASSIGN_OR_RETURN(...) \
- STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \
- STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \
- STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \
- (__VA_ARGS__)
-
-// =================================================================
-// == Implementation details, do not rely on anything below here. ==
-// =================================================================
-
-// MSVC incorrectly expands variadic macros, splice together a macro call to
-// work around the bug.
-#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
-#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \
- STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args
-
-#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
- STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_))
-#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \
- STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
- STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \
- error_expression)
-#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
- error_expression) \
- auto statusor = (rexpr); \
- if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
- iree::StatusBuilder _(std::move(statusor).status(), IREE_LOC); \
- (void)_; /* error_expression is allowed to not use this variable */ \
- return (error_expression); \
- } \
- lhs = std::move(statusor).ValueOrDie()
-
-// Internal helper for concatenating macro values.
-#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
-#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
-
-// clang-format off
-#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT
-// clang-format on
-
-namespace iree {
-namespace status_macro_internal {
-
-// Provides a conversion to bool so that it can be used inside an if statement
-// that declares a variable.
-class StatusAdaptorForMacros {
- public:
- StatusAdaptorForMacros(const Status& status, SourceLocation loc)
- : builder_(status, loc) {}
-
- StatusAdaptorForMacros(Status&& status, SourceLocation loc)
- : builder_(std::move(status), loc) {}
-
- StatusAdaptorForMacros(const StatusBuilder& builder, SourceLocation loc)
- : builder_(builder) {}
-
- StatusAdaptorForMacros(StatusBuilder&& builder, SourceLocation loc)
- : builder_(std::move(builder)) {}
-
- StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete;
- StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete;
-
- explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); }
-
- StatusBuilder&& Consume() { return std::move(builder_); }
-
- private:
- StatusBuilder builder_;
-};
-
-} // namespace status_macro_internal
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUS_MACROS_H_
diff --git a/iree/base/internal/status_matchers.h b/iree/base/internal/status_matchers.h
deleted file mode 100644
index 83e6a8f..0000000
--- a/iree/base/internal/status_matchers.h
+++ /dev/null
@@ -1,299 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
-#define IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
-
-#include <memory>
-
-#include "absl/strings/str_cat.h"
-#include "absl/types/optional.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-
-#undef EXPECT_OK
-#undef ASSERT_OK
-#undef ASSERT_OK_AND_ASSIGN
-
-namespace iree {
-
-namespace internal {
-
-// Implements a gMock matcher that checks that an iree::StaturOr<T> has an OK
-// status and that the contained T value matches another matcher.
-template <typename T>
-class IsOkAndHoldsMatcher
- : public ::testing::MatcherInterface<const StatusOr<T> &> {
- public:
- template <typename MatcherT>
- IsOkAndHoldsMatcher(MatcherT &&value_matcher)
- : value_matcher_(::testing::SafeMatcherCast<const T &>(value_matcher)) {}
-
- // From testing::MatcherInterface.
- void DescribeTo(std::ostream *os) const override {
- *os << "is OK and contains a value that ";
- value_matcher_.DescribeTo(os);
- }
-
- // From testing::MatcherInterface.
- void DescribeNegationTo(std::ostream *os) const override {
- *os << "is not OK or contains a value that ";
- value_matcher_.DescribeNegationTo(os);
- }
-
- // From testing::MatcherInterface.
- bool MatchAndExplain(
- const StatusOr<T> &status_or,
- ::testing::MatchResultListener *listener) const override {
- if (!status_or.ok()) {
- *listener << "which is not OK";
- return false;
- }
-
- ::testing::StringMatchResultListener value_listener;
- bool is_a_match =
- value_matcher_.MatchAndExplain(status_or.ValueOrDie(), &value_listener);
- std::string value_explanation = value_listener.str();
- if (!value_explanation.empty()) {
- *listener << absl::StrCat("which contains a value ", value_explanation);
- }
-
- return is_a_match;
- }
-
- private:
- const ::testing::Matcher<const T &> value_matcher_;
-};
-
-// A polymorphic IsOkAndHolds() matcher.
-//
-// IsOkAndHolds() returns a matcher that can be used to process an IsOkAndHolds
-// expectation. However, the value type T is not provided when IsOkAndHolds() is
-// invoked. The value type is only inferable when the gUnit framework invokes
-// the matcher with a value. Consequently, the IsOkAndHolds() function must
-// return an object that is implicitly convertible to a matcher for StatusOr<T>.
-// gUnit refers to such an object as a polymorphic matcher, since it can be used
-// to match with more than one type of value.
-template <typename ValueMatcherT>
-class IsOkAndHoldsGenerator {
- public:
- explicit IsOkAndHoldsGenerator(ValueMatcherT value_matcher)
- : value_matcher_(std::move(value_matcher)) {}
-
- template <typename T>
- operator ::testing::Matcher<const StatusOr<T> &>() const {
- return ::testing::MakeMatcher(new IsOkAndHoldsMatcher<T>(value_matcher_));
- }
-
- private:
- const ValueMatcherT value_matcher_;
-};
-
-// Implements a gMock matcher for checking error-code expectations on
-// iree::Status objects.
-template <typename Enum>
-class StatusMatcher : public ::testing::MatcherInterface<const Status &> {
- public:
- StatusMatcher(Enum code, absl::optional<absl::string_view> message)
- : code_(code), message_(message) {}
-
- // From testing::MatcherInterface.
- //
- // Describes the expected error code.
- void DescribeTo(std::ostream *os) const override {
- *os << "error code " << StatusCodeToString(code_);
- if (message_.has_value()) {
- *os << "::'" << message_.value() << "'";
- }
- }
-
- // From testing::MatcherInterface.
- //
- // Tests whether |status| has an error code that meets this matcher's
- // expectation. If an error message string is specified in this matcher, it
- // also tests that |status| has an error message that matches that
- // expectation.
- bool MatchAndExplain(
- const Status &status,
- ::testing::MatchResultListener *listener) const override {
- if (status.code() != code_) {
- *listener << "whose error code is " << StatusCodeToString(status.code());
- return false;
- }
- if (message_.has_value() && status.message() != message_.value()) {
- *listener << "whose error message is '" << status.message() << "'";
- return false;
- }
- return true;
- }
-
- private:
- // Expected error code.
- const Enum code_;
-
- // Expected error message (empty if none expected and verified).
- const absl::optional<std::string> message_;
-};
-
-// Implements a gMock matcher that checks whether a status container (e.g.
-// iree::Status or iree::StatusOr<T>) has an OK status.
-template <class T>
-class IsOkMatcherImpl : public ::testing::MatcherInterface<T> {
- public:
- IsOkMatcherImpl() = default;
-
- // From testing::MatcherInterface.
- //
- // Describes the OK expectation.
- void DescribeTo(std::ostream *os) const override { *os << "is OK"; }
-
- // From testing::MatcherInterface.
- //
- // Describes the negative OK expectation.
- void DescribeNegationTo(std::ostream *os) const override {
- *os << "is not OK";
- }
-
- // From testing::MatcherInterface.
- //
- // Tests whether |status_container|'s OK value meets this matcher's
- // expectation.
- bool MatchAndExplain(
- const T &status_container,
- ::testing::MatchResultListener *listener) const override {
- if (!status_container.ok()) {
- *listener << "which is not OK";
- return false;
- }
- return true;
- }
-};
-
-// IsOkMatcherGenerator is an intermediate object returned by iree::IsOk().
-// It implements implicit type-cast operators to supported matcher types:
-// Matcher<const Status &> and Matcher<const StatusOr<T> &>. These typecast
-// operators create gMock matchers that test OK expectations on a status
-// container.
-class IsOkMatcherGenerator {
- public:
- // Type-cast operator for Matcher<const iree::Status &>.
- operator ::testing::Matcher<const Status &>() const {
- return ::testing::MakeMatcher(
- new internal::IsOkMatcherImpl<const Status &>());
- }
-
- // Type-cast operator for Matcher<const iree::StatusOr<T> &>.
- template <class T>
- operator ::testing::Matcher<const StatusOr<T> &>() const {
- return ::testing::MakeMatcher(
- new internal::IsOkMatcherImpl<const StatusOr<T> &>());
- }
-};
-
-} // namespace internal
-
-// Returns a gMock matcher that expects an iree::StatusOr<T> object to have an
-// OK status and for the contained T object to match |value_matcher|.
-//
-// Example:
-//
-// StatusOr<string> raven_speech_result = raven.Speak();
-// EXPECT_THAT(raven_speech_result, IsOkAndHolds(HasSubstr("nevermore")));
-//
-// If foo is an object of type T and foo_result is an object of type
-// StatusOr<T>, you can write:
-//
-// EXPECT_THAT(foo_result, IsOkAndHolds(foo));
-//
-// instead of:
-//
-// EXPECT_THAT(foo_result, IsOkAndHolds(Eq(foo)));
-template <typename ValueMatcherT>
-internal::IsOkAndHoldsGenerator<ValueMatcherT> IsOkAndHolds(
- ValueMatcherT value_matcher) {
- return internal::IsOkAndHoldsGenerator<ValueMatcherT>(value_matcher);
-}
-
-// Returns a gMock matcher that expects an iree::Status object to have the
-// given |code|.
-template <typename Enum>
-::testing::Matcher<const Status &> StatusIs(Enum code) {
- return ::testing::MakeMatcher(
- new internal::StatusMatcher<Enum>(code, absl::nullopt));
-}
-
-// Returns a gMock matcher that expects an iree::Status object to have the
-// given |code| and |message|.
-template <typename Enum>
-::testing::Matcher<const Status &> StatusIs(Enum code,
- absl::string_view message) {
- return ::testing::MakeMatcher(
- new internal::StatusMatcher<Enum>(code, message));
-}
-
-// Returns an internal::IsOkMatcherGenerator, which may be typecast to a
-// Matcher<iree::Status> or Matcher<iree::StatusOr<T>>. These gMock
-// matchers test that a given status container has an OK status.
-inline internal::IsOkMatcherGenerator IsOk() {
- return internal::IsOkMatcherGenerator();
-}
-
-// Macros for testing the results of functions that return iree::Status or
-// iree::StatusOr<T> (for any type T).
-#define EXPECT_OK(rexpr) EXPECT_THAT(rexpr, ::iree::IsOk())
-#define ASSERT_OK(rexpr) ASSERT_THAT(rexpr, ::iree::IsOk())
-
-// Executes an expression that returns an iree::StatusOr<T>, and assigns the
-// contained variable to lhs if the error code is OK.
-// If the Status is non-OK, generates a test failure and returns from the
-// current function, which must have a void return type.
-//
-// Example: Assigning to an existing value
-// ASSERT_OK_AND_ASSIGN(ValueType value, MaybeGetValue(arg));
-//
-// The value assignment example might expand into:
-// StatusOr<ValueType> status_or_value = MaybeGetValue(arg);
-// ASSERT_OK(status_or_value.status());
-// ValueType value = status_or_value.ValueOrDie();
-#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
- IREE_ASSERT_OK_AND_ASSIGN_IMPL( \
- IREE_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
- rexpr);
-
-#define IREE_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \
- auto statusor = (rexpr); \
- ASSERT_OK(statusor.status()) << statusor.status(); \
- lhs = std::move(statusor.ValueOrDie())
-#define IREE_STATUS_MACROS_CONCAT_NAME(x, y) \
- IREE_STATUS_MACROS_CONCAT_IMPL(x, y)
-#define IREE_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
-
-// Implements the PrintTo() method for iree::StatusOr<T>. This method is
-// used by gUnit to print iree::StatusOr<T> objects for debugging. The
-// implementation relies on gUnit for printing values of T when a
-// iree::StatusOr<T> object is OK and contains a value.
-template <typename T>
-void PrintTo(const StatusOr<T> &statusor, std::ostream *os) {
- if (!statusor.ok()) {
- *os << statusor.status();
- } else {
- *os << absl::StrCat("OK: ",
- ::testing::PrintToString(statusor.ValueOrDie()));
- }
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
diff --git a/iree/base/internal/status_win32_errors.cc b/iree/base/internal/status_win32_errors.cc
deleted file mode 100644
index 16814d9..0000000
--- a/iree/base/internal/status_win32_errors.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/status_win32_errors.h"
-
-#include "absl/strings/str_cat.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#include <windows.h>
-
-namespace iree {
-
-StatusCode Win32ErrorToCanonicalCode(uint32_t error) {
- switch (error) {
- case ERROR_SUCCESS:
- return StatusCode::kOk;
- case ERROR_FILE_NOT_FOUND:
- case ERROR_PATH_NOT_FOUND:
- return StatusCode::kNotFound;
- case ERROR_TOO_MANY_OPEN_FILES:
- case ERROR_OUTOFMEMORY:
- case ERROR_HANDLE_DISK_FULL:
- case ERROR_HANDLE_EOF:
- return StatusCode::kResourceExhausted;
- case ERROR_ACCESS_DENIED:
- return StatusCode::kPermissionDenied;
- case ERROR_INVALID_HANDLE:
- return StatusCode::kInvalidArgument;
- case ERROR_NOT_READY:
- case ERROR_READ_FAULT:
- return StatusCode::kUnavailable;
- case ERROR_WRITE_FAULT:
- return StatusCode::kDataLoss;
- case ERROR_NOT_SUPPORTED:
- return StatusCode::kUnimplemented;
- default:
- return StatusCode::kUnknown;
- }
-}
-
-StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error,
- SourceLocation location) {
- // TODO(benvanik): use FormatMessage; or defer until required?
- return StatusBuilder(
- Status(Win32ErrorToCanonicalCode(error), absl::StrCat("<TBD>", error)),
- location);
-}
-
-} // namespace iree
-
-#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/status_win32_errors.h b/iree/base/internal/status_win32_errors.h
deleted file mode 100644
index 599852d..0000000
--- a/iree/base/internal/status_win32_errors.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
-#define IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
-
-#include "absl/strings/string_view.h"
-#include "iree/base/internal/statusor.h"
-#include "iree/base/source_location.h"
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-namespace iree {
-
-// Returns the code for |error| which should be a Win32 error dword.
-StatusCode Win32ErrorToCanonicalCode(uint32_t error);
-
-// Returns a StatusBuilder with a status describing the |error| and |location|.
-StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error,
- SourceLocation location);
-
-} // namespace iree
-
-#endif // IREE_PLATFORM_WINDOWS
-
-#endif // IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
diff --git a/iree/base/internal/statusor.cc b/iree/base/internal/statusor.cc
deleted file mode 100644
index f7eb153..0000000
--- a/iree/base/internal/statusor.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/internal/statusor.h"
-
-#include "iree/base/internal/status_errors.h"
-
-namespace iree {
-
-namespace internal_statusor {
-
-void Helper::HandleInvalidStatusCtorArg(Status* status) {
- const char* kMessage =
- "An OK status is not a valid constructor argument to StatusOr<T>";
- LOG(ERROR) << kMessage;
- *status = InternalError(kMessage);
- abort();
-}
-
-void Helper::Crash(const Status& status) {
- LOG(FATAL) << "Attempting to fetch value instead of handling error "
- << status;
- abort();
-}
-
-} // namespace internal_statusor
-
-} // namespace iree
diff --git a/iree/base/internal/statusor.h b/iree/base/internal/statusor.h
deleted file mode 100644
index 4784d38..0000000
--- a/iree/base/internal/statusor.h
+++ /dev/null
@@ -1,699 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INTERNAL_STATUSOR_H_
-#define IREE_BASE_INTERNAL_STATUSOR_H_
-
-#include "absl/base/attributes.h"
-#include "iree/base/internal/status.h"
-#include "iree/base/internal/status_builder.h"
-
-namespace iree {
-
-template <typename T>
-class ABSL_MUST_USE_RESULT StatusOr;
-
-namespace internal_statusor {
-
-template <typename T, typename U>
-using IsStatusOrConversionAmbiguous =
- absl::disjunction<std::is_constructible<T, StatusOr<U>&>,
- std::is_constructible<T, const StatusOr<U>&>,
- std::is_constructible<T, StatusOr<U>&&>,
- std::is_constructible<T, const StatusOr<U>&&>,
- std::is_convertible<StatusOr<U>&, T>,
- std::is_convertible<const StatusOr<U>&, T>,
- std::is_convertible<StatusOr<U>&&, T>,
- std::is_convertible<const StatusOr<U>&&, T>>;
-
-template <typename T, typename U>
-using IsStatusOrConversionAssigmentAmbiguous =
- absl::disjunction<IsStatusOrConversionAmbiguous<T, U>,
- std::is_assignable<T&, StatusOr<U>&>,
- std::is_assignable<T&, const StatusOr<U>&>,
- std::is_assignable<T&, StatusOr<U>&&>,
- std::is_assignable<T&, const StatusOr<U>&&>>;
-
-template <typename T, typename U>
-struct IsAmbiguousStatusOrForInitialization
- : // Strip const-value refs from type and check again, else false_type.
- public absl::conditional_t<
- std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
- U>::value,
- std::false_type,
- IsAmbiguousStatusOrForInitialization<
- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
-
-template <typename T, typename U>
-struct IsAmbiguousStatusOrForInitialization<T, StatusOr<U>>
- : public IsStatusOrConversionAmbiguous<T, U> {};
-
-template <typename T, typename U>
-using IsStatusOrDirectInitializationAmbiguous = absl::disjunction<
- std::is_same<StatusOr<T>, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<Status, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<StatusBuilder, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<absl::in_place_t,
- absl::remove_cv_t<absl::remove_reference_t<U>>>,
- IsAmbiguousStatusOrForInitialization<T, U>>;
-
-template <typename T, typename U>
-using IsStatusOrDirectInitializationValid = absl::disjunction<
- // The is_same allows nested status ors to ignore this check iff same type.
- std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- absl::negation<IsStatusOrDirectInitializationAmbiguous<T, U>>>;
-
-class Helper {
- public:
- ABSL_ATTRIBUTE_NORETURN static void HandleInvalidStatusCtorArg(Status*);
- ABSL_ATTRIBUTE_NORETURN static void Crash(const Status& status);
-};
-
-// Construct an instance of T in `p` through placement new, passing Args... to
-// the constructor.
-// This abstraction is here mostly for the gcc performance fix.
-template <typename T, typename... Args>
-void PlacementNew(void* p, Args&&... args) {
-#if defined(__GNUC__) && !defined(__clang__)
- // Teach gcc that 'p' cannot be null, fixing code size issues.
- if (p == nullptr) __builtin_unreachable();
-#endif
- new (p) T(std::forward<Args>(args)...);
-}
-
-// Helper base class to hold the data and all operations.
-// We move all this to a base class to allow mixing with the appropriate
-// TraitsBase specialization.
-template <typename T>
-class StatusOrData {
- template <typename U>
- friend class StatusOrData;
-
- public:
- StatusOrData() = delete;
-
- StatusOrData(const StatusOrData& other) {
- if (other.ok()) {
- MakeValue(other.data_);
- MakeStatus();
- } else {
- MakeStatus(other.status_);
- }
- }
-
- StatusOrData(StatusOrData&& other) noexcept {
- if (other.ok()) {
- MakeValue(std::move(other.data_));
- MakeStatus();
- } else {
- MakeStatus(std::move(other.status_));
- }
- }
-
- template <typename U>
- explicit StatusOrData(const StatusOrData<U>& other) {
- if (other.ok()) {
- MakeValue(other.data_);
- MakeStatus();
- } else {
- MakeStatus(other.status_);
- }
- }
-
- template <typename U>
- explicit StatusOrData(StatusOrData<U>&& other) {
- if (other.ok()) {
- MakeValue(std::move(other.data_));
- MakeStatus();
- } else {
- MakeStatus(std::move(other.status_));
- }
- }
-
- template <typename... Args>
- explicit StatusOrData(absl::in_place_t, Args&&... args)
- : data_(std::forward<Args>(args)...) {
- MakeStatus();
- }
-
- explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); }
- explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); }
-
- explicit StatusOrData(const Status& status) : status_(status) {
- EnsureNotOk();
- }
- explicit StatusOrData(Status&& status) : status_(status) { EnsureNotOk(); }
-
- explicit StatusOrData(const StatusBuilder& builder) : status_(builder) {
- EnsureNotOk();
- }
- explicit StatusOrData(StatusBuilder&& builder) : status_(std::move(builder)) {
- EnsureNotOk();
- }
-
- StatusOrData& operator=(const StatusOrData& other) {
- if (this == &other) return *this;
- if (other.ok())
- Assign(other.data_);
- else
- Assign(other.status_);
- return *this;
- }
-
- StatusOrData& operator=(StatusOrData&& other) {
- if (this == &other) return *this;
- if (other.ok())
- Assign(std::move(other.data_));
- else
- Assign(std::move(other.status_));
- return *this;
- }
-
- ~StatusOrData() {
- if (ok()) {
- status_.~Status();
- data_.~T();
- } else {
- status_.~Status();
- }
- }
-
- void Assign(const T& value) {
- if (ok()) {
- data_.~T();
- MakeValue(value);
- } else {
- MakeValue(value);
- status_ = OkStatus();
- }
- }
-
- void Assign(T&& value) {
- if (ok()) {
- data_.~T();
- MakeValue(std::move(value));
- } else {
- MakeValue(std::move(value));
- status_ = OkStatus();
- }
- }
-
- void Assign(const Status& status) {
- Clear();
- status_ = status;
- EnsureNotOk();
- }
-
- void Assign(Status&& status) {
- Clear();
- status_ = std::move(status);
- EnsureNotOk();
- }
-
- bool ok() const { return status_.ok(); }
-
- protected:
- // status_ will always be active after the constructor.
- // Union to be able to initialize exactly how we need without waste.
- // Eg. in the copy constructor we use the default constructor of Status in
- // the ok() path to avoid an extra Ref call.
- union {
- Status status_;
- };
-
- // data_ is active iff status_.ok()==true
- struct Dummy {};
- union {
- // When T is const, we need some non-const object we can cast to void* for
- // the placement new. dummy_ is that object.
- Dummy dummy_;
- T data_;
- };
-
- void Clear() {
- if (ok()) data_.~T();
- }
-
- void EnsureOk() const {
- if (!ok()) Helper::Crash(status_);
- }
-
- void EnsureNotOk() {
- if (ok()) Helper::HandleInvalidStatusCtorArg(&status_);
- }
-
- // Construct the value (data_) through placement new with the passed arg.
- template <typename Arg>
- void MakeValue(Arg&& arg) {
- internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg));
- }
-
- // Construct the status (status_) through placement new with the passed arg.
- template <typename... Args>
- void MakeStatus(Args&&... args) {
- internal_statusor::PlacementNew<Status>(&status_,
- std::forward<Args>(args)...);
- }
-};
-
-// Helper base class to allow implicitly deleted constructors and assignment
-// operations in StatusOr.
-// TraitsBase will explicitly delete what it can't support and StatusOr will
-// inherit that behavior implicitly.
-template <bool Copy, bool Move>
-struct TraitsBase {
- TraitsBase() = default;
- TraitsBase(const TraitsBase&) = default;
- TraitsBase(TraitsBase&&) = default;
- TraitsBase& operator=(const TraitsBase&) = default;
- TraitsBase& operator=(TraitsBase&&) = default;
-};
-
-template <>
-struct TraitsBase<false, true> {
- TraitsBase() = default;
- TraitsBase(const TraitsBase&) = delete;
- TraitsBase(TraitsBase&&) = default;
- TraitsBase& operator=(const TraitsBase&) = delete;
- TraitsBase& operator=(TraitsBase&&) = default;
-};
-
-template <>
-struct TraitsBase<false, false> {
- TraitsBase() = default;
- TraitsBase(const TraitsBase&) = delete;
- TraitsBase(TraitsBase&&) = delete;
- TraitsBase& operator=(const TraitsBase&) = delete;
- TraitsBase& operator=(TraitsBase&&) = delete;
-};
-
-} // namespace internal_statusor
-
-// StatusOr<T> is the union of a Status object and a T object.
-//
-// A StatusOr object either holds a usable value, or an error Status explaining
-// why such a value is not present.
-template <typename T>
-class StatusOr : private internal_statusor::StatusOrData<T>,
- private internal_statusor::TraitsBase<
- std::is_copy_constructible<T>::value,
- std::is_move_constructible<T>::value> {
- template <typename U>
- friend class StatusOr;
-
- typedef internal_statusor::StatusOrData<T> Base;
-
- public:
- typedef T element_type;
-
- // Constructs a new StatusOr with StatusCode::kUnknown status.
- explicit StatusOr();
-
- // StatusOr<T> is copy constructible/assignable if T is copy constructible.
- StatusOr(const StatusOr&) = default;
- StatusOr& operator=(const StatusOr&) = default;
-
- // StatusOr<T> is move constructible/assignable if T is move constructible.
- StatusOr(StatusOr&&) = default;
- StatusOr& operator=(StatusOr&&) = default;
-
- // Converting constructors from StatusOr<U>, when T is constructible from U.
- // To avoid ambiguity, they are disabled if T is also constructible from
- // StatusOr<U>. Explicit iff the corresponding construction of T from U is
- // explicit.
- template <
- typename U,
- absl::enable_if_t<
- absl::conjunction<
- absl::negation<std::is_same<T, U>>,
- std::is_constructible<T, const U&>,
- std::is_convertible<const U&, T>,
- absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
- T, U>>>::value,
- int> = 0>
- StatusOr(const StatusOr<U>& other) // NOLINT
- : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
- template <
- typename U,
- absl::enable_if_t<
- absl::conjunction<
- absl::negation<std::is_same<T, U>>,
- std::is_constructible<T, const U&>,
- absl::negation<std::is_convertible<const U&, T>>,
- absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
- T, U>>>::value,
- int> = 0>
- explicit StatusOr(const StatusOr<U>& other)
- : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
-
- template <
- typename U,
- absl::enable_if_t<
- absl::conjunction<
- absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
- std::is_convertible<U&&, T>,
- absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
- T, U>>>::value,
- int> = 0>
- StatusOr(StatusOr<U>&& other) // NOLINT
- : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
- template <
- typename U,
- absl::enable_if_t<
- absl::conjunction<
- absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
- absl::negation<std::is_convertible<U&&, T>>,
- absl::negation<internal_statusor::IsStatusOrConversionAmbiguous<
- T, U>>>::value,
- int> = 0>
- explicit StatusOr(StatusOr<U>&& other)
- : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
-
- // Conversion copy/move assignment operator, T must be constructible and
- // assignable from U. Only enable if T cannot be directly assigned from
- // StatusOr<U>.
- template <
- typename U,
- absl::enable_if_t<
- absl::conjunction<
- absl::negation<std::is_same<T, U>>,
- std::is_constructible<T, const U&>,
- std::is_assignable<T, const U&>,
- absl::negation<
- internal_statusor::IsStatusOrConversionAssigmentAmbiguous<
- T, U>>>::value,
- int> = 0>
- StatusOr& operator=(const StatusOr<U>& other) {
- this->Assign(other);
- return *this;
- }
- template <
- typename U,
- absl::enable_if_t<
- absl::conjunction<
- absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
- std::is_assignable<T, U&&>,
- absl::negation<
- internal_statusor::IsStatusOrConversionAssigmentAmbiguous<
- T, U>>>::value,
- int> = 0>
- StatusOr& operator=(StatusOr<U>&& other) {
- this->Assign(std::move(other));
- return *this;
- }
-
- // Constructs a new StatusOr with the given value. After calling this
- // constructor, this->ok() will be true and the contained value may be
- // retrieved with ValueOrDie(), operator*(), or operator->().
- StatusOr(const T& value);
-
- // Constructs a new StatusOr with the given non-ok status. After calling this
- // constructor, this->ok() will be false and calls to ValueOrDie() will
- // CHECK-fail.
- StatusOr(const Status& status);
- StatusOr& operator=(const Status& status);
- StatusOr(const StatusBuilder& builder);
- StatusOr& operator=(const StatusBuilder& builder);
-
- // Similar to the `const T&` overload.
- //
- // REQUIRES: T is move constructible.
- StatusOr(T&& value);
-
- // RValue versions of the operations declared above.
- StatusOr(Status&& status);
- StatusOr& operator=(Status&& status);
- StatusOr(StatusBuilder&& builder);
- StatusOr& operator=(StatusBuilder&& builder);
-
- // Constructs the inner value T in-place using the provided args, using the
- // T(args...) constructor.
- template <typename... Args>
- explicit StatusOr(absl::in_place_t, Args&&... args);
- template <typename U, typename... Args>
- explicit StatusOr(absl::in_place_t, std::initializer_list<U> ilist,
- Args&&... args);
-
- // Constructs the inner value T in-place using the provided args, using the
- // T(U) (direct-initialization) constructor. Only valid if T can be
- // constructed from a U. Can accept move or copy constructors. Explicit it
- // U is not convertible to T. To avoid ambiguity, this is disabled if U is
- // a StatusOr<J>, where J is convertible to T.
- template <
- typename U = T,
- absl::enable_if_t<
- absl::conjunction<
- internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
- std::is_constructible<T, U&&>,
- std::is_convertible<U&&, T>>::value,
- int> = 0>
- StatusOr(U&& u) // NOLINT
- : StatusOr(absl::in_place, std::forward<U>(u)) {}
-
- template <
- typename U = T,
- absl::enable_if_t<
- absl::conjunction<
- internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
- std::is_constructible<T, U&&>,
- absl::negation<std::is_convertible<U&&, T>>>::value,
- int> = 0>
- explicit StatusOr(U&& u) // NOLINT
- : StatusOr(absl::in_place, std::forward<U>(u)) {}
-
- // Returns this->ok()
- explicit operator bool() const { return ok(); }
-
- // Returns this->status().ok()
- ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); }
-
- // Returns a reference to our status. If this contains a T, then
- // returns OkStatus().
- const Status& status() const&;
- Status status() &&;
-
- // Returns a reference to our current value, or CHECK-fails if !this->ok(). If
- // you have already checked the status using this->ok() or operator bool(),
- // then you probably want to use operator*() or operator->() to access the
- // current value instead of ValueOrDie().
- const T& ValueOrDie() const&;
- T& ValueOrDie() &;
- const T&& ValueOrDie() const&&;
- T&& ValueOrDie() &&;
-
- // Returns a reference to the current value.
- //
- // REQUIRES: this->ok() == true, otherwise the behavior is undefined.
- const T& operator*() const&;
- T& operator*() &;
- const T&& operator*() const&&;
- T&& operator*() &&;
-
- // Returns a pointer to the current value.
- //
- // REQUIRES: this->ok() == true, otherwise the behavior is undefined.
- const T* operator->() const;
- T* operator->();
-
- // Returns a copy of the current value if this->ok() == true. Otherwise
- // returns a default value.
- template <typename U>
- T value_or(U&& default_value) const&;
- template <typename U>
- T value_or(U&& default_value) &&;
-
- // Ignores any errors. This method does nothing except potentially suppress
- // complaints from any tools that are checking that errors are not dropped on
- // the floor.
- void IgnoreError() const;
-
- private:
- using internal_statusor::StatusOrData<T>::Assign;
- template <typename U>
- void Assign(const StatusOr<U>& other);
- template <typename U>
- void Assign(StatusOr<U>&& other);
-};
-
-////////////////////////////////////////////////////////////////////////////////
-// Implementation details for StatusOr<T>
-
-template <typename T>
-StatusOr<T>::StatusOr() : Base(Status(StatusCode::kUnknown, "")) {}
-
-template <typename T>
-StatusOr<T>::StatusOr(const T& value) : Base(value) {}
-
-template <typename T>
-StatusOr<T>::StatusOr(const Status& status) : Base(status) {}
-
-template <typename T>
-StatusOr<T>::StatusOr(const StatusBuilder& builder) : Base(builder) {}
-
-template <typename T>
-StatusOr<T>& StatusOr<T>::operator=(const Status& status) {
- this->Assign(status);
- return *this;
-}
-
-template <typename T>
-StatusOr<T>& StatusOr<T>::operator=(const StatusBuilder& builder) {
- return *this = static_cast<Status>(builder);
-}
-
-template <typename T>
-StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {}
-
-template <typename T>
-StatusOr<T>::StatusOr(Status&& status) : Base(std::move(status)) {}
-
-template <typename T>
-StatusOr<T>::StatusOr(StatusBuilder&& builder) : Base(std::move(builder)) {}
-
-template <typename T>
-StatusOr<T>& StatusOr<T>::operator=(Status&& status) {
- this->Assign(std::move(status));
- return *this;
-}
-
-template <typename T>
-StatusOr<T>& StatusOr<T>::operator=(StatusBuilder&& builder) {
- return *this = static_cast<Status>(std::move(builder));
-}
-
-template <typename T>
-template <typename U>
-inline void StatusOr<T>::Assign(const StatusOr<U>& other) {
- if (other.ok()) {
- this->Assign(other.ValueOrDie());
- } else {
- this->Assign(other.status());
- }
-}
-
-template <typename T>
-template <typename U>
-inline void StatusOr<T>::Assign(StatusOr<U>&& other) {
- if (other.ok()) {
- this->Assign(std::move(other).ValueOrDie());
- } else {
- this->Assign(std::move(other).status());
- }
-}
-template <typename T>
-template <typename... Args>
-StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args)
- : Base(absl::in_place, std::forward<Args>(args)...) {}
-
-template <typename T>
-template <typename U, typename... Args>
-StatusOr<T>::StatusOr(absl::in_place_t, std::initializer_list<U> ilist,
- Args&&... args)
- : Base(absl::in_place, ilist, std::forward<Args>(args)...) {}
-
-template <typename T>
-const Status& StatusOr<T>::status() const& {
- return this->status_;
-}
-template <typename T>
-Status StatusOr<T>::status() && {
- return ok() ? OkStatus() : std::move(this->status_);
-}
-
-template <typename T>
-const T& StatusOr<T>::ValueOrDie() const& {
- this->EnsureOk();
- return this->data_;
-}
-
-template <typename T>
-T& StatusOr<T>::ValueOrDie() & {
- this->EnsureOk();
- return this->data_;
-}
-
-template <typename T>
-const T&& StatusOr<T>::ValueOrDie() const&& {
- this->EnsureOk();
- return std::move(this->data_);
-}
-
-template <typename T>
-T&& StatusOr<T>::ValueOrDie() && {
- this->EnsureOk();
- return std::move(this->data_);
-}
-
-template <typename T>
-const T& StatusOr<T>::operator*() const& {
- this->EnsureOk();
- return this->data_;
-}
-
-template <typename T>
-T& StatusOr<T>::operator*() & {
- this->EnsureOk();
- return this->data_;
-}
-
-template <typename T>
-const T&& StatusOr<T>::operator*() const&& {
- this->EnsureOk();
- return std::move(this->data_);
-}
-
-template <typename T>
-T&& StatusOr<T>::operator*() && {
- this->EnsureOk();
- return std::move(this->data_);
-}
-
-template <typename T>
-const T* StatusOr<T>::operator->() const {
- this->EnsureOk();
- return &this->data_;
-}
-
-template <typename T>
-T* StatusOr<T>::operator->() {
- this->EnsureOk();
- return &this->data_;
-}
-
-template <typename T>
-template <typename U>
-T StatusOr<T>::value_or(U&& default_value) const& {
- if (ok()) {
- return this->data_;
- }
- return std::forward<U>(default_value);
-}
-
-template <typename T>
-template <typename U>
-T StatusOr<T>::value_or(U&& default_value) && {
- if (ok()) {
- return std::move(this->data_);
- }
- return std::forward<U>(default_value);
-}
-
-template <typename T>
-void StatusOr<T>::IgnoreError() const {
- // no-op
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_INTERNAL_STATUSOR_H_
diff --git a/iree/base/intrusive_list.h b/iree/base/intrusive_list.h
deleted file mode 100644
index ce31c54..0000000
--- a/iree/base/intrusive_list.h
+++ /dev/null
@@ -1,758 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Doubly linked list using element interior storage.
-// This has the performance of std::list (that means O(1) on insert and remove)
-// but performs no allocations and has better caching behavior.
-//
-// Elements are maintained in lists by way of IntrusiveListLinks, with each link
-// allowing the element to exist in one list simultaneously. In the most simple
-// case subclassing IntrusiveLinkBase will let the type be added to a list with
-// little boilerplate. If an element must be in more than one list
-// simultaneously IntrusiveListLinks can be added as members.
-//
-// Usage (simple):
-// class MySimpleElement : public IntrusiveLinkBase {};
-// IntrusiveList<MySimpleElement> list;
-// list.push_back(new MySimpleElement());
-// for (auto element : list) { ... }
-//
-// Usage (multiple lists):
-// class MultiElement {
-// public:
-// IntrusiveListLink list_link_a;
-// IntrusiveListLink list_link_b;
-// };
-// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_a)> list_a;
-// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_b)> list_b;
-//
-// By default elements in the list are not retained and must be kept alive
-// externally. For automatic memory management there are specializations for
-// std::unique_ptr.
-//
-// Usage (unique_ptr):
-// IntrusiveList<std::unique_ptr<MyElement>> list;
-// list.push_back(absl::make_unique<MyElement>());
-// std::unique_ptr<MyElement> elm = list.take(list.front());
-//
-// This type is thread-unsafe.
-
-#ifndef IREE_BASE_INTRUSIVE_LIST_H_
-#define IREE_BASE_INTRUSIVE_LIST_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <iterator>
-#include <limits>
-
-#include "iree/base/logging.h"
-
-namespace iree {
-
-// Define to enable extensive checks after each mutation of the intrusive list.
-// #define IREE_PARANOID_INTRUSIVE_LIST
-
-// Storage for the doubly-linked list.
-// This is embedded within all elements in an intrusive list.
-struct IntrusiveListLink {
- IntrusiveListLink* prev = nullptr;
- IntrusiveListLink* next = nullptr;
-
- IntrusiveListLink() = default;
-
- // Prevent copies.
- IntrusiveListLink(const IntrusiveListLink&) = delete;
- IntrusiveListLink& operator=(const IntrusiveListLink&) = delete;
-};
-
-template <class T>
-struct IntrusiveLinkBase : public T {
- public:
- IntrusiveListLink link;
-};
-
-template <>
-struct IntrusiveLinkBase<void> {
- public:
- IntrusiveListLink link;
-};
-
-// Base type for intrusive lists.
-// This is either used directly when the list is on naked pointers or
-// specialized to std::unique_ptr.
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-class IntrusiveListBase {
- public:
- using self_type = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
-
- IntrusiveListBase() = default;
- virtual ~IntrusiveListBase() { clear(); }
-
- // Prevent copies.
- IntrusiveListBase(const IntrusiveListBase&) = delete;
- IntrusiveListBase& operator=(const IntrusiveListBase&) = delete;
-
- // Returns true if the list is empty.
- // Performance: O(1)
- constexpr bool empty() const { return head_ == nullptr; }
-
- // Returns the total number of items in the list.
- // Performance: O(1)
- constexpr size_t size() const { return count_; }
-
- // Returns true if the given item is contained within the list.
- // Performance: O(n)
- bool contains(T* value) const;
-
- // Appends the contents of the given list to this one.
- // The |other_list| is cleared.
- // Performance: O(1)
- void merge_from(self_type* other_list);
-
- // Removes all items from the list.
- // Performance: O(n)
- void clear();
-
- IteratorT begin() const { return IteratorT(head_); }
- IteratorT end() const { return IteratorT(nullptr); }
- ReverseIteratorT rbegin() const { return ReverseIteratorT(tail_); }
- ReverseIteratorT rend() const { return ReverseIteratorT(nullptr); }
-
- // Returns the next item in the list relative to the given item.
- // |value| must exist in the list.
- // Performance: O(1)
- T* next(T* value) const;
-
- // Returns the previous item in the list relative to the given item.
- // |value| must exist in the list.
- // Performance: O(1)
- T* previous(T* value) const;
-
- // Returns the item at the front of the list, if any.
- // Performance: O(1)
- T* front() const;
-
- // Inserts an item at the front of the list.
- // Performance: O(1)
- void push_front(T* value);
-
- // Removes the item at the front of the list.
- // Performance: O(1)
- void pop_front();
-
- // Returns the item at the back of the list, if any.
- // Performance: O(1)
- T* back() const;
-
- // Inserts an item at the back of the list.
- // Performance: O(1)
- void push_back(T* value);
-
- // Removes the item at the back of the list.
- // Performance: O(1)
- void pop_back();
-
- // Inserts an item into the list before the given iterator.
- // Performance: O(1)
- void insert(const IteratorT& it, T* value) { return insert(*it, value); }
- void insert(T* position, T* value);
-
- // Erases the given item from the list.
- // Returns the item following the erased item, if any.
- // Performance: O(1)
- T* erase(T* value);
-
- // Erases the item from the list at the given iterator.
- // Performance: O(1)
- IteratorT erase(const IteratorT& it);
- ReverseIteratorT erase(const ReverseIteratorT& it);
-
- // Replaces the item with a new item at the same position.
- // |new_value| must not be contained in any list.
- // Performance: O(1)
- void replace(T* old_value, T* new_value);
-
- // Sorts the list with the given comparison function.
- // The sort function is the same as used by std::sort.
- //
- // Uses merge sort O(N log N) using the algorithm described here:
- // http://www.chiark.greenend.org.uk/~sgtatham/algorithms/listsort.html
- void sort(bool (*compare_fn)(T* a, T* b));
-
- protected:
- // Called when an item is added to the list.
- virtual void OnAdd(T* value) {}
- // Called when an item is removed from the list.
- virtual void OnRemove(T* value) {}
- // Called when an item is removed and deallocated.
- virtual void OnDeallocate(T* value) {}
-
- // Performs expensive correctness checks on the list structure. It's too slow
- // to use in normal builds (even dbg), so it should only be used when there's
- // a suspected issue with an intrusive list. Define
- // IREE_PARANOID_INTRUSIVE_LIST to enable.
- void CheckCorrectness() const;
-
- IntrusiveListLink* head_ = nullptr;
- IntrusiveListLink* tail_ = nullptr;
- size_t count_ = 0;
-};
-
-// Basic iterator for an IntrusiveList.
-template <typename T, size_t kOffset, bool kForward>
-class IntrusiveListIterator
- : public std::iterator<std::input_iterator_tag, int> {
- public:
- using self_type = IntrusiveListIterator<T, kOffset, kForward>;
-
- explicit IntrusiveListIterator(IntrusiveListLink* current)
- : current_(current) {}
- IntrusiveListIterator& operator++();
- self_type operator++(int);
- self_type& operator--();
- self_type operator--(int);
- bool operator==(const self_type& rhs) const;
- bool operator!=(const self_type& rhs) const;
- T* operator*() const;
-
- protected:
- IntrusiveListLink* current_;
-};
-
-// Specialized IntrusiveListBase used for unreferenced naked pointers.
-// This very thinly wraps the base type and does no special memory management.
-template <typename T, size_t kOffset>
-class IntrusiveListUnrefBase
- : public IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>,
- IntrusiveListIterator<T, kOffset, false>,
- kOffset> {
- public:
- using IteratorT = IntrusiveListIterator<T, kOffset, true>;
- using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>;
- using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
-
- using base_list::clear;
-
- // Removes all items from the list and calls the given deleter function for
- // each of them. The built-in OnDeallocate will not be used.
- // Performance: O(n)
- void clear(const std::function<void(T*)>& deleter);
-
- private:
- using base_list::count_;
- using base_list::head_;
- using base_list::tail_;
-};
-
-constexpr size_t kUseDefaultLinkOffset = std::numeric_limits<size_t>::max();
-
-// IntrusiveList for raw pointers with a specified offset.
-// Use this if there are multiple links within a type.
-//
-// Usage:
-// struct MyType {
-// IntrusiveListLink link_a;
-// IntrusiveListLink link_b;
-// };
-// IntrusiveList<MyType, offsetof(MyType, link_a)> list_a;
-// IntrusiveList<MyType, offsetof(MyType, link_b)> list_b;
-template <typename T, size_t kOffset = kUseDefaultLinkOffset>
-class IntrusiveList : public IntrusiveListUnrefBase<T, kOffset> {};
-
-// IntrusiveList for raw pointers.
-// Items added to the list will not be owned by the list and must be freed by
-// the caller.
-//
-// Usage:
-// struct MyType : public IntrusiveListBase<void> {};
-// IntrusiveList<MyType> list;
-// auto* p = new MyType();
-// list.push_back(p); // p is not retained and won't be freed!
-// delete p;
-template <typename T>
-class IntrusiveList<T, kUseDefaultLinkOffset>
- : public IntrusiveListUnrefBase<T, offsetof(T, link)> {};
-
-// -- implementation --
-
-namespace impl {
-
-// Maps an IntrusiveListLink to its containing type T.
-template <typename T, size_t kOffset>
-static inline T* LinkToT(IntrusiveListLink* link) {
- if (link) {
- return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(link) - kOffset);
- } else {
- return nullptr;
- }
-}
-
-// Maps a containing type T to its IntrusiveListLink.
-template <typename T, size_t kOffset>
-static inline IntrusiveListLink* TToLink(T* value) {
- if (value) {
- return reinterpret_cast<IntrusiveListLink*>(
- reinterpret_cast<uintptr_t>(value) + kOffset);
- } else {
- return nullptr;
- }
-}
-
-} // namespace impl
-
-template <typename T, size_t kOffset, bool kForward>
-IntrusiveListIterator<T, kOffset, kForward>&
-IntrusiveListIterator<T, kOffset, kForward>::operator++() {
- if (current_) {
- current_ = kForward ? current_->next : current_->prev;
- }
- return *this;
-}
-
-template <typename T, size_t kOffset, bool kForward>
-IntrusiveListIterator<T, kOffset, kForward>
-IntrusiveListIterator<T, kOffset, kForward>::operator++(int) {
- self_type tmp(current_);
- operator++();
- return tmp;
-}
-
-template <typename T, size_t kOffset, bool kForward>
-IntrusiveListIterator<T, kOffset, kForward>&
-IntrusiveListIterator<T, kOffset, kForward>::operator--() {
- if (current_) {
- current_ = kForward ? current_->prev : current_->next;
- }
- return *this;
-}
-
-template <typename T, size_t kOffset, bool kForward>
-IntrusiveListIterator<T, kOffset, kForward>
-IntrusiveListIterator<T, kOffset, kForward>::operator--(int) {
- self_type tmp(current_);
- operator--();
- return tmp;
-}
-
-template <typename T, size_t kOffset, bool kForward>
-bool IntrusiveListIterator<T, kOffset, kForward>::operator==(
- const self_type& rhs) const {
- return rhs.current_ == current_;
-}
-
-template <typename T, size_t kOffset, bool kForward>
-bool IntrusiveListIterator<T, kOffset, kForward>::operator!=(
- const self_type& rhs) const {
- return !operator==(rhs);
-}
-
-template <typename T, size_t kOffset, bool kForward>
-T* IntrusiveListIterator<T, kOffset, kForward>::operator*() const {
- return impl::LinkToT<T, kOffset>(current_);
-}
-
-template <typename T, size_t kOffset>
-void IntrusiveListUnrefBase<T, kOffset>::clear(
- const std::function<void(T*)>& deleter) {
- auto* link = head_;
- while (link) {
- auto* next = link->next;
- link->prev = link->next = nullptr;
- deleter(impl::LinkToT<T, kOffset>(link));
- link = next;
- }
- head_ = tail_ = nullptr;
- count_ = 0;
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT,
- kOffset>::CheckCorrectness() const {
-#if defined(IREE_PARANOID_INTRUSIVE_LIST)
- auto* link = head_;
- IntrusiveListLink* previous = nullptr;
- size_t actual_count = 0;
- while (link) {
- ++actual_count;
- if (!link->prev) {
- DCHECK_EQ(link, head_);
- }
- if (!link->next) {
- DCHECK_EQ(link, tail_);
- }
- DCHECK_EQ(link->prev, previous);
- previous = link;
- link = link->next;
- }
- DCHECK_EQ(actual_count, count_);
-#endif // IREE_PARANOID_INTRUSIVE_LIST
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-bool IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::contains(
- T* value) const {
- if (!value) return false;
- // TODO(benvanik): faster way of checking? requires list ptr in link?
- auto* needle = impl::TToLink<T, kOffset>(value);
- auto* link = head_;
- while (link) {
- if (link == needle) {
- return true;
- }
- link = link->next;
- }
- return false;
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::merge_from(
- self_type* other_list) {
- if (tail_) {
- tail_->next = other_list->head_;
- }
- if (other_list->head_) {
- other_list->head_->prev = tail_;
- }
- if (!head_) {
- head_ = other_list->head_;
- }
- tail_ = other_list->tail_;
-
- other_list->head_ = nullptr;
- other_list->tail_ = nullptr;
-
- count_ += other_list->count_;
- other_list->count_ = 0;
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::clear() {
- auto* link = head_;
- while (link) {
- auto* next = link->next;
- link->prev = link->next = nullptr;
- OnDeallocate(impl::LinkToT<T, kOffset>(link));
- link = next;
- }
- head_ = tail_ = nullptr;
- count_ = 0;
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::next(
- T* value) const {
- if (!value) {
- return nullptr;
- }
- auto* link = impl::TToLink<T, kOffset>(value);
- return impl::LinkToT<T, kOffset>(link->next);
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::previous(
- T* value) const {
- if (!value) {
- return nullptr;
- }
- auto* link = impl::TToLink<T, kOffset>(value);
- return impl::LinkToT<T, kOffset>(link->prev);
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::front()
- const {
- return impl::LinkToT<T, kOffset>(head_);
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_front(
- T* value) {
- DCHECK(value);
- auto* link = impl::TToLink<T, kOffset>(value);
- DCHECK(!link->next);
- DCHECK(!link->prev);
- link->next = head_;
- link->prev = nullptr;
- head_ = link;
- if (link->next) {
- link->next->prev = link;
- }
- if (!tail_) {
- tail_ = link;
- }
- ++count_;
- OnAdd(value);
- CheckCorrectness();
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_front() {
- DCHECK(head_);
- auto* link = head_;
- if (link) {
- head_ = head_->next;
- link->next = link->prev = nullptr;
- if (head_) {
- head_->prev = nullptr;
- }
- if (link == tail_) {
- tail_ = nullptr;
- }
- --count_;
- OnDeallocate(impl::LinkToT<T, kOffset>(link));
- }
- CheckCorrectness();
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::back()
- const {
- return impl::LinkToT<T, kOffset>(tail_);
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_back(
- T* value) {
- DCHECK(value);
- auto* link = impl::TToLink<T, kOffset>(value);
- DCHECK(!link->next);
- DCHECK(!link->prev);
- link->prev = tail_;
- link->next = nullptr;
- tail_ = link;
- if (link->prev) {
- link->prev->next = link;
- }
- if (!head_) {
- head_ = link;
- }
- ++count_;
- OnAdd(value);
- CheckCorrectness();
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_back() {
- DCHECK(tail_);
- auto* link = tail_;
- if (link) {
- tail_ = tail_->prev;
- link->next = link->prev = nullptr;
- if (tail_) {
- tail_->next = nullptr;
- }
- if (link == head_) {
- head_ = nullptr;
- }
- --count_;
- OnDeallocate(impl::LinkToT<T, kOffset>(link));
- }
- CheckCorrectness();
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::insert(
- T* position, T* value) {
- DCHECK(value);
- auto* link = impl::TToLink<T, kOffset>(value);
- auto* position_link = impl::TToLink<T, kOffset>(position);
- DCHECK(!link->next);
- DCHECK(!link->prev);
-
- if (position_link == head_) {
- push_front(value);
- } else if (position_link == nullptr) {
- push_back(value);
- } else {
- link->next = position_link;
- link->prev = position_link->prev;
- position_link->prev->next = link;
- position_link->prev = link;
- ++count_;
- OnAdd(value);
- }
- CheckCorrectness();
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase(T* value) {
- if (!value) {
- return nullptr;
- }
- auto* link = impl::TToLink<T, kOffset>(value);
- if (link->prev) {
- DCHECK_NE(link, head_);
- link->prev->next = link->next;
- } else {
- DCHECK_EQ(link, head_);
- head_ = link->next;
- }
- if (link->next) {
- DCHECK_NE(link, tail_);
- link->next->prev = link->prev;
- } else {
- DCHECK_EQ(link, tail_);
- tail_ = link->prev;
- }
- auto* next = link->next;
- link->next = link->prev = nullptr;
- --count_;
- OnDeallocate(value);
- CheckCorrectness();
- return impl::LinkToT<T, kOffset>(next);
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-IteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase(
- const IteratorT& it) {
- return IteratorT(impl::TToLink<T, kOffset>(erase(*it)));
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-ReverseIteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT,
- kOffset>::erase(const ReverseIteratorT& it) {
- return ReverseIteratorT(impl::TToLink<T, kOffset>(erase(*it)));
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::replace(
- T* old_value, T* new_value) {
- DCHECK(old_value);
- DCHECK(new_value);
- DCHECK_NE(old_value, new_value);
- auto* old_link = impl::TToLink<T, kOffset>(old_value);
- auto* new_link = impl::TToLink<T, kOffset>(new_value);
- new_link->next = old_link->next;
- new_link->prev = old_link->prev;
- if (new_link->prev) {
- new_link->prev->next = new_link;
- } else {
- head_ = new_link;
- }
- if (new_link->next) {
- new_link->next->prev = new_link;
- } else {
- tail_ = new_link;
- }
- old_link->next = old_link->prev = nullptr;
- OnAdd(new_value);
- OnDeallocate(old_value);
- CheckCorrectness();
-}
-
-template <typename T, typename IteratorT, typename ReverseIteratorT,
- size_t kOffset>
-void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::sort(
- bool (*compare_fn)(T* a, T* b)) {
- if (empty()) {
- // Empty list no-op.
- return;
- }
- // Repeatedly run until the list is sorted.
- int in_size = 1;
- while (true) {
- IntrusiveListLink* p = head_;
- IntrusiveListLink* q = nullptr;
- IntrusiveListLink* e = nullptr;
- IntrusiveListLink* tail = nullptr;
- head_ = nullptr;
- tail_ = nullptr;
- // Repeatedly merge sublists.
- int merge_count = 0;
- do {
- ++merge_count;
- q = p;
- // Determine the size of the first part and find the second.
- int p_size = 0;
- for (int i = 0; i < in_size; ++i) {
- ++p_size;
- q = q->next;
- if (!q) {
- break;
- }
- }
- // Merge the two lists (if we have two).
- int q_size = in_size;
- while (p_size > 0 || (q_size > 0 && q)) {
- if (p_size == 0) {
- // p is empty; e must come from q.
- e = q;
- q = q->next;
- --q_size;
- } else if (q_size == 0 || !q) {
- // q is empty; e must come from p.
- e = p;
- p = p->next;
- --p_size;
- } else if (compare_fn(impl::LinkToT<T, kOffset>(p),
- impl::LinkToT<T, kOffset>(q))) {
- // p <= q; e must come from p.
- e = p;
- p = p->next;
- --p_size;
- } else {
- // q < p; e must come from q.
- e = q;
- q = q->next;
- --q_size;
- }
- // Append e to the merged list.
- if (tail) {
- tail->next = e;
- } else {
- head_ = e;
- }
- e->prev = tail;
- tail = e;
- }
- p = q;
- } while (p);
- tail->next = nullptr;
- if (merge_count <= 1) {
- // List is now sorted; stash and return.
- tail_ = tail;
- CheckCorrectness();
- return;
- }
- // Run merge again with larger lists.
- in_size *= 2;
- }
-}
-
-} // namespace iree
-
-// Specializations:
-#include "iree/base/intrusive_list_ref_ptr.inc"
-#include "iree/base/intrusive_list_unique_ptr.inc"
-
-#endif // IREE_BASE_INTRUSIVE_LIST_H_
diff --git a/iree/base/intrusive_list_ref_ptr.inc b/iree/base/intrusive_list_ref_ptr.inc
deleted file mode 100644
index cdc7ad5..0000000
--- a/iree/base/intrusive_list_ref_ptr.inc
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// IWYU pragma: private, include "iree/base/intrusive_list.h"
-
-#ifndef IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
-#define IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
-
-#include <cstddef>
-#include <iterator>
-
-#include "iree/base/intrusive_list.h"
-#include "iree/base/ref_ptr.h"
-
-namespace iree {
-
-// Iterator for an IntrusiveList specialized to ref_ptr.
-template <typename T, size_t kOffset, bool kForward>
-class IntrusiveListRefPtrIterator
- : public std::iterator<std::input_iterator_tag, int> {
- public:
- using self_type = IntrusiveListRefPtrIterator<T, kOffset, kForward>;
-
- explicit IntrusiveListRefPtrIterator(IntrusiveListLink* current)
- : current_(current) {}
- self_type& operator++() {
- if (current_) {
- current_ = kForward ? current_->next : current_->prev;
- }
- return *this;
- }
- self_type operator++(int) {
- self_type tmp(current_);
- operator++();
- return tmp;
- }
- self_type& operator--() {
- if (current_) {
- current_ = kForward ? current_->prev : current_->next;
- }
- return *this;
- }
- self_type operator--(int) {
- self_type tmp(current_);
- operator--();
- return tmp;
- }
- bool operator==(const self_type& rhs) const {
- return rhs.current_ == current_;
- }
- bool operator!=(const self_type& rhs) const { return !operator==(rhs); }
- ref_ptr<T> operator*() const {
- return add_ref(impl::LinkToT<T, kOffset>(current_));
- }
-
- protected:
- IntrusiveListLink* current_;
-};
-
-// Specialized IntrusiveListBase for ref_ptr types.
-// This makes the list methods accept/return ref_ptrs and iterate with
-// a ref_ptr iterator.
-template <typename T, size_t kOffset>
-class IntrusiveListRefPtrBase
- : private IntrusiveListBase<
- T, IntrusiveListRefPtrIterator<T, kOffset, true>,
- IntrusiveListRefPtrIterator<T, kOffset, false>, kOffset> {
- public:
- using IteratorT = IntrusiveListRefPtrIterator<T, kOffset, true>;
- using ReverseIteratorT = IntrusiveListRefPtrIterator<T, kOffset, false>;
- using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
-
- IntrusiveListRefPtrBase() = default;
-
- using base_list::empty;
- using base_list::size;
-
- using base_list::contains;
- bool contains(const ref_ptr<T>& value) const {
- return base_list::contains(value.get());
- }
-
- using base_list::clear;
-
- using base_list::begin;
- using base_list::end;
- using base_list::rbegin;
- using base_list::rend;
-
- inline ref_ptr<T> next(const ref_ptr<T>& value) const {
- return add_ref(base_list::next(value.get()));
- }
- inline ref_ptr<T> next(T* value) const {
- return add_ref(base_list::next(value));
- }
-
- inline ref_ptr<T> previous(const ref_ptr<T>& value) const {
- return add_ref(base_list::previous(value.get()));
- }
- inline ref_ptr<T> previous(T* value) const {
- return add_ref(base_list::previous(value));
- }
-
- // Performance: O(1)
- inline ref_ptr<T> front() const {
- return add_ref(impl::LinkToT<T, kOffset>(head_));
- }
-
- void push_front(const ref_ptr<T>& value) {
- base_list::push_front(value.get());
- }
-
- using base_list::pop_front;
-
- // Performance: O(1)
- inline ref_ptr<T> back() const {
- return add_ref(impl::LinkToT<T, kOffset>(tail_));
- }
-
- void push_back(const ref_ptr<T>& value) { base_list::push_back(value.get()); }
-
- using base_list::pop_back;
-
- void insert(const IteratorT& it, const ref_ptr<T>& value) {
- base_list::insert(it, value.get());
- }
-
- using base_list::erase;
-
- ref_ptr<T> erase(const ref_ptr<T>& value) {
- return add_ref(base_list::erase(value.get()));
- }
-
- void replace(const ref_ptr<T>& old_value, const ref_ptr<T>& new_value) {
- base_list::replace(old_value.get(), new_value.get());
- }
- void replace(T* old_value, const ref_ptr<T>& new_value) {
- base_list::replace(old_value, new_value.get());
- }
-
- using base_list::sort;
-
- private:
- void OnAdd(T* value) override { value->AddReference(); }
- void OnRemove(T* value) override { value->ReleaseReference(); }
- void OnDeallocate(T* value) override { value->ReleaseReference(); }
-
- using base_list::count_;
- using base_list::head_;
- using base_list::tail_;
-};
-
-template <typename U, size_t kOffset>
-class IntrusiveList<ref_ptr<U>, kOffset>
- : public IntrusiveListRefPtrBase<U, kOffset> {};
-
-template <typename U>
-class IntrusiveList<ref_ptr<U>, kUseDefaultLinkOffset>
- : public IntrusiveListRefPtrBase<U, offsetof(U, link)> {};
-
-} // namespace iree
-
-#endif // IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
diff --git a/iree/base/intrusive_list_ref_ptr_test.cc b/iree/base/intrusive_list_ref_ptr_test.cc
deleted file mode 100644
index b94da7d..0000000
--- a/iree/base/intrusive_list_ref_ptr_test.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "gtest/gtest.h"
-#include "iree/base/intrusive_list.h"
-
-namespace iree {
-namespace {
-
-static int alloc_count = 0;
-struct RefCountedType : public RefObject<RefCountedType> {
- IntrusiveListLink link;
- RefCountedType() { ++alloc_count; }
- ~RefCountedType() { --alloc_count; }
- static void Deallocate(RefCountedType* value) { delete value; }
- using RefObject<RefCountedType>::counter_;
-};
-
-TEST(IntrusiveListRefPtrTest, PushAndClear) {
- alloc_count = 0;
- IntrusiveList<ref_ptr<RefCountedType>> list;
- EXPECT_EQ(0, alloc_count);
- list.push_back(make_ref<RefCountedType>());
- EXPECT_EQ(1, alloc_count);
- EXPECT_NE(nullptr, list.front());
- EXPECT_EQ(2, list.front()->counter_);
- list.clear();
- EXPECT_EQ(0, alloc_count);
-}
-
-TEST(IntrusiveListRefPtrTest, PushPop) {
- alloc_count = 0;
- IntrusiveList<ref_ptr<RefCountedType>> list;
- list.push_back(make_ref<RefCountedType>());
- EXPECT_EQ(1, alloc_count);
- list.push_back(make_ref<RefCountedType>());
- EXPECT_EQ(2, alloc_count);
- EXPECT_NE(list.front(), list.back());
- list.pop_back();
- EXPECT_EQ(1, alloc_count);
- list.pop_front();
- EXPECT_EQ(0, alloc_count);
-}
-
-TEST(IntrusiveListRefPtrTest, PushErase) {
- alloc_count = 0;
- IntrusiveList<ref_ptr<RefCountedType>> list;
- list.push_back(make_ref<RefCountedType>());
- EXPECT_EQ(1, alloc_count);
- EXPECT_NE(nullptr, list.front());
- EXPECT_EQ(2, list.front()->counter_);
- auto item = list.front();
- EXPECT_NE(nullptr, item.get());
- EXPECT_EQ(3, list.front()->counter_);
- EXPECT_EQ(1, alloc_count);
- list.erase(item);
- EXPECT_EQ(1, alloc_count);
- item.reset();
- EXPECT_EQ(0, alloc_count);
-}
-
-TEST(IntrusiveListRefPtrTest, PushReplace) {
- alloc_count = 0;
- IntrusiveList<ref_ptr<RefCountedType>> list;
- list.push_back(make_ref<RefCountedType>());
- EXPECT_EQ(1, alloc_count);
- list.replace(list.front(), make_ref<RefCountedType>());
- EXPECT_EQ(1, alloc_count);
- list.clear();
- EXPECT_EQ(0, alloc_count);
-}
-
-TEST(IntrusiveListRefPtrTest, Iteration) {
- alloc_count = 0;
- IntrusiveList<ref_ptr<RefCountedType>> list;
- list.push_back(make_ref<RefCountedType>());
- list.push_back(make_ref<RefCountedType>());
- list.push_back(make_ref<RefCountedType>());
- EXPECT_EQ(3, alloc_count);
- for (auto item : list) {
- const ref_ptr<RefCountedType>& item_ref = item;
- EXPECT_NE(nullptr, item_ref.get());
- }
- list.clear();
- EXPECT_EQ(0, alloc_count);
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/intrusive_list_test.cc b/iree/base/intrusive_list_test.cc
deleted file mode 100644
index 37216ce..0000000
--- a/iree/base/intrusive_list_test.cc
+++ /dev/null
@@ -1,523 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/intrusive_list.h"
-
-#include <algorithm>
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace iree {
-namespace {
-
-using ::testing::ElementsAre;
-
-struct Item {
- size_t some_data_0;
- IntrusiveListLink list_a;
- size_t some_data_1;
- IntrusiveListLink list_b;
- size_t some_data_2;
- int value;
-
- static const size_t kToken = 0xDEADBEEF;
- explicit Item(int value)
- : some_data_0(kToken),
- some_data_1(kToken),
- some_data_2(kToken),
- value(value) {}
- bool is_valid() {
- return some_data_0 == kToken && some_data_1 == kToken &&
- some_data_2 == kToken;
- }
-};
-
-template <typename T, size_t V>
-std::vector<T*> ExtractItems(const IntrusiveList<T, V>& list) {
- std::vector<T*> items;
- for (auto* item : list) {
- items.push_back(item);
- }
- return items;
-}
-
-template <typename T, size_t V>
-std::vector<int> ExtractValues(const IntrusiveList<T, V>& list) {
- std::vector<int> values;
- for (auto* item : list) {
- values.push_back(item->value);
- }
- return values;
-}
-
-template <typename T, size_t V>
-std::vector<int> ExtractValuesMutable(const IntrusiveList<T, V>& list) {
- std::vector<int> values;
- for (auto* item : list) {
- values.push_back(item->value);
- }
- return values;
-}
-
-TEST(IntrusiveListTest, PushPopItems) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- EXPECT_TRUE(items.empty());
- EXPECT_EQ(items.size(), 0u);
- EXPECT_EQ(items.front(), nullptr);
- EXPECT_EQ(items.back(), nullptr);
- EXPECT_TRUE(items.begin() == items.end());
- items.push_front(&item1);
- EXPECT_FALSE(items.empty());
- EXPECT_EQ(items.size(), 1u);
- EXPECT_EQ(items.front(), &item1);
- EXPECT_EQ(items.back(), &item1);
- EXPECT_FALSE(items.begin() == items.end());
- items.push_front(&item2);
- EXPECT_EQ(items.size(), 2u);
- EXPECT_EQ(items.front(), &item2);
- EXPECT_EQ(items.back(), &item1);
- items.push_front(&item3);
- EXPECT_EQ(items.size(), 3u);
- EXPECT_EQ(items.front(), &item3);
- EXPECT_EQ(items.back(), &item1);
- EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1));
-
- items.push_back(&item4);
- EXPECT_EQ(items.size(), 4u);
- EXPECT_EQ(items.front(), &item3);
- EXPECT_EQ(items.back(), &item4);
- EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1, 4));
-
- items.pop_front();
- EXPECT_EQ(items.size(), 3u);
- EXPECT_EQ(items.front(), &item2);
- EXPECT_EQ(items.back(), &item4);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1, 4));
-
- items.pop_back();
- EXPECT_EQ(items.size(), 2u);
- EXPECT_EQ(items.front(), &item2);
- EXPECT_EQ(items.back(), &item1);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1));
-
- items.pop_back();
- items.pop_front();
- EXPECT_TRUE(items.empty());
- EXPECT_EQ(items.size(), 0u);
- EXPECT_EQ(items.front(), nullptr);
- EXPECT_EQ(items.back(), nullptr);
- EXPECT_TRUE(items.begin() == items.end());
-
- EXPECT_TRUE(item1.is_valid());
- EXPECT_TRUE(item2.is_valid());
- EXPECT_TRUE(item3.is_valid());
- EXPECT_TRUE(item4.is_valid());
-}
-
-TEST(IntrusiveListTest, Contains) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- items.push_back(&item1);
- items.push_back(&item2);
- items.push_back(&item3);
- // item4 omitted.
-
- EXPECT_TRUE(items.contains(&item1));
- EXPECT_TRUE(items.contains(&item2));
- EXPECT_TRUE(items.contains(&item3));
- EXPECT_FALSE(items.contains(&item4));
-
- EXPECT_FALSE(items.contains(nullptr));
-}
-
-TEST(IntrusiveListTest, MergeFrom) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items0;
- items0.push_back(&item1);
- items0.push_back(&item2);
- items0.push_back(&item3);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items1;
- items1.push_back(&item4);
-
- items0.merge_from(&items1);
- EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4));
- EXPECT_TRUE(items1.empty());
-}
-
-TEST(IntrusiveListTest, MergeFromEmpty) {
- IntrusiveList<Item, offsetof(Item, list_a)> items0;
- IntrusiveList<Item, offsetof(Item, list_a)> items1;
- items0.merge_from(&items1);
-}
-
-TEST(IntrusiveListTest, MergeFromAll) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
- IntrusiveList<Item, offsetof(Item, list_a)> items0;
- items0.push_back(&item1);
- items0.push_back(&item2);
- items0.push_back(&item3);
- items0.push_back(&item4);
- IntrusiveList<Item, offsetof(Item, list_a)> items1;
-
- // Merge all items from items1 into items0. Shouldn't change anything.
- items0.merge_from(&items1);
- EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4));
- EXPECT_TRUE(items1.empty());
-
- // Merge all items from items0 into items1. Should move everything.
- items1.merge_from(&items0);
- EXPECT_TRUE(items0.empty());
- EXPECT_THAT(ExtractValues(items1), ElementsAre(1, 2, 3, 4));
-}
-
-TEST(IntrusiveListTest, Erase) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- items.push_back(&item1);
- items.push_back(&item2);
- items.push_back(&item3);
- items.push_back(&item4);
-
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
- items.erase(&item3);
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 4));
- items.erase(&item1);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4));
- items.erase(&item4);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2));
- items.erase(&item2);
- EXPECT_TRUE(items.empty());
-
- items.push_back(&item1);
- items.push_back(&item2);
- items.push_back(&item3);
- items.push_back(&item4);
-
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
- auto it = items.begin();
- items.erase(it);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4));
- it = items.end();
- items.erase(it);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4));
- it = items.begin();
- ++it;
- items.erase(it);
- EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4));
-
- it = items.begin();
- it = items.erase(it);
- EXPECT_EQ(4, (*it)->value);
- EXPECT_THAT(ExtractValues(items), ElementsAre(4));
- it = items.erase(it);
- EXPECT_TRUE(items.empty());
- EXPECT_EQ(items.end(), it);
-}
-
-TEST(IntrusiveListTest, MultipleLists) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items_a;
- IntrusiveList<Item, offsetof(Item, list_b)> items_b;
- items_a.push_back(&item1);
- items_a.push_back(&item2);
- items_a.push_back(&item3);
- items_a.push_back(&item4);
- items_b.push_front(&item1);
- items_b.push_front(&item2);
- items_b.push_front(&item3);
- items_b.push_front(&item4);
- EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4));
- EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 3, 2, 1));
- items_b.erase(&item3);
- EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4));
- EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1));
- items_a.pop_back();
- EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3));
- EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1));
-}
-
-TEST(IntrusiveListTest, MutableIterator) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- items.push_back(&item4);
- items.push_front(&item1);
- items.push_front(&item2);
- items.push_front(&item3);
-
- EXPECT_THAT(ExtractValuesMutable(items), ElementsAre(3, 2, 1, 4));
-}
-
-struct BaseType {
- explicit BaseType(int value) : value(value) {}
- int value;
- IntrusiveListLink base_link;
-};
-struct SubType : public BaseType {
- explicit SubType(int value) : BaseType(value) {}
- IntrusiveListLink sub_link;
-};
-TEST(IntrusiveListTest, SimpleType) {
- SubType item1(1);
- SubType item2(2);
- SubType item3(3);
- SubType item4(4);
-
- IntrusiveList<BaseType, offsetof(BaseType, base_link)> items_a;
- items_a.push_front(&item1);
- items_a.push_front(&item2);
- items_a.push_front(&item3);
- items_a.push_front(&item4);
- EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1));
-
- IntrusiveList<SubType, offsetof(SubType, sub_link)> items_b;
- items_b.push_back(&item1);
- items_b.push_back(&item2);
- items_b.push_back(&item3);
- items_b.push_back(&item4);
- EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4));
-}
-
-struct AbstractType {
- explicit AbstractType(int value) : value(value) {}
- virtual ~AbstractType() = default;
- virtual int DoSomething() = 0;
- int value;
- IntrusiveListLink base_link;
-};
-struct ImplType : public AbstractType {
- explicit ImplType(int value) : AbstractType(value) {}
- int DoSomething() override { return value; }
- IntrusiveListLink sub_link;
-};
-
-TEST(IntrusiveListTest, ComplexType) {
- ImplType item1(1);
- ImplType item2(2);
- ImplType item3(3);
- ImplType item4(4);
-
- IntrusiveList<AbstractType, offsetof(AbstractType, base_link)> items_a;
- items_a.push_front(&item1);
- items_a.push_front(&item2);
- items_a.push_front(&item3);
- items_a.push_front(&item4);
- EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1));
-
- IntrusiveList<ImplType, offsetof(ImplType, sub_link)> items_b;
- items_b.push_back(&item1);
- items_b.push_back(&item2);
- items_b.push_back(&item3);
- items_b.push_back(&item4);
- EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4));
-}
-
-bool Comparison(Item* a, Item* b) { return a->value < b->value; }
-
-TEST(IntrusiveListTest, Inserting) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- items.insert(items.end(), &item3);
- items.insert(items.begin(), &item1);
- items.insert(items.end(), &item4);
-
- auto pos = std::upper_bound(items.begin(), items.end(), &item2, Comparison);
- items.insert(pos, &item2);
-
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
-}
-
-// TODO(benvanik): test reverse iteration.
-
-TEST(IntrusiveListTest, NextPrevious) {
- Item item1(1);
- Item item2(2);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- EXPECT_EQ(nullptr, items.previous(nullptr));
- EXPECT_EQ(nullptr, items.next(nullptr));
-
- items.push_back(&item1);
- EXPECT_EQ(nullptr, items.previous(&item1));
- EXPECT_EQ(nullptr, items.next(&item1));
-
- items.push_back(&item2);
- EXPECT_EQ(nullptr, items.previous(&item1));
- EXPECT_EQ(&item2, items.next(&item1));
- EXPECT_EQ(&item1, items.previous(&item2));
- EXPECT_EQ(nullptr, items.next(&item2));
-}
-
-TEST(IntrusiveListTest, Clear) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
-
- // Empty clear.
- items.clear();
- EXPECT_TRUE(items.empty());
-
- // 1 item clear.
- items.push_back(&item1);
- items.clear();
- EXPECT_TRUE(items.empty());
-
- // Multi-item clear.
- items.push_back(&item1);
- items.push_back(&item2);
- items.push_back(&item3);
- items.push_back(&item4);
- items.clear();
- EXPECT_TRUE(items.empty());
-}
-
-TEST(IntrusiveListTest, ClearDeleter) {
- Item item1(1);
- Item item2(2);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
-
- // No-op first.
- int delete_count = 0;
- items.clear([&](Item* item) { ++delete_count; });
- EXPECT_EQ(0, delete_count);
-
- // Now with items.
- items.push_back(&item1);
- items.push_back(&item2);
- items.clear([&](Item* item) { ++delete_count; });
- EXPECT_EQ(2, delete_count);
- EXPECT_TRUE(items.empty());
-}
-
-TEST(IntrusiveListTest, Replace) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
- items.push_back(&item1);
- items.push_back(&item2);
-
- items.replace(&item1, &item3);
- EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2));
- EXPECT_FALSE(items.contains(&item1));
- items.replace(&item2, &item1);
- EXPECT_THAT(ExtractValues(items), ElementsAre(3, 1));
- EXPECT_FALSE(items.contains(&item2));
-}
-
-TEST(IntrusiveListTest, Sort) {
- Item item1(1);
- Item item2(2);
- Item item3(3);
- Item item4(4);
-
- IntrusiveList<Item, offsetof(Item, list_a)> items;
-
- // Empty sort.
- items.sort([](Item* a, Item* b) { return a->value < b->value; });
-
- // Single item sort.
- items.clear();
- items.push_back(&item1);
- items.sort([](Item* a, Item* b) { return a->value < b->value; });
- EXPECT_THAT(ExtractValues(items), ElementsAre(1));
-
- // Already sorted.
- items.clear();
- items.push_back(&item1);
- items.push_back(&item2);
- items.push_back(&item3);
- items.push_back(&item4);
- items.sort([](Item* a, Item* b) { return a->value < b->value; });
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
-
- // Reverse.
- items.clear();
- items.push_back(&item4);
- items.push_back(&item3);
- items.push_back(&item2);
- items.push_back(&item1);
- items.sort([](Item* a, Item* b) { return a->value < b->value; });
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
-
- // Random.
- items.clear();
- items.push_back(&item2);
- items.push_back(&item4);
- items.push_back(&item1);
- items.push_back(&item3);
- items.sort([](Item* a, Item* b) { return a->value < b->value; });
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4));
-
- // Stability.
- Item item1a(1);
- Item item2a(2);
- items.clear();
- items.push_back(&item2);
- items.push_back(&item4);
- items.push_back(&item1);
- items.push_back(&item3);
- items.push_back(&item1a);
- items.push_back(&item2a);
- items.sort([](Item* a, Item* b) { return a->value <= b->value; });
- EXPECT_THAT(ExtractValues(items), ElementsAre(1, 1, 2, 2, 3, 4));
- auto items_vector = ExtractItems(items);
- EXPECT_EQ(&item1, items_vector[0]);
- EXPECT_EQ(&item1a, items_vector[1]);
- EXPECT_EQ(&item2, items_vector[2]);
- EXPECT_EQ(&item2a, items_vector[3]);
- items.clear();
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/intrusive_list_unique_ptr.inc b/iree/base/intrusive_list_unique_ptr.inc
deleted file mode 100644
index 94c541d..0000000
--- a/iree/base/intrusive_list_unique_ptr.inc
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// IWYU pragma: private, include "iree/base/intrusive_list.h"
-
-#ifndef IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
-#define IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
-
-#include <cstddef>
-#include <memory>
-
-#include "iree/base/intrusive_list.h"
-#include "iree/base/logging.h"
-
-namespace iree {
-
-// Specialized IntrusiveListBase for std::unique_ptr types.
-// This makes the list methods accept std::unique_ptrs and contains a special
-// take() method that takes ownership of a list item.
-template <typename T, size_t kOffset>
-class IntrusiveListUniquePtrBase
- : private IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>,
- IntrusiveListIterator<T, kOffset, false>,
- kOffset> {
- public:
- using IteratorT = IntrusiveListIterator<T, kOffset, true>;
- using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>;
- using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>;
-
- IntrusiveListUniquePtrBase() = default;
-
- using base_list::empty;
- using base_list::size;
-
- using base_list::contains;
-
- using base_list::clear;
-
- using base_list::begin;
- using base_list::end;
- using base_list::rbegin;
- using base_list::rend;
-
- using base_list::next;
-
- using base_list::previous;
-
- using base_list::front;
-
- void push_front(std::unique_ptr<T> value) {
- base_list::push_front(value.release());
- }
-
- using base_list::pop_front;
-
- using base_list::back;
-
- void push_back(std::unique_ptr<T> value) {
- base_list::push_back(value.release());
- }
-
- using base_list::pop_back;
-
- void insert(const IteratorT& it, std::unique_ptr<T> value) {
- base_list::insert(it, value.release());
- }
-
- using base_list::erase;
-
- // Removes an item from the list at the given iterator and transfers ownership
- // to the caller.
- // Performance: O(1)
- std::unique_ptr<T> take(IteratorT& it) { // NOLINT(runtime/references)
- return take(*it);
- }
-
- // Removes an item from the list and transfers ownership to the caller.
- // Performance: O(1)
- std::unique_ptr<T> take(T* value) {
- if (!value) {
- return {nullptr};
- }
- auto* link = impl::TToLink<T, kOffset>(value);
- if (link->prev) {
- DCHECK_NE(link, head_);
- link->prev->next = link->next;
- } else {
- DCHECK_EQ(link, head_);
- head_ = link->next;
- }
- if (link->next) {
- DCHECK_NE(link, tail_);
- link->next->prev = link->prev;
- } else {
- DCHECK_EQ(link, tail_);
- tail_ = link->prev;
- }
- link->next = link->prev = nullptr;
- --count_;
- base_list::OnRemove(value);
- base_list::CheckCorrectness();
- return std::unique_ptr<T>(value);
- }
-
- void replace(T* old_value, std::unique_ptr<T> new_value) {
- base_list::replace(old_value, new_value.release());
- }
-
- using base_list::sort;
-
- private:
- void OnDeallocate(T* value) override { delete value; }
-
- using base_list::count_;
- using base_list::head_;
- using base_list::tail_;
-};
-
-template <typename U, size_t kOffset>
-class IntrusiveList<std::unique_ptr<U>, kOffset>
- : public IntrusiveListUniquePtrBase<U, kOffset> {};
-
-template <typename U>
-class IntrusiveList<std::unique_ptr<U>, kUseDefaultLinkOffset>
- : public IntrusiveListUniquePtrBase<U, offsetof(U, link)> {};
-
-} // namespace iree
-
-#endif // IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
diff --git a/iree/base/intrusive_list_unique_ptr_test.cc b/iree/base/intrusive_list_unique_ptr_test.cc
deleted file mode 100644
index be52c61..0000000
--- a/iree/base/intrusive_list_unique_ptr_test.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "absl/memory/memory.h"
-#include "gtest/gtest.h"
-#include "iree/base/intrusive_list.h"
-
-namespace iree {
-namespace {
-
-struct AllocatedType : public IntrusiveLinkBase<void> {
- AllocatedType() { ++alloc_count; }
- ~AllocatedType() { --alloc_count; }
- static int alloc_count;
-};
-int AllocatedType::alloc_count = 0;
-
-TEST(IntrusiveListUniquePtrTest, UniquePtr) {
- AllocatedType::alloc_count = 0;
-
- // Push/clear.
- IntrusiveList<std::unique_ptr<AllocatedType>> list;
- EXPECT_EQ(0, AllocatedType::alloc_count);
- list.push_back(absl::make_unique<AllocatedType>());
- EXPECT_EQ(1, AllocatedType::alloc_count);
- EXPECT_NE(nullptr, list.front());
- list.clear();
- EXPECT_EQ(0, AllocatedType::alloc_count);
-
- // Push/pop.
- list.push_back(absl::make_unique<AllocatedType>());
- EXPECT_EQ(1, AllocatedType::alloc_count);
- EXPECT_NE(nullptr, list.front());
- for (auto item : list) {
- EXPECT_EQ(item, list.front());
- }
- list.pop_back();
- EXPECT_EQ(0, AllocatedType::alloc_count);
-
- // Push/take.
- list.push_back(absl::make_unique<AllocatedType>());
- EXPECT_EQ(1, AllocatedType::alloc_count);
- EXPECT_NE(nullptr, list.front());
- auto item = list.take(list.front());
- EXPECT_TRUE(list.empty());
- EXPECT_NE(nullptr, item.get());
- EXPECT_EQ(1, AllocatedType::alloc_count);
- item.reset();
- EXPECT_EQ(0, AllocatedType::alloc_count);
-
- // Push/replace.
- list.push_back(absl::make_unique<AllocatedType>());
- EXPECT_EQ(1, AllocatedType::alloc_count);
- list.replace(list.front(), absl::make_unique<AllocatedType>());
- EXPECT_EQ(1, AllocatedType::alloc_count);
- list.clear();
- EXPECT_EQ(0, AllocatedType::alloc_count);
-
- // Iteration.
- list.push_back(absl::make_unique<AllocatedType>());
- list.push_back(absl::make_unique<AllocatedType>());
- list.push_back(absl::make_unique<AllocatedType>());
- EXPECT_EQ(3, AllocatedType::alloc_count);
- for (auto item : list) {
- AllocatedType* item_ptr = item;
- EXPECT_NE(nullptr, item_ptr);
- }
- list.clear();
- EXPECT_EQ(0, AllocatedType::alloc_count);
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/logging.h b/iree/base/logging.h
deleted file mode 100644
index 8bbf57a..0000000
--- a/iree/base/logging.h
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_LOGGING_H_
-#define IREE_BASE_LOGGING_H_
-
-// Logging macros live in their own file so that we can use external versions
-// as required.
-//
-// LOG(severity) << ...;
-// Logs a message at the given severity.
-// Severity:
-// INFO Logs information text.
-// WARNING Logs a warning.
-// ERROR Logs an error.
-// FATAL Logs an error and exit(1).
-//
-// VLOG(level) << ...;
-// Logs a verbose message at the given verbosity level.
-//
-// DVLOG(level) << ...;
-// Behaves like `VLOG` in debug mode (i.e. `#ifndef NDEBUG`).
-// Otherwise, it compiles away and does nothing.
-//
-// CHECK(condition) << ...;
-// Runtime asserts that the given condition is true even in release builds.
-// It's recommended that DCHECK is used instead as too many CHECKs
-// can impact performance.
-//
-// CHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...;
-// Runtime assert the specified operation with the given values.
-//
-// DCHECK(condition) << ...;
-// Runtime asserts that the given condition is true only in non-opt builds.
-//
-// DCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...;
-// Runtime assert the specified operation with the given values in non-opt
-// builds.
-//
-// QCHECK(condition) << ...;
-// QCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...;
-// These behave like `CHECK` but do not print a full stack trace.
-// They are useful when problems are definitely unrelated to program flow,
-// e.g. when validating user input.
-
-#ifdef IREE_CONFIG_GOOGLE_INTERNAL
-#include "iree/base/google/logging_google.h"
-#else
-#include "iree/base/internal/logging.h"
-#endif // IREE_CONFIG_GOOGLE_INTERNAL
-
-#endif // IREE_BASE_LOGGING_H_
diff --git a/iree/base/ref_ptr.h b/iree/base/ref_ptr.h
deleted file mode 100644
index dd44034..0000000
--- a/iree/base/ref_ptr.h
+++ /dev/null
@@ -1,364 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_REF_PTR_H_
-#define IREE_BASE_REF_PTR_H_
-
-#include <atomic>
-#include <cstdint>
-#include <type_traits>
-#include <utility>
-
-#include "absl/base/attributes.h"
-#include "iree/base/logging.h"
-
-namespace iree {
-
-// Use this to get really verbose refptr logging:
-// #define IREE_VERBOSE_REF_PTR
-
-template <class T>
-class ref_ptr;
-
-// Allocates a new ref_ptr type.
-// Like make_unique, but for ref_ptr.
-//
-// Usage:
-// ref_ptr<MyType> p = make_ref<MyType>(1, 2, 3);
-template <typename T, typename... Args>
-ref_ptr<T> make_ref(Args&&... args) {
- return ref_ptr<T>(new T(std::forward<Args>(args)...));
-}
-
-// Assigns a raw pointer to a ref_ptr without adding a reference.
-//
-// Usage:
-// ref_ptr<MyType> p = assign_ref(new MyType());
-template <typename T>
-inline ref_ptr<T> assign_ref(T* value) {
- return ref_ptr<T>(value);
-}
-
-// Adds a reference to the given raw pointer.
-//
-// Usage:
-// MyType* raw_ptr = AcquirePointerFromSomewhere();
-// ref_ptr<MyType> p = add_ref(raw_ptr);
-template <typename T>
-inline ref_ptr<T> add_ref(T* value) {
- if (value) ref_ptr_add_ref(value);
- return ref_ptr<T>(value);
-}
-
-// Adds a reference to the given ref_ptr.
-//
-// Usage:
-// ref_ptr<MyType> a = make_ref<MyType>();
-// ref_ptr<MyType> p = add_ref(a);
-template <typename T>
-inline ref_ptr<T> add_ref(const ref_ptr<T>& value) {
- if (value.get()) ref_ptr_add_ref(value.get());
- return ref_ptr<T>(value.get());
-}
-
-// Reference counted pointer container.
-// This is modeled on boost::instrusive_ptr in that it requires no
-// extra storage over the pointer type and should compile to almost
-// no additional code. It also allows us to round-trip object pointers
-// through regular pointers, which is critical when having to round-trip
-// them through JNI/etc where we can't use things like unique_ptr/shared_ptr.
-//
-// ref_ptr<Foo> p1(new Foo()); // ref count 1
-// ref_ptr<Foo> p2(p1); // ref count 2
-// p1.reset(); // ref count 1
-// p2.reset(); // ref count 0, deleted
-//
-// When round-tripping the pointer through external APIs, use release():
-// ref_ptr<Foo> p1(new Foo()); // ref count 1
-// Foo* raw_p = p1.release(); // ref count 1
-// // pass to API
-// ref_ptr<Foo> p2(raw_p); // ref count 1 (don't add ref)
-// p2.reset(); // ref count 0, deleted
-//
-// See the boost intrusive_ptr docs for details of behavior:
-// http://www.boost.org/doc/libs/1_55_0/libs/smart_ptr/intrusive_ptr.html
-//
-// ref_ptr manages the target objects in a thread-safe way, though you'll want
-// to take care with objects that may have pinned threads for deallocation. If
-// you release the last reference to an object on a thread other than what it
-// was expecting you're gonna have a bad time.
-//
-// Compatible only with types that subclass RefObject or implement the following
-// methods:
-// ref_ptr_add_ref
-// ref_ptr_release_ref
-template <class T>
-class ref_ptr {
- private:
- typedef ref_ptr this_type;
- typedef T* this_type::*unspecified_bool_type;
-
- public:
- // Initializes with nullptr.
- ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr() noexcept = default;
-
- // Initializes with nullptr so that there is no way to create an
- // uninitialized ref_ptr.
- ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr(std::nullptr_t) noexcept {} // NOLINT
-
- // Initializes the pointer to the given value.
- // The value will not have its reference count incremented (as it is with
- // unique_ptr). Use Retain to add to the reference count.
- ABSL_ATTRIBUTE_ALWAYS_INLINE explicit ref_ptr(T* p) noexcept : px_(p) {}
-
- // Decrements the reference count of the owned pointer.
- ABSL_ATTRIBUTE_ALWAYS_INLINE ~ref_ptr() noexcept {
- if (px_) ref_ptr_release_ref(px_);
- }
-
- // No implicit ref_ptr copying allowed; use add_ref instead.
- ref_ptr(const ref_ptr&) noexcept = delete;
- ref_ptr& operator=(const ref_ptr&) noexcept = delete;
-
- // Move support to transfer ownership from one ref_ptr to another.
- ref_ptr(ref_ptr&& rhs) noexcept : px_(rhs.release()) {}
- ref_ptr& operator=(ref_ptr&& rhs) noexcept {
- if (px_ != rhs.px_) {
- if (px_) ref_ptr_release_ref(px_);
- px_ = rhs.release();
- }
- return *this;
- }
-
- // Move support from another compatible type.
- template <typename U>
- ref_ptr(ref_ptr<U>&& rhs) noexcept : px_(rhs.release()) {} // NOLINT
- template <typename U>
- ref_ptr& operator=(ref_ptr<U>&& rhs) noexcept {
- if (px_ != rhs.get()) {
- if (px_) ref_ptr_release_ref(px_);
- px_ = rhs.release();
- }
- return *this;
- }
-
- // Resets the object to nullptr and decrements the reference count, possibly
- // deleting it.
- void reset() noexcept {
- if (px_) {
- ref_ptr_release_ref(px_);
- px_ = nullptr;
- }
- }
-
- // Releases a pointer.
- // Returns the current pointer held by this object without having
- // its reference count decremented and resets the ref_ptr to empty.
- // Returns nullptr if the ref_ptr holds no value.
- // To re-wrap in a ref_ptr use either ref_ptr<T>(value) or assign().
- ABSL_ATTRIBUTE_ALWAYS_INLINE T* release() noexcept {
- T* p = px_;
- px_ = nullptr;
- return p;
- }
-
- // Assigns a pointer.
- // The pointer will be accepted by the ref_ptr and its reference count will
- // not be incremented.
- ABSL_ATTRIBUTE_ALWAYS_INLINE void assign(T* value) noexcept {
- reset();
- px_ = value;
- }
-
- // Gets the pointer referenced by this instance.
- // operator* and operator-> will assert() if there is no current object.
- constexpr T* get() const noexcept { return px_; }
- constexpr T& operator*() const noexcept { return *px_; }
- constexpr T* operator->() const noexcept { return px_; }
-
- // Support boolean expression evaluation ala unique_ptr/shared_ptr:
- // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool
- constexpr operator unspecified_bool_type() const noexcept {
- return px_ ? &this_type::px_ : nullptr;
- }
- // Supports unary expression evaluation.
- constexpr bool operator!() const noexcept { return !px_; }
-
- // Swap support.
- void swap(ref_ptr& rhs) { std::swap(px_, rhs.px_); }
-
- private:
- T* px_ = nullptr;
-};
-
-// Base class for reference counted objects.
-// Reference counted objects should be used with the ref_ptr pointer type.
-// As reference counting can be tricky always prefer to use unique_ptr and
-// avoid this type. Only use this when unique_ptr is not possible, such as
-// when round-tripping objects through marshaling boundaries (v8/Java) or
-// any objects that may have their lifetime tied to a garbage collected
-// object.
-//
-// Subclasses should protect their dtor so that reference counting must
-// be used.
-//
-// This is designed to avoid the need for extra vtable space or for adding
-// methods to the vtable of subclasses. This differs from the boost Pointable
-// version of this object.
-// Inspiration for this comes from Peter Weinert's Dr. Dobb's article:
-// http://www.drdobbs.com/cpp/a-base-class-for-intrusively-reference-c/229218807
-//
-// RefObjects are thread safe and may be used with ref_ptrs from multiple
-// threads.
-//
-// Subclasses may implement a custom Delete operator to handle their
-// deallocation. It should be thread safe as it may be called from any thread.
-//
-// Usage:
-// class MyRefObject : public RefObject<MyRefObject> {
-// public:
-// MyRefObject() = default;
-// // Optional; can be used to return to pool/etc - must be public:
-// static void Delete(MyRefObject* ptr) {
-// ::operator delete(ptr);
-// }
-// };
-template <class T>
-class RefObject {
- static_assert(!std::is_array<T>::value, "T must not be an array");
-
- // value is true if a static Delete(T*) function is present.
- struct has_custom_deleter {
- template <typename C>
- static auto Test(C* p) -> decltype(C::Delete(nullptr), std::true_type());
- template <typename>
- static std::false_type Test(...);
- static constexpr bool value =
- std::is_same<std::true_type, decltype(Test<T>(nullptr))>::value;
- };
-
- template <typename V, bool has_custom_deleter>
- struct delete_thunk {
- static void Delete(V* p) {
- auto ref_obj = static_cast<RefObject<V>*>(p);
- int previous_count = ref_obj->counter_.fetch_sub(1);
-#ifdef IREE_VERBOSE_REF_PTR
- LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now "
- << previous_count - 1
- << (previous_count == 1 ? " DEAD (CUSTOM)" : "");
-#endif // IREE_VERBOSE_REF_PTR
- if (previous_count == 1) {
- // We delete type T pointer here to avoid the need for a virtual dtor.
- V::Delete(p);
- }
- }
- };
-
- template <typename V>
- struct delete_thunk<V, false> {
- static void Delete(V* p) {
- auto ref_obj = static_cast<RefObject<V>*>(p);
- int previous_count = ref_obj->counter_.fetch_sub(1);
-#ifdef IREE_VERBOSE_REF_PTR
- LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now "
- << previous_count - 1 << (previous_count == 1 ? " DEAD" : "");
-#endif // IREE_VERBOSE_REF_PTR
- if (previous_count == 1) {
- // We delete type T pointer here to avoid the need for a virtual dtor.
- delete p;
- }
- }
- };
-
- public:
- // Adds a reference; used by ref_ptr.
- friend void ref_ptr_add_ref(T* p) {
- auto ref_obj = static_cast<RefObject*>(p);
- ++ref_obj->counter_;
-
-#ifdef IREE_VERBOSE_REF_PTR
- LOG(INFO) << "ro++ " << typeid(T).name() << " " << p << " now "
- << ref_obj->counter_;
-#endif // IREE_VERBOSE_REF_PTR
- }
-
- // Releases a reference, potentially deleting the object; used by ref_ptr.
- friend void ref_ptr_release_ref(T* p) {
- delete_thunk<T, has_custom_deleter::value>::Delete(p);
- }
-
- // Adds a reference.
- // ref_ptr should be used instead of this in most cases. This is required
- // for when interoperating with marshaling APIs.
- void AddReference() { ref_ptr_add_ref(static_cast<T*>(this)); }
-
- // Releases a reference, potentially deleting the object.
- // ref_ptr should be used instead of this in most cases. This is required
- // for when interoperating with marshaling APIs.
- void ReleaseReference() { ref_ptr_release_ref(static_cast<T*>(this)); }
-
- protected:
- RefObject() { ref_ptr_add_ref(static_cast<T*>(this)); }
- RefObject(const RefObject&) = default;
- RefObject& operator=(const RefObject&) { return *this; }
-
- std::atomic<intptr_t> counter_{0};
-};
-
-// Various comparison operator overloads.
-
-template <class T, class U>
-inline bool operator==(ref_ptr<T> const& a, ref_ptr<U> const& b) {
- return a.get() == b.get();
-}
-
-template <class T, class U>
-inline bool operator!=(ref_ptr<T> const& a, ref_ptr<U> const& b) {
- return a.get() != b.get();
-}
-
-template <class T, class U>
-inline bool operator==(ref_ptr<T> const& a, U* b) {
- return a.get() == b;
-}
-
-template <class T, class U>
-inline bool operator!=(ref_ptr<T> const& a, U* b) {
- return a.get() != b;
-}
-
-template <class T, class U>
-inline bool operator==(T* a, ref_ptr<U> const& b) {
- return a == b.get();
-}
-
-template <class T, class U>
-inline bool operator!=(T* a, ref_ptr<U> const& b) {
- return a != b.get();
-}
-
-template <class T>
-inline bool operator<(ref_ptr<T> const& a, ref_ptr<T> const& b) {
- return a.get() < b.get();
-}
-
-// Swaps the pointers of two ref_ptrs.
-template <class T>
-void swap(ref_ptr<T>& lhs, ref_ptr<T>& rhs) {
- lhs.swap(rhs);
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_REF_PTR_H_
diff --git a/iree/base/ref_ptr_test.cc b/iree/base/ref_ptr_test.cc
deleted file mode 100644
index e518eb4..0000000
--- a/iree/base/ref_ptr_test.cc
+++ /dev/null
@@ -1,330 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/ref_ptr.h"
-
-#include "gtest/gtest.h"
-
-namespace iree {
-namespace {
-
-class MyType : public RefObject<MyType> {
- public:
- int x = 5;
-
- using RefObject<MyType>::counter_; // Expose for testing.
-};
-
-TEST(RefPtrTest, Construction) {
- // Empty.
- ref_ptr<MyType> n1;
- EXPECT_EQ(nullptr, n1.get());
- ref_ptr<MyType> n2(nullptr);
- EXPECT_EQ(nullptr, n2.get());
-
- // Assign a new ptr and add ref.
- MyType* a_ptr = new MyType();
- EXPECT_EQ(1, a_ptr->counter_);
- ref_ptr<MyType> a(a_ptr);
- EXPECT_EQ(1, a->counter_);
-
- // Assign existing ptr without adding a ref.
- ref_ptr<MyType> b(a_ptr);
- EXPECT_EQ(1, b->counter_);
-
- // Add a new ref.
- ref_ptr<MyType> c = add_ref(b);
- EXPECT_EQ(2, c->counter_);
-
- b.release();
-}
-
-TEST(RefPtrTest, Assign) {
- // Ok to assign nothing.
- ref_ptr<MyType> n1 = assign_ref<MyType>(nullptr);
- EXPECT_EQ(nullptr, n1.get());
-
- ref_ptr<MyType> mt = make_ref<MyType>();
- EXPECT_EQ(1, mt->counter_);
- ref_ptr<MyType> n2 = assign_ref(mt.get());
- EXPECT_EQ(1, mt->counter_);
- mt.release(); // must release, as we assigned to n2.
- EXPECT_EQ(1, n2->counter_);
- n2.reset();
-}
-
-TEST(RefPtrTest, Retain) {
- // Ok to retain nothing.
- ref_ptr<MyType> n1 = add_ref<MyType>(nullptr);
- EXPECT_EQ(nullptr, n1.get());
-
- ref_ptr<MyType> mt = make_ref<MyType>();
- EXPECT_EQ(1, mt->counter_);
- ref_ptr<MyType> n2 = add_ref(mt.get());
- EXPECT_EQ(2, mt->counter_);
- mt.reset();
- EXPECT_EQ(1, n2->counter_);
- n2.reset();
-}
-
-TEST(RefPtrTest, Reset) {
- ref_ptr<MyType> a(new MyType());
- ref_ptr<MyType> b(new MyType());
-
- // Reset to drop reference.
- ref_ptr<MyType> a_copy = add_ref(a);
- EXPECT_EQ(2, a_copy->counter_);
- a.reset();
- EXPECT_EQ(1, a_copy->counter_);
-
- // Reset via = operator.
- a = nullptr;
- EXPECT_EQ(1, a_copy->counter_);
- a = add_ref(a_copy);
- EXPECT_EQ(2, a_copy->counter_);
-
- // No-op on empty ptrs.
- ref_ptr<MyType> n;
- n.reset();
- n.assign(nullptr);
-}
-
-TEST(RefPtrTest, ReleaseAssign) {
- ref_ptr<MyType> a(new MyType());
-
- // Release a's pointer.
- MyType* a_raw_ptr = a.get();
- MyType* a_ptr = a.release();
- EXPECT_EQ(a_raw_ptr, a_ptr);
- EXPECT_EQ(nullptr, a.get());
- EXPECT_EQ(1, a_ptr->counter_);
-
- // Re-wrap in a ref_ptr.
- a.assign(a_ptr);
- EXPECT_EQ(1, a->counter_);
-
- // No-op on empty ptrs.
- ref_ptr<MyType> n;
- EXPECT_EQ(nullptr, n.release());
-}
-
-TEST(RefPtrTest, Accessors) {
- ref_ptr<MyType> a(new MyType());
- EXPECT_EQ(5, a->x);
- a->x = 100;
- EXPECT_EQ(100, a->x);
-
- MyType& ra = *a;
- ra.x = 200;
- EXPECT_EQ(200, ra.x);
-
- const MyType& cra = *a;
- EXPECT_EQ(200, cra.x);
-}
-
-TEST(RefPtrTest, BooleanExpressions) {
- ref_ptr<MyType> a(new MyType());
- ref_ptr<MyType> n;
-
- EXPECT_NE(nullptr, a.get());
- EXPECT_TRUE(a);
- EXPECT_FALSE(!a);
- EXPECT_EQ(true, static_cast<bool>(a));
-
- EXPECT_EQ(nullptr, n.get());
- EXPECT_FALSE(n);
- EXPECT_TRUE(!n);
- EXPECT_EQ(false, static_cast<bool>(n));
-}
-
-TEST(RefPtrTest, Comparisons) {
- ref_ptr<MyType> a(new MyType());
- ref_ptr<MyType> b(new MyType());
- ref_ptr<MyType> n;
-
- EXPECT_TRUE(a == a);
- EXPECT_TRUE(a == a.get());
- EXPECT_TRUE(a.get() == a);
- EXPECT_FALSE(a != a);
- EXPECT_FALSE(a != a.get());
- EXPECT_FALSE(a.get() != a);
-
- EXPECT_FALSE(a == b);
- EXPECT_FALSE(a == b.get());
- EXPECT_FALSE(a.get() == b);
- EXPECT_TRUE(a != b);
- EXPECT_TRUE(a != b.get());
- EXPECT_TRUE(a.get() != b);
-
- EXPECT_TRUE(n == n);
- EXPECT_TRUE(n == n.get());
- EXPECT_TRUE(n.get() == n);
- EXPECT_FALSE(n != n);
- EXPECT_FALSE(n != n.get());
- EXPECT_FALSE(n.get() != n);
-
- EXPECT_FALSE(a < a);
- EXPECT_TRUE(n < a);
-}
-
-TEST(RefPtrTest, Swap) {
- ref_ptr<MyType> a(new MyType());
- ref_ptr<MyType> b(new MyType());
- MyType* a_ptr = a.get();
- MyType* b_ptr = b.get();
-
- swap(a, a);
- EXPECT_EQ(a_ptr, a);
-
- swap(a, b);
- EXPECT_EQ(a_ptr, b.get());
- EXPECT_EQ(b_ptr, a.get());
-
- swap(a, b);
- EXPECT_EQ(a_ptr, a.get());
- EXPECT_EQ(b_ptr, b.get());
-
- ref_ptr<MyType> c;
- swap(a, c);
- EXPECT_EQ(a_ptr, c.get());
- EXPECT_EQ(nullptr, a.get());
-}
-
-TEST(RefPtrTest, Move) {
- auto a = make_ref<MyType>();
- auto b = make_ref<MyType>();
- ref_ptr<MyType> c;
- EXPECT_EQ(nullptr, c.get());
-
- c = std::move(a);
- EXPECT_NE(nullptr, c.get());
-
- b = std::move(c);
- EXPECT_NE(nullptr, b.get());
-}
-
-TEST(RefPtrTest, MoveCompatible) {
- struct MyBaseType : public RefObject<MyBaseType> {
- int x = 5;
- using RefObject<MyBaseType>::counter_; // Expose for testing.
- };
- struct MyTypeA : public MyBaseType {
- int a = 6;
- };
- struct MyTypeB : public MyBaseType {
- int b = 7;
- };
-
- ref_ptr<MyTypeA> a = make_ref<MyTypeA>();
- EXPECT_EQ(1, a->counter_);
- ref_ptr<MyBaseType> base = add_ref(a);
- EXPECT_EQ(a.get(), base.get());
- EXPECT_EQ(2, a->counter_);
-
- base = make_ref<MyTypeB>();
- EXPECT_EQ(1, a->counter_);
- EXPECT_EQ(1, base->counter_);
-}
-
-TEST(RefPtrTest, StackAllocation) {
- static int alloc_count = 0;
- class StackAllocationType : public RefObject<StackAllocationType> {
- public:
- StackAllocationType() { ++alloc_count; }
- ~StackAllocationType() { --alloc_count; }
- };
- {
- StackAllocationType a;
- EXPECT_EQ(1, alloc_count);
- }
- EXPECT_EQ(0, alloc_count);
-}
-
-TEST(RefPtrTest, DefaultDeleter) {
- static int alloc_count = 0;
- class DefaultDeleterType : public RefObject<DefaultDeleterType> {
- public:
- DefaultDeleterType() { ++alloc_count; }
- ~DefaultDeleterType() { --alloc_count; }
- };
-
- // Empty is ok.
- ref_ptr<DefaultDeleterType> n;
- n.reset();
-
- // Lifecycle.
- EXPECT_EQ(0, alloc_count);
- ref_ptr<DefaultDeleterType> a = make_ref<DefaultDeleterType>();
- EXPECT_EQ(1, alloc_count);
- a.reset();
- EXPECT_EQ(0, alloc_count);
-}
-
-TEST(RefPtrTest, InlineDeallocator) {
- static int alloc_count = 0;
- class CustomDeleterType : public RefObject<CustomDeleterType> {
- public:
- CustomDeleterType() { ++alloc_count; }
- static void Delete(CustomDeleterType* ptr) {
- --alloc_count;
- ::operator delete(ptr);
- }
- };
-
- // Empty is ok.
- ref_ptr<CustomDeleterType> n;
- n.reset();
-
- // Lifecycle.
- EXPECT_EQ(0, alloc_count);
- auto a = make_ref<CustomDeleterType>();
- EXPECT_EQ(1, alloc_count);
- a.reset();
- EXPECT_EQ(0, alloc_count);
-}
-
-class VirtualDtorTypeA : public RefObject<VirtualDtorTypeA> {
- public:
- VirtualDtorTypeA() { ++alloc_count_a; }
- virtual ~VirtualDtorTypeA() { --alloc_count_a; }
- static int alloc_count_a;
-};
-int VirtualDtorTypeA::alloc_count_a = 0;
-
-class VirtualDtorTypeB : public VirtualDtorTypeA {
- public:
- VirtualDtorTypeB() { ++alloc_count_b; }
- ~VirtualDtorTypeB() override { --alloc_count_b; }
- static int alloc_count_b;
-};
-int VirtualDtorTypeB::alloc_count_b = 0;
-
-TEST(RefPtrTest, VirtualDestructor) {
- // Empty is ok.
- ref_ptr<VirtualDtorTypeB> n;
- n.reset();
-
- // Lifecycle.
- EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a);
- EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b);
- ref_ptr<VirtualDtorTypeB> a = make_ref<VirtualDtorTypeB>();
- EXPECT_EQ(1, VirtualDtorTypeA::alloc_count_a);
- EXPECT_EQ(1, VirtualDtorTypeB::alloc_count_b);
- a.reset();
- EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a);
- EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b);
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/shape.cc b/iree/base/shape.cc
deleted file mode 100644
index 875d119..0000000
--- a/iree/base/shape.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/shape.h"
-
-#include <cstddef>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-Shape::Shape(const int* values, int size) : rank_(size) {
- QCHECK_LE(size, kMaxRank)
- << "Max rank of " << kMaxRank << ", shape has " << size;
- std::memcpy(value_, values, size * sizeof(int));
-}
-
-std::string Shape::DebugString() const {
- return absl::StrCat("[", absl::StrJoin(subspan(), ","), "]");
-}
-
-absl::Span<const int> Shape::subspan(size_type pos, size_type len) const {
- if (len == npos) {
- len = rank_ - pos;
- }
- return absl::MakeConstSpan(&value_[pos], len);
-}
-
-void Shape::push_back(int dim) {
- DCHECK_LE(rank_ + 1, kMaxRank);
- value_[rank_++] = dim;
-}
-
-void Shape::insert(iterator pos, int dim) {
- int axis = static_cast<int>(pos - value_);
- DCHECK_GE(axis, 0);
- DCHECK_LE(axis, rank_);
- DCHECK_LE(rank_ + 1, kMaxRank);
- ++rank_;
- for (int i = rank_ - 1; i > axis; --i) {
- value_[i] = value_[i - 1];
- }
- value_[axis] = dim;
-}
-
-void Shape::erase(iterator pos) {
- int axis = static_cast<int>(pos - value_);
- DCHECK_GE(axis, 0);
- DCHECK_LE(axis, rank_);
- for (int i = axis; i < rank_ - 1; ++i) {
- value_[i] = value_[i + 1];
- }
- --rank_;
-}
-
-int Shape::element_count() const {
- size_t element_count = 1;
- for (int i = 0; i < rank_; ++i) {
- int dim = value_[i];
- if (dim == -1) {
- return 0;
- }
- element_count *= dim;
- }
- return element_count;
-}
-
-StatusOr<int> Shape::ResolveAxis(int axis) const {
- if (rank_ == 0 && (axis == -1 || axis == 0)) {
- // Scalar axes resolves to 0.
- return 0;
- }
-
- int new_axis = axis;
- if (new_axis < 0) {
- new_axis += rank_;
- }
- if (new_axis < 0 || new_axis >= rank_) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Axis " << new_axis << " (orig " << axis
- << ") out of bounds of rank " << rank_;
- }
- return new_axis;
-}
-
-} // namespace iree
diff --git a/iree/base/shape.h b/iree/base/shape.h
deleted file mode 100644
index 758d66f..0000000
--- a/iree/base/shape.h
+++ /dev/null
@@ -1,156 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_SHAPE_H_
-#define IREE_BASE_SHAPE_H_
-
-#include <array>
-#include <cstring>
-#include <initializer_list>
-#include <iterator>
-#include <string>
-#include <type_traits>
-#include <vector>
-
-#include "absl/meta/type_traits.h"
-#include "absl/types/span.h"
-#include "iree/base/logging.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-// For simplicity we limit our shapes to a max of rank-N (shape.size() == N) as
-// this prevents dynamic allocations and rarely are there greater ranks.
-constexpr int kMaxRank = 5;
-
-// Represent indices and lengths of tensors.
-using Index = std::array<int, kMaxRank>;
-using Length = std::array<int, kMaxRank>;
-
-// Represents the number of elements in multiple dimensions.
-// Can be rank-0 (scalar) to rank-kMaxRank. Tries to match the API of
-// std::vector and can be converted to a Span via subspan().
-//
-// https://www.tensorflow.org/guide/tensors#shape
-class Shape {
- public:
- using size_type = int;
- static constexpr size_type npos = ~(size_type(0)); // NOLINT
- using iterator = int*;
- using const_iterator = const int*;
-
- Shape() = default;
- Shape(const int* values, int size);
- Shape(std::initializer_list<int> values)
- : Shape(values.begin(), values.size()) {}
- explicit Shape(absl::Span<const int> values)
- : Shape(values.data(), values.size()) {}
-
- template <typename Iterator>
- using EnableIfForwardIterator = absl::enable_if_t<std::is_convertible<
- typename std::iterator_traits<Iterator>::iterator_category,
- std::forward_iterator_tag>::value>;
- template <typename Iterator, EnableIfForwardIterator<Iterator>* = nullptr>
- Shape(Iterator first, Iterator last) {
- rank_ = std::distance(first, last);
- QCHECK_LE(rank_, kMaxRank);
- for (int i = 0; first != last; ++i, static_cast<void>(++first)) {
- value_[i] = *first;
- }
- }
-
- // Returns a string representation of the given shape.
- std::string DebugString() const;
-
- // Size (aka 'rank') of the shape, counting the number of dimensions.
- constexpr size_type size() const noexcept { return rank_; }
-
- // Whether the shape is rank-0 (scalar).
- constexpr bool empty() const noexcept { return rank_ == 0; }
-
- // Returns the total elements in the tensor shape.
- // Returns 0 if the tensor shape is not complete and 1 if the shape is a
- // scalar value.
- int element_count() const;
-
- // Resolves an axis in [-R,R) to the real axis value and verifies the range.
- StatusOr<int> ResolveAxis(int axis) const;
-
- // Compares two shapes for equality.
- inline static bool Equal(const Shape& a, const Shape& b) {
- return a.rank_ == b.rank_ &&
- std::memcmp(a.value_, b.value_, a.rank_ * sizeof(value_[0])) == 0;
- }
-
- int& operator[](size_type i) noexcept {
- DCHECK_GE(i, 0);
- DCHECK_LT(i, rank_);
- return value_[i];
- }
-
- const int& operator[](size_type i) const noexcept {
- DCHECK_GE(i, 0);
- DCHECK_LT(i, rank_);
- return value_[i];
- }
-
- int front() const noexcept {
- DCHECK_GE(rank_, 1);
- return value_[0];
- }
-
- int back() const noexcept {
- DCHECK_GE(rank_, 1);
- return value_[rank_ - 1];
- }
-
- constexpr iterator begin() const noexcept {
- return const_cast<iterator>(&value_[0]);
- }
- constexpr iterator end() const noexcept {
- return const_cast<iterator>(&value_[rank_]);
- }
- constexpr const_iterator cbegin() const noexcept { return &value_[0]; }
- constexpr const_iterator cend() const noexcept { return &value_[rank_]; }
-
- absl::Span<const int> subspan(size_type pos = 0, size_type len = npos) const;
- absl::Span<const int> data() const { return subspan(); }
-
- void push_back(int dim);
-
- void insert(iterator pos, int dim);
-
- void erase(iterator pos);
-
- void clear() { rank_ = 0; }
-
- private:
- size_type rank_ = 0;
- int value_[kMaxRank];
-};
-
-inline bool operator==(const Shape& a, const Shape& b) {
- return Shape::Equal(a, b);
-}
-
-inline bool operator!=(const Shape& a, const Shape& b) { return !(a == b); }
-
-inline std::ostream& operator<<(std::ostream& stream, const Shape& shape) {
- stream << shape.DebugString();
- return stream;
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_SHAPE_H_
diff --git a/iree/base/shape_test.cc b/iree/base/shape_test.cc
deleted file mode 100644
index ca163a3..0000000
--- a/iree/base/shape_test.cc
+++ /dev/null
@@ -1,222 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/shape.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-
-namespace iree {
-namespace {
-
-using ::testing::ElementsAre;
-
-// Tests shapes that represent 0-D scalar values.
-TEST(ShapeTest, Scalar) {
- Shape shape;
- EXPECT_EQ(0, shape.size());
- EXPECT_TRUE(shape.empty());
- EXPECT_EQ(1, shape.element_count());
- EXPECT_EQ(shape, shape);
- EXPECT_EQ(0, shape.subspan().size());
- for (const int dim : shape) {
- FAIL() << "Should have no dimensions, have: " << dim;
- }
- EXPECT_EQ(shape.begin(), shape.end());
- EXPECT_EQ(shape.cbegin(), shape.cend());
- shape.clear();
- EXPECT_EQ(0, shape.size());
-}
-
-// Tests the various ways of constructing a 1+D shape.
-TEST(ShapeTest, NonScalarConstruction) {
- EXPECT_EQ(0, Shape().size());
- EXPECT_EQ(0, Shape({}).size());
- EXPECT_EQ(1, Shape({10}).size());
- EXPECT_EQ(4, Shape({10, 20, 30, 40}).size());
-
- std::vector<int> empty_data = {};
- EXPECT_EQ(0, Shape(empty_data.data(), empty_data.size()).size());
- EXPECT_EQ(0, Shape(empty_data.begin(), empty_data.end()).size());
- EXPECT_EQ(0, Shape(absl::MakeConstSpan(empty_data)).size());
-
- EXPECT_THAT(Shape({}).subspan(), ElementsAre());
- EXPECT_THAT(Shape({10}).subspan(), ElementsAre(10));
- EXPECT_THAT(Shape({10, 20, 30, 40}).subspan(), ElementsAre(10, 20, 30, 40));
-
- std::vector<int> valid_data = {10, 20, 30, 40};
- EXPECT_THAT(Shape(valid_data.begin(), valid_data.end()).subspan(),
- ElementsAre(10, 20, 30, 40));
- EXPECT_THAT(Shape(absl::MakeConstSpan(valid_data)).subspan(),
- ElementsAre(10, 20, 30, 40));
-}
-
-// Tests shapes that represent 1+D multidimensional values.
-TEST(ShapeTest, NonScalarAccess) {
- Shape shape = {1, 2, 3, 4};
- EXPECT_EQ(4, shape.size());
- EXPECT_FALSE(shape.empty());
- EXPECT_EQ(1 * 2 * 3 * 4, shape.element_count());
- EXPECT_EQ(shape, shape);
- EXPECT_NE(shape, Shape({4, 3, 2, 1}));
- EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4));
- std::vector<int> readout;
- for (const int dim : shape) {
- readout.push_back(dim);
- }
- EXPECT_THAT(readout, ElementsAre(1, 2, 3, 4));
- EXPECT_EQ(1, shape[0]);
- EXPECT_EQ(2, shape[1]);
- EXPECT_EQ(3, shape[2]);
- EXPECT_EQ(4, shape[3]);
- EXPECT_EQ(1, shape.front());
- EXPECT_EQ(4, shape.back());
-}
-
-TEST(ShapeTest, PushBack) {
- Shape shape;
- EXPECT_EQ(0, shape.size());
-
- shape.push_back(10);
- EXPECT_EQ(1, shape.size());
- EXPECT_EQ(10, shape.front());
- EXPECT_EQ(10, shape.back());
- EXPECT_EQ(10, shape[0]);
- EXPECT_THAT(shape.subspan(), ElementsAre(10));
-
- shape.push_back(20);
- EXPECT_EQ(2, shape.size());
- EXPECT_EQ(10, shape.front());
- EXPECT_EQ(20, shape.back());
- EXPECT_EQ(10, shape[0]);
- EXPECT_EQ(20, shape[1]);
- EXPECT_THAT(shape.subspan(), ElementsAre(10, 20));
-}
-
-TEST(ShapeTest, Insert) {
- Shape shape;
- EXPECT_EQ(0, shape.size());
-
- shape.insert(shape.begin(), 20);
- EXPECT_THAT(shape.subspan(), ElementsAre(20));
- shape.insert(shape.begin(), 10);
- EXPECT_THAT(shape.subspan(), ElementsAre(10, 20));
- shape.insert(shape.end(), 40);
- EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 40));
- shape.insert(shape.begin() + 2, 30);
- EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 30, 40));
-
- Shape ex_shape{72, 4};
- ex_shape.insert(ex_shape.begin(), 144);
- EXPECT_THAT(ex_shape.subspan(), ElementsAre(144, 72, 4));
-}
-
-TEST(ShapeTest, Erase) {
- Shape shape = {1, 2, 3, 4};
- EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4));
- shape.erase(shape.begin());
- EXPECT_THAT(shape.subspan(), ElementsAre(2, 3, 4));
- shape.erase(shape.end());
- EXPECT_THAT(shape.subspan(), ElementsAre(2, 3));
- shape.erase(shape.begin() + 1);
- EXPECT_THAT(shape.subspan(), ElementsAre(2));
- shape.erase(shape.end());
- EXPECT_THAT(shape.subspan(), ElementsAre());
-}
-
-TEST(ShapeTest, Clear) {
- Shape shape;
- EXPECT_EQ(0, shape.size());
- shape.clear();
- EXPECT_EQ(0, shape.size());
-
- shape = Shape({1});
- shape.clear();
- EXPECT_EQ(0, shape.size());
-
- shape = Shape({1, 2, 3, 4});
- shape.clear();
- EXPECT_EQ(0, shape.size());
-}
-
-TEST(ShapeTest, DebugString) {
- EXPECT_EQ("[]", Shape({}).DebugString());
- EXPECT_EQ("[1]", Shape({1}).DebugString());
- EXPECT_EQ("[1,2]", Shape({1, 2}).DebugString());
-}
-
-TEST(ShapeTest, ElementCount) {
- EXPECT_EQ(1, Shape({}).element_count());
- EXPECT_EQ(0, Shape({0}).element_count());
- EXPECT_EQ(1, Shape({1}).element_count());
- EXPECT_EQ(2, Shape({2, 1}).element_count());
- EXPECT_EQ(10, Shape({2, 5}).element_count());
- EXPECT_EQ(9216, Shape({72, 1, 128}).element_count());
- EXPECT_EQ(9216, Shape({1, 72, 128}).element_count());
-
- // Partial shaping should yield no elements.
- EXPECT_EQ(0, Shape({1, -1, 2, 3}).element_count());
-}
-
-TEST(ShapeTest, ResolveAxis) {
- int axis;
- ASSERT_OK_AND_ASSIGN(axis, Shape({0}).ResolveAxis(0));
- EXPECT_EQ(0, axis);
- ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(1));
- EXPECT_EQ(1, axis);
- ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(2));
- EXPECT_EQ(2, axis);
-
- EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(3).status()));
-}
-
-TEST(ShapeTest, ResolveAxisNegative) {
- int axis;
- ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-3));
- EXPECT_EQ(0, axis);
- ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-2));
- EXPECT_EQ(1, axis);
- ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-1));
- EXPECT_EQ(2, axis);
-
- EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(-4).status()));
-}
-
-TEST(ShapeTest, ResolveAxisScalar) {
- int axis;
- ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(0));
- EXPECT_EQ(0, axis);
- ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(-1));
- EXPECT_EQ(0, axis);
-
- EXPECT_TRUE(IsInvalidArgument(Shape({}).ResolveAxis(1).status()));
-}
-
-TEST(ShapeTest, Equality) {
- EXPECT_EQ(Shape({}), Shape({}));
- EXPECT_EQ(Shape({0}), Shape({0}));
- EXPECT_EQ(Shape({1}), Shape({1}));
- EXPECT_EQ(Shape({1, 2}), Shape({1, 2}));
-
- EXPECT_NE(Shape({}), Shape({1}));
- EXPECT_NE(Shape({-1}), Shape({1}));
- EXPECT_NE(Shape({1}), Shape({}));
- EXPECT_NE(Shape({1}), Shape({2}));
- EXPECT_NE(Shape({1, 2}), Shape({3, 4}));
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/base/source_location.h b/iree/base/source_location.h
deleted file mode 100644
index d4a10ca..0000000
--- a/iree/base/source_location.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_SOURCE_LOCATION_H_
-#define IREE_BASE_SOURCE_LOCATION_H_
-
-#ifdef IREE_CONFIG_GOOGLE_INTERNAL
-#include "iree/base/google/source_location_google.h"
-#else
-#include "iree/base/internal/source_location.h"
-#endif // IREE_CONFIG_GOOGLE_INTERNAL
-
-#endif // IREE_BASE_SOURCE_LOCATION_H_
diff --git a/iree/base/status.h b/iree/base/status.h
deleted file mode 100644
index 4565101..0000000
--- a/iree/base/status.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_STATUS_H_
-#define IREE_BASE_STATUS_H_
-
-#ifdef IREE_CONFIG_GOOGLE_INTERNAL
-#include "iree/base/google/status_google.h"
-#else
-#include "iree/base/internal/status.h"
-#include "iree/base/internal/status_builder.h"
-#include "iree/base/internal/status_errno.h"
-#include "iree/base/internal/status_errors.h"
-#include "iree/base/internal/status_macros.h"
-#include "iree/base/internal/status_win32_errors.h"
-#include "iree/base/internal/statusor.h"
-#endif // IREE_CONFIG_GOOGLE_INTERNAL
-
-#include "iree/base/source_location.h" // IWYU pragma: export
-
-#endif // IREE_BASE_STATUS_H_
diff --git a/iree/base/status_matchers.h b/iree/base/status_matchers.h
deleted file mode 100644
index 7920c57..0000000
--- a/iree/base/status_matchers.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_STATUS_MATCHERS_H_
-#define IREE_BASE_STATUS_MATCHERS_H_
-
-#ifdef IREE_CONFIG_GOOGLE_INTERNAL
-
-#include "iree/base/google/status_matchers_google.h" // IWYU pragma: export
-
-#else
-
-#include "iree/base/internal/status_matchers.h" // IWYU pragma: export
-
-#endif // IREE_CONFIG_GOOGLE_INTERNAL
-
-#endif // IREE_BASE_STATUS_MATCHERS_H_
diff --git a/iree/base/tracing.cc b/iree/base/tracing.cc
deleted file mode 100644
index 51fa5d0..0000000
--- a/iree/base/tracing.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Force the header to detect WTF_ENABLE so that this library builds
-// (for when building recursively).
-#if !defined(WTF_ENABLE)
-#define WTF_ENABLE
-#endif
-
-#include "iree/base/tracing.h"
-
-#include <thread> // NOLINT: Fiber doesn't work during startup on Android.
-
-#include "absl/base/attributes.h"
-#include "absl/base/const_init.h"
-#include "absl/base/thread_annotations.h"
-#include "absl/flags/flag.h"
-#include "absl/strings/str_cat.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "iree/base/file_io.h"
-#include "iree/base/file_path.h"
-#include "iree/base/init.h"
-#include "iree/base/logging.h"
-#include "iree/base/status.h"
-
-ABSL_FLAG(int32_t, iree_trace_file_period, 5,
- "Seconds between automatic flushing of WTF trace files. 0 to "
- "disable auto-flush.");
-ABSL_FLAG(std::string, iree_trace_file, "/dev/null",
- "wtf-trace file to save if --define=GLOBAL_WTF_ENABLE=1 was used "
- "when building.");
-
-namespace iree {
-namespace {
-
-// Guards global WTF state (like the flush fiber and IO).
-ABSL_CONST_INIT absl::Mutex global_tracing_mutex(absl::kConstInit);
-
-// True when tracing has been enabled and initialized.
-bool global_tracing_initialized ABSL_GUARDED_BY(global_tracing_mutex) = false;
-
-// If there is an existing file at the given path back it up by moving it aside.
-// Only kMaxBackups will be kept to avoid unbounded growth.
-void RollTraceFiles(const std::string& path) {
- std::string path_stem = file_path::JoinPaths(file_path::DirectoryName(path),
- file_path::Stem(path));
- const int kMaxBackups = 5;
- for (int i = kMaxBackups; i >= 0; i--) {
- std::string source_name;
- if (i > 0) {
- source_name = absl::StrCat(path_stem, ".", i, ".wtf-trace");
- } else {
- source_name = path;
- }
- if (!file_io::FileExists(source_name).ok()) {
- continue;
- }
-
- Status status;
- if (i == kMaxBackups) {
- status = file_io::DeleteFile(source_name);
- } else {
- std::string backup_name =
- absl::StrCat(path_stem, ".", (i + 1), ".wtf-trace");
- status = file_io::MoveFile(source_name, backup_name);
- }
- if (!status.ok()) {
- LOG(WARNING) << "Could not remove backup trace file " << source_name
- << ": " << status;
- }
- }
-}
-
-// Flushes all recorded trace data since the last flush.
-void FlushTraceFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(global_tracing_mutex) {
- if (!global_tracing_initialized) return;
-
- const auto& trace_path = absl::GetFlag(FLAGS_iree_trace_file);
-
- static ::wtf::Runtime::SaveCheckpoint checkpoint;
- static bool is_first_flush = true;
-
- if (is_first_flush && trace_path != "/dev/null") {
- // Backup existing any existing trace files at the specified path.
- RollTraceFiles(trace_path);
- }
-
- auto save_options =
- ::wtf::Runtime::SaveOptions::ForStreamingFile(&checkpoint);
- if (is_first_flush) {
- // On the first time, truncate the file. All subsequent flushes append.
- save_options.open_mode = std::ios_base::trunc;
- }
-
- is_first_flush = false;
-
- auto* runtime = ::wtf::Runtime::GetInstance();
- if (!runtime->SaveToFile(trace_path, save_options)) {
- LOG(ERROR) << "Error saving WTF file: " << trace_path;
- return;
- }
-
- VLOG(1) << "Flushed WTF trace to: " << trace_path;
-}
-
-} // namespace
-
-void InitializeTracing() {
- if (!::wtf::kMasterEnable) {
- if (!absl::GetFlag(FLAGS_iree_trace_file).empty()) {
- LOG(WARNING) << "WTF trace save requested but WTF is not compiled in. "
- << "Enable by building with --define=GLOBAL_WTF_ENABLE=1.";
- }
- return;
- }
-
- absl::MutexLock lock(&global_tracing_mutex);
- if (global_tracing_initialized) return;
- global_tracing_initialized = true;
-
- LOG(INFO) << "Tracing enabled and streaming to: "
- << absl::GetFlag(FLAGS_iree_trace_file);
-
- // Enable tracing on this thread, which we know is main.
- IREE_TRACE_THREAD_ENABLE("main");
-
- // Register atexit callback to stop tracking.
- atexit(StopTracing);
-
- // Launch a thread to periodically flush the trace.
- if (absl::GetFlag(FLAGS_iree_trace_file_period) > 0) {
- auto flush_thread = std::thread(+[]() {
- absl::Duration period =
- absl::Seconds(absl::GetFlag(FLAGS_iree_trace_file_period));
- while (true) {
- absl::SleepFor(period);
- absl::MutexLock lock(&global_tracing_mutex);
- if (!global_tracing_initialized) {
- return;
- }
- FlushTraceFile();
- }
- });
- flush_thread.detach();
- }
-}
-
-// Stops tracing if currently initialized.
-void StopTracing() {
- if (!::wtf::kMasterEnable) return;
- absl::MutexLock lock(&global_tracing_mutex);
- if (!global_tracing_initialized) return;
-
- // Flush any pending trace data.
- FlushTraceFile();
-
- // Mark WTF as uninitialized to kill the flush thread.
- global_tracing_initialized = false;
-
- LOG(INFO) << "Tracing stopped and flushed to file: "
- << absl::GetFlag(FLAGS_iree_trace_file);
-}
-
-void FlushTrace() {
- if (!::wtf::kMasterEnable) return;
- absl::MutexLock lock(&global_tracing_mutex);
- if (!global_tracing_initialized) return;
- FlushTraceFile();
-}
-
-} // namespace iree
-
-IREE_DECLARE_MODULE_INITIALIZER(iree_tracing);
-
-IREE_REGISTER_MODULE_INITIALIZER(iree_tracing, ::iree::InitializeTracing());
diff --git a/iree/base/tracing_disabled.cc b/iree/base/tracing_disabled.cc
deleted file mode 100644
index 96ea115..0000000
--- a/iree/base/tracing_disabled.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file is linked in only when WTF is not enabled. It allows us to keep the
-// same flags and functions without needing to do a bunch of ifdef hackery or
-// undefok mangling.
-
-#include <cstdint>
-#include <string>
-
-#include "absl/flags/flag.h"
-#include "iree/base/tracing.h"
-
-// TODO(benvanik): remove this when disabled so that we don't dep on flags.
-ABSL_FLAG(int32_t, iree_trace_file_period, 0,
- "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF.");
-ABSL_FLAG(std::string, iree_trace_file, "",
- "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF.");
diff --git a/iree/base/wait_handle.cc b/iree/base/wait_handle.cc
deleted file mode 100644
index 39e5ca9..0000000
--- a/iree/base/wait_handle.cc
+++ /dev/null
@@ -1,532 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/wait_handle.h"
-
-#include <errno.h>
-#include <fcntl.h>
-#include <poll.h>
-#include <time.h>
-#include <unistd.h>
-
-#include <type_traits>
-#include <utility>
-
-#include "absl/container/fixed_array.h"
-#include "absl/strings/str_cat.h"
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-
-// TODO(benvanik): organize these macros - they are terrible.
-
-#if !defined(__ANDROID__) && !defined(OS_IOS) && !defined(__EMSCRIPTEN__)
-#define IREE_HAS_PPOLL 1
-#endif // !__ANDROID__ && !__EMSCRIPTEN__
-#define IREE_HAS_POLL 1
-
-#if !defined(OS_IOS) && !defined(OS_MACOSX) && !defined(__EMSCRIPTEN__)
-#define IREE_HAS_EVENTFD 1
-#endif
-#define IREE_HAS_PIPE 1
-// #define IREE_HAS_SYNC_FILE 1
-
-#if defined(IREE_HAS_EVENTFD)
-#include <sys/eventfd.h>
-#endif // IREE_HAS_EVENTFD
-
-namespace iree {
-
-namespace {
-
-constexpr int kInvalidFd = WaitableObject::kInvalidFd;
-constexpr int kSignaledFd = WaitableObject::kSignaledFd;
-
-// Retries a syscall until it succeeds or fails for a real reason.
-template <typename SyscallT, typename... ParamsT>
-StatusOr<typename std::result_of<SyscallT(ParamsT...)>::type> Syscall(
- SyscallT syscall, ParamsT&&... params) {
- while (true) {
- const auto rv = syscall(std::forward<ParamsT>(params)...);
- if (rv >= 0) return rv;
- if (errno == EINTR) {
- // Retry on EINTR.
- continue;
- } else {
- return ErrnoToCanonicalStatus(errno, "");
- }
- }
-}
-
-#if defined(IREE_HAS_PPOLL)
-
-// ppoll(), present on Linux.
-// ppoll is preferred as it has a much better timing mechanism; poll can have a
-// large slop on the deadline.
-// Documentation: https://linux.die.net/man/2/poll
-StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) {
- // Convert the deadline into a tmo_p struct for ppoll that controls whether
- // the call is blocking or non-blocking. Note that we must do this every
- // iteration of the loop as a previous ppoll may have taken some of the
- // time.
- //
- // See the ppoll docs for more information as to what the expected value is:
- // http://man7.org/linux/man-pages/man2/poll.2.html
- timespec timeout_spec;
- timespec* tmo_p;
- if (deadline == absl::InfinitePast()) {
- // 0 for non-blocking.
- timeout_spec = {0};
- tmo_p = &timeout_spec;
- } else if (deadline == absl::InfiniteFuture()) {
- // nullptr to ppoll() to block forever.
- tmo_p = nullptr;
- } else {
- // Wait only for as much time as we have before the deadline is exceeded.
- absl::Duration remaining_time = deadline - absl::Now();
- if (remaining_time < absl::ZeroDuration()) {
- // Note: we likely have already bailed before getting here with a negative
- // duration.
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
- timeout_spec = absl::ToTimespec(remaining_time);
- tmo_p = &timeout_spec;
- }
- return Syscall(::ppoll, poll_fds.data(), poll_fds.size(), tmo_p, nullptr);
-}
-
-#elif defined(IREE_HAS_POLL)
-
-// poll(), present pretty much everywhere.
-// Documentation: https://linux.die.net/man/2/poll
-StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) {
- int timeout;
- if (deadline == absl::InfinitePast()) {
- // Don't block.
- timeout = 0;
- } else if (deadline == absl::InfiniteFuture()) {
- // Block forever.
- timeout = -1;
- } else {
- absl::Duration remaining_time = deadline - absl::Now();
- if (remaining_time < absl::ZeroDuration()) {
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
- timeout = static_cast<int>(absl::ToInt64Milliseconds(remaining_time));
- }
- return Syscall(::poll, poll_fds.data(), poll_fds.size(), timeout);
-}
-
-#else
-#error "No SystemPoll implementation"
-#endif // IREE_HAS_PPOLL / IREE_HAS_POLL / etc
-
-// Builds the list of pollfds to for ppoll wait on and will perform any
-// required wait handle callbacks.
-//
-// The provided deadline will be observed if any of the wait handles needs to
-// block for acquiring an fd.
-StatusOr<absl::FixedArray<pollfd>> AcquireWaitHandles(
- WaitHandle::WaitHandleSpan wait_handles, absl::Time deadline) {
- absl::FixedArray<pollfd> poll_fds{wait_handles.size()};
- for (int i = 0; i < wait_handles.size(); ++i) {
- poll_fds[i].events = POLLIN | POLLPRI | POLLERR | POLLHUP | POLLNVAL;
- poll_fds[i].revents = 0;
- // NOTE: poll will ignore any negative fds and our kInvalidFd == -1 so we
- // can still put them in the list and it'll just skip them.
- if (!wait_handles[i] || !wait_handles[i]->object()) {
- poll_fds[i].fd = kInvalidFd;
- continue;
- }
-
- // Acquire the file descriptor for waiting.
- // This may block (if |deadline| allows it) if the fd is not yet available.
- // This is like a pre-wait for the actual poll operation. It can be bad with
- // WaitAny, though we could handle that better here.
- ASSIGN_OR_RETURN(auto fd_info,
- wait_handles[i]->object()->AcquireFdForWait(deadline));
- poll_fds[i].fd = fd_info.second;
-
- // Abort if deadline exceeded.
- if (deadline != absl::InfinitePast() && deadline < absl::Now()) {
- return DeadlineExceededErrorBuilder(IREE_LOC)
- << "Deadline exceeded acquiring for fds";
- }
- }
- return poll_fds;
-}
-
-Status ClearFd(WaitableObject::FdType fd_type, int fd) {
- // Read in a loop until the read would block.
- // Depending on how the users setup the fd the act of reading may reset the
- // entire handle (such as with the default eventfd mode) or multiple reads
- // may be required (such as with semaphores).
- while (true) {
-#if defined(IREE_HAS_EVENTFD)
- eventfd_t val = 0;
- int rv = ::eventfd_read(fd, &val);
-#elif defined(IREE_HAS_PIPE)
- char buf;
- int rv = ::read(fd, &buf, 1);
-#else
- return UnimplementedErrorBuilder(IREE_LOC) << "fd_type cannot be cleared";
-#endif // IREE_HAS_EVENTFD
- if (rv != -1) {
- // Success! Keep going.
- continue;
- } else {
- if (errno == EWOULDBLOCK) {
- // The read would have blocked meaning that we've hit the end and
- // successfully cleared the fd.
- return OkStatus();
- } else if (errno == EINTR) {
- // Retry.
- continue;
- } else {
- return ErrnoToCanonicalStatus(errno, "ClearFd failed");
- }
- }
- }
-}
-
-// Performs a single poll on multiple fds and returns information about the
-// signaled fds, if any.
-Status MultiPoll(WaitHandle::WaitHandleSpan wait_handles,
- absl::Span<pollfd> poll_fds, absl::Time deadline,
- int* out_any_signaled_index, int* out_unsignaled_count) {
- *out_any_signaled_index = -1;
- *out_unsignaled_count = 0;
-
- // poll has a nasty behavior where it allows -1 for fds... except for at [0].
- // To keep the rest of the code sane we correct for that here as epoll doesn't
- // have that behavior and we may want to special case this later.
- bool any_valid_fds = true;
- int swapped_zero_index = -1;
- if (poll_fds[0].fd < 0) {
- // Find a valid handle.
- for (int i = 1; i < poll_fds.size(); ++i) {
- if (poll_fds[i].fd > 0) {
- swapped_zero_index = i;
- std::swap(poll_fds[0], poll_fds[i]);
- break;
- }
- }
- if (swapped_zero_index == -1) {
- // No valid handles found, meaning that all handles are invalid.
- // We'll skip the wait below so we can share the processing code for any
- // fds that may be kSignaledFd.
- any_valid_fds = false;
- }
- }
-
- // Pass handles to ppoll.
- // http://man7.org/linux/man-pages/man2/poll.2.html
- if (any_valid_fds) {
- ASSIGN_OR_RETURN(int rv, SystemPoll(poll_fds, deadline));
- if (rv == 0) {
- // Call timed out and no descriptors were ready.
- // If this was just a poll then that's fine.
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
- }
-
- // If we had swapped fds[0] above we need to correct for that now.
- if (swapped_zero_index != -1) {
- std::swap(poll_fds[0], poll_fds[swapped_zero_index]);
- }
-
- // |rv| denotes the number of fds that were ready. Run through the list and
- // find the ones that were ready and mark them as completed.
- for (int i = 0; i < poll_fds.size(); ++i) {
- if (poll_fds[i].fd == kSignaledFd || poll_fds[i].revents == POLLIN) {
- // First attempt any resolve actions. If these fail we can't consider the
- // fd as having been signaled.
- ASSIGN_OR_RETURN(
- bool resolved,
- wait_handles[i]->object()->TryResolveWakeOnFd(poll_fds[i].fd));
- if (!resolved) {
- ++(*out_unsignaled_count);
- continue;
- }
-
- // Successful wait. Kill the fd so it is ignored on the next poll.
- poll_fds[i].fd = kInvalidFd;
- *out_any_signaled_index = i;
- } else if (poll_fds[i].revents) {
- if (poll_fds[i].revents & POLLERR) {
- return InternalErrorBuilder(IREE_LOC);
- } else if (poll_fds[i].revents & POLLHUP) {
- return CancelledErrorBuilder(IREE_LOC);
- } else if (poll_fds[i].revents & POLLNVAL) {
- return InvalidArgumentErrorBuilder(IREE_LOC);
- } else {
- return UnknownErrorBuilder(IREE_LOC);
- }
- } else if (poll_fds[i].fd != kInvalidFd) {
- ++(*out_unsignaled_count);
- }
- }
-
- return OkStatus();
-}
-
-} // namespace
-
-// static
-std::atomic<uint64_t> WaitHandle::next_unique_id_{1};
-
-// static
-WaitHandle WaitHandle::AlwaysSignaling() {
- class AlwaysSignalingObject : public WaitableObject {
- public:
- std::string DebugString() const override { return "signal"; }
- StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- absl::Time deadline) override {
- return std::make_pair(FdType::kPermanent, kSignaledFd);
- }
- StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; }
- };
- static auto* obj = new AlwaysSignalingObject();
- return WaitHandle(add_ref(obj));
-}
-
-// static
-WaitHandle WaitHandle::AlwaysFailing() {
- class AlwaysFailingObject : public WaitableObject {
- public:
- std::string DebugString() const override { return "fail"; }
- StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- absl::Time deadline) override {
- return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject";
- }
- StatusOr<bool> TryResolveWakeOnFd(int fd) override {
- return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject";
- }
- };
- static auto* obj = new AlwaysFailingObject();
- return WaitHandle(add_ref(obj));
-}
-
-// static
-Status WaitHandle::WaitAll(WaitHandleSpan wait_handles, absl::Time deadline) {
- if (wait_handles.empty()) return OkStatus();
-
- // Build the list of pollfds to wait on.
- ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline));
-
- // Loop until all handles have been signaled or the deadline is exceeded.
- int unsignaled_count = 0;
- do {
- int any_signaled_index = 0;
- RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline,
- &any_signaled_index, &unsignaled_count));
- } while (unsignaled_count > 0 && absl::Now() < deadline);
-
- if (unsignaled_count == 0) {
- // All waits resolved.
- return OkStatus();
- } else {
- // One or more were unsignaled.
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
-}
-
-// static
-StatusOr<bool> WaitHandle::TryWaitAll(WaitHandleSpan wait_handles) {
- auto status = WaitAll(wait_handles, absl::InfinitePast());
- if (status.ok()) {
- return true;
- } else if (IsDeadlineExceeded(status)) {
- return false;
- }
- return status;
-}
-
-// static
-StatusOr<int> WaitHandle::WaitAny(WaitHandleSpan wait_handles,
- absl::Time deadline) {
- if (wait_handles.empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "At least one wait handle is required for WaitAny";
- }
-
- // Build the list of pollfds to wait on.
- ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline));
-
- // Poll once; this makes a WaitAny just a WaitMulti that doesn't loop.
- int any_signaled_index = -1;
- int unsignaled_count = 0;
- RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline,
- &any_signaled_index, &unsignaled_count));
- if (any_signaled_index == -1) {
- // No wait handles were valid. Pretend 0 was signaled.
- return 0;
- }
- return any_signaled_index;
-}
-
-// static
-StatusOr<int> WaitHandle::TryWaitAny(WaitHandleSpan wait_handles) {
- auto status_or = WaitAny(wait_handles, absl::InfinitePast());
- return IsDeadlineExceeded(status_or.status()) ? -1 : status_or;
-}
-
-// Storage for static class variables; these won't be needed when we can use
-// c++17 everywhere.
-constexpr int WaitableObject::kInvalidFd;
-constexpr int WaitableObject::kSignaledFd;
-
-WaitHandle::WaitHandle(ref_ptr<WaitableObject> object)
- : unique_id_(++next_unique_id_), object_(std::move(object)) {}
-
-WaitHandle::~WaitHandle() { Dispose(); }
-
-void WaitHandle::Dispose() { object_.reset(); }
-
-WaitHandle::WaitHandle(WaitHandle&& other)
- : unique_id_(other.unique_id_), object_(std::move(other.object_)) {
- other.unique_id_ = 0;
-}
-
-WaitHandle& WaitHandle::operator=(WaitHandle&& other) {
- if (this != std::addressof(other)) {
- // Close current handle.
- Dispose();
-
- // Take ownership of handle and resources.
- object_ = std::move(other.object_);
-
- other.unique_id_ = ++next_unique_id_;
- }
- return *this;
-}
-
-std::string WaitHandle::DebugString() const {
- return object_ ? object_->DebugString() : absl::StrCat("wh_", unique_id_);
-}
-
-StatusOr<bool> WaitHandle::TryWait() {
- auto status = WaitAll({this}, absl::InfinitePast());
- if (status.ok()) {
- return true;
- } else if (IsDeadlineExceeded(status)) {
- return false;
- }
- return status;
-}
-
-ManualResetEvent::ManualResetEvent(const char* debug_name)
- : debug_name_(debug_name) {
- Initialize();
-}
-
-ManualResetEvent::~ManualResetEvent() { Dispose(); }
-
-void ManualResetEvent::Initialize() {
-#if defined(IREE_HAS_EVENTFD)
- // Create with an eventfd by default when we support it.
- // eventfd has lower overhead than pipes (the syscalls are cheap).
- // This usually will only fail if the system is completely out of handles.
- //
- // Docs: http://man7.org/linux/man-pages/man2/eventfd.2.html
- fd_type_ = FdType::kEventFd;
- fd_ = Syscall(::eventfd, 0, EFD_CLOEXEC | EFD_NONBLOCK).ValueOrDie();
-#elif defined(IREE_HAS_PIPE)
- // Android/Linux/iOS-compatible POSIX pipe handle.
- // Two handles are generated: one for transmitting and one for receiving.
- //
- // Docs: http://man7.org/linux/man-pages/man2/pipe.2.html
- fd_type_ = FdType::kPipe;
- int pipefd[2];
- Syscall(::pipe, pipefd).ValueOrDie();
- Syscall(::fcntl, pipefd[0], F_SETFL, O_NONBLOCK).ValueOrDie();
- fd_ = pipefd[0];
- write_fd_ = pipefd[1];
-#else
-// NOTE: sync_file does not use Notifier as they come from the kernel.
-#error "No fd-based sync primitive on this platform"
-#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE / etc
-}
-
-void ManualResetEvent::Dispose() {
- if (fd_ != kInvalidFd) {
- // Always signal, as we need to ensure waiters are woken.
- CHECK_OK(Set());
- Syscall(::close, fd_).ValueOrDie();
- fd_ = kInvalidFd;
- }
- if (write_fd_ != kInvalidFd) {
- Syscall(::close, write_fd_).ValueOrDie();
- write_fd_ = kInvalidFd;
- }
-}
-
-ManualResetEvent::ManualResetEvent(ManualResetEvent&& other)
- : fd_type_(other.fd_type_),
- fd_(other.fd_),
- write_fd_(other.write_fd_),
- debug_name_(other.debug_name_) {
- other.fd_type_ = FdType::kPermanent;
- other.fd_ = kInvalidFd;
- other.write_fd_ = kInvalidFd;
- other.debug_name_ = nullptr;
-}
-
-ManualResetEvent& ManualResetEvent::operator=(ManualResetEvent&& other) {
- if (this != std::addressof(other)) {
- Dispose();
- fd_type_ = other.fd_type_;
- fd_ = other.fd_;
- write_fd_ = other.write_fd_;
- debug_name_ = other.debug_name_;
- other.fd_type_ = FdType::kPermanent;
- other.fd_ = kInvalidFd;
- other.write_fd_ = kInvalidFd;
- other.debug_name_ = nullptr;
- other.Initialize();
- }
- return *this;
-}
-
-std::string ManualResetEvent::DebugString() const {
- if (debug_name_) {
- return debug_name_;
- }
-#if defined(IREE_HAS_EVENTFD)
- return absl::StrCat("eventfd_", fd_);
-#elif defined(IREE_HAS_PIPE)
- return absl::StrCat("pipe_", fd_, "_", write_fd_);
-#else
- return absl::StrCat("unknown_", fd_, "_", write_fd_);
-#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE
-}
-
-Status ManualResetEvent::Set() {
-#if defined(IREE_HAS_EVENTFD)
- return Syscall(::eventfd_write, fd_, 1ull).status();
-#elif defined(IREE_HAS_PIPE)
- char buf = '\n';
- return Syscall(::write, write_fd_, &buf, 1).status();
-#else
- return UnimplementedErrorBuilder(IREE_LOC)
- << "No fd-based sync primitive on this platform";
-#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE
-}
-
-Status ManualResetEvent::Reset() { return ClearFd(fd_type_, fd_); }
-
-WaitHandle ManualResetEvent::OnSet() { return WaitHandle(add_ref(this)); }
-
-} // namespace iree
diff --git a/iree/base/wait_handle.h b/iree/base/wait_handle.h
deleted file mode 100644
index d051e51..0000000
--- a/iree/base/wait_handle.h
+++ /dev/null
@@ -1,321 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_WAIT_HANDLE_H_
-#define IREE_BASE_WAIT_HANDLE_H_
-
-#include <atomic>
-#include <cstdint>
-#include <string>
-#include <utility>
-
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
-
-namespace iree {
-
-// Interfaces for waitable objects that can produce WaitHandles.
-// WaitableObjects are much like ::thread::Selectable, only they support both
-// the classic locking style as well as file descriptors for use with select().
-//
-// Usage:
-// class MyWaitableObject : public WaitableObject {
-// public:
-// std::string DebugString() const override { return "something useful"; }
-// WaitHandle OnAsyncTask() {
-// return WaitHandle(retain_ref(this));
-// }
-// private:
-// StatusOr<std::pair<FdType, int>> AcquireFdForWait(
-// absl::Time deadline) override {
-// // If blocking traditionally do so now and then return this:
-// return std::make_pair(FdType::kPermanent, kSignaledFd);
-// // Otherwise, see ManualResetEvent for an example using fds.
-// }
-// StatusOr<bool> TryResolveWakeOnFd(int fd) override {
-// // Return true iff the object is really acquired, such as the semaphore
-// // being decremented.
-// return true;
-// }
-// };
-class WaitableObject : public RefObject<WaitableObject> {
- public:
- // Indicates that a file descriptor is invalid. It will not block when waited
- // upon.
- constexpr static int kInvalidFd = -1;
- // Indicates that a file descriptor should be treated as signaled.
- // Waiting on this fd should return as if it has already been signaled.
- constexpr static int kSignaledFd = -2;
-
- // Defines the type of the native handle used for synchronization.
- enum class FdType : uint16_t {
- // Event has no handle and should be treated as permanently signaled.
- kPermanent,
-
- // Android/Linux/iOS-compatible POSIX pipe handle.
- // Two handles are generated: one for transmitting and one for receiving.
- //
- // More information:
- // http://man7.org/linux/man-pages/man2/pipe.2.html
- kPipe,
-
- // Android/Linux eventfd handle.
- // These are akin to pipe() but require only a single handle and have
- // significantly lower overhead (equivalent if not slightly better than
- // pthreads condvars).
- //
- // eventfds support acting as both semaphores and auto reset events.
- //
- // More information:
- // http://man7.org/linux/man-pages/man2/eventfd.2.html
- kEventFd,
-
- // Android/Linux sync_file handle (aka 'sync fence').
- // The handle is allocated indirectly by the device driver via the
- // <linux/sync_file.h> API. It may be waited upon with poll(), select(), or
- // epoll() and must be closed with close() when no longer required. If
- // waiting on multiple sync_files the caller should first merge them
- // together.
- //
- // A sync_file must only be used as fences (one-shot manual reset events).
- //
- // More information:
- // https://www.kernel.org/doc/Documentation/sync_file.txt
- // https://lwn.net/Articles/702339/
- // https://source.android.com/devices/graphics/implement-vsync#explicit_synchronization
- kSyncFile,
- };
-
- virtual ~WaitableObject() = default;
-
- // Returns a string representing the object, either specified as a debug_name
- // or a unique ID.
- virtual std::string DebugString() const = 0;
-
- // Attempts to acquire a file descriptor for the waitable objects by the given
- // |deadline|. In many cases this will return immediately with a valid fd.
- //
- // In cases where the file descriptor may not be available the call may block
- // until either it is available or the |deadline| has elapsed. Use
- // absl::InfinitePast() to prevent blocking.
- //
- // Returns a valid file descriptor or kInvalidFd as an indication that the
- // object should not be waited on (already signaled, etc). Can return
- // kSignaledFd to indicate that it's already known that the handle has been
- // signaled and the caller should resolve as if it caused a wake normally.
- virtual StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- absl::Time deadline) = 0;
-
- // Tries to resolve the object with the given |fd|.
- // In many cases this will no-op, however some types may require additional
- // checks to ensure that the wait operation succeeded (such as semaphores
- // that may need to query a count). If resolution fails the waitable object
- // must not be considered signaled. This call will never block.
- virtual StatusOr<bool> TryResolveWakeOnFd(int fd) = 0;
-};
-
-// Handle to waitable objects.
-// WaitHandles are created by a particular synchronization primitive, such as
-// Fence, as a way for one or more observers to poll or wait for notification.
-//
-// External synchronization primitives can be wrapped in WaitHandles to enable
-// other libraries or languages to be waited on alongside WaitHandles created
-// by the IREE primitives like Fence. See the notes on WaitHandleType for a list
-// of handle types that are supported.
-//
-// Wait handles are thread-safe in that multiple threads may be waiting on them
-// concurrently.
-class WaitHandle {
- public:
- // Returns a WaitHandle that when waited on will never block.
- static WaitHandle AlwaysSignaling();
-
- // Returns a WaitHandle that when waited on will always fail.
- static WaitHandle AlwaysFailing();
-
- using WaitHandleSpan = absl::Span<WaitHandle* const>;
-
- // Blocks the caller until all passed |wait_handles| are signaled or the
- // |deadline| elapses.
- //
- // Returns success if the wait is successful and all events have been
- // signaled.
- //
- // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all handles
- // having been signaled. Note that a subset of the |wait_handles| may have
- // been signaled and each can be queried to see which one.
- static Status WaitAll(WaitHandleSpan wait_handles, absl::Time deadline);
- static Status WaitAll(WaitHandleSpan wait_handles, absl::Duration timeout) {
- return WaitAll(wait_handles, RelativeTimeoutToDeadline(timeout));
- }
- static Status WaitAll(WaitHandleSpan wait_handles) {
- return WaitAll(wait_handles, absl::InfiniteFuture());
- }
-
- // Tries waiting on the handles and returns immediately if it would have
- // blocked. The caller will not be blocked even if a handle has not yet been
- // signaled.
- //
- // Returns true if all handles have been signaled.
- static StatusOr<bool> TryWaitAll(WaitHandleSpan wait_handles);
-
- // Blocks the caller until at least one of the |wait_handles| is signaled or
- // the |deadline| elapses.
- //
- // Returns the index into |wait_handles| of a handle that was signaled. Note
- // that more than one handle may have been signaled and all of the other
- // |wait_handles| should be queried or waited on again until waits for them
- // succeed.
- //
- // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any handles
- // having been signaled.
- static StatusOr<int> WaitAny(WaitHandleSpan wait_handles,
- absl::Time deadline);
- static StatusOr<int> WaitAny(WaitHandleSpan wait_handles,
- absl::Duration timeout) {
- return WaitAny(wait_handles, RelativeTimeoutToDeadline(timeout));
- }
- static StatusOr<int> WaitAny(WaitHandleSpan wait_handles) {
- return WaitAny(wait_handles, absl::InfiniteFuture());
- }
-
- // Tries waiting for at least one handle to complete and returns immediately
- // if none have been. The caller will not be blocked even if a handle has not
- // yet been signaled.
- //
- // Returns the index into |wait_handles| of a handle that was signaled. Note
- // that more than one handle may have been signaled and all of the other
- // |wait_handles| should be queried or waited on again until waits for them
- // succeed.
- //
- // Returns -1 if no handles were signaled.
- static StatusOr<int> TryWaitAny(WaitHandleSpan wait_handles);
-
- // Default constructor creates a permanently signaled handle.
- // Waiting on this handle will never block.
- WaitHandle() = default;
-
- // Wraps an existing sync file descriptor.
- // Ownership of the file descriptor is transferred to the WaitHandle and must
- // be duplicated by the caller if they want to continue using it.
- explicit WaitHandle(ref_ptr<WaitableObject> object);
-
- ~WaitHandle();
-
- // Copying not supported. Create a new WaitHandle from the source.
- WaitHandle(const WaitHandle&) = delete;
- WaitHandle& operator=(const WaitHandle&) = delete;
-
- // Moving supported; sync primitive ownership is transferred.
- WaitHandle(WaitHandle&& other);
- WaitHandle& operator=(WaitHandle&& other);
-
- // Unique ID for the WaitHandle instance.
- // Two wait handles, even if waiting on the same underlying primitive, will
- // have differing unique_ids. This can be used for deduping the handles or
- // storing handles in a map.
- uint64_t unique_id() const { return unique_id_; }
-
- // Returns a unique string representing the handle.
- std::string DebugString() const;
-
- // Blocks the caller until the handle is signaled or the |deadline| elapses.
- //
- // If waiting on multiple wait handles use WaitAll or WaitAny instead of
- // multiple calls to Wait as they can significantly reduce overhead.
- //
- // Returns success if the wait is successful and the |wait_handle| was
- // signaled. Returns DEADLINE_EXCEEDED if the timeout elapses without the
- // handle having been signaled.
- Status Wait(absl::Time deadline) { return WaitAll({this}, deadline); }
- Status Wait(absl::Duration timeout) {
- return WaitAll({this}, RelativeTimeoutToDeadline(timeout));
- }
- Status Wait() { return WaitAll({this}, absl::InfiniteFuture()); }
-
- // Tries waiting on the handle and returns immediately if it would have
- // waited. The caller will not be blocked even if the handle has not yet been
- // signaled.
- //
- // Returns true if the handle has been signaled.
- StatusOr<bool> TryWait();
-
- // These accessors should generally be considered opaque but may be useful to
- // code trying to interop with other runtimes.
- const ref_ptr<WaitableObject>& object() const { return object_; }
-
- private:
- // Disposes the handle by closing the fd and issuing callbacks.
- void Dispose();
-
- static std::atomic<uint64_t> next_unique_id_;
-
- uint64_t unique_id_ = 0;
- ref_ptr<WaitableObject> object_;
-};
-
-// A manually-resettable event primitive.
-// Effectively a binary semaphore with a maximum_count of 1 when running in
-// auto-reset mode but also provides a sticky manual reset mode.
-class ManualResetEvent : public WaitableObject {
- public:
- explicit ManualResetEvent(const char* debug_name = nullptr);
-
- ~ManualResetEvent() override;
-
- // Copying not supported.
- ManualResetEvent(const ManualResetEvent&) = delete;
- ManualResetEvent& operator=(const ManualResetEvent&) = delete;
-
- // Moving supported; sync primitive ownership is transferred.
- ManualResetEvent(ManualResetEvent&& other);
- ManualResetEvent& operator=(ManualResetEvent&& other);
-
- std::string DebugString() const override;
-
- // Sets the specified event object to the signaled state.
- // The event stays signaled until Reset is called. Multiple waiters will be
- // woken.
- Status Set();
-
- // Resets the specified event object to the nonsignaled state.
- // Resetting an event that is already reset has no effect.
- Status Reset();
-
- // Returns a WaitHandle that will be signaled when the event is set.
- WaitHandle OnSet();
-
- protected:
- void Initialize();
- void Dispose();
-
- StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- absl::Time deadline) override {
- return std::make_pair(fd_type_, fd_);
- }
- StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; }
-
- FdType fd_type_ = FdType::kPermanent;
- int fd_ = kInvalidFd;
- int write_fd_ = kInvalidFd; // Used only for fd_type_ == kPipe.
- const char* debug_name_ = nullptr;
-};
-
-} // namespace iree
-
-#endif // IREE_BASE_WAIT_HANDLE_H_
diff --git a/iree/base/wait_handle_test.cc b/iree/base/wait_handle_test.cc
deleted file mode 100644
index 600a63f..0000000
--- a/iree/base/wait_handle_test.cc
+++ /dev/null
@@ -1,555 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/wait_handle.h"
-
-#include <unistd.h>
-
-#include <string>
-#include <thread> // NOLINT
-#include <type_traits>
-
-#include "absl/time/time.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-
-// StatusOr<bool> will be true if the status is ok, which is bad.
-#define ASSERT_STATUSOR_TRUE(x) ASSERT_TRUE(x.ValueOrDie())
-#define ASSERT_STATUSOR_FALSE(x) ASSERT_FALSE(x.ValueOrDie())
-
-namespace iree {
-namespace {
-
-using ::testing::_;
-using ::testing::Return;
-
-// Tests the AlwaysSignaling helper.
-TEST(WaitHandleTest, AlwaysSignaling) {
- ASSERT_OK(WaitHandle::AlwaysSignaling().Wait());
- EXPECT_FALSE(WaitHandle::AlwaysSignaling().DebugString().empty());
-}
-
-// Tests the AlwaysFailing helper.
-TEST(WaitHandleTest, AlwaysFailing) {
- ASSERT_FALSE(WaitHandle::AlwaysFailing().Wait().ok());
- EXPECT_FALSE(WaitHandle::AlwaysFailing().DebugString().empty());
-}
-
-// Tests the basic lifecycle of a permanently signaled wait handle.
-TEST(WaitHandleTest, LifecyclePermanentSignaled) {
- // Just to be sure it's ok to safely no-op a WaitHandle value.
- WaitHandle wh_never_used;
- (void)wh_never_used;
-
- // Try waiting; should return immediately.
- WaitHandle wh0;
- ASSERT_OK(wh0.Wait());
-
- // Waits on multiple permanent handles should be ok.
- WaitHandle wh1;
- ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
-}
-
-// Tests moving permanent WaitHandles around.
-TEST(WaitHandleTest, MovePermanent) {
- WaitHandle wh0;
- WaitHandle wh1{std::move(wh0)};
- WaitHandle wh2 = std::move(wh1);
- wh1 = std::move(wh2);
-}
-
-// Tests moving around real handles (that may require closing).
-TEST(WaitHandleTest, MoveRealHandle) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- WaitHandle wh1{std::move(wh0)};
- WaitHandle wh2 = std::move(wh1);
- wh1 = std::move(wh2);
-
- // Now overwrite the handle value to force a close.
- ManualResetEvent fence1;
- WaitHandle wh3 = fence1.OnSet();
- wh1 = std::move(wh3);
- wh1 = WaitHandle(); // Ensure handle dies first.
-}
-
-// Tests the various forms of waiting on a single WaitHandle.
-// Since these just call WaitAll we leave the involved testing to those.
-TEST(WaitHandleTest, SingleWait) {
- WaitHandle wh;
- ASSERT_OK(wh.Wait());
- ASSERT_OK(wh.Wait(absl::Now() + absl::Seconds(1)));
- ASSERT_OK(wh.Wait(absl::Seconds(1)));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
-}
-
-// Tests using WaitAll with no valid handles. This should no-op.
-TEST(WaitHandleTest, WaitAllNop) {
- ASSERT_OK(WaitHandle::WaitAll({}));
- ASSERT_OK(WaitHandle::WaitAll({nullptr}));
- ASSERT_OK(WaitHandle::WaitAll({nullptr, nullptr}));
-}
-
-// Tests polling with WaitAll with multiple wait handles.
-TEST(WaitHandleTest, WaitAllPoll) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- ManualResetEvent fence1;
- WaitHandle wh1 = fence1.OnSet();
-
- // Poll; should return immediately with timeout.
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())));
-
- // Notify fence1.
- ASSERT_OK(fence1.Set());
-
- // Poll; should return immediately with timeout as fence1 is not signaled.
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())));
-
- // Notify fence0.
- ASSERT_OK(fence0.Set());
-
- // Poll again; should return immediately with success.
- ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()));
-}
-
-// Tests waiting when the first file handle is invalid. This is to verify a
-// workaround for bad poll() behavior with fds[0] == -1.
-TEST(WaitHandleTest, WaitAllWithInvalid0) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
-
- // Poll; should return immediately with timeout as fence is not signaled.
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast())));
-
- // Notify fence.
- ASSERT_OK(fence.Set());
-
- // Poll again; should return immediately with success.
- ASSERT_OK(WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast()));
-}
-
-// Tests exceeding the timeout deadline with WaitAll.
-TEST(WaitHandleTest, WaitAllTimeout) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
-
- // Wait with timeout on the unsignaled fence:
- // Via polling (should never block):
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::InfinitePast())));
- ASSERT_STATUSOR_FALSE(WaitHandle::TryWaitAll({&wh}));
- // Via time in the near future (should block):
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250))));
- // Via time in the past, should exceed deadline.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250))));
-
- // Notify and ensure no more timeouts.
- ASSERT_OK(fence.Set());
- ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::InfinitePast()));
- ASSERT_STATUSOR_TRUE(WaitHandle::TryWaitAll({&wh}));
- ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250)));
-
- // Via time in the past, should exceed deadline even if signaled.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250))));
-}
-
-// Tests using WaitAll to wait on other threads.
-TEST(WaitHandleTest, WaitAllThreaded) {
- // Spin up two threads.
- ManualResetEvent fence0;
- std::thread t0{[&]() {
- ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
- ASSERT_OK(fence0.Set());
- }};
- ManualResetEvent fence1;
- std::thread t1{[&]() {
- ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
- ASSERT_OK(fence1.Set());
- }};
-
- // Wait on both threads to complete.
- WaitHandle wh0 = fence0.OnSet();
- WaitHandle wh1 = fence1.OnSet();
- ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
-
- t0.join();
- t1.join();
-}
-
-// Tests using WaitAll with multiple wait handles from the same fence.
-TEST(WaitHandleTest, WaitAllSameSource) {
- ManualResetEvent fence;
- WaitHandle wh0 = fence.OnSet();
- WaitHandle wh1 = fence.OnSet();
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())));
- ASSERT_OK(fence.Set());
- ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
-}
-
-// Tests using WaitAll with literally the same wait handles.
-TEST(WaitHandleTest, WaitAllSameHandle) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAll({&wh, &wh}, absl::InfinitePast())));
- ASSERT_OK(fence.Set());
- ASSERT_OK(WaitHandle::WaitAll({&wh, &wh}));
-}
-
-// Tests WaitAll when a wait handle fails.
-TEST(WaitHandleTest, WaitAllFailure) {
- WaitHandle good_wh;
- // Create a purposefully bad handle to induce an error.
- WaitHandle bad_wh = WaitHandle::AlwaysFailing();
- // Should fail with some posixy error.
- ASSERT_FALSE(WaitHandle::WaitAll({&good_wh, &bad_wh}).ok());
-}
-
-// Tests using WaitAny with no valid handles. This should no-op.
-TEST(WaitHandleTest, WaitAnyNop) {
- ASSERT_TRUE(IsInvalidArgument(WaitHandle::WaitAny({}).status()));
- ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({nullptr}));
- ASSERT_EQ(0, index);
- ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({nullptr, nullptr}));
- ASSERT_EQ(0, index);
-}
-
-// Tests polling with WaitAny with multiple wait handles.
-TEST(WaitHandleTest, WaitAnyPoll) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- ManualResetEvent fence1;
- WaitHandle wh1 = fence1.OnSet();
-
- // Poll; should return immediately with timeout.
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status()));
-
- // Notify fence1.
- ASSERT_OK(fence1.Set());
-
- // Poll; should return immediately with fence1 signaled.
- ASSERT_OK_AND_ASSIGN(int index,
- WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()));
- EXPECT_EQ(1, index);
-
- // Notify fence0.
- ASSERT_OK(fence0.Set());
-
- // Poll again; should return immediately; which one is signaled is undefined.
- ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()));
- ASSERT_TRUE(index == 0 || index == 1);
-}
-
-// Tests exceeding the timeout deadline with WaitAny.
-TEST(WaitHandleTest, WaitAnyTimeout) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- ManualResetEvent fence1;
- WaitHandle wh1 = fence1.OnSet();
-
- // Wait with timeout on the unsignaled fences:
- // Via polling (should never block):
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status()));
- ASSERT_OK_AND_ASSIGN(int index, WaitHandle::TryWaitAny({&wh0, &wh1}));
- ASSERT_EQ(-1, index);
- // Via time in the near future (should block):
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250)).status()));
-
- // Notify one of the fences. Should return immediately.
- ASSERT_OK(fence1.Set());
- ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()));
- ASSERT_EQ(1, index);
- ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0, &wh1}));
- ASSERT_EQ(1, index);
- ASSERT_OK_AND_ASSIGN(
- index, WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250)));
- ASSERT_EQ(1, index);
-
- // The unnotified fence should still timeout.
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0}, absl::InfinitePast()).status()));
- ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0}));
- ASSERT_EQ(-1, index);
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250)).status()));
-
- // Notify last fence and ensure complete.
- ASSERT_OK(fence0.Set());
- ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0}, absl::InfinitePast()));
- ASSERT_EQ(0, index);
- ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0}));
- ASSERT_EQ(0, index);
- ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250)));
- ASSERT_EQ(0, index);
-}
-
-// Tests using WaitAny to wait on other threads.
-TEST(WaitHandleTest, WaitAnyThreaded) {
- // Spin up two threads.
- // t1 will wait on t0 such that they will act in sequence.
- ManualResetEvent fence0;
- std::thread t0{[&]() {
- ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
- ASSERT_OK(fence0.Set());
- }};
- ManualResetEvent fence1;
- std::thread t1{[&]() {
- ASSERT_OK(fence0.OnSet().Wait());
- ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250)));
- ASSERT_OK(fence1.Set());
- }};
-
- // Wait on both threads. We expect 0 to complete first.
- WaitHandle wh0 = fence0.OnSet();
- WaitHandle wh1 = fence1.OnSet();
- ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1}));
- ASSERT_EQ(0, index);
-
- // Now wait for thread 1.
- ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({&wh1}));
- ASSERT_EQ(0, index);
-
- t0.join();
- t1.join();
-}
-
-// Tests using WaitAny with multiple wait handles from the same fence.
-TEST(WaitHandleTest, WaitAnySameSource) {
- ManualResetEvent fence;
- WaitHandle wh0 = fence.OnSet();
- WaitHandle wh1 = fence.OnSet();
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status()));
- ASSERT_OK(fence.Set());
- ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1}));
- ASSERT_TRUE(index == 0 || index == 1);
-}
-
-// Tests using WaitAny with literally the same wait handles.
-TEST(WaitHandleTest, WaitAnySameHandle) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh, &wh}, absl::InfinitePast()).status()));
- ASSERT_OK(fence.Set());
- ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh, &wh}));
- ASSERT_TRUE(index == 0 || index == 1);
-}
-
-// Tests WaitAny when a wait handle fails.
-TEST(WaitHandleTest, WaitAnyFailure) {
- WaitHandle good_wh;
- // Create a purposefully bad handle to induce an error.
- WaitHandle bad_wh = WaitHandle::AlwaysFailing();
- // Should fail with some posixy error.
- ASSERT_FALSE(WaitHandle::WaitAny({&good_wh, &bad_wh}).ok());
-}
-
-// ManualResetEvent with innards exposed. Meh.
-class ExposedManualResetEvent : public ManualResetEvent {
- public:
- using ManualResetEvent::AcquireFdForWait;
- using ManualResetEvent::TryResolveWakeOnFd;
-};
-
-// Mock type for the WaitableObject methods.
-class MockWaitableObject : public ::testing::StrictMock<WaitableObject> {
- public:
- MockWaitableObject() : ::testing::StrictMock<WaitableObject>() {}
-
- MOCK_CONST_METHOD0(DebugString, std::string());
- MOCK_METHOD1(AcquireFdForWait,
- StatusOr<std::pair<FdType, int>>(absl::Time deadline));
- MOCK_METHOD1(TryResolveWakeOnFd, StatusOr<bool>(int fd));
-
- WaitHandle OnSomething() { return WaitHandle(add_ref(this)); }
-};
-
-// Tests normal AcquireFdForWait + TryResolveWakeOnFd use.
-TEST(WaitableObjectTest, AcquireAndResolve) {
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
-
- // Use a MRE for testing, as we can just use its fd.
- ExposedManualResetEvent mre;
-
- // Try waiting; we should see the AcquireFdForWait and then return because
- // the fd has not been resolved.
- EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) {
- // Return the valid FD from the MRE.
- return mre.AcquireFdForWait(deadline);
- });
- ASSERT_STATUSOR_FALSE(wh.TryWait());
-
- // Signal the MRE.
- ASSERT_OK(mre.Set());
-
- // Try waiting again; we should get the AcquireFdForWait and then also get
- // the TryResolveWakeOnFd.
- EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) {
- // Return the valid (and now signaled) FD from the MRE.
- return mre.AcquireFdForWait(deadline);
- });
- EXPECT_CALL(mwo, TryResolveWakeOnFd(_)).WillOnce(Return(true));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
-}
-
-// Tests timing out in AcquireFdForWait.
-TEST(WaitableObjectTest, AcquireFdForWaitTimeout) {
- ManualResetEvent mre;
- WaitHandle always_wait = mre.OnSet();
- WaitHandle always_signal = WaitHandle::AlwaysSignaling();
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
-
- // Make the AcquireFdForWait take longer than the timeout. We should hit
- // deadline exceeded even though always_wait hasn't be signaled.
- EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([](absl::Time deadline) {
- ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(10)));
- return std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kInvalidFd);
- });
- ASSERT_TRUE(IsDeadlineExceeded(WaitHandle::WaitAll(
- {&wh, &always_signal}, absl::Now() - absl::Milliseconds(250))));
-}
-
-// Tests TryResolveWakeOnFd when a handle is a permanent kSignaledFd.
-TEST(WaitableObjectTest, SignaledFd) {
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
-
- // Return the kSignaledFd handle and expect that we still get our notify call.
- // We can do this multiple times.
- for (int i = 0; i < 4; ++i) {
- EXPECT_CALL(mwo, AcquireFdForWait(_))
- .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kSignaledFd)));
- EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
- .WillOnce(Return(true));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
- }
-}
-
-// Tests that waiting will not resolve if TryResolveWakeOnFd returns false.
-TEST(WaitableObjectTest, UnresolvedWake) {
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
-
- // Fail to resolve the first time.
- // Since we are only trying to wait it should bail.
- EXPECT_CALL(mwo, AcquireFdForWait(_))
- .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kSignaledFd)));
- EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
- .WillOnce(Return(false));
- ASSERT_STATUSOR_FALSE(wh.TryWait());
-
- // Resolve on the next try.
- EXPECT_CALL(mwo, AcquireFdForWait(_))
- .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kSignaledFd)));
- EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
- .WillOnce(Return(true));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
-}
-
-// Tests the normal lifecycle of a ManualResetEvent.
-TEST(ManualResetEventTest, Lifecycle) {
- ManualResetEvent ev;
- EXPECT_FALSE(ev.DebugString().empty());
- WaitHandle wh0 = ev.OnSet();
- EXPECT_EQ(ev.DebugString(), wh0.DebugString());
- WaitHandle wh1 = ev.OnSet();
- EXPECT_EQ(ev.DebugString(), wh1.DebugString());
- // Should not be set.
- ASSERT_STATUSOR_FALSE(wh0.TryWait());
- ASSERT_STATUSOR_FALSE(wh1.TryWait());
- // Set should be sticky.
- ASSERT_OK(ev.Set());
- ASSERT_STATUSOR_TRUE(wh0.TryWait());
- ASSERT_STATUSOR_TRUE(wh1.TryWait());
- // Reset should clear.
- ASSERT_OK(ev.Reset());
- ASSERT_STATUSOR_FALSE(wh0.TryWait());
- ASSERT_STATUSOR_FALSE(wh1.TryWait());
- // Setting again should enable the previous WaitHandles to be signaled.
- ASSERT_OK(ev.Set());
- ASSERT_STATUSOR_TRUE(wh0.TryWait());
- ASSERT_STATUSOR_TRUE(wh1.TryWait());
-}
-
-// Tests moving ManualResetEvents around.
-TEST(ManualResetEventTest, Move) {
- ManualResetEvent ev0;
- WaitHandle wh = ev0.OnSet();
- ManualResetEvent ev1{std::move(ev0)};
- ManualResetEvent ev2 = std::move(ev1);
- ev1 = std::move(ev2);
- ASSERT_OK(ev1.Set());
- ASSERT_STATUSOR_TRUE(wh.TryWait());
-}
-
-// Tests redundantly setting and resetting ManualResetEvents.
-TEST(ManualResetEventTest, RedundantUse) {
- ManualResetEvent ev;
- ASSERT_OK(ev.Reset());
- ASSERT_OK(ev.Reset());
- ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie());
- ASSERT_OK(ev.Set());
- ASSERT_OK(ev.Set());
- ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie());
- ASSERT_OK(ev.Reset());
- ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie());
-}
-
-// Tests waiting on an initially-set ManualResetEvent;
-TEST(ManualResetEventTest, SetThenWait) {
- ManualResetEvent ev;
- ASSERT_OK(ev.Set());
- ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie());
-}
-
-// Tests that dangling an event will not wake waiters.
-// This is intentional (for now); we could with a bit of wrangling make it so
-// that WaitableObjects tracked their waiters and ensured they were all cleaned
-// up, but that seems hard. Don't drop your objects.
-TEST(ManualResetEventTest, NeverSet) {
- ManualResetEvent ev;
- WaitHandle wh = ev.OnSet();
- ASSERT_STATUSOR_FALSE(wh.TryWait());
- // Kill event to unblock waiters.
- ev = ManualResetEvent();
- // Waiter should not have woken.
- ASSERT_STATUSOR_FALSE(wh.TryWait());
-}
-
-} // namespace
-} // namespace iree
diff --git a/iree/bindings/python/BUILD b/iree/bindings/python/BUILD
deleted file mode 100644
index 0ac8ef1..0000000
--- a/iree/bindings/python/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:build_defs.bzl", "iree_py_library")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_py_library(
- name = "pathsetup",
- imports = ["."],
-)
diff --git a/iree/bindings/python/pyiree/BUILD b/iree/bindings/python/pyiree/BUILD
deleted file mode 100644
index 558df05..0000000
--- a/iree/bindings/python/pyiree/BUILD
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:build_defs.bzl", "NUMPY_DEPS", "iree_py_extension")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-COMPILER_DEPS = [
- "//iree/compiler/Translation/Sequencer",
- "//iree/compiler/Translation/Interpreter",
- "//iree/compiler/Translation/SPIRV",
-]
-
-DRIVER_DEPS = [
- "//iree/hal/interpreter:interpreter_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
-]
-
-iree_py_extension(
- name = "binding",
- srcs = [
- "binding.cc",
- "binding.h",
- "compiler.cc",
- "compiler.h",
- "hal.cc",
- "hal.h",
- "initialize.cc",
- "initialize.h",
- "rt.cc",
- "rt.h",
- "status_utils.cc",
- "status_utils.h",
- "vm.cc",
- "vm.h",
- ],
- copts = [
- "-fexceptions",
- ],
- features = ["-use_header_modules"],
- deps = [
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "//iree/base:api",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:api",
- "//iree/rt:api",
- "//iree/schemas",
- "//iree/vm:api",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Parser",
- "@iree_pybind11//:pybind11",
- ] + COMPILER_DEPS + DRIVER_DEPS,
-)
-
-py_test(
- name = "compiler_test",
- srcs = ["compiler_test.py"],
- python_version = "PY3",
- deps = [
- ":binding",
- "//iree/bindings/python:pathsetup",
- "@absl_py//absl/testing:absltest",
- ],
-)
-
-py_test(
- name = "hal_test",
- srcs = ["hal_test.py"],
- python_version = "PY3",
- deps = [
- ":binding",
- "//iree/bindings/python:pathsetup",
- "@absl_py//absl/testing:absltest",
- ],
-)
-
-py_test(
- name = "runtime_test",
- srcs = ["runtime_test.py"],
- python_version = "PY3",
- deps = NUMPY_DEPS + [
- ":binding",
- "@absl_py//absl/testing:absltest",
- ],
-)
diff --git a/iree/bindings/python/pyiree/binding.cc b/iree/bindings/python/pyiree/binding.cc
deleted file mode 100644
index 3fc4b59..0000000
--- a/iree/bindings/python/pyiree/binding.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/binding.h"
-
-#include "iree/bindings/python/pyiree/compiler.h"
-#include "iree/bindings/python/pyiree/hal.h"
-#include "iree/bindings/python/pyiree/initialize.h"
-#include "iree/bindings/python/pyiree/rt.h"
-#include "iree/bindings/python/pyiree/status_utils.h"
-#include "iree/bindings/python/pyiree/vm.h"
-
-namespace iree {
-namespace python {
-
-PYBIND11_MODULE(binding, m) {
- m.doc() = "IREE Binding Backend Helpers";
- py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob");
- m.def("initialize_extension", &InitializeExtension);
-
- auto compiler_m = m.def_submodule("compiler", "IREE compiler support");
- SetupCompilerBindings(compiler_m);
-
- auto hal_m = m.def_submodule("hal", "IREE HAL support");
- SetupHalBindings(hal_m);
-
- auto rt_m = m.def_submodule("rt", "IREE RT api");
- SetupRtBindings(rt_m);
-
- auto vm_m = m.def_submodule("vm", "IREE VM api");
- SetupVmBindings(vm_m);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/binding.h b/iree/bindings/python/pyiree/binding.h
deleted file mode 100644
index 47ee323..0000000
--- a/iree/bindings/python/pyiree/binding.h
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
-
-#include <vector>
-
-#include "absl/types/optional.h"
-#include "iree/base/api.h"
-#include "pybind11/pybind11.h"
-#include "pybind11/stl.h"
-
-namespace pybind11 {
-namespace detail {
-#if !defined(ABSL_HAVE_STD_OPTIONAL)
-// Make absl::optional act like the future C++17 optional for pybind11.
-// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional
-// and the default type caster is sufficient.
-template <typename T>
-struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {};
-#endif
-} // namespace detail
-} // namespace pybind11
-
-namespace iree {
-namespace python {
-
-namespace py = pybind11;
-
-// Wrapper around a blob of memory.
-// Used to transport blobs back and forth between C++ and Python.
-class OpaqueBlob {
- public:
- OpaqueBlob() : data_(nullptr), size_(0) {}
- OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {}
- virtual ~OpaqueBlob() = default;
-
- void* data() { return data_; }
- const void* data() const { return data_; }
- size_t size() const { return size_; }
-
- // Create a free function from the OpaqueBlob shared pointer.
- using BufferFreeFn = void (*)(void* self, iree_byte_span_t);
- static std::pair<BufferFreeFn, void*> CreateFreeFn(
- std::shared_ptr<OpaqueBlob> blob) {
- // Note that there are more efficient ways to write this which
- // don't bounce through an extra heap alloc, but this is not
- // intended to be a high impact code path.
- struct Holder {
- std::shared_ptr<OpaqueBlob> blob;
- };
- Holder* holder = new Holder{std::move(blob)};
- auto free_fn = +([](void* self, iree_byte_span_t) {
- Holder* self_holder = static_cast<Holder*>(self);
- delete self_holder;
- });
- return {free_fn, holder};
- }
-
- protected:
- void* data_;
- size_t size_;
-};
-
-// Opaque blob that owns a vector.
-class OpaqueByteVectorBlob : public OpaqueBlob {
- public:
- OpaqueByteVectorBlob(std::vector<uint8_t> v)
- : OpaqueBlob(), v_(std::move(v)) {
- data_ = v_.data();
- size_ = v_.size();
- }
-
- private:
- std::vector<uint8_t> v_;
-};
-
-template <typename T>
-struct ApiPtrAdapter {};
-
-template <typename Self, typename T>
-class ApiRefCounted {
- public:
- ApiRefCounted() : instance_(nullptr) {}
- ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
- other.instance_ = nullptr;
- }
- void operator=(const ApiRefCounted&) = delete;
-
- ~ApiRefCounted() { Release(); }
-
- // Creates an instance of the ref counted wrapper based on an instance
- // that has already been retained. Ownership is transferred to the
- // wrapper.
- static Self CreateRetained(T* retained_inst) {
- auto self = Self();
- self.instance_ = retained_inst;
- return self;
- }
-
- // Creates a new instance, retaining the underlying object.
- static Self RetainAndCreate(T* non_retained_inst) {
- auto self = Self();
- self.instance_ = non_retained_inst;
- if (non_retained_inst) {
- ApiPtrAdapter<T>::Retain(non_retained_inst);
- }
- return self;
- }
-
- T* raw_ptr() {
- if (!instance_) {
- throw std::invalid_argument("API object is null");
- }
- return instance_;
- }
- void Retain() {
- if (instance_) {
- ApiPtrAdapter<T>::Retain(instance_);
- }
- }
- void Release() {
- if (instance_) {
- ApiPtrAdapter<T>::Release(instance_);
- }
- }
-
- private:
- T* instance_;
-};
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/iree/bindings/python/pyiree/compiler.cc b/iree/bindings/python/pyiree/compiler.cc
deleted file mode 100644
index c093d56..0000000
--- a/iree/bindings/python/pyiree/compiler.cc
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/compiler.h"
-
-#include <stdexcept>
-
-#include "iree/bindings/python/pyiree/binding.h"
-#include "iree/bindings/python/pyiree/initialize.h"
-#include "iree/bindings/python/pyiree/status_utils.h"
-#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h"
-#include "iree/schemas/module_def_generated.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Parser.h"
-
-namespace py = pybind11;
-
-using namespace mlir;
-using namespace mlir::iree_compiler;
-
-using llvm::MemoryBuffer;
-using llvm::MemoryBufferRef;
-using llvm::StringRef;
-
-namespace iree {
-namespace python {
-
-namespace {
-
-OwningModuleRef parseMLIRModuleFromString(StringRef contents,
- MLIRContext* context) {
- std::unique_ptr<MemoryBuffer> contents_buffer;
- if (contents.back() == 0) {
- // If it has a nul terminator, just use as-is.
- contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back());
- } else {
- // Otherwise, make a copy.
- contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED");
- }
-
- llvm::SourceMgr source_mgr;
- source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
- OwningModuleRef mlir_module = parseSourceFile(source_mgr, context);
- return mlir_module;
-}
-
-} // namespace
-
-std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) {
- InitializeExtension({});
-
- MLIRContext context;
-
- // Arrange to get a view that includes a terminating null to avoid additional
- // copy.
- const char* moduleAsmChars = moduleAsm.c_str();
- StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1);
-
- // TODO(laurenzo): This error handling is super hoaky. Hook into the MLIR
- // error reporter and plumb through properly.
- OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context);
- if (!mlirModule) {
- throw std::runtime_error("Failed to parse MLIR asm");
- }
-
- auto moduleBlob =
- mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get());
- if (moduleBlob.empty()) {
- throw std::runtime_error("Failed to translate MLIR module");
- }
- return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob));
-}
-
-void SetupCompilerBindings(pybind11::module m) {
- m.def("compile_module_from_asm", CompileModuleFromAsm);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/compiler.h b/iree/bindings/python/pyiree/compiler.h
deleted file mode 100644
index faa65b7..0000000
--- a/iree/bindings/python/pyiree/compiler.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
-
-#include <string>
-
-#include "iree/bindings/python/pyiree/binding.h"
-
-namespace iree {
-namespace python {
-
-void SetupCompilerBindings(pybind11::module m);
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
diff --git a/iree/bindings/python/pyiree/hal.cc b/iree/bindings/python/pyiree/hal.cc
deleted file mode 100644
index 648822b..0000000
--- a/iree/bindings/python/pyiree/hal.cc
+++ /dev/null
@@ -1,135 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/hal.h"
-
-#include "iree/hal/api.h"
-
-namespace iree {
-namespace python {
-
-namespace {
-
-class HalMappedMemory {
- public:
- HalMappedMemory(iree_hal_mapped_memory_t mapped_memory,
- iree_hal_buffer_view_t* bv)
- : mapped_memory_(mapped_memory), bv_(bv) {
- iree_hal_buffer_view_retain(bv_);
- }
- ~HalMappedMemory() {
- if (bv_) {
- iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_);
- CHECK_EQ(iree_hal_buffer_unmap(buffer, &mapped_memory_), IREE_STATUS_OK);
- iree_hal_buffer_view_release(bv_);
- }
- }
- HalMappedMemory(HalMappedMemory&& other)
- : mapped_memory_(other.mapped_memory_), bv_(other.bv_) {
- other.bv_ = nullptr;
- }
-
- static HalMappedMemory Create(HalBufferView& bv) {
- iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr());
- iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
- iree_hal_mapped_memory_t mapped_memory;
- CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ,
- 0 /* element_offset */, byte_length,
- &mapped_memory),
- "Could not map memory");
- return HalMappedMemory(mapped_memory, bv.raw_ptr());
- }
-
- py::buffer_info ToBufferInfo() {
- iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_);
- iree_shape_t shape;
- CheckApiStatus(iree_hal_buffer_view_shape(bv_, &shape),
- "Error getting buffer view shape");
- int8_t element_size = iree_hal_buffer_view_element_size(bv_);
- iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
- absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> dims;
- dims.resize(shape.rank);
- for (int i = 0; i < shape.rank; ++i) {
- dims[i] = shape.dims[i];
- }
- absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> strides;
- strides.resize(shape.rank);
- for (int i = 1; i < shape.rank; ++i) {
- strides[i - 1] = shape.dims[i] * element_size;
- }
- if (!strides.empty()) {
- strides.back() = 1 * element_size;
- }
-
- // TODO(laurenzo): We need to figure out how to propagate dtype in the
- // buffer view.
- return py::buffer_info(
- mapped_memory_.contents.data, element_size,
- py::format_descriptor<float>::format(), // TODO(laurenzo): DTYPE!
- shape.rank, dims, strides);
- }
-
- private:
- iree_hal_mapped_memory_t mapped_memory_;
- iree_hal_buffer_view_t* bv_;
-};
-
-} // namespace
-
-void SetupHalBindings(pybind11::module m) {
- // Enums.
- py::enum_<iree_hal_memory_type_t>(m, "MemoryType")
- .value("NONE", IREE_HAL_MEMORY_TYPE_NONE)
- .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT)
- .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)
- .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT)
- .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED)
- .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL)
- .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)
- .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)
- .export_values();
- py::enum_<iree_hal_buffer_usage_t>(m, "BufferUsage")
- .value("NONE", IREE_HAL_BUFFER_USAGE_NONE)
- .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT)
- .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER)
- .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING)
- .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH)
- .value("ALL", IREE_HAL_BUFFER_USAGE_ALL)
- .export_values();
- py::enum_<iree_hal_memory_access_t>(m, "MemoryAccess")
- .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE)
- .value("READ", IREE_HAL_MEMORY_ACCESS_READ)
- .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE)
- .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD)
- .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE)
- .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL)
- .export_values();
-
- py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
- py::class_<HalBufferView>(m, "BufferView")
- .def("map", HalMappedMemory::Create);
- py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol())
- .def_buffer(&HalMappedMemory::ToBufferInfo);
- py::class_<HalBuffer>(m, "Buffer")
- .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer,
- py::arg("memory_type"), py::arg("usage"),
- py::arg("allocation_size"))
- .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
- py::arg("byte_length"))
- .def("create_view", &HalBuffer::CreateView, py::arg("shape"),
- py::arg("element_size"));
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/hal.h b/iree/bindings/python/pyiree/hal.h
deleted file mode 100644
index eab644e..0000000
--- a/iree/bindings/python/pyiree/hal.h
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
-
-#include "iree/bindings/python/pyiree/binding.h"
-#include "iree/bindings/python/pyiree/status_utils.h"
-#include "iree/hal/api.h"
-
-namespace iree {
-namespace python {
-
-template <>
-struct ApiPtrAdapter<iree_hal_buffer_t> {
- static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); }
- static void Release(iree_hal_buffer_t* b) { iree_hal_buffer_release(b); }
-};
-
-template <>
-struct ApiPtrAdapter<iree_hal_buffer_view_t> {
- static void Retain(iree_hal_buffer_view_t* bv) {
- iree_hal_buffer_view_retain(bv);
- }
- static void Release(iree_hal_buffer_view_t* bv) {
- iree_hal_buffer_view_release(bv);
- }
-};
-
-struct HalShape {
- public:
- static HalShape FromIntVector(std::vector<int32_t> indices) {
- if (indices.size() > IREE_SHAPE_MAX_RANK) {
- throw RaiseValueError("Shape exceeded maximum rank");
- }
- HalShape s;
- s.s.rank = indices.size();
- for (size_t i = 0, e = indices.size(); i < e; ++i) {
- s.s.dims[i] = indices[i];
- }
- return s;
- }
-
- iree_shape_t s;
-};
-
-class HalBufferView
- : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> {
- public:
-};
-
-class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
- public:
- static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage,
- iree_host_size_t allocation_size) {
- iree_hal_buffer_t* buffer = nullptr;
- CheckApiStatus(
- iree_hal_heap_buffer_allocate(
- static_cast<iree_hal_memory_type_t>(memory_type),
- static_cast<iree_hal_buffer_usage_t>(usage), allocation_size,
- IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &buffer),
- "Error allocating heap buffer");
- return HalBuffer::CreateRetained(buffer);
- }
-
- void FillZero(iree_device_size_t byte_offset,
- iree_device_size_t byte_length) {
- CheckApiStatus(iree_hal_buffer_zero(raw_ptr(), byte_offset, byte_length),
- "Error zero filling buffer");
- }
-
- HalBufferView CreateView(HalShape& shape, size_t element_size) {
- iree_hal_buffer_view_t* bv;
- CheckApiStatus(iree_hal_buffer_view_create(raw_ptr(), shape.s, element_size,
- IREE_ALLOCATOR_DEFAULT, &bv),
- "Error creating buffer view");
- return HalBufferView::CreateRetained(bv);
- }
-};
-
-void SetupHalBindings(pybind11::module m);
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
diff --git a/iree/bindings/python/pyiree/initialize.cc b/iree/bindings/python/pyiree/initialize.cc
deleted file mode 100644
index 2d4f48e..0000000
--- a/iree/bindings/python/pyiree/initialize.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/initialize.h"
-
-#include <string.h>
-
-#include <mutex> // NOLINT
-
-#include "iree/base/init.h"
-
-namespace iree {
-namespace python {
-
-namespace {
-
-void InternalInitialize(const std::vector<std::string>& arguments) {
- int argc = arguments.size() + 1; // plus one for program name.
- char** argv = static_cast<char**>(
- malloc(sizeof(char*) * (argc + 1))); // plus one for null terminator.
- char** orig_argv = argv;
- argv[0] = strdup("<python_extension>");
- for (int i = 1; i < argc; ++i) {
- argv[i] = strdup(arguments[i - 1].c_str());
- }
- argv[argc] = nullptr;
- InitializeEnvironment(&argc, &argv);
- for (int i = 0; i < argc; ++i) {
- free(argv[i]);
- }
- free(orig_argv);
-}
-
-} // namespace
-
-void InitializeExtension(const std::vector<std::string>& arguments) {
- static std::once_flag init_once;
- std::call_once(init_once, InternalInitialize, arguments);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/rt.cc b/iree/bindings/python/pyiree/rt.cc
deleted file mode 100644
index f09f20d..0000000
--- a/iree/bindings/python/pyiree/rt.cc
+++ /dev/null
@@ -1,150 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/rt.h"
-
-#include "iree/base/api.h"
-#include "iree/bindings/python/pyiree/status_utils.h"
-#include "iree/hal/api.h"
-
-namespace iree {
-namespace python {
-
-HalBufferView RtContext::WrapPyBufferForInput(py::buffer py_buffer) {
- auto py_buffer_info = py_buffer.request(false /* writable */);
- if (py_buffer_info.ndim > IREE_SHAPE_MAX_RANK || py_buffer_info.ndim < 0) {
- RaiseValueError("Unsupported buffer rank");
- }
- if (py_buffer_info.size < 0) {
- RaiseValueError("Illegal buffer size");
- }
-
- // For the moment, allocate a device visible buffer of equivalent size and
- // copy into it.
- // TODO(laurenzo): Once sequencer is in place, switch to HeapBuffer, wrap
- // and retain the original buffer.
- iree_host_size_t byte_size = py_buffer_info.size * py_buffer_info.itemsize;
- HalBuffer buffer =
- AllocateDeviceVisible(byte_size, IREE_HAL_BUFFER_USAGE_CONSTANT |
- IREE_HAL_BUFFER_USAGE_TRANSFER |
- IREE_HAL_BUFFER_USAGE_DISPATCH);
- CheckApiStatus(iree_hal_buffer_write_data(buffer.raw_ptr(), 0,
- py_buffer_info.ptr, byte_size),
- "Error writing to input buffer");
-
- // Create the buffer view.
- // TODO(laurenzo): This does no validation on dtype and only cares if the
- // elementsize matches. Figure out where to enforce actual dtype.
- iree_shape_t shape;
- shape.rank = py_buffer_info.ndim;
-
- // Verify strides are row-major.
- // TODO(laurenzo): Test this with rank>1.
- for (int i = 1; i < shape.rank; ++i) {
- if ((py_buffer_info.strides[i - 1] * py_buffer_info.itemsize) !=
- py_buffer_info.shape[i]) {
- RaiseValueError("Expected row-major layout");
- }
- }
- if (!py_buffer_info.strides.empty()) {
- if (py_buffer_info.strides.back() != 1) {
- RaiseValueError("Expected row-major layout");
- }
- }
-
- // Populate shape.
- for (int i = 0; i < shape.rank; ++i) {
- ssize_t dim = py_buffer_info.shape[i];
- if (dim < 0) {
- RaiseValueError("Unsupported negative dim");
- }
- shape.dims[i] = dim;
- }
-
- iree_hal_buffer_view_t* bv;
- CheckApiStatus(iree_hal_buffer_view_create(buffer.raw_ptr(), shape,
- py_buffer_info.itemsize,
- IREE_ALLOCATOR_DEFAULT, &bv),
- "Error allocating buffer view");
-
- return HalBufferView::CreateRetained(bv);
-}
-
-void SetupRtBindings(pybind11::module m) {
- // BufferPlacement.
- py::enum_<BufferPlacement>(m, "BufferPlacement")
- .value("HEAP", BufferPlacement::kHeap)
- .value("DEVICE_VISIBLE", BufferPlacement::kDeviceVisible)
- .value("DEVICE_LOCAL", BufferPlacement::kDeviceLocal)
- .export_values();
-
- // RtModule.
- py::class_<RtModule>(m, "Module")
- .def_property_readonly("name", &RtModule::name)
- .def("lookup_function_by_ordinal", &RtModule::lookup_function_by_ordinal)
- .def("lookup_function_by_name", &RtModule::lookup_function_by_name);
- // RtFunction.
- py::class_<RtFunction>(m, "Function")
- .def_property_readonly("name", &RtFunction::name)
- .def_property_readonly("signature", &RtFunction::signature);
- py::class_<iree_rt_function_signature_t>(m, "FunctionSignature")
- .def_readonly("argument_count",
- &iree_rt_function_signature_t::argument_count)
- .def_readonly("result_count",
- &iree_rt_function_signature_t::result_count);
-
- // RtPolicy.
- py::class_<RtPolicy>(m, "Policy").def(py::init(&RtPolicy::Create));
-
- // RtInstance.
- py::class_<RtInstance>(m, "Instance")
- .def(py::init(&RtInstance::Create),
- py::arg_v("driver_name", absl::optional<std::string>()));
-
- // RtContext.
- py::class_<RtContext>(m, "Context")
- .def(py::init(&RtContext::Create), py::arg("instance"), py::arg("policy"))
- .def_property_readonly("context_id", &RtContext::context_id)
- .def("register_modules", &RtContext::RegisterModules, py::arg("modules"))
- .def("register_module", &RtContext::RegisterModule, py::arg("module"))
- .def("lookup_module_by_name", &RtContext::LookupModuleByName,
- py::arg("name"))
- .def("resolve_function", &RtContext::ResolveFunction,
- py::arg("full_name"))
- .def("allocate", &RtContext::Allocate, py::arg("allocation_size"),
- py::arg("placement") = BufferPlacement::kHeap,
- py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL)
- .def("allocate_device_local", &RtContext::AllocateDeviceLocal,
- py::arg("allocation_size"),
- py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL)
- .def("allocate_device_visible", &RtContext::AllocateDeviceVisible,
- py::arg("allocation_size"),
- py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL)
- .def("wrap_for_input", &RtContext::WrapPyBufferForInput, py::arg("v"))
- .def("invoke", &RtContext::Invoke, py::arg("f"), py::arg("policy"),
- py::arg("arguments"),
- py::arg("results") = absl::optional<std::vector<HalBufferView*>>());
-
- // RtInvocation.
- py::class_<RtInvocation>(m, "Invocation")
- .def("query_status", &RtInvocation::QueryStatus)
- .def("await", &RtInvocation::Await,
- py::arg("deadline") = IREE_TIME_INFINITE_FUTURE)
- .def("await_optional", &RtInvocation::AwaitOptional,
- py::arg("deadline") = IREE_TIME_INFINITE_FUTURE)
- .def_property_readonly("results", &RtInvocation::ConsumeResults);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/rt.h b/iree/bindings/python/pyiree/rt.h
deleted file mode 100644
index 2f1dbd0..0000000
--- a/iree/bindings/python/pyiree/rt.h
+++ /dev/null
@@ -1,390 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_RT_H_
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/api.h"
-#include "iree/bindings/python/pyiree/binding.h"
-#include "iree/bindings/python/pyiree/hal.h"
-#include "iree/bindings/python/pyiree/initialize.h"
-#include "iree/bindings/python/pyiree/status_utils.h"
-#include "iree/hal/api.h"
-#include "iree/rt/api.h"
-
-namespace iree {
-namespace python {
-
-// When creating a buffer via the context, switch between the different
-// allocation entry-points via an enum (these are separate functions in the
-// C API).
-enum class BufferPlacement {
- kHeap,
- kDeviceVisible,
- kDeviceLocal,
-};
-
-// Adapts API pointer access to retain/release API calls.
-template <>
-struct ApiPtrAdapter<iree_rt_module_t> {
- static void Retain(iree_rt_module_t* m) { iree_rt_module_retain(m); }
- static void Release(iree_rt_module_t* m) { iree_rt_module_release(m); }
-};
-
-template <>
-struct ApiPtrAdapter<iree_rt_instance_t> {
- static void Retain(iree_rt_instance_t* inst) {
- iree_rt_instance_retain(inst);
- }
- static void Release(iree_rt_instance_t* inst) {
- iree_rt_instance_release(inst);
- }
-};
-
-template <>
-struct ApiPtrAdapter<iree_rt_policy_t> {
- static void Retain(iree_rt_policy_t* p) { iree_rt_policy_retain(p); }
- static void Release(iree_rt_policy_t* p) { iree_rt_policy_release(p); }
-};
-
-template <>
-struct ApiPtrAdapter<iree_rt_context_t> {
- static void Retain(iree_rt_context_t* c) { iree_rt_context_retain(c); }
- static void Release(iree_rt_context_t* c) { iree_rt_context_release(c); }
-};
-
-template <>
-struct ApiPtrAdapter<iree_rt_invocation_t> {
- static void Retain(iree_rt_invocation_t* inv) {
- iree_rt_invocation_retain(inv);
- }
- static void Release(iree_rt_invocation_t* inv) {
- iree_rt_invocation_release(inv);
- }
-};
-
-// Wrapper classes. These mirror the Python declarations.
-class RtFunction {
- public:
- // Note that this will retain the module.
- RtFunction(iree_rt_function_t function) : function_(function) {
- iree_rt_module_retain(function_.module);
- }
- ~RtFunction() {
- if (function_.module) iree_rt_module_release(function_.module);
- }
- RtFunction(RtFunction&& other) : function_(other.function_) {
- other.function_.module = nullptr;
- }
- void operator=(const RtFunction&) = delete;
-
- std::string name() {
- auto sv = iree_rt_function_name(&function_);
- return std::string(sv.data, sv.size);
- }
-
- iree_rt_function_signature_t signature() {
- iree_rt_function_signature_t sig;
- CheckApiStatus(iree_rt_function_signature(&function_, &sig),
- "Error getting function signature");
- return sig;
- }
-
- iree_rt_function_t& raw_function() { return function_; }
-
- private:
- iree_rt_function_t function_;
-};
-
-class RtModule : public ApiRefCounted<RtModule, iree_rt_module_t> {
- public:
- std::string name() {
- auto sv = iree_rt_module_name(raw_ptr());
- return std::string(sv.data, sv.size);
- }
-
- absl::optional<RtFunction> lookup_function_by_ordinal(int32_t ordinal) {
- iree_rt_function_t f;
- // TODO(laurenzo): Support an optional linkage argument.
- auto module = raw_ptr();
- auto status = iree_rt_module_lookup_function_by_ordinal(
- module, IREE_RT_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
- if (status == IREE_STATUS_NOT_FOUND) {
- return absl::optional<RtFunction>();
- }
- CheckApiStatus(status, "Error looking up function");
- return RtFunction(f);
- }
-
- absl::optional<RtFunction> lookup_function_by_name(const std::string& name) {
- iree_rt_function_t f;
- // TODO(laurenzo): Support an optional linkage argument.
- auto module = raw_ptr();
- iree_string_view_t name_sv{name.data(), name.size()};
- auto status = iree_rt_module_lookup_function_by_name(
- module, IREE_RT_FUNCTION_LINKAGE_EXPORT, name_sv, &f);
- if (status == IREE_STATUS_NOT_FOUND) {
- return absl::optional<RtFunction>();
- }
- CheckApiStatus(status, "Error looking up function");
- return RtFunction(f);
- }
-};
-
-class RtInstance : public ApiRefCounted<RtInstance, iree_rt_instance_t> {
- public:
- // TODO(laurenzo): Support optional allocator argument.
- static RtInstance Create(absl::optional<std::string> driver_name) {
- InitializeExtension({});
- iree_rt_instance_t* raw_inst;
- CheckApiStatus(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &raw_inst),
- "Error creating instance");
- RtInstance inst = RtInstance::CreateRetained(raw_inst);
-
- if (!driver_name) {
- driver_name = "interpreter";
- }
- CheckApiStatus(iree_rt_instance_register_driver_ex(
- raw_inst, iree_string_view_t{driver_name->c_str(),
- driver_name->size()}),
- "Error registering drivers");
-
- return inst;
- }
-};
-
-class RtPolicy : public ApiRefCounted<RtPolicy, iree_rt_policy_t> {
- public:
- // TODO(laurenzo): Support optional allocator argument.
- static RtPolicy Create() {
- iree_rt_policy_t* policy;
- CheckApiStatus(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &policy),
- "Error creating policy");
- return RtPolicy::CreateRetained(policy);
- }
-};
-
-class RtInvocation : public ApiRefCounted<RtInvocation, iree_rt_invocation_t> {
- public:
- // Returns whether ready.
- // Raises exception on error.
- bool QueryStatus() {
- auto status = iree_rt_invocation_query_status(raw_ptr());
- if (status == IREE_STATUS_OK) {
- return true;
- } else if (status == IREE_STATUS_UNAVAILABLE) {
- return false;
- } else {
- CheckApiStatus(status, "Error in function invocation");
- return false;
- }
- }
-
- // TODO(laurenzo): Convert to the pybind chrono support.
- // Returns whether the invocation is ready.
- bool AwaitOptional(iree_time_t epoch_nanos_deadline) {
- auto status = iree_rt_invocation_await(raw_ptr(), epoch_nanos_deadline);
- if (status == IREE_STATUS_OK) {
- return true;
- } else if (status == IREE_STATUS_DEADLINE_EXCEEDED) {
- return false;
- } else {
- CheckApiStatus(status, "Error in invocation");
- return false;
- }
- }
-
- // Similar to AwaitOptional but will raise an error unless if the status
- // is ready.
- void Await(iree_time_t epoch_nanos_deadline) {
- if (!AwaitOptional(epoch_nanos_deadline)) {
- RaiseValueError("Deadline expired");
- }
- }
-
- std::vector<HalBufferView> ConsumeResults() {
- static constexpr size_t kInlineSize = 8;
- iree_host_size_t result_count;
- absl::InlinedVector<iree_hal_buffer_view_t*, kInlineSize> result_bvs;
- result_bvs.resize(kInlineSize);
- auto status = iree_rt_invocation_consume_results(
- raw_ptr(), kInlineSize, IREE_ALLOCATOR_DEFAULT, &result_bvs[0],
- &result_count);
- if (status == IREE_STATUS_OUT_OF_RANGE) {
- // Resize/retry.
- result_bvs.resize(result_count);
- status = iree_rt_invocation_consume_results(
- raw_ptr(), result_count, IREE_ALLOCATOR_DEFAULT, &result_bvs[0],
- &result_count);
- }
- CheckApiStatus(status, "Error consuming invocation results");
- result_bvs.resize(result_count);
- std::vector<HalBufferView> results;
- for (auto* raw_bv : result_bvs) {
- results.push_back(HalBufferView::CreateRetained(raw_bv));
- }
- return results;
- }
-};
-
-class RtContext : public ApiRefCounted<RtContext, iree_rt_context_t> {
- public:
- static RtContext Create(RtInstance* instance, RtPolicy* policy) {
- iree_rt_context_t* context;
- // TODO(laurenzo): Support optional allocator argument.
- CheckApiStatus(
- iree_rt_context_create(instance->raw_ptr(), policy->raw_ptr(),
- IREE_ALLOCATOR_DEFAULT, &context),
- "Error creating instance");
- return RtContext::CreateRetained(context);
- }
-
- int context_id() { return iree_rt_context_id(raw_ptr()); }
-
- void RegisterModules(std::vector<RtModule*> modules) {
- std::vector<iree_rt_module_t*> module_raw_ptrs;
- module_raw_ptrs.resize(modules.size());
- for (size_t i = 0, e = modules.size(); i < e; ++i) {
- auto module_raw_ptr = modules[i]->raw_ptr();
- module_raw_ptrs[i] = module_raw_ptr;
- }
- CheckApiStatus(
- iree_rt_context_register_modules(raw_ptr(), module_raw_ptrs.data(),
- module_raw_ptrs.size()),
- "Error registering modules");
- }
-
- void RegisterModule(RtModule* module) {
- iree_rt_module_t* module_raw_ptr = module->raw_ptr();
- CheckApiStatus(
- iree_rt_context_register_modules(raw_ptr(), &module_raw_ptr, 1),
- "Error registering module");
- }
-
- absl::optional<RtModule> LookupModuleByName(const std::string& name) {
- iree_rt_module_t* module = iree_rt_context_lookup_module_by_name(
- raw_ptr(), {name.data(), name.size()});
- if (!module) {
- return absl::optional<RtModule>();
- }
- return RtModule::RetainAndCreate(module);
- }
-
- absl::optional<RtFunction> ResolveFunction(const std::string& full_name) {
- iree_rt_function_t f;
- auto status = iree_rt_context_resolve_function(
- raw_ptr(), {full_name.data(), full_name.size()}, &f);
- if (status == IREE_STATUS_NOT_FOUND) {
- return absl::optional<RtFunction>();
- }
- CheckApiStatus(status, "Error resolving function");
- return RtFunction(f);
- }
-
- // Convenience method to allocate host, device-visible or device-local
- // buffers.
- HalBuffer Allocate(iree_host_size_t allocation_size,
- BufferPlacement placement, int32_t usage) {
- iree_hal_buffer_t* raw_buffer = nullptr;
- switch (placement) {
- case BufferPlacement::kHeap:
- // Even though allocating a heap buffer does not require the context,
- // provide it here to make the API easier to navigate.
- CheckApiStatus(
- iree_hal_heap_buffer_allocate(
- IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
- static_cast<iree_hal_buffer_usage_t>(usage), allocation_size,
- IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &raw_buffer),
- "Error allocating heap buffer");
- break;
- case BufferPlacement::kDeviceLocal:
- CheckApiStatus(
- iree_rt_context_allocate_device_local_buffer(
- raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage),
- allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer),
- "Error allocating device local buffer");
- break;
- case BufferPlacement::kDeviceVisible:
- CheckApiStatus(
- iree_rt_context_allocate_device_visible_buffer(
- raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage),
- allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer),
- "Error allocating device visible buffer");
- break;
- default:
- throw RaiseValueError("Unknown BufferPlacement");
- }
-
- return HalBuffer::CreateRetained(raw_buffer);
- }
-
- HalBuffer AllocateHeap(iree_host_size_t allocation_size, int32_t usage) {
- return Allocate(allocation_size, BufferPlacement::kHeap, usage);
- }
-
- HalBuffer AllocateDeviceLocal(iree_host_size_t allocation_size,
- int32_t usage) {
- return Allocate(allocation_size, BufferPlacement::kDeviceLocal, usage);
- }
-
- HalBuffer AllocateDeviceVisible(iree_host_size_t allocation_size,
- int32_t usage) {
- return Allocate(allocation_size, BufferPlacement::kDeviceVisible, usage);
- }
-
- // One stop convenience method for wrapping a python buffer protocol buffer
- // for input to a function. At the runtime's discretion, this may make a copy
- // or do something smarter, meaning the data in the backing python buffer
- // will either be accessed immediately or at some future point.
- HalBufferView WrapPyBufferForInput(py::buffer py_buffer);
-
- RtInvocation Invoke(RtFunction& f, RtPolicy& policy,
- std::vector<HalBufferView*> arguments,
- absl::optional<std::vector<HalBufferView*>> results) {
- absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_arguments;
- raw_arguments.resize(arguments.size());
- for (size_t i = 0, e = arguments.size(); i < e; ++i) {
- auto inst = arguments[i];
- CheckApiNotNull(inst, "Argument buffer view cannot be None");
- raw_arguments[i] = inst->raw_ptr();
- }
- absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_results;
- if (results) {
- raw_results.resize(results->size());
- for (size_t i = 0, e = results->size(); i < e; ++i) {
- auto inst = (*results)[i];
- CheckApiNotNull(inst, "Result buffer view cannot be None");
- raw_results[i] = inst->raw_ptr();
- }
- }
-
- iree_rt_invocation_t* invocation;
- CheckApiStatus(iree_rt_invocation_create(
- raw_ptr(), &f.raw_function(), policy.raw_ptr(),
- nullptr /* dependencies */, raw_arguments.data(),
- raw_arguments.size(), raw_results.data(),
- raw_results.size(), IREE_ALLOCATOR_DEFAULT, &invocation),
- "Error invoking function");
-
- return RtInvocation::CreateRetained(invocation);
- }
-};
-
-void SetupRtBindings(pybind11::module m);
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_H_
diff --git a/iree/bindings/python/pyiree/status_utils.cc b/iree/bindings/python/pyiree/status_utils.cc
deleted file mode 100644
index a8c5008..0000000
--- a/iree/bindings/python/pyiree/status_utils.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/status_utils.h"
-
-#include "absl/strings/str_cat.h"
-
-namespace iree {
-namespace python {
-
-namespace {
-
-PyObject* StatusToPyExcClass(const Status& status) {
- switch (status.code()) {
- case StatusCode::kInvalidArgument:
- return PyExc_ValueError;
- case StatusCode::kOutOfRange:
- return PyExc_IndexError;
- case StatusCode::kUnimplemented:
- return PyExc_NotImplementedError;
- default:
- return PyExc_RuntimeError;
- }
-}
-
-PyObject* ApiStatusToPyExcClass(iree_status_t status) {
- switch (status) {
- case IREE_STATUS_INVALID_ARGUMENT:
- return PyExc_ValueError;
- case IREE_STATUS_OUT_OF_RANGE:
- return PyExc_IndexError;
- case IREE_STATUS_UNIMPLEMENTED:
- return PyExc_NotImplementedError;
- default:
- return PyExc_RuntimeError;
- }
-}
-
-} // namespace
-
-pybind11::error_already_set StatusToPyExc(const Status& status) {
- assert(!status.ok());
- PyErr_SetString(StatusToPyExcClass(status), status.error_message().c_str());
- return pybind11::error_already_set();
-}
-
-pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
- const char* message) {
- assert(status != IREE_STATUS_OK);
- auto full_message = absl::StrCat(message, ": ", static_cast<int>(status));
- PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str());
- return pybind11::error_already_set();
-}
-
-pybind11::error_already_set RaiseValueError(const char* message) {
- PyErr_SetString(PyExc_ValueError, message);
- return pybind11::error_already_set();
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/status_utils.h b/iree/bindings/python/pyiree/status_utils.h
deleted file mode 100644
index ec12b01..0000000
--- a/iree/bindings/python/pyiree/status_utils.h
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
-
-#include "iree/base/api.h"
-#include "iree/base/status.h"
-#include "pybind11/pytypes.h"
-
-namespace iree {
-namespace python {
-
-// Converts a failing status to a throwable exception, setting Python
-// error information.
-// Correct usage is something like:
-// if (!status.ok()) {
-// throw StatusToPyExc(status);
-// }
-pybind11::error_already_set StatusToPyExc(const Status& status);
-
-// Raises a value error with the given message.
-// Correct usage:
-// throw RaiseValueError("Foobar'd");
-pybind11::error_already_set RaiseValueError(const char* message);
-
-// Consumes a StatusOr<T>, returning an rvalue reference to the T if the
-// status is ok(). Otherwise, throws an exception.
-template <typename T>
-T&& PyConsumeStatusOr(iree::StatusOr<T>&& sor) {
- if (sor.ok()) {
- return std::move(*sor);
- }
- throw StatusToPyExc(sor.status());
-}
-
-pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
- const char* message);
-
-static void CheckApiStatus(iree_status_t status, const char* message) {
- if (status == IREE_STATUS_OK) {
- return;
- }
- throw ApiStatusToPyExc(status, message);
-}
-
-static void CheckApiNotNull(const void* p, const char* message) {
- if (!p) {
- throw RaiseValueError(message);
- }
-}
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
diff --git a/iree/bindings/python/pyiree/vm.cc b/iree/bindings/python/pyiree/vm.cc
deleted file mode 100644
index c97ebce..0000000
--- a/iree/bindings/python/pyiree/vm.cc
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/bindings/python/pyiree/vm.h"
-
-#include "iree/bindings/python/pyiree/status_utils.h"
-
-namespace iree {
-namespace python {
-
-RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) {
- iree_rt_module_t* module;
- auto free_fn = OpaqueBlob::CreateFreeFn(blob);
- auto status = iree_vm_bytecode_module_create_from_buffer(
- {static_cast<const uint8_t*>(blob->data()), blob->size()}, free_fn.first,
- free_fn.second, IREE_ALLOCATOR_DEFAULT, &module);
- CheckApiStatus(status, "Error creating vm module from blob");
- return RtModule::CreateRetained(module);
-}
-
-void SetupVmBindings(pybind11::module m) {
- m.def("create_module_from_blob", CreateModuleFromBlob);
-}
-
-} // namespace python
-} // namespace iree
diff --git a/iree/bindings/python/pyiree/vm.h b/iree/bindings/python/pyiree/vm.h
deleted file mode 100644
index fcb7982..0000000
--- a/iree/bindings/python/pyiree/vm.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_
-
-#include "iree/bindings/python/pyiree/binding.h"
-#include "iree/bindings/python/pyiree/rt.h"
-#include "iree/vm/api.h"
-
-namespace iree {
-namespace python {
-
-void SetupVmBindings(pybind11::module m);
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_
diff --git a/iree/build_defs.bzl b/iree/build_defs.bzl
deleted file mode 100644
index d42ebcf..0000000
--- a/iree/build_defs.bzl
+++ /dev/null
@@ -1,130 +0,0 @@
-"""Common Bazel definitions for IREE."""
-
-load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
-load("@iree_native_python//:build_defs.bzl", "py_extension")
-load("@iree_core//build_tools/third_party/glslang:build_defs.bzl", "glsl_vulkan")
-load("@rules_python//python:defs.bzl", "py_library")
-
-NUMPY_DEPS = []
-
-def platform_trampoline_deps(basename):
- """Produce a list of deps for the given `basename` platform target.
-
- Example:
- "file_mapping" -> ["//iree/base/internal/file_mapping_internal"]
-
- This is used for compatibility with various methods of including the
- library in foreign source control systems.
-
- Args:
- basename: Library name prefix for a library in base/internal.
- Returns:
- A list of dependencies for depending on the library in a platform
- sensitive way.
- """
- return [
- "//iree/base/internal:%s_internal" % basename,
- ]
-
-# A platform-sensitive list of copts for the Vulkan loader.
-PLATFORM_VULKAN_LOADER_COPTS = select({
- "//iree/hal/vulkan:native_vk": [],
- "//iree/hal/vulkan:swiftshader_vk": [],
- "//conditions:default": [],
-})
-
-# A platform-sensitive list of dependencies for non-test targets using Vulkan.
-PLATFORM_VULKAN_DEPS = select({
- "//iree/hal/vulkan:native_vk": [],
- "//iree/hal/vulkan:swiftshader_vk": [],
- "//conditions:default": [],
-})
-
-# A platform-sensitive list of dependencies for tests using Vulkan.
-PLATFORM_VULKAN_TEST_DEPS = [
- "@com_google_googletest//:gtest_main",
-]
-
-def iree_py_library(**kwargs):
- """Compatibility py_library which has bazel compatible args."""
-
- # This is used when args are needed that are incompatible with upstream.
- # Presently, this includes:
- # imports
- py_library(**kwargs)
-
-def iree_py_extension(deps = [], **kwargs):
- """Delegates to the real py_extension."""
- py_extension(
- deps = ["@iree_native_python//:python_headers"] + deps,
- **kwargs
- )
-
-def iree_build_test(name, targets):
- """Dummy rule to ensure that targets build.
-
- This is currently undefined in bazel and is preserved for compatibility.
- """
- pass
-
-def iree_setup_lit_package(data):
- """Should be called once per test package that contains globbed lit tests.
-
- Args:
- data: Additional, project specific data deps to add.
- """
-
- # Bundle together all of the test utilities that are used by tests.
- native.filegroup(
- name = "lit_test_utilities",
- testonly = True,
- data = data + [
- "@llvm//:FileCheck",
- ],
- )
-
-def iree_glob_lit_tests(
- data = [":lit_test_utilities"],
- driver = "//iree/tools:run_lit.sh",
- test_file_exts = ["mlir"]):
- """Globs lit test files into tests for a package.
-
- For most packages, the defaults suffice. Packages that include this must
- also include a call to iree_setup_lit_package().
-
- Args:
- data: Data files to include/build.
- driver: Test driver.
- test_file_exts: File extensions to glob.
- """
- for test_file_ext in test_file_exts:
- test_files = native.glob([
- "*.%s" % (test_file_ext,),
- "**/*.%s" % (test_file_ext,),
- ])
- for test_file in test_files:
- test_file_location = "$(location %s)" % (test_file,)
- native.sh_test(
- name = "%s.test" % (test_file,),
- size = "small",
- srcs = [driver],
- data = data + [test_file],
- args = [test_file_location],
- )
-
-# The OSS build currently has issues with generating flatbuffer reflections.
-# It is hard-coded to disabled here (and in iree_flatbuffer_cc_library) until triaged/fixed.
-FLATBUFFER_SUPPORTS_REFLECTIONS = False
-
-def iree_flatbuffer_cc_library(**kwargs):
- """Wrapper for the flatbuffer_cc_library."""
-
- # TODO(laurenzo): The bazel rule for reflections seems broken in OSS
- # builds. Fix it and enable by default.
- flatbuffer_cc_library(
- gen_reflections = False,
- **kwargs
- )
-
-def iree_glsl_vulkan(**kwargs):
- glsl_vulkan(**kwargs)
diff --git a/iree/compiler/IR/ConfigOps.cpp b/iree/compiler/IR/ConfigOps.cpp
deleted file mode 100644
index 3ec129a..0000000
--- a/iree/compiler/IR/ConfigOps.cpp
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/ConfigOps.h"
-
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/IR/Types.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/STLExtras.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-
-//===----------------------------------------------------------------------===//
-// Generic printers and parsers.
-//===----------------------------------------------------------------------===//
-
-// Parses an op that has no inputs and no outputs.
-static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
- if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
- return success();
-}
-
-// Prints an op that has no inputs and no outputs.
-static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
- printer << op->getName();
- printer.printOptionalAttrDict(op->getAttrs());
-}
-
-//===----------------------------------------------------------------------===//
-// iree.target_config
-//===----------------------------------------------------------------------===//
-
-void ExecutableTargetConfigOp::build(Builder *builder, OperationState &state,
- std::string backend) {
- state.addAttribute("backend", builder->getStringAttr(backend));
- ensureTerminator(*state.addRegion(), *builder, state.location);
-}
-
-static ParseResult parseExecutableTargetConfigOp(OpAsmParser &parser,
- OperationState &state) {
- llvm::SMLoc backendLoc;
- StringAttr backendAttr;
- if (failed(parser.parseLParen()) ||
- failed(parser.getCurrentLocation(&backendLoc)) ||
- failed(parser.parseAttribute(backendAttr, "backend", state.attributes))) {
- return failure();
- }
-
- Region *body = state.addRegion();
- if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) {
- return failure();
- }
- if (succeeded(parser.parseOptionalKeyword("attributes"))) {
- if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
- }
-
- ExecutableTargetConfigOp::ensureTerminator(*body, parser.getBuilder(),
- state.location);
-
- return success();
-}
-
-static void printExecutableTargetConfigOp(OpAsmPrinter &printer,
- ExecutableTargetConfigOp op) {
- printer << op.getOperationName() << "(" << op.backend() << ")";
-
- printer.printRegion(op.body(), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/false);
-
- // Print out executable attributes, if present.
- SmallVector<StringRef, 1> ignoredAttrs = {
- "backend",
- };
- if (op.getAttrs().size() > ignoredAttrs.size()) {
- printer << "\n attributes ";
- printer.printOptionalAttrDict(op.getAttrs(), ignoredAttrs);
- }
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/ConfigOps.cpp.inc"
-
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/ConfigOps.h b/iree/compiler/IR/ConfigOps.h
deleted file mode 100644
index 1cef1f6..0000000
--- a/iree/compiler/IR/ConfigOps.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_CONFIGOPS_H_
-#define IREE_COMPILER_IR_CONFIGOPS_H_
-
-#include <cstdint>
-
-#include "iree/compiler/IR/Types.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/ConfigOps.h.inc"
-
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_CONFIGOPS_H_
diff --git a/iree/compiler/IR/ConfigOps.td b/iree/compiler/IR/ConfigOps.td
deleted file mode 100644
index 21681c5..0000000
--- a/iree/compiler/IR/ConfigOps.td
+++ /dev/null
@@ -1,44 +0,0 @@
-// Ops used to declare configuration used by the IREE compiler.
-// These allow inline config that follows along the IR they are associated with.
-// Multiple config ops are allowed within a single scope to indicate that the
-// parent IR node should be processed for multiple targets.
-
-#ifdef IREE_CONFIG_OPS
-#else
-#define IREE_CONFIG_OPS
-
-include "iree/compiler/IR/OpBase.td"
-
-class IREE_ConfigOp<string mnemonic, list<OpTrait> traits = []> :
- Op<IREE_Dialect, mnemonic, traits> {
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ print$cppClass(p, *this); }];
-}
-
-//===----------------------------------------------------------------------===//
-// iree.executable configuration
-//===----------------------------------------------------------------------===//
-
-def IREE_ExecutableTargetConfigOp : IREE_ConfigOp<"target_config", [
- IREE_ExecutableOnly,
- SingleBlockImplicitTerminator<"ExecutableTargetConfigEndOp">
-]> {
- let arguments = (ins
- StrAttr:$backend
- );
-
- let regions = (region SizedRegion<1>:$body);
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<"Builder *builder, OperationState &state, std::string backend">,
- ];
-}
-
-def IREE_ExecutableTargetConfigEndOp :
- IREE_ConfigOp<"_target_config_end", [Terminator, IREE_ExecutableTargetConfigOnly]> {
- let parser = [{ return parseNoIOOp(parser, result); }];
- let printer = [{ printNoIOOp(getOperation(), p); }];
-}
-
-#endif // IREE_CONFIG_OPS
diff --git a/iree/compiler/IR/Dialect.cpp b/iree/compiler/IR/Dialect.cpp
deleted file mode 100644
index 0946319..0000000
--- a/iree/compiler/IR/Dialect.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-
-#include "iree/compiler/IR/ConfigOps.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/IR/Types.h"
-#include "llvm/Support/SourceMgr.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-static DialectRegistration<IREEDialect> iree_dialect;
-
-IREEDialect::IREEDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
-#define IREE_ADD_TYPE(NAME, KIND, TYPE) addTypes<TYPE>();
- IREE_TYPE_TABLE(IREE_ADD_TYPE);
-
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/Ops.cpp.inc"
- >();
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/ConfigOps.cpp.inc"
- >();
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/StructureOps.cpp.inc"
- >();
-}
-
-//===----------------------------------------------------------------------===//
-// Type Parsing
-//===----------------------------------------------------------------------===//
-
-#define IREE_TYPE_PARSER(NAME, KIND, TYPE) \
- static Type parse##TYPE(IREEDialect const &dialect, StringRef spec, \
- Location loc) { \
- spec.consume_front(NAME); \
- return TYPE::get(dialect.getContext()); \
- }
-IREE_TYPE_TABLE(IREE_TYPE_PARSER);
-
-#define IREE_PARSE_TYPE(NAME, KIND, TYPE) \
- if (spec.startswith(NAME)) { \
- return parse##TYPE(*this, spec, loc); \
- }
-Type IREEDialect::parseType(StringRef spec, Location loc) const {
- IREE_TYPE_TABLE(IREE_PARSE_TYPE);
- emitError(loc, "unknown IREE type: ") << spec;
- return Type();
-}
-
-//===----------------------------------------------------------------------===//
-// Type Printing
-//===----------------------------------------------------------------------===//
-
-#define IREE_TYPE_PRINTER(NAME, KIND, TYPE) \
- static void print##TYPE(TYPE type, llvm::raw_ostream &os) { os << NAME; }
-IREE_TYPE_TABLE(IREE_TYPE_PRINTER);
-
-#define IREE_PRINT_TYPE(NAME, KIND, TYPE) \
- case KIND: \
- print##TYPE(type.cast<TYPE>(), os); \
- return;
-void IREEDialect::printType(Type type, llvm::raw_ostream &os) const {
- switch (type.getKind()) {
- IREE_TYPE_TABLE(IREE_PRINT_TYPE);
- default:
- llvm_unreachable("unhandled IREE type");
- }
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/BUILD b/iree/compiler/IR/Interpreter/BUILD
deleted file mode 100644
index 63a063d..0000000
--- a/iree/compiler/IR/Interpreter/BUILD
+++ /dev/null
@@ -1,75 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-load("@local_config_mlir//:tblgen.bzl", "gentbl")
-
-filegroup(
- name = "td_files",
- srcs = glob(["*.td"]),
-)
-
-cc_library(
- name = "Interpreter",
- srcs = [
- "HLDialect.cpp",
- "HLOps.cpp",
- "HLOps.cpp.inc",
- "LLDialect.cpp",
- "LLOps.cpp",
- "LLOps.cpp.inc",
- "OpWriters.cpp",
- ],
- hdrs = [
- "HLDialect.h",
- "HLOps.h",
- "HLOps.h.inc",
- "LLDialect.h",
- "LLOps.h",
- "LLOps.h.inc",
- "OpWriters.h",
- ],
- deps = [
- ":HLOpsGen",
- ":LLOpsGen",
- "//iree/compiler/IR",
- "//iree/compiler/Serialization",
- "//iree/compiler/Utils",
- "//iree/schemas/bytecode:interpreter_bytecode_v0",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:StandardOps",
- ],
- alwayslink = 1,
-)
-
-gentbl(
- name = "HLOpsGen",
- tbl_outs = [
- ("-gen-op-decls", "HLOps.h.inc"),
- ("-gen-op-defs", "HLOps.cpp.inc"),
- ],
- tblgen = "@local_config_mlir//:mlir-tblgen",
- td_file = "HLOps.td",
- td_srcs = [
- ":td_files",
- "@local_config_mlir//:include/mlir/IR/OpBase.td",
- "//iree/compiler/IR:OpBase.td",
- ],
-)
-
-gentbl(
- name = "LLOpsGen",
- tbl_outs = [
- ("-gen-op-decls", "LLOps.h.inc"),
- ("-gen-op-defs", "LLOps.cpp.inc"),
- ],
- tblgen = "@local_config_mlir//:mlir-tblgen",
- td_file = "LLOps.td",
- td_srcs = [
- ":td_files",
- "@local_config_mlir//:include/mlir/IR/OpBase.td",
- "//iree/compiler/IR:OpBase.td",
- ],
-)
diff --git a/iree/compiler/IR/Interpreter/HLDialect.cpp b/iree/compiler/IR/Interpreter/HLDialect.cpp
deleted file mode 100644
index 55b3fbb..0000000
--- a/iree/compiler/IR/Interpreter/HLDialect.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Interpreter/HLDialect.h"
-
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "llvm/Support/SourceMgr.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-IREEHLInterpreterDialect::IREEHLInterpreterDialect(MLIRContext* context)
- : Dialect(getDialectNamespace(), context) {
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/Interpreter/HLOps.cpp.inc"
- >();
-}
-
-static DialectRegistration<IREEHLInterpreterDialect> iree_hl_interp_dialect;
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/HLOps.cpp b/iree/compiler/IR/Interpreter/HLOps.cpp
deleted file mode 100644
index c708787..0000000
--- a/iree/compiler/IR/Interpreter/HLOps.cpp
+++ /dev/null
@@ -1,246 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREEInterp {
-namespace HL {
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.call
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
- SymbolRefAttr calleeAttr;
- FunctionType calleeType;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- auto calleeLoc = parser.getNameLoc();
- if (parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(state.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.addTypesToList(calleeType.getResults(), state.types) ||
- parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
- state.operands)) {
- return failure();
- }
- return success();
-}
-
-static void printCallOp(OpAsmPrinter &p, CallOp op) {
- p << "iree_hl_interp.call " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
-}
-
-FunctionType CallOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.call_indirect
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallIndirectOp(OpAsmParser &parser,
- OperationState &result) {
- FunctionType calleeType;
- OpAsmParser::OperandType callee;
- llvm::SMLoc operandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- return failure(
- parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(result.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.resolveOperand(callee, calleeType, result.operands) ||
- parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
- result.operands) ||
- parser.addTypesToList(calleeType.getResults(), result.types));
-}
-
-static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
- p << "iree_hl_interp.call_indirect ";
- p.printOperand(op.getCallee());
- p << '(';
- auto operandRange = op.getOperands();
- p.printOperands(++operandRange.begin(), operandRange.end());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : " << op.getCallee()->getType();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.return
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type, 2> types;
- llvm::SMLoc loc = parser.getCurrentLocation();
- return failure(parser.parseOperandList(opInfo) ||
- (!opInfo.empty() && parser.parseColonTypeList(types)) ||
- parser.resolveOperands(opInfo, types, loc, state.operands));
-}
-
-static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
- p << "iree_hl_interp.return";
- if (op.getNumOperands() > 0) {
- p << ' ';
- p.printOperands(op.operand_begin(), op.operand_end());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
- Block *dest;
- SmallVector<Value *, 4> destOperands;
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
- return success();
-}
-
-static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
- p << "iree_hl_interp.br ";
- p.printSuccessorAndUseList(op.getOperation(), 0);
-}
-
-Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
-
-void BranchOp::setDest(Block *block) {
- return getOperation()->setSuccessor(block, 0);
-}
-
-void BranchOp::eraseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(0, index);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.cond_br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCondBranchOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<Value *, 4> destOperands;
- Block *dest;
- OpAsmParser::OperandType condInfo;
-
- // Parse the condition.
- Type int1Ty = parser.getBuilder().getI1Type();
- if (parser.parseOperand(condInfo) || parser.parseComma() ||
- parser.resolveOperand(condInfo, int1Ty, result.operands)) {
- return parser.emitError(parser.getNameLoc(),
- "expected condition type was boolean (i1)");
- }
-
- // Parse the true successor.
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
-
- // Parse the false successor.
- destOperands.clear();
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
-
- return success();
-}
-
-static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
- p << "iree_hl_interp.cond_br ";
- p.printOperand(op.getCondition());
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.clone
-//===----------------------------------------------------------------------===//
-
-OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
- // If this is the only usage, we know the clone is unnecessary.
- // TODO(b/135053584) More sophisticated analysis.
- if (src()->hasOneUse()) return src();
- return {};
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_interp.concat
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct ConcatToCopies : public OpRewritePattern<ConcatOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ConcatOp concatOp,
- PatternRewriter &rewriter) const override {
- auto finalType = concatOp.getResult()->getType().cast<ShapedType>();
- auto loc = concatOp.getLoc();
- std::vector<Value *> dimPieces;
- auto dst =
- rewriter.create<IREEInterp::HL::AllocHeapOp>(loc, finalType, dimPieces);
-
- llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0);
- auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset);
-
- auto concatDimension = concatOp.dimension().getZExtValue();
- llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0);
- for (auto *src : concatOp.srcs()) {
- auto srcShape = src->getType().cast<ShapedType>().getShape();
- auto lengths = createArrayConstant(rewriter, loc, srcShape);
- auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices);
- rewriter.create<IREEInterp::HL::CopyOp>(loc, src, srcIndices, dst,
- dstIndicesOp, lengths);
- dstIndices[concatDimension] += srcShape[concatDimension];
- }
-
- concatOp.replaceAllUsesWith(dst.getResult());
-
- return matchSuccess();
- }
-};
-} // namespace
-
-void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<ConcatToCopies>(context);
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Interpreter/HLOps.cpp.inc"
-
-} // namespace HL
-} // namespace IREEInterp
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/HLOps.h b/iree/compiler/IR/Interpreter/HLOps.h
deleted file mode 100644
index e3a4d23..0000000
--- a/iree/compiler/IR/Interpreter/HLOps.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
-#define IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
-
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREEInterp {
-namespace HL {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Interpreter/HLOps.h.inc"
-
-} // namespace HL
-} // namespace IREEInterp
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
diff --git a/iree/compiler/IR/Interpreter/HLOps.td b/iree/compiler/IR/Interpreter/HLOps.td
deleted file mode 100644
index 3458b45..0000000
--- a/iree/compiler/IR/Interpreter/HLOps.td
+++ /dev/null
@@ -1,658 +0,0 @@
-// IREE high-level interpreter op definitions.
-// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in
-// normal SSA form.
-//
-// Through lowering these high-level ops are converted to low-level ops in the
-// LLOps.td (iree_ll_interp.*). These map 1:1 with the bytecode,
-// accept only MemRef types, and generally use output parameters instead of
-// return types.
-//
-// The source of truth for bytecode opcodes is:
-// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h
-
-#ifdef IREE_INTERPRETER_HL_OPS
-#else
-#define IREE_INTERPRETER_HL_OPS
-
-#ifdef IREE_OP_BASE
-#else
-include "iree/compiler/IR/OpBase.td"
-#endif // IREE_OP_BASE
-
-def IREEInterpHL_Dialect : Dialect {
- let name = "iree_hl_interp";
- let cppNamespace = "IREEInterp::HL";
-}
-
-//===----------------------------------------------------------------------===//
-// Base op classes
-//===----------------------------------------------------------------------===//
-
-class IREEInterpHL_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<IREEInterpHL_Dialect, mnemonic, traits>;
-
-class IREEInterpHL_PureOp<string mnemonic, list<OpTrait> traits = []> :
- IREEInterpHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-//===----------------------------------------------------------------------===//
-// High-level interpreter ops
-//===----------------------------------------------------------------------===//
-
-def IREEInterpHL_CallOp : IREEInterpHL_Op<"call"> {
- let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>);
- let results = (outs Variadic<IREEHL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, FuncOp callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(callee.getType().getResults());
- }]>, OpBuilder<
- "Builder *builder, OperationState &result, StringRef callee,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getCallee() { return callee(); }
- FunctionType getCalleeType();
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpHL_CallIndirectOp : IREEInterpHL_Op<"call_indirect"> {
- let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands);
- let results = (outs Variadic<IREEHL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.operands.push_back(callee);
- result.addOperands(operands);
- result.addTypes(callee->getType().cast<FunctionType>().getResults());
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- Value *getCallee() { return getOperand(0); }
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- operand_iterator arg_operand_begin() { return ++operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpHL_ReturnOp : IREEInterpHL_Op<"return", [Terminator]> {
- let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
-
- let builders = [OpBuilder<
- "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
- >];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpHL_BranchOp : IREEInterpHL_Op<"br", [Terminator]> {
- let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Block *dest,"
- "ArrayRef<Value *> operands = {}", [{
- result.addSuccessor(dest, operands);
- }]>];
-
- let extraClassDeclaration = [{
- Block *getDest();
- void setDest(Block *block);
-
- /// Erase the operand at 'index' from the operand list.
- void eraseOperand(unsigned index);
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpHL_CondBranchOp : IREEInterpHL_Op<"cond_br", [Terminator]> {
- let arguments = (ins
- IREEHL_BoolScalar:$condition,
- Variadic<IREEHL_MemRef>:$branchOperands
- );
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *condition,"
- "Block *trueDest, ArrayRef<Value *> trueOperands,"
- "Block *falseDest, ArrayRef<Value *> falseOperands", [{
- result.addOperands(condition);
- result.addSuccessor(trueDest, trueOperands);
- result.addSuccessor(falseDest, falseOperands);
- }]>];
-
- let extraClassDeclaration = [{
- // These are the indices into the dests list.
- enum { trueIndex = 0, falseIndex = 1 };
-
- // The condition operand is the first operand in the list.
- Value *getCondition() { return getOperand(0); }
-
- /// Return the destination if the condition is true.
- Block *getTrueDest() {
- return getOperation()->getSuccessor(trueIndex);
- }
-
- /// Return the destination if the condition is false.
- Block *getFalseDest() {
- return getOperation()->getSuccessor(falseIndex);
- }
-
- // Accessors for operands to the 'true' destination.
- Value *getTrueOperand(unsigned idx) {
- assert(idx < getNumTrueOperands());
- return getOperand(getTrueDestOperandIndex() + idx);
- }
-
- void setTrueOperand(unsigned idx, Value *value) {
- assert(idx < getNumTrueOperands());
- setOperand(getTrueDestOperandIndex() + idx, value);
- }
-
- operand_iterator true_operand_begin() {
- return operand_begin() + getTrueDestOperandIndex();
- }
- operand_iterator true_operand_end() {
- return true_operand_begin() + getNumTrueOperands();
- }
- operand_range getTrueOperands() {
- return {true_operand_begin(), true_operand_end()};
- }
-
- unsigned getNumTrueOperands() {
- return getOperation()->getNumSuccessorOperands(trueIndex);
- }
-
- /// Erase the operand at 'index' from the true operand list.
- void eraseTrueOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(trueIndex, index);
- }
-
- // Accessors for operands to the 'false' destination.
- Value *getFalseOperand(unsigned idx) {
- assert(idx < getNumFalseOperands());
- return getOperand(getFalseDestOperandIndex() + idx);
- }
- void setFalseOperand(unsigned idx, Value *value) {
- assert(idx < getNumFalseOperands());
- setOperand(getFalseDestOperandIndex() + idx, value);
- }
-
- operand_iterator false_operand_begin() { return true_operand_end(); }
- operand_iterator false_operand_end() {
- return false_operand_begin() + getNumFalseOperands();
- }
- operand_range getFalseOperands() {
- return {false_operand_begin(), false_operand_end()};
- }
-
- unsigned getNumFalseOperands() {
- return getOperation()->getNumSuccessorOperands(falseIndex);
- }
-
- /// Erase the operand at 'index' from the false operand list.
- void eraseFalseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(falseIndex, index);
- }
-
- private:
- /// Get the index of the first true destination operand.
- unsigned getTrueDestOperandIndex() { return 1; }
-
- /// Get the index of the first false destination operand.
- unsigned getFalseDestOperandIndex() {
- return getTrueDestOperandIndex() + getNumTrueOperands();
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpHL_CmpIOp :
- IREEInterpHL_PureOp<"cmp_i", [SameOperandsAndResultShape,
- AllTypesMatch<["lhs", "rhs"]>]> {
- let arguments = (ins
- I32Attr:$predicate,
- IREEHL_IntMemRef:$lhs,
- IREEHL_IntMemRef:$rhs
- );
- let results = (outs IREEHL_BoolMemRef);
-}
-
-def IREEInterpHL_CmpFOp :
- IREEInterpHL_PureOp<"cmp_f", [SameOperandsAndResultShape,
- AllTypesMatch<["lhs", "rhs"]>]> {
- let arguments = (ins
- I32Attr:$predicate,
- IREEHL_FloatMemRef:$lhs,
- IREEHL_FloatMemRef:$rhs
- );
- let results = (outs IREEHL_BoolMemRef);
-}
-
-// TODO(b/142012496): Add trait that enables DCE but not CSE.
-def IREEInterpHL_AllocHeapOp : IREEInterpHL_Op<"alloc_heap"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins
- Variadic<IREEHL_MemRef>:$dim_pieces
- );
- let results = (outs
- IREEHL_MemRef
- );
-}
-
-def IREEInterpHL_DiscardOp : IREEInterpHL_Op<"discard"> {
- let arguments = (ins IREEHL_MemRef);
-}
-
-def IREEInterpHL_RankOp : IREEInterpHL_PureOp<"rank"> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_IntScalar);
-}
-
-def IREEInterpHL_DimOp : IREEInterpHL_PureOp<"dim"> {
- // TODO(benvanik) add dim attr (I32Attr:$dim)
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_IntScalar);
-}
-
-def IREEInterpHL_ShapeOp : IREEInterpHL_PureOp<"shape"> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_1DIntMemRef);
-}
-
-def IREEInterpHL_LengthOp : IREEInterpHL_PureOp<"length"> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_IndexScalar);
-}
-
-def IREEInterpHL_SliceOp :
- IREEInterpHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>,
- AllTypesMatch<["srcIndices", "lengths"]>]> {
- let arguments = (ins
- IREEHL_MemRef:$src,
- IREEHL_1DIndexMemRef:$srcIndices,
- IREEHL_1DIndexMemRef:$lengths
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_CopyOp : IREEInterpHL_Op<"copy", [
- AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
- AllRanksMatch<["src", "dst"]>,
- // The checks above are redundant with this one, but they give more specific
- // error messages.
- AllMatch<[
- Rank<"src">.result,
- Rank<"dst">.result,
- ElementCount<"srcIndices">.result,
- ElementCount<"dstIndices">.result,
- ElementCount<"lengths">.result
- ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">,
- AllElementTypesMatch<["src", "dst"]>
-]> {
- let arguments = (ins
- IREEHL_MemRef:$src,
- IREEHL_1DIndexMemRef:$srcIndices,
- IREEHL_MemRef:$dst,
- IREEHL_1DIndexMemRef:$dstIndices,
- IREEHL_1DIndexMemRef:$lengths
- );
-}
-
-def IREEInterpHL_CloneOp :
- IREEInterpHL_PureOp<"clone", [SameOperandsAndResultType]> {
- let arguments = (ins IREEHL_MemRef:$src);
- let results = (outs IREEHL_MemRef);
-
- let hasFolder = 1;
-}
-
-// A pseudo op provided for convenience. This gets canonicalized to a series of
-// copies.
-def IREEInterpHL_ConcatOp : IREEInterpHL_PureOp<"concat"> {
- let arguments = (ins
- Variadic<IREEHL_MemRef>:$srcs,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_MemRef);
-
- let hasCanonicalizer = 1;
-}
-
-// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops?
-def IREEInterpHL_SplitOp :
- IREEInterpHL_PureOp<"split", [SameOperandsAndResultElementType]> {
- let arguments = (ins IREEHL_MemRef:$src);
- let results = (outs Variadic<IREEHL_MemRef>);
-}
-
-def IREEInterpHL_AssignOp :
- IREEInterpHL_PureOp<"assign", [SameOperandsAndResultType]> {
- let arguments = (ins IREEHL_MemRef:$src);
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_CondAssignOp :
- IREEInterpHL_PureOp<"cond_assign",
- [AllTypesMatch<["lhs", "rhs", "result"]>]> {
- let arguments = (ins
- IREEHL_BoolScalar:$cond,
- IREEHL_MemRef:$lhs,
- IREEHL_MemRef:$rhs
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_ReshapeOp : IREEInterpHL_PureOp<"reshape"> {
- let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape);
- let results = (outs IREEHL_MemRef);
-}
-
-def IREEInterpHL_SelectOp :
- IREEInterpHL_PureOp<"select", [AllTypesMatch<["lhs", "rhs", "result"]>]> {
- let arguments = (ins
- IREEHL_BoolMemRef:$cond,
- IREEHL_MemRef:$lhs,
- IREEHL_MemRef:$rhs
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_BroadcastOp :
- IREEInterpHL_PureOp<"broadcast",
- [AllElementTypesMatch<["operand", "result"]>]> {
- let arguments = (ins
- IREE_ScalarMemRefOf<[AnyType]>:$operand,
- IREEHL_1DIntMemRef:$shape
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_PadOp :
- IREEInterpHL_PureOp<
- "pad", [AllElementTypesMatch<["src", "result", "padding_value"]>]> {
- let arguments = (ins
- IREEHL_MemRef:$src,
- IREEHL_AnyScalar:$padding_value,
- IREEHL_1DIndexMemRef:$edge_padding_low,
- IREEHL_1DIndexMemRef:$edge_padding_high,
- IREEHL_1DIndexMemRef:$interior_padding
- );
-
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_TileOp :
- IREEInterpHL_PureOp<"tile", [AllElementTypesMatch<["operand", "result"]>]> {
- let arguments = (ins
- IREEHL_MemRef:$operand,
- IREEHL_1DIntMemRef:$shape
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_TransposeOp :
- IREEInterpHL_PureOp<"transpose", [
- AllElementTypesMatch<["operand", "result"]>,
- AllRanksMatch<["operand", "result"]>,
- AllElementCountsMatch<["operand", "result"]>
- ]> {
- let arguments = (ins
- IREEHL_MemRef:$operand,
- IREEHL_1DIntMemRef:$permutation
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREEInterpHL_ReverseOp :
- IREEInterpHL_PureOp<"reverse", [AllTypesMatch<["operand", "result"]>]> {
- let arguments = (ins
- IREEHL_MemRef:$operand,
- IREEHL_1DIntMemRef:$dims
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-class IREEInterpHL_UnaryElementwiseOp<string mnemonic, Type type,
- list<OpTrait> traits = []> :
- IREEInterpHL_PureOp<mnemonic,
- !listconcat(traits, [SameOperandsAndResultType])> {
- let arguments = (ins type);
- let results = (outs type);
-}
-
-class IREEInterpHL_UnaryElementwiseFloatOp<string mnemonic,
- list<OpTrait> traits = []> :
- IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_FloatMemRef, traits>;
-
-class IREEInterpHL_UnaryElementwiseIntOp<string mnemonic,
- list<OpTrait> traits = []> :
- IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_IntMemRef, traits>;
-
-class IREEInterpHL_BinaryElementwiseOp<string mnemonic, Type type,
- list<OpTrait> traits> :
- IREEInterpHL_PureOp<mnemonic,
- !listconcat(traits, [SameOperandsAndResultType])> {
- let arguments = (ins type:$lhs, type:$rhs);
- let results = (outs type);
-}
-
-class IREEInterpHL_BinaryElementwiseFloatOp<string mnemonic,
- list<OpTrait> traits = []> :
- IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_FloatMemRef,
- traits>;
-
-class IREEInterpHL_BinaryElementwiseIntOp<string mnemonic,
- list<OpTrait> traits = []> :
- IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_IntMemRef,
- traits>;
-
-class IREEInterpHL_TernaryOp<string mnemonic,
- Type type = IREEHL_MemRef,
- list<OpTrait> traits = []> :
- IREEInterpHL_PureOp<mnemonic, traits> {
- let arguments = (ins type:$a, type:$b, type:$c);
- let results = (outs type);
-}
-
-// TODO(benvanik): add traits for broadcasting support.
-
-def IREEInterpHL_NotOp : IREEInterpHL_UnaryElementwiseIntOp<"not">;
-def IREEInterpHL_AndOp : IREEInterpHL_BinaryElementwiseIntOp<"and">;
-def IREEInterpHL_OrOp : IREEInterpHL_BinaryElementwiseIntOp<"or">;
-def IREEInterpHL_XorOp : IREEInterpHL_BinaryElementwiseIntOp<"xor">;
-def IREEInterpHL_ShiftLeftOp : IREEInterpHL_BinaryElementwiseIntOp<"sll">;
-def IREEInterpHL_ShiftRightLogicalOp : IREEInterpHL_BinaryElementwiseIntOp<"srl">;
-def IREEInterpHL_ShiftRightArithmeticOp : IREEInterpHL_BinaryElementwiseIntOp<"sra">;
-
-def IREEInterpHL_AddIOp : IREEInterpHL_BinaryElementwiseIntOp<"add_i">;
-def IREEInterpHL_AddFOp : IREEInterpHL_BinaryElementwiseFloatOp<"add_f">;
-def IREEInterpHL_SubIOp : IREEInterpHL_BinaryElementwiseIntOp<"sub_i">;
-def IREEInterpHL_SubFOp : IREEInterpHL_BinaryElementwiseFloatOp<"sub_f">;
-def IREEInterpHL_AbsIOp : IREEInterpHL_UnaryElementwiseIntOp<"abs_i">;
-def IREEInterpHL_AbsFOp : IREEInterpHL_UnaryElementwiseFloatOp<"abs_f">;
-def IREEInterpHL_MulIOp : IREEInterpHL_BinaryElementwiseIntOp<"mul_i">;
-def IREEInterpHL_MulFOp : IREEInterpHL_BinaryElementwiseFloatOp<"mul_f">;
-def IREEInterpHL_DivISOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_s">;
-def IREEInterpHL_DivIUOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_u">;
-def IREEInterpHL_DivFOp : IREEInterpHL_BinaryElementwiseFloatOp<"div_f">;
-def IREEInterpHL_MulAddIOp : IREEInterpHL_TernaryOp<"madd_i", IREEHL_IntMemRef>;
-def IREEInterpHL_MulAddFOp : IREEInterpHL_TernaryOp<"madd_f", IREEHL_FloatMemRef>;
-def IREEInterpHL_ExpFOp : IREEInterpHL_UnaryElementwiseFloatOp<"exp_f">;
-def IREEInterpHL_LogFOp : IREEInterpHL_UnaryElementwiseFloatOp<"log_f">;
-def IREEInterpHL_RsqrtFOp : IREEInterpHL_UnaryElementwiseFloatOp<"rsqrt_f">;
-def IREEInterpHL_CosFOp : IREEInterpHL_UnaryElementwiseFloatOp<"cos_f">;
-def IREEInterpHL_SinFOp : IREEInterpHL_UnaryElementwiseFloatOp<"sin_f">;
-def IREEInterpHL_TanhFOp : IREEInterpHL_UnaryElementwiseFloatOp<"tanh_f">;
-def IREEInterpHL_Atan2FOp : IREEInterpHL_UnaryElementwiseFloatOp<"atan2_f">;
-
-def IREEInterpHL_MinISOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_s">;
-def IREEInterpHL_MinIUOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_u">;
-def IREEInterpHL_MinFOp : IREEInterpHL_BinaryElementwiseFloatOp<"min_f">;
-def IREEInterpHL_MaxISOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_s">;
-def IREEInterpHL_MaxIUOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_u">;
-def IREEInterpHL_MaxFOp : IREEInterpHL_BinaryElementwiseFloatOp<"max_f">;
-def IREEInterpHL_ClampFOp : IREEInterpHL_TernaryOp<"clamp_f", IREEHL_FloatMemRef>;
-def IREEInterpHL_FloorFOp : IREEInterpHL_UnaryElementwiseFloatOp<"floor_f">;
-def IREEInterpHL_CeilFOp : IREEInterpHL_UnaryElementwiseFloatOp<"ceil_f">;
-
-class IREEInterpHL_ConversionOp<string mnemonic, Type inputType,
- Type outputType> :
- IREEInterpHL_PureOp<mnemonic, [SameOperandsAndResultShape]> {
- let arguments = (ins inputType);
- let results = (outs outputType);
-}
-
-def IREEInterpHL_ConvertSSOp :
- IREEInterpHL_ConversionOp<"convert_s_s", IREEHL_IntMemRef,
- IREEHL_IntMemRef>;
-def IREEInterpHL_ConvertSUOp :
- IREEInterpHL_ConversionOp<"convert_s_u", IREEHL_IntMemRef,
- IREEHL_IntMemRef>;
-def IREEInterpHL_ConvertSFOp :
- IREEInterpHL_ConversionOp<"convert_s_f", IREEHL_IntMemRef,
- IREEHL_FloatMemRef>;
-
-def IREEInterpHL_ConvertUSOp :
- IREEInterpHL_ConversionOp<"convert_u_s", IREEHL_IntMemRef,
- IREEHL_IntMemRef>;
-def IREEInterpHL_ConvertUUOp :
- IREEInterpHL_ConversionOp<"convert_u_u", IREEHL_IntMemRef,
- IREEHL_IntMemRef>;
-def IREEInterpHL_ConvertUFOp :
- IREEInterpHL_ConversionOp<"convert_u_f", IREEHL_IntMemRef,
- IREEHL_FloatMemRef>;
-
-def IREEInterpHL_ConvertFSOp :
- IREEInterpHL_ConversionOp<"convert_f_s", IREEHL_FloatMemRef,
- IREEHL_IntMemRef>;
-def IREEInterpHL_ConvertFUOp :
- IREEInterpHL_ConversionOp<"convert_f_u", IREEHL_FloatMemRef,
- IREEHL_IntMemRef>;
-def IREEInterpHL_ConvertFFOp :
- IREEInterpHL_ConversionOp<"convert_f_f", IREEHL_FloatMemRef,
- IREEHL_FloatMemRef>;
-
-def IREEInterpHL_MatMulIOp :
- IREEInterpHL_PureOp<"matmul_i",
- [AllElementTypesMatch<["lhs", "rhs", "result"]>]> {
- let arguments = (ins
- IREEHL_IntMemRef:$lhs,
- IREEHL_IntMemRef:$rhs,
- IREEHL_IntMemRef:$multiplier_mantissa,
- IREEHL_IntMemRef:$multiplier_exponent
- );
- let results = (outs IREEHL_IntMemRef:$result);
-}
-def IREEInterpHL_MatMulFOp :
- IREEInterpHL_PureOp<"matmul_f", [SameOperandsAndResultElementType]> {
- let arguments = (ins
- IREEHL_FloatMemRef:$lhs,
- IREEHL_FloatMemRef:$rhs
- );
- let results = (outs IREEHL_FloatMemRef);
-}
-
-def IREEInterpHL_ReduceSumIOp :
- IREEInterpHL_PureOp<"reduce_sum_i",
- [AllElementTypesMatch<["src", "result", "init"]>]> {
- let arguments = (ins
- IREEHL_IntMemRef:$src,
- IREEHL_IntMemRef:$init,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_IntMemRef:$result);
-}
-def IREEInterpHL_ReduceSumFOp :
- IREEInterpHL_PureOp<"reduce_sum_f",
- [AllElementTypesMatch<["src", "result", "init"]>]> {
- let arguments = (ins
- IREEHL_FloatMemRef:$src,
- IREEHL_FloatMemRef:$init,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_FloatMemRef:$result);
-}
-def IREEInterpHL_ReduceMinIOp :
- IREEInterpHL_PureOp<"reduce_min_i",
- [AllElementTypesMatch<["src", "result", "init"]>]> {
- let arguments = (ins
- IREEHL_IntMemRef:$src,
- IREEHL_IntMemRef:$init,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_IntMemRef:$result);
-}
-def IREEInterpHL_ReduceMinFOp :
- IREEInterpHL_PureOp<"reduce_min_f",
- [AllElementTypesMatch<["src", "result", "init"]>]> {
- let arguments = (ins
- IREEHL_FloatMemRef:$src,
- IREEHL_FloatMemRef:$init,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_FloatMemRef:$result);
-}
-def IREEInterpHL_ReduceMaxIOp :
- IREEInterpHL_PureOp<"reduce_max_i",
- [AllElementTypesMatch<["src", "result", "init"]>]> {
- let arguments = (ins
- IREEHL_IntMemRef:$src,
- IREEHL_IntMemRef:$init,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_IntMemRef:$result);
-}
-def IREEInterpHL_ReduceMaxFOp :
- IREEInterpHL_PureOp<"reduce_max_f",
- [AllElementTypesMatch<["src", "result", "init"]>]> {
- let arguments = (ins
- IREEHL_FloatMemRef:$src,
- IREEHL_FloatMemRef:$init,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_FloatMemRef:$result);
-}
-
-def IREEInterpHL_TraceOp : IREEInterpHL_Op<"trace"> {
- let arguments = (ins Variadic<IREEHL_MemRef>:$srcs);
-}
-
-def IREEInterpHL_CondBreakOp : IREEInterpHL_Op<"cond_break"> {
- let arguments = (ins IREEHL_BoolScalar:$cond);
-}
-
-def IREEInterpHL_BreakOp : IREEInterpHL_Op<"break">;
-
-#endif // IREE_INTERPRETER_HL_OPS
diff --git a/iree/compiler/IR/Interpreter/LLDialect.cpp b/iree/compiler/IR/Interpreter/LLDialect.cpp
deleted file mode 100644
index ffe4289..0000000
--- a/iree/compiler/IR/Interpreter/LLDialect.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Interpreter/LLDialect.h"
-
-#include "iree/compiler/IR/Interpreter/LLOps.h"
-#include "llvm/Support/SourceMgr.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-IREELLInterpreterDialect::IREELLInterpreterDialect(MLIRContext* context)
- : Dialect(getDialectNamespace(), context) {
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/Interpreter/LLOps.cpp.inc"
- >();
-}
-
-static DialectRegistration<IREELLInterpreterDialect> iree_ll_interp_dialect;
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/LLOps.cpp b/iree/compiler/IR/Interpreter/LLOps.cpp
deleted file mode 100644
index 2fbbb64..0000000
--- a/iree/compiler/IR/Interpreter/LLOps.cpp
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Interpreter/LLOps.h"
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/OpImplementation.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREEInterp {
-namespace LL {
-
-//===----------------------------------------------------------------------===//
-// iree_ll_interp.call
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
- SymbolRefAttr calleeAttr;
- FunctionType calleeType;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- auto calleeLoc = parser.getNameLoc();
- if (parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(state.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.addTypesToList(calleeType.getResults(), state.types) ||
- parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
- state.operands)) {
- return failure();
- }
- return success();
-}
-
-static void printCallOp(OpAsmPrinter &p, CallOp op) {
- p << "iree_ll_interp.call " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
-}
-
-FunctionType CallOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_interp.call_import
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallImportOp(OpAsmParser &parser,
- OperationState &state) {
- SymbolRefAttr calleeAttr;
- FunctionType calleeType;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- auto calleeLoc = parser.getNameLoc();
- if (parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(state.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.addTypesToList(calleeType.getResults(), state.types) ||
- parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
- state.operands)) {
- return failure();
- }
- return success();
-}
-
-static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) {
- p << "iree_ll_interp.call_import " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
-}
-
-FunctionType CallImportOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_interp.call_indirect
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallIndirectOp(OpAsmParser &parser,
- OperationState &result) {
- FunctionType calleeType;
- OpAsmParser::OperandType callee;
- llvm::SMLoc operandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- return failure(
- parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(result.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.resolveOperand(callee, calleeType, result.operands) ||
- parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
- result.operands) ||
- parser.addTypesToList(calleeType.getResults(), result.types));
-}
-
-static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
- p << "iree_ll_interp.call_indirect ";
- p.printOperand(op.getCallee());
- p << '(';
- auto operandRange = op.getOperands();
- p.printOperands(++operandRange.begin(), operandRange.end());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : " << op.getCallee()->getType();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_interp.return
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type, 2> types;
- llvm::SMLoc loc = parser.getCurrentLocation();
- return failure(parser.parseOperandList(opInfo) ||
- (!opInfo.empty() && parser.parseColonTypeList(types)) ||
- parser.resolveOperands(opInfo, types, loc, state.operands));
-}
-
-static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
- p << "iree_ll_interp.return";
- if (op.getNumOperands() > 0) {
- p << ' ';
- p.printOperands(op.operand_begin(), op.operand_end());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_interp.br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
- Block *dest;
- SmallVector<Value *, 4> destOperands;
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
- return success();
-}
-
-static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
- p << "iree_ll_interp.br ";
- p.printSuccessorAndUseList(op.getOperation(), 0);
-}
-
-Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
-
-void BranchOp::setDest(Block *block) {
- return getOperation()->setSuccessor(block, 0);
-}
-
-void BranchOp::eraseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(0, index);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_interp.cond_br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCondBranchOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<Value *, 4> destOperands;
- Block *dest;
- OpAsmParser::OperandType condInfo;
-
- // Parse the condition.
- Type int1Ty = parser.getBuilder().getI1Type();
- if (parser.parseOperand(condInfo) || parser.parseComma() ||
- parser.resolveOperand(condInfo, int1Ty, result.operands)) {
- return parser.emitError(parser.getNameLoc(),
- "expected condition type was boolean (i1)");
- }
-
- // Parse the true successor.
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
-
- // Parse the false successor.
- destOperands.clear();
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
-
- return success();
-}
-
-static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
- p << "iree_ll_interp.cond_br ";
- p.printOperand(op.getCondition());
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Interpreter/LLOps.cpp.inc"
-
-} // namespace LL
-} // namespace IREEInterp
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/LLOps.h b/iree/compiler/IR/Interpreter/LLOps.h
deleted file mode 100644
index 578b33f..0000000
--- a/iree/compiler/IR/Interpreter/LLOps.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
-#define IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
-
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREEInterp {
-namespace LL {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Interpreter/LLOps.h.inc"
-
-} // namespace LL
-} // namespace IREEInterp
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
diff --git a/iree/compiler/IR/Interpreter/LLOps.td b/iree/compiler/IR/Interpreter/LLOps.td
deleted file mode 100644
index b4ec592..0000000
--- a/iree/compiler/IR/Interpreter/LLOps.td
+++ /dev/null
@@ -1,633 +0,0 @@
-// IREE low-level interpreter op definitions.
-// These map 1:1 with the bytecode, accept only MemRef types and generally use
-// output parameters instead of return types.
-//
-// The source of truth for bytecode opcodes is:
-// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h
-
-#ifdef IREE_INTERPRETER_LL_OPS
-#else
-#define IREE_INTERPRETER_LL_OPS
-
-#ifdef IREE_OP_BASE
-#else
-include "iree/compiler/IR/OpBase.td"
-#endif // IREE_OP_BASE
-
-def IREEInterpLL_Dialect : Dialect {
- let name = "iree_ll_interp";
- let cppNamespace = "IREEInterp::LL";
-}
-
-//===----------------------------------------------------------------------===//
-// Base op classes
-//===----------------------------------------------------------------------===//
-
-class IREEInterpLL_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<IREEInterpLL_Dialect, mnemonic, traits>;
-
-class IREEInterpLL_PureOp<string mnemonic, list<OpTrait> traits = []> :
- IREEInterpLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-class IREEInterpLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef,
- list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> {
- let arguments = (ins type:$input, type:$dst);
-}
-
-class IREEInterpLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef,
- list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> {
- let arguments = (ins type:$lhs, type:$rhs, type:$dst);
-}
-
-class IREEInterpLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef,
- list<OpTrait> traits = []>
- : IREEInterpLL_Op<mnemonic, traits> {
- let arguments = (ins type : $a, type : $b, type : $c, type : $dst);
-}
-
-//===----------------------------------------------------------------------===//
-// Low-level interpreter ops
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): value attribute.
-def IREEInterpLL_ConstantOp : IREEInterpLL_PureOp<"constant"> {
- let results = (outs IREELL_MemRef);
-}
-
-def IREEInterpLL_CallOp : IREEInterpLL_Op<"call"> {
- let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, FuncOp callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(callee.getType().getResults());
- }]>, OpBuilder<
- "Builder *builder, OperationState &result, StringRef callee,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getCallee() { return callee(); }
- FunctionType getCalleeType();
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-// TODO(benvanik): add verifier that target isExternal.
-def IREEInterpLL_CallImportOp : IREEInterpLL_Op<"call_import"> {
- let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, FuncOp callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(callee.getType().getResults());
- }]>, OpBuilder<
- "Builder *builder, OperationState &result, StringRef callee,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getCallee() { return callee(); }
- FunctionType getCalleeType();
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpLL_CallIndirectOp : IREEInterpLL_Op<"call_indirect"> {
- let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands);
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.operands.push_back(callee);
- result.addOperands(operands);
- result.addTypes(callee->getType().cast<FunctionType>().getResults());
- }]>];
-
- let extraClassDeclaration = [{
- Value *getCallee() { return getOperand(0); }
-
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return ++operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpLL_ReturnOp : IREEInterpLL_Op<"return", [Terminator]> {
- let arguments = (ins Variadic<IREELL_MemRef>:$operands);
-
- let builders = [OpBuilder<
- "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
- >];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpLL_BranchOp : IREEInterpLL_Op<"br", [Terminator]> {
- let arguments = (ins Variadic<IREELL_MemRef>:$operands);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Block *dest,"
- "ArrayRef<Value *> operands = {}", [{
- result.addSuccessor(dest, operands);
- }]>];
-
- let extraClassDeclaration = [{
- Block *getDest();
- void setDest(Block *block);
-
- /// Erase the operand at 'index' from the operand list.
- void eraseOperand(unsigned index);
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpLL_CondBranchOp : IREEInterpLL_Op<"cond_br", [Terminator]> {
- let arguments = (ins
- IREELL_BoolScalar:$condition,
- Variadic<IREELL_MemRef>:$branchOperands
- );
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *condition,"
- "Block *trueDest, ArrayRef<Value *> trueOperands,"
- "Block *falseDest, ArrayRef<Value *> falseOperands", [{
- result.addOperands(condition);
- result.addSuccessor(trueDest, trueOperands);
- result.addSuccessor(falseDest, falseOperands);
- }]>];
-
- let extraClassDeclaration = [{
- // These are the indices into the dests list.
- enum { trueIndex = 0, falseIndex = 1 };
-
- // The condition operand is the first operand in the list.
- Value *getCondition() { return getOperand(0); }
-
- /// Return the destination if the condition is true.
- Block *getTrueDest() {
- return getOperation()->getSuccessor(trueIndex);
- }
-
- /// Return the destination if the condition is false.
- Block *getFalseDest() {
- return getOperation()->getSuccessor(falseIndex);
- }
-
- // Accessors for operands to the 'true' destination.
- Value *getTrueOperand(unsigned idx) {
- assert(idx < getNumTrueOperands());
- return getOperand(getTrueDestOperandIndex() + idx);
- }
-
- void setTrueOperand(unsigned idx, Value *value) {
- assert(idx < getNumTrueOperands());
- setOperand(getTrueDestOperandIndex() + idx, value);
- }
-
- operand_iterator true_operand_begin() {
- return operand_begin() + getTrueDestOperandIndex();
- }
- operand_iterator true_operand_end() {
- return true_operand_begin() + getNumTrueOperands();
- }
- operand_range getTrueOperands() {
- return {true_operand_begin(), true_operand_end()};
- }
-
- unsigned getNumTrueOperands() {
- return getOperation()->getNumSuccessorOperands(trueIndex);
- }
-
- /// Erase the operand at 'index' from the true operand list.
- void eraseTrueOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(trueIndex, index);
- }
-
- // Accessors for operands to the 'false' destination.
- Value *getFalseOperand(unsigned idx) {
- assert(idx < getNumFalseOperands());
- return getOperand(getFalseDestOperandIndex() + idx);
- }
- void setFalseOperand(unsigned idx, Value *value) {
- assert(idx < getNumFalseOperands());
- setOperand(getFalseDestOperandIndex() + idx, value);
- }
-
- operand_iterator false_operand_begin() { return true_operand_end(); }
- operand_iterator false_operand_end() {
- return false_operand_begin() + getNumFalseOperands();
- }
- operand_range getFalseOperands() {
- return {false_operand_begin(), false_operand_end()};
- }
-
- unsigned getNumFalseOperands() {
- return getOperation()->getNumSuccessorOperands(falseIndex);
- }
-
- /// Erase the operand at 'index' from the false operand list.
- void eraseFalseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(falseIndex, index);
- }
-
- private:
- /// Get the index of the first true destination operand.
- unsigned getTrueDestOperandIndex() { return 1; }
-
- /// Get the index of the first false destination operand.
- unsigned getFalseDestOperandIndex() {
- return getTrueDestOperandIndex() + getNumTrueOperands();
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREEInterpLL_CmpIOp : IREEInterpLL_Op<"cmp_i"> {
- let arguments = (ins
- I32Attr:$predicate,
- IREELL_IntMemRef:$lhs,
- IREELL_IntMemRef:$rhs,
- IREELL_BoolMemRef:$dst
- );
-}
-
-def IREEInterpLL_CmpFOp : IREEInterpLL_Op<"cmp_f"> {
- let arguments = (ins
- I32Attr:$predicate,
- IREELL_FloatMemRef:$lhs,
- IREELL_FloatMemRef:$rhs,
- IREELL_BoolMemRef:$dst
- );
-}
-
-def IREEInterpLL_AllocStaticOp : IREEInterpLL_PureOp<"alloc_static"> {
- // TODO(benvanik): attributes and args.
- let results = (outs IREELL_MemRef);
-}
-
-def IREEInterpLL_AllocStackOp : IREEInterpLL_PureOp<"alloc_stack"> {
- // TODO(benvanik): atributes and args.
- let arguments = (ins
- Variadic<IREELL_MemRef>:$dim_pieces
- );
- let results = (outs
- IREELL_MemRef
- );
-}
-
-def IREEInterpLL_AllocStackInitOp : IREEInterpLL_PureOp<"alloc_stack_init"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins
- Variadic<IREELL_MemRef>:$dim_pieces
- );
- let results = (outs
- IREELL_MemRef
- );
-}
-
-// TODO(b/142012496): Add trait that enables DCE but not CSE.
-def IREEInterpLL_AllocHeapOp : IREEInterpLL_Op<"alloc_heap"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins
- Variadic<IREELL_MemRef>:$dim_pieces
- );
- let results = (outs
- IREELL_MemRef
- );
-}
-
-def IREEInterpLL_DiscardOp : IREEInterpLL_Op<"discard"> {
- let arguments = (ins IREELL_MemRef);
-}
-
-def IREEInterpLL_RankOp : IREEInterpLL_Op<"rank"> {
- let arguments = (ins
- IREELL_MemRef:$input,
- IREELL_I32Scalar:$dst
- );
-}
-
-def IREEInterpLL_DimOp : IREEInterpLL_Op<"dim"> {
- // TODO(benvanik) add dim attr (I32Attr:$dim)
- let arguments = (ins
- IREELL_MemRef:$input,
- IREELL_I32Scalar:$dst
- );
-}
-
-def IREEInterpLL_ShapeOp : IREEInterpLL_Op<"shape"> {
- let arguments = (ins
- IREELL_MemRef:$input,
- IREELL_I32MemRef:$dst
- );
-}
-
-def IREEInterpLL_LengthOp : IREEInterpLL_Op<"length"> {
- let arguments = (ins
- IREELL_MemRef:$input,
- IREELL_I32Scalar:$dst
- );
-}
-
-
-def IREEInterpLL_DynamicSliceOp : IREEInterpLL_PureOp<"dynamic_slice"> {
- let arguments = (ins
- IREELL_MemRef:$src,
- IREELL_1DIndexMemRef:$srcIndices,
- IREELL_1DIndexMemRef:$lengths
- );
- let results = (outs
- IREELL_MemRef
- );
-}
-
-// TODO(benvanik): add attribute requirements/types.
-def IREEInterpLL_StaticSliceOp :
- IREEInterpLL_PureOp<"static_slice", [SameOperandsAndResultElementType]> {
- let arguments = (ins IREELL_MemRef:$src);
- let results = (outs IREELL_MemRef);
-}
-
-def IREEInterpLL_DynamicCopyOp : IREEInterpLL_Op<"dynamic_copy", [
- AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
-]> {
- let arguments = (ins
- IREELL_MemRef:$src,
- IREELL_1DIndexMemRef:$srcIndices,
- IREELL_MemRef:$dst,
- IREELL_1DIndexMemRef:$dstIndices,
- IREELL_1DIndexMemRef:$lengths
- );
-}
-
-def IREEInterpLL_StaticCopyOp : IREEInterpLL_Op<"static_copy", [
- AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
-]> {
- let arguments = (ins
- IREELL_MemRef:$src,
- I32ElementsAttr:$srcIndices,
- IREELL_MemRef:$dst,
- I32ElementsAttr:$dstIndices,
- I32ElementsAttr:$lengths
- );
-}
-
-def IREEInterpLL_CloneOp :
- IREEInterpLL_PureOp<"clone", [SameOperandsAndResultType]> {
- let arguments = (ins IREELL_MemRef:$src);
- let results = (outs IREELL_MemRef);
-}
-
-// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops?
-def IREEInterpLL_SplitOp : IREEInterpLL_PureOp<"split"> {
- let arguments = (ins
- IREELL_MemRef:$src
- );
- let results = (outs
- Variadic<IREELL_MemRef>
- );
-}
-
-def IREEInterpLL_AssignOp :
- IREEInterpLL_Op<"assign", [SameOperandsAndResultType]> {
- let arguments = (ins IREELL_MemRef:$src);
- let results = (outs IREELL_MemRef);
-}
-
-def IREEInterpLL_CondAssignOp : IREEInterpLL_Op<"cond_assign"> {
- let arguments = (ins
- IREELL_BoolScalar:$cond,
- IREELL_MemRef:$lhs,
- IREELL_MemRef:$rhs
- );
- let results = (outs
- IREELL_MemRef
- );
-}
-
-def IREEInterpLL_ReshapeOp : IREEInterpLL_Op<"reshape"> {
- let arguments = (ins
- IREELL_MemRef:$input,
- IREELL_1DIntMemRef:$shape
- );
- let results = (outs
- IREELL_MemRef
- );
-}
-
-def IREEInterpLL_SelectOp : IREEInterpLL_Op<"select"> {
- let arguments = (ins
- IREELL_MemRef:$cond,
- IREELL_MemRef:$lhs,
- IREELL_MemRef:$rhs,
- IREELL_MemRef:$dst
- );
-}
-
-def IREEInterpLL_PadOp :
- IREEInterpLL_Op<
- "pad", [AllElementTypesMatch<["src", "dst", "padding_value"]>]> {
- let arguments = (ins
- IREELL_MemRef:$src,
- IREELL_ElementScalar:$padding_value,
- IREELL_1DIndexMemRef:$edge_padding_low,
- IREELL_1DIndexMemRef:$edge_padding_high,
- IREELL_1DIndexMemRef:$interior_padding,
- IREELL_MemRef:$dst
- );
-}
-
-def IREEInterpLL_TransposeOp : IREEInterpLL_BinaryOp<"transpose">;
-
-def IREEInterPLL_ReverseOp : IREEInterpLL_BinaryOp<"reverse">;
-
-def IREEInterpLL_BroadcastOp : IREEInterpLL_BinaryOp<"broadcast">;
-
-def IREEInterpLL_TileOp : IREEInterpLL_BinaryOp<"tile">;
-
-// TODO(benvanik): add traits for broadcasting support.
-
-def IREEInterpLL_NotOp : IREEInterpLL_UnaryOp<"not">;
-def IREEInterpLL_AndOp : IREEInterpLL_BinaryOp<"and">;
-def IREEInterpLL_OrOp : IREEInterpLL_BinaryOp<"or">;
-def IREEInterpLL_XorOp : IREEInterpLL_BinaryOp<"xor">;
-def IREEInterpLL_ShiftLeftOp : IREEInterpLL_BinaryOp<"sll">;
-def IREEInterpLL_ShiftRightLogicalOp : IREEInterpLL_BinaryOp<"srl">;
-def IREEInterpLL_ShiftRightArithmeticOp : IREEInterpLL_BinaryOp<"sra">;
-
-def IREEInterpLL_AddIOp : IREEInterpLL_BinaryOp<"add_i", IREELL_IntMemRef>;
-def IREEInterpLL_AddFOp : IREEInterpLL_BinaryOp<"add_f", IREELL_FloatMemRef>;
-def IREEInterpLL_SubIOp : IREEInterpLL_BinaryOp<"sub_i", IREELL_IntMemRef>;
-def IREEInterpLL_SubFOp : IREEInterpLL_BinaryOp<"sub_f", IREELL_FloatMemRef>;
-def IREEInterpLL_AbsIOp : IREEInterpLL_UnaryOp<"abs_i", IREELL_IntMemRef>;
-def IREEInterpLL_AbsFOp : IREEInterpLL_UnaryOp<"abs_f", IREELL_FloatMemRef>;
-def IREEInterpLL_MulIOp : IREEInterpLL_BinaryOp<"mul_i", IREELL_IntMemRef>;
-def IREEInterpLL_MulFOp : IREEInterpLL_BinaryOp<"mul_f", IREELL_FloatMemRef>;
-def IREEInterpLL_DivISOp : IREEInterpLL_BinaryOp<"div_i_s", IREELL_IntMemRef>;
-def IREEInterpLL_DivIUOp : IREEInterpLL_BinaryOp<"div_i_u", IREELL_IntMemRef>;
-def IREEInterpLL_DivFOp : IREEInterpLL_BinaryOp<"div_f", IREELL_FloatMemRef>;
-def IREEInterpLL_MulAddIOp : IREEInterpLL_BinaryOp<"madd_i", IREELL_IntMemRef>;
-def IREEInterpLL_MulAddFOp : IREEInterpLL_BinaryOp<"madd_f", IREELL_FloatMemRef>;
-def IREEInterpLL_ExpFOp : IREEInterpLL_UnaryOp<"exp_f", IREELL_FloatMemRef>;
-def IREEInterpLL_LogFOp : IREEInterpLL_UnaryOp<"log_f", IREELL_FloatMemRef>;
-def IREEInterpLL_RsqrtFOp : IREEInterpLL_UnaryOp<"rsqrt_f", IREELL_FloatMemRef>;
-def IREEInterpLL_CosFOp : IREEInterpLL_UnaryOp<"cos_f", IREELL_FloatMemRef>;
-def IREEInterpLL_SinFOp : IREEInterpLL_UnaryOp<"sin_f", IREELL_FloatMemRef>;
-def IREEInterpLL_TanhFOp : IREEInterpLL_UnaryOp<"tanh_f", IREELL_FloatMemRef>;
-def IREEInterpLL_Atan2FOp : IREEInterpLL_UnaryOp<"atan2_f", IREELL_FloatMemRef>;
-
-def IREEInterpLL_MinISOp : IREEInterpLL_BinaryOp<"min_i_s", IREELL_IntMemRef>;
-def IREEInterpLL_MinIUOp : IREEInterpLL_BinaryOp<"min_i_u", IREELL_IntMemRef>;
-def IREEInterpLL_MinFOp : IREEInterpLL_BinaryOp<"min_f", IREELL_FloatMemRef>;
-def IREEInterpLL_MaxISOp : IREEInterpLL_BinaryOp<"max_i_s", IREELL_IntMemRef>;
-def IREEInterpLL_MaxIUOp : IREEInterpLL_BinaryOp<"max_i_u", IREELL_IntMemRef>;
-def IREEInterpLL_MaxFOp : IREEInterpLL_BinaryOp<"max_f", IREELL_FloatMemRef>;
-def IREEInterpLL_ClampFOp : IREEInterpLL_TernaryOp<"clamp_f", IREELL_FloatMemRef>;
-def IREEInterpLL_FloorFOp : IREEInterpLL_UnaryOp<"floor_f", IREELL_FloatMemRef>;
-def IREEInterpLL_CeilFOp : IREEInterpLL_UnaryOp<"ceil_f", IREELL_FloatMemRef>;
-
-def IREEInterpLL_ConvertSSOp : IREEInterpLL_UnaryOp<"convert_s_s", IREELL_MemRef>;
-def IREEInterpLL_ConvertSUOp : IREEInterpLL_UnaryOp<"convert_s_u", IREELL_MemRef>;
-def IREEInterpLL_ConvertSFOp : IREEInterpLL_UnaryOp<"convert_s_f", IREELL_MemRef>;
-
-def IREEInterpLL_ConvertUSOp : IREEInterpLL_UnaryOp<"convert_u_s", IREELL_MemRef>;
-def IREEInterpLL_ConvertUUOp : IREEInterpLL_UnaryOp<"convert_u_u", IREELL_MemRef>;
-def IREEInterpLL_ConvertUFOp : IREEInterpLL_UnaryOp<"convert_u_f", IREELL_MemRef>;
-
-def IREEInterpLL_ConvertFSOp : IREEInterpLL_UnaryOp<"convert_f_s", IREELL_MemRef>;
-def IREEInterpLL_ConvertFUOp : IREEInterpLL_UnaryOp<"convert_f_u", IREELL_MemRef>;
-def IREEInterpLL_ConvertFFOp : IREEInterpLL_UnaryOp<"convert_f_f", IREELL_MemRef>;
-
-def IREEInterpLL_MatMulIOp : IREEInterpLL_Op<"matmul_i"> {
- let arguments = (ins
- IREELL_IntMemRef:$lhs,
- IREELL_IntMemRef:$rhs,
- IREELL_IntMemRef:$multiplier_mantissa,
- IREELL_IntMemRef:$multiplier_exponent,
- IREELL_IntMemRef:$dst
- );
-}
-def IREEInterpLL_MatMulFOp : IREEInterpLL_Op<"matmul_f"> {
- let arguments = (ins
- IREELL_FloatMemRef:$lhs,
- IREELL_FloatMemRef:$rhs,
- IREELL_FloatMemRef:$dst
- );
-}
-
-def IREEInterpLL_ReduceSumIOp : IREEInterpLL_Op<"reduce_sum_i"> {
- let arguments = (ins
- IREELL_IntMemRef:$src,
- IREELL_IntMemRef:$init,
- I32Attr:$dimension,
- IREELL_IntMemRef:$dst
- );
-}
-def IREEInterpLL_ReduceSumFOp : IREEInterpLL_Op<"reduce_sum_f"> {
- let arguments = (ins
- IREELL_FloatMemRef:$src,
- IREELL_FloatMemRef:$init,
- I32Attr:$dimension,
- IREELL_FloatMemRef:$dst
- );
-}
-
-def IREEInterpLL_ReduceMinIOp : IREEInterpLL_Op<"reduce_min_i"> {
- let arguments = (ins
- IREELL_IntMemRef:$src,
- IREELL_IntMemRef:$init,
- I32Attr:$dimension,
- IREELL_IntMemRef:$dst
- );
-}
-def IREEInterpLL_ReduceMinFOp : IREEInterpLL_Op<"reduce_min_f"> {
- let arguments = (ins
- IREELL_FloatMemRef:$src,
- IREELL_FloatMemRef:$init,
- I32Attr:$dimension,
- IREELL_FloatMemRef:$dst
- );
-}
-
-def IREEInterpLL_ReduceMaxIOp : IREEInterpLL_Op<"reduce_max_i"> {
- let arguments = (ins
- IREELL_IntMemRef:$src,
- IREELL_IntMemRef:$init,
- I32Attr:$dimension,
- IREELL_IntMemRef:$dst
- );
-}
-def IREEInterpLL_ReduceMaxFOp : IREEInterpLL_Op<"reduce_max_f"> {
- let arguments = (ins
- IREELL_FloatMemRef:$src,
- IREELL_FloatMemRef:$init,
- I32Attr:$dimension,
- IREELL_FloatMemRef:$dst
- );
-}
-
-def IREEInterpLL_TraceOp : IREEInterpLL_Op<"trace"> {
- let arguments = (ins
- Variadic<IREELL_MemRef>:$srcs
- );
-}
-
-def IREEInterpLL_CondBreakOp : IREEInterpLL_Op<"cond_break"> {
- let arguments = (ins
- IREELL_BoolScalar:$cond
- );
-}
-
-def IREEInterpLL_BreakOp : IREEInterpLL_Op<"break">;
-
-#endif // IREE_INTERPRETER_LL_OPS
diff --git a/iree/compiler/IR/Interpreter/OpWriters.cpp b/iree/compiler/IR/Interpreter/OpWriters.cpp
deleted file mode 100644
index 2e07793..0000000
--- a/iree/compiler/IR/Interpreter/OpWriters.cpp
+++ /dev/null
@@ -1,261 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Interpreter/OpWriters.h"
-
-#include "iree/compiler/IR/Interpreter/LLOps.h"
-#include "iree/compiler/Serialization/BytecodeWriter.h"
-#include "iree/compiler/Utils/Macros.h"
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/TypeUtilities.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Sequencer ops
-//===----------------------------------------------------------------------===//
-
-LogicalResult writeOp(IREEInterp::LL::ConstantOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConstant));
- auto memrefType = op.getType().dyn_cast<MemRefType>();
- if (!memrefType) {
- return op.emitOpError()
- << "Constant has an unsupported type; must be a memref: "
- << op.getType();
- }
- RETURN_IF_FAILURE(writer->WriteConstant(memrefType, op.getAttr("value")));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::CallOp op, BytecodeWriter *writer) {
- auto module = op.getOperation()->getParentOfType<ModuleOp>();
- auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
- // TODO(benvanik): transforms to convert Call->CallImport.
- // TODO(benvanik): switch with kCallTail if attr exists.
- if (callee.isExternal()) {
- RETURN_IF_FAILURE(
- writer->WriteOpcode(iree::InterpreterOpcode::kCallImport));
- } else {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCall));
- }
- RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::CallImportOp op, BytecodeWriter *writer) {
- auto module = op.getOperation()->getParentOfType<ModuleOp>();
- auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
- // TODO(benvanik): switch with kCallTail if attr exists.
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCallImport));
- RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::CallIndirectOp op,
- BytecodeWriter *writer) {
- RETURN_IF_FAILURE(
- writer->WriteOpcode(iree::InterpreterOpcode::kCallIndirect));
- RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult WriteConvertOperands(Operation *op, BytecodeWriter *writer) {
- auto *src = op->getOperand(0);
- RETURN_IF_FAILURE(
- writer->WriteTypeIndex(getElementTypeOrSelf(src->getType())));
- RETURN_IF_FAILURE(writer->WriteLocal(src));
- auto *dst = op->getOperand(1);
- RETURN_IF_FAILURE(
- writer->WriteTypeIndex(getElementTypeOrSelf(dst->getType())));
- RETURN_IF_FAILURE(writer->WriteLocal(dst));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::ConvertSSOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSS));
- return WriteConvertOperands(op, writer);
-}
-
-LogicalResult writeOp(IREEInterp::LL::ConvertUUOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUU));
- return WriteConvertOperands(op, writer);
-}
-
-LogicalResult writeOp(IREEInterp::LL::ConvertSUOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSU));
- return WriteConvertOperands(op, writer);
-}
-
-LogicalResult writeOp(IREEInterp::LL::ConvertUSOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUS));
- return WriteConvertOperands(op, writer);
-}
-
-LogicalResult writeOp(IREEInterp::LL::BranchOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kBranch));
- RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest()));
- RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands()));
- for (int i = 0; i < op.getNumOperands(); ++i) {
- // Copy src->dst.
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i)));
- }
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::CondBranchOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCondBranch));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition()));
- RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest()));
- RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands()));
- for (int i = 0; i < op.getNumTrueOperands(); ++i) {
- // Copy src->dst.
- RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i)));
- }
- RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest()));
- RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands()));
- for (int i = 0; i < op.getNumFalseOperands(); ++i) {
- // Copy src->dst.
- RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i)));
- }
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::CmpIOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpI));
- RETURN_IF_FAILURE(
- writer->WriteUint8(static_cast<uint8_t>(op.predicate().getZExtValue())));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(0)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(1)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(2)));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::CmpFOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpF));
- RETURN_IF_FAILURE(
- writer->WriteUint8(static_cast<uint8_t>(op.predicate().getZExtValue())));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(0)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(1)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(2)));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::AllocHeapOp op, BytecodeWriter *writer) {
- auto memrefType = op.getType().cast<MemRefType>();
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kAllocHeap));
- RETURN_IF_FAILURE(writer->WriteInt32(0));
- RETURN_IF_FAILURE(writer->WriteTypeIndex(memrefType.getElementType()));
- RETURN_IF_FAILURE(writer->WriteShapePieces(memrefType));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::StaticCopyOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kStaticCopy));
- RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
- RETURN_IF_FAILURE(writer->WriteShapePieces(op.srcIndices()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
- RETURN_IF_FAILURE(writer->WriteShapePieces(op.dstIndices()));
- RETURN_IF_FAILURE(writer->WriteShapePieces(op.lengths()));
- return success();
-}
-
-LogicalResult writeReduceOperands(Operation *op, BytecodeWriter *writer,
- APInt dimension) {
- RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(0)));
- RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(1)));
- RETURN_IF_FAILURE(writer->WriteInt32(dimension.getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(2)));
- return success();
-}
-
-LogicalResult writeOp(IREEInterp::LL::ReduceSumIOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumI));
- return writeReduceOperands(op, writer, op.dimension());
-}
-
-LogicalResult writeOp(IREEInterp::LL::ReduceSumFOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumF));
- return writeReduceOperands(op, writer, op.dimension());
-}
-
-LogicalResult writeOp(IREEInterp::LL::ReduceMinIOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinI));
- return writeReduceOperands(op, writer, op.dimension());
-}
-
-LogicalResult writeOp(IREEInterp::LL::ReduceMinFOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinF));
- return writeReduceOperands(op, writer, op.dimension());
-}
-
-LogicalResult writeOp(IREEInterp::LL::ReduceMaxIOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxI));
- return writeReduceOperands(op, writer, op.dimension());
-}
-
-LogicalResult writeOp(IREEInterp::LL::ReduceMaxFOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxF));
- return writeReduceOperands(op, writer, op.dimension());
-}
-
-} // namespace
-
-void registerInterpreterCustomWriters(VMFunctionBuilder *builder) {
-#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \
- builder->RegisterCustomWriter( \
- op_type::getOperationName(), \
- +[](Operation *op, BytecodeWriter *writer) { \
- return writeOp(cast<op_type>(op), writer); \
- });
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConstantOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallImportOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallIndirectOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::BranchOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CondBranchOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSSOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUUOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSUOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUSOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpIOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpFOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::AllocHeapOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::StaticCopyOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumIOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumFOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinIOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinFOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxIOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxFOp);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/OpWriters.h b/iree/compiler/IR/Interpreter/OpWriters.h
deleted file mode 100644
index 302fc88..0000000
--- a/iree/compiler/IR/Interpreter/OpWriters.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
-#define IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
-
-#include "iree/compiler/Serialization/VMFunctionBuilder.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Registers custom op writers with the builder.
-// Ops not registered will use the generic writer.
-void registerInterpreterCustomWriters(VMFunctionBuilder *builder);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
diff --git a/iree/compiler/IR/Interpreter/test/BUILD b/iree/compiler/IR/Interpreter/test/BUILD
deleted file mode 100644
index 4df56ac..0000000
--- a/iree/compiler/IR/Interpreter/test/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- "//iree/tools:iree-run-mlir",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/IR/Ops.cpp b/iree/compiler/IR/Ops.cpp
deleted file mode 100644
index ada40ab..0000000
--- a/iree/compiler/IR/Ops.cpp
+++ /dev/null
@@ -1,639 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Ops.h"
-
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/SMLoc.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-
-//===----------------------------------------------------------------------===//
-// iree.constant
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseConstantOp(OpAsmParser &parser,
- OperationState &result) {
- Attribute valueAttr;
- Type type;
- if (parser.parseLSquare() ||
- parser.parseAttribute(valueAttr, "value", result.attributes) ||
- parser.parseRSquare() ||
- parser.parseOptionalAttributeDict(result.attributes) ||
- parser.parseColonType(type))
- return failure();
-
- return parser.addTypeToList(type, result.types);
-}
-
-static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
- p << "iree.constant[";
- p.printAttribute(op.getValue());
- p << "] ";
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
-
- p << " : ";
- p.printType(op.getType());
-}
-
-namespace {
-
-// TODO(gcmn) this is duplicated from MemRefUtils to avoid a circular
-// dependency. Extract op-dependent parts of memref utils to allow reuse.
-MemRefType convertTypeToMemRef(Type type) {
- if (type.isIntOrIndexOrFloat()) {
- return MemRefType::get({}, type, {}, 0);
- } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- } else if (auto memRefType = type.dyn_cast<MemRefType>()) {
- return MemRefType::get(memRefType.getShape(), memRefType.getElementType());
- } else {
- llvm_unreachable("Unconvertable type");
- }
-}
-
-} // namespace
-
-void ConstantOp::build(Builder *builder, OperationState &state,
- ElementsAttr value) {
- auto type = convertTypeToMemRef(value.getType());
- return build(builder, state, type, value);
-}
-
-// TODO(b/134575149): enable folder when we store the correct type.
-// OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-// assert(operands.empty() && "constant has no operands");
-// return getValue();
-// }
-
-//===----------------------------------------------------------------------===//
-// iree.tensor_to_memref
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseTensorToMemRefOp(OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::OperandType operand;
- Type operandType;
- Type resultType;
- if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
- failed(parser.parseColonType(operandType)) ||
- failed(parser.resolveOperand(operand, operandType, state.operands)) ||
- failed(parser.parseRParen()) ||
- failed(parser.parseColonType(resultType)) ||
- failed(parser.addTypeToList(resultType, state.types))) {
- return failure();
- }
- return success();
-}
-
-static void printTensorToMemRefOp(OpAsmPrinter &p, TensorToMemRefOp &op) {
- p << "iree.tensor_to_memref(";
- p.printOperand(op.getOperand());
- p << " : ";
- p.printType(op.getOperand()->getType());
- p << ") : ";
- p.printType(op.getType());
-}
-
-OpFoldResult TensorToMemRefOp::fold(ArrayRef<Attribute> operands) {
- if (auto memrefToTensorOp = dyn_cast_or_null<IREE::MemRefToTensorOp>(
- getOperand()->getDefiningOp())) {
- return memrefToTensorOp.getOperand();
- }
-
- return {};
-}
-
-void TensorToMemRefOp::build(Builder *builder, OperationState &state,
- Value *arg) {
- build(builder, state, convertTypeToMemRef(arg->getType()), arg);
-}
-
-//===----------------------------------------------------------------------===//
-// iree.memref_to_tensor
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseMemRefToTensorOp(OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::OperandType operand;
- Type operandType;
- Type resultType;
- if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
- failed(parser.parseColonType(operandType)) ||
- failed(parser.resolveOperand(operand, operandType, state.operands)) ||
- failed(parser.parseRParen()) ||
- failed(parser.parseColonType(resultType)) ||
- failed(parser.addTypeToList(resultType, state.types))) {
- return failure();
- }
- return success();
-}
-
-static void printMemRefToTensorOp(OpAsmPrinter &p, MemRefToTensorOp &op) {
- p << "iree.memref_to_tensor(";
- p.printOperand(op.getOperand());
- p << " : ";
- p.printType(op.getOperand()->getType());
- p << ") : ";
- p.printType(op.getType());
-}
-
-OpFoldResult MemRefToTensorOp::fold(ArrayRef<Attribute> operands) {
- if (auto tensorToMemRefOp = dyn_cast_or_null<IREE::TensorToMemRefOp>(
- getOperand()->getDefiningOp())) {
- return tensorToMemRefOp.getOperand();
- }
-
- return {};
-}
-
-void MemRefToTensorOp::build(Builder *builder, OperationState &state,
- Value *arg) {
- // TODO(gcmn) Use getTensorType from MemRefUtils when circular dependency can
- // be avoided.
- auto memRefType = arg->getType().cast<MemRefType>();
- auto tensorType =
- RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
- build(builder, state, tensorType, arg);
-}
-
-//===----------------------------------------------------------------------===//
-// iree.scalar_to_memref
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseScalarToMemRefOp(OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::OperandType operand;
- Type operandType;
- Type resultType;
- if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
- failed(parser.parseColonType(operandType)) ||
- failed(parser.resolveOperand(operand, operandType, state.operands)) ||
- failed(parser.parseRParen()) ||
- failed(parser.parseColonType(resultType)) ||
- failed(parser.addTypeToList(resultType, state.types))) {
- return failure();
- }
- return success();
-}
-
-static void printScalarToMemRefOp(OpAsmPrinter &p, ScalarToMemRefOp &op) {
- p << "iree.scalar_to_memref(";
- p.printOperand(op.getOperand());
- p << " : ";
- p.printType(op.getOperand()->getType());
- p << ") : ";
- p.printType(op.getType());
-}
-
-OpFoldResult ScalarToMemRefOp::fold(ArrayRef<Attribute> operands) {
- if (auto memrefToScalarOp = dyn_cast_or_null<IREE::MemRefToScalarOp>(
- getOperand()->getDefiningOp())) {
- return memrefToScalarOp.getOperand();
- }
-
- return {};
-}
-
-void ScalarToMemRefOp::build(Builder *builder, OperationState &state,
- Value *arg) {
- build(builder, state, convertTypeToMemRef(arg->getType()), arg);
-}
-
-//===----------------------------------------------------------------------===//
-// iree.memref_to_scalar
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseMemRefToScalarOp(OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::OperandType operand;
- Type operandType;
- Type resultType;
- if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
- failed(parser.parseColonType(operandType)) ||
- failed(parser.resolveOperand(operand, operandType, state.operands)) ||
- failed(parser.parseRParen()) ||
- failed(parser.parseColonType(resultType)) ||
- failed(parser.addTypeToList(resultType, state.types))) {
- return failure();
- }
- return success();
-}
-
-static void printMemRefToScalarOp(OpAsmPrinter &p, MemRefToScalarOp &op) {
- p << "iree.memref_to_scalar(";
- p.printOperand(op.getOperand());
- p << " : ";
- p.printType(op.getOperand()->getType());
- p << ") : ";
- p.printType(op.getType());
-}
-
-OpFoldResult MemRefToScalarOp::fold(ArrayRef<Attribute> operands) {
- if (auto scalarToMemRefOp = dyn_cast_or_null<IREE::ScalarToMemRefOp>(
- getOperand()->getDefiningOp())) {
- return scalarToMemRefOp.getOperand();
- }
-
- return {};
-}
-
-void MemRefToScalarOp::build(Builder *builder, OperationState &state,
- Value *arg) {
- build(builder, state, getElementTypeOrSelf(arg), arg);
-}
-
-//===----------------------------------------------------------------------===//
-// iree.dispatch_region
-//===----------------------------------------------------------------------===//
-
-void DispatchRegionOp::build(Builder *builder, OperationState &state,
- ArrayRef<Type> resultTypes, Value *workload,
- ArrayRef<Value *> operands,
- ArrayRef<NamedAttribute> attributes) {
- state.addTypes(resultTypes);
- state.addOperands({workload});
- state.addOperands(operands);
- state.addAttributes(attributes);
- state.addRegion();
- state.setOperandListToResizable();
-}
-
-ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState &state) {
- // Parse required workload.
- OpAsmParser::OperandType workloadArg;
- Type workloadArgType;
- if (failed(parser.parseLSquare()) ||
- failed(parser.parseOperand(workloadArg)) ||
- failed(parser.parseColonType(workloadArgType)) ||
- failed(parser.parseRSquare()) ||
- failed(parser.resolveOperand(workloadArg, workloadArgType,
- state.operands))) {
- return failure();
- }
-
- // Parse (optional) args.
- SmallVector<OpAsmParser::OperandType, 16> regionArgs;
- SmallVector<Type, 16> regionArgTypes;
- if (failed(parser.parseLParen())) {
- return failure();
- }
- if (failed(parser.parseOptionalRParen())) {
- SmallVector<OpAsmParser::OperandType, 16> regionOperands;
- auto argsLoc = parser.getCurrentLocation();
- do {
- // Reserve entries in the lists.
- regionArgs.emplace_back();
- regionOperands.emplace_back();
- regionArgTypes.emplace_back();
- if (failed(parser.parseRegionArgument(regionArgs.back())) ||
- failed(parser.parseEqual()) ||
- failed(parser.parseOperand(regionOperands.back())) ||
- failed(parser.parseColonType(regionArgTypes.back()))) {
- return failure();
- }
- } while (succeeded(parser.parseOptionalComma()));
- if (failed(parser.parseRParen()) ||
- failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc,
- state.operands))) {
- return failure();
- }
- }
- state.setOperandListToResizable();
-
- // Parse (optional) results.
- if (failed(parser.parseOptionalColonTypeList(state.types))) {
- return failure();
- }
-
- // Parse region body.
- Region *body = state.addRegion();
- if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
- failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
- return success();
-}
-
-void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) {
- p << "iree.dispatch_region";
-
- // Print the workload argument.
- p << "[";
- p.printOperand(op.getWorkload());
- p << " : ";
- p.printType(op.getWorkload()->getType());
- p << "]";
-
- // Print the data argument remapping.
- p << "(";
- interleaveComma(
- llvm::zip(op.getBody().front().getArguments(), op.getArgOperands()), p,
- [&](std::tuple<BlockArgument *, Value *> it) {
- p << *std::get<0>(it) << " = " << *std::get<1>(it);
- p << " : ";
- p << std::get<1>(it)->getType();
- });
- p << ")";
-
- // Print the result types, if any.
- if (op.getNumResults() > 0) {
- p << " : ";
- interleaveComma(op.getResultTypes(), p);
- }
-
- p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
- p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{});
-}
-
-//===----------------------------------------------------------------------===//
-// iree.reduction_region
-//===----------------------------------------------------------------------===//
-
-void ReductionRegionOp::build(Builder *builder, OperationState &state,
- ArrayRef<Type> resultTypes, Value *workload,
- ArrayRef<Value *> operands,
- ArrayRef<Value *> initialValues,
- ArrayRef<int64_t> dimensions,
- ArrayRef<NamedAttribute> attributes) {
- state.addTypes(resultTypes);
- state.addOperands({workload});
- state.addOperands(operands);
- state.addOperands(initialValues);
- state.addAttribute(
- "dimensions",
- DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(dimensions.size())},
- builder->getIntegerType(64)),
- dimensions));
- state.addAttributes(attributes);
- state.addRegion();
- state.setOperandListToResizable();
-}
-
-void ReductionRegionOp::build(
- Builder *builder, OperationState &state, ArrayRef<Type> resultTypes,
- Value *workload, ArrayRef<Value *> operands,
- ArrayRef<Value *> initialValues, ArrayRef<int64_t> windowDimensions,
- ArrayRef<int64_t> windowStrides, ArrayRef<int64_t> baseDilations,
- ArrayRef<int64_t> windowDilations, PaddingMode paddingMode,
- ArrayRef<NamedAttribute> attributes) {
- state.addTypes(resultTypes);
- state.addOperands({workload});
- state.addOperands(operands);
- state.addOperands(initialValues);
- state.addAttribute(
- "window_dimensions",
- DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(windowDimensions.size())},
- builder->getIntegerType(64)),
- windowDimensions));
- state.addAttribute(
- "window_strides",
- DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(windowStrides.size())},
- builder->getIntegerType(64)),
- windowStrides));
- state.addAttribute(
- "base_dilations",
- DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(baseDilations.size())},
- builder->getIntegerType(64)),
- baseDilations));
- state.addAttribute(
- "window_dilations",
- DenseIntElementsAttr::get(
- RankedTensorType::get({static_cast<int64_t>(windowDilations.size())},
- builder->getIntegerType(64)),
- windowDilations));
- state.addAttribute("padding_mode", builder->getI32IntegerAttr(
- static_cast<int32_t>(paddingMode)));
- state.addAttributes(attributes);
- state.addRegion();
- state.setOperandListToResizable();
-}
-
-ParseResult parseReductionRegionOp(OpAsmParser &parser, OperationState &state) {
- OpAsmParser::OperandType workloadArg;
- Type workloadArgType;
- if (failed(parser.parseLSquare()) ||
- failed(parser.parseOperand(workloadArg)) ||
- failed(parser.parseColonType(workloadArgType)) ||
- failed(parser.parseRSquare()) ||
- failed(parser.resolveOperand(workloadArg, workloadArgType,
- state.operands))) {
- return failure();
- }
-
- SmallVector<OpAsmParser::OperandType, 8> reductionOperands;
- Type reductionType;
- auto operandsLoc = parser.getCurrentLocation();
- if (failed(parser.parseLParen()) ||
- failed(parser.parseOperandList(reductionOperands)) ||
- failed(parser.parseRParen()) ||
- failed(parser.parseColonType(reductionType)) ||
- failed(parser.resolveOperands(
- reductionOperands, reductionType.cast<FunctionType>().getInputs(),
- operandsLoc, state.operands))) {
- return failure();
- }
- for (auto type : reductionType.cast<FunctionType>().getResults()) {
- state.types.push_back(type);
- }
- state.setOperandListToResizable();
-
- SmallVector<OpAsmParser::OperandType, 8> regionArgs;
- SmallVector<Type, 8> regionArgTypes;
- if (failed(parser.parseKeyword("invocation")) ||
- failed(parser.parseLParen())) {
- return failure();
- }
- do {
- Type argType;
- SmallVector<OpAsmParser::OperandType, 2> reductionRegionArgs;
- OpAsmParser::OperandType initialValue;
- if (failed(parser.parseLParen()) ||
- failed(parser.parseOperandList(reductionRegionArgs, 2)) ||
- failed(parser.parseRParen()) || failed(parser.parseEqual()) ||
- failed(parser.parseOperand(initialValue)) ||
- failed(parser.parseColonType(argType)) ||
- failed(parser.resolveOperand(initialValue, argType, state.operands))) {
- return failure();
- }
- regionArgs.push_back(reductionRegionArgs[0]);
- regionArgTypes.push_back(argType);
- regionArgs.push_back(reductionRegionArgs[1]);
- regionArgTypes.push_back(argType);
- } while (succeeded(parser.parseOptionalComma()));
- if (failed(parser.parseRParen())) {
- return failure();
- }
-
- // Parse region body.
- Region *body = state.addRegion();
- if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
- failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
-
- return success();
-}
-
-void printReductionRegionOp(OpAsmPrinter &p, ReductionRegionOp op) {
- p << "iree.reduction_region";
-
- // Print the workload argument.
- p << "[";
- p.printOperand(op.getWorkload());
- p << " : ";
- p.printType(op.getWorkload()->getType());
- p << "]";
-
- p << "(";
- p.printOperands(op.getODSOperands(1));
- p << ")";
- if (op.getNumResults() > 0) {
- p << " : (";
- interleaveComma(op.getODSOperands(1), p,
- [&](Value *operand) { p.printType(operand->getType()); });
- p << ")";
- p << " -> (";
- interleaveComma(op.getResultTypes(), p);
- p << ")";
- }
- p << "\n";
-
- p << " invocation(";
- auto &entryBlock = op.getBody().getBlocks().front();
- int regionArgIndex = 0;
- interleaveComma(op.getODSOperands(2), p, [&](Value *operand) {
- p << "(";
- p.printOperand(entryBlock.getArgument(regionArgIndex++));
- p << ", ";
- p.printOperand(entryBlock.getArgument(regionArgIndex++));
- p << ") = ";
- p.printOperand(operand);
- p << " : ";
- p.printType(operand->getType());
- });
- p << ") ";
-
- p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
- p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{});
-}
-
-//===----------------------------------------------------------------------===//
-// iree.return
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type, 2> types;
- llvm::SMLoc loc = parser.getCurrentLocation();
- return failure(parser.parseOperandList(opInfo) ||
- (!opInfo.empty() && parser.parseColonTypeList(types)) ||
- parser.resolveOperands(opInfo, types, loc, state.operands));
-}
-
-static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
- p << "iree.return";
- if (op.getNumOperands() > 0) {
- p << ' ';
- p.printOperands(op.operand_begin(), op.operand_end());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// iree.load_input
-//===----------------------------------------------------------------------===//
-
-ParseResult parseLoadInputOp(OpAsmParser &parser, OperationState &state) {
- OpAsmParser::OperandType operand;
- Type argType;
- if (parser.parseLParen() || parser.parseOperand(operand) ||
- parser.parseColonType(argType) || parser.parseRParen() ||
- parser.resolveOperand(operand, argType, state.operands)) {
- return failure();
- }
- Type outputType;
- if (parser.parseColonType(outputType) ||
- parser.addTypeToList(outputType, state.types)) {
- return failure();
- }
- return success();
-}
-
-void printLoadInputOp(OpAsmPrinter &printer, Operation *op) {
- auto *inputValue = op->getOperand(0);
- auto *outputValue = op->getResult(0);
- printer << op->getName() << '(';
- printer.printOperand(inputValue);
- printer << " : ";
- printer.printType(inputValue->getType());
- printer << ") : ";
- printer.printType(outputValue->getType());
-}
-
-//===----------------------------------------------------------------------===//
-// iree.store_output
-//===----------------------------------------------------------------------===//
-
-ParseResult parseStoreOutputOp(OpAsmParser &parser, OperationState &state) {
- OpAsmParser::OperandType op0, op1;
- Type argType0, argType1;
- if (parser.parseLParen() || parser.parseOperand(op0) ||
- parser.parseColonType(argType0) || parser.parseComma() ||
- parser.resolveOperand(op0, argType0, state.operands) ||
- parser.parseOperand(op1) || parser.parseColonType(argType1) ||
- parser.parseRParen() ||
- parser.resolveOperand(op1, argType1, state.operands)) {
- return failure();
- }
- return success();
-}
-
-void printStoreOutputOp(OpAsmPrinter &printer, Operation *op) {
- auto *inputValue = op->getOperand(0);
- auto *outputValue = op->getOperand(1);
- printer << op->getName() << '(';
- printer.printOperand(inputValue);
- printer << " : ";
- printer.printType(inputValue->getType());
- printer << ", ";
- printer.printOperand(outputValue);
- printer << " : ";
- printer.printType(outputValue->getType());
- printer << ")";
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Ops.cpp.inc"
-
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Ops.h b/iree/compiler/IR/Ops.h
deleted file mode 100644
index 34f0bc3..0000000
--- a/iree/compiler/IR/Ops.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_OPS_H_
-#define IREE_COMPILER_IR_OPS_H_
-
-#include "iree/compiler/IR/Types.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Ops.h.inc"
-
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_OPS_H_
diff --git a/iree/compiler/IR/Ops.td b/iree/compiler/IR/Ops.td
deleted file mode 100644
index cb54101..0000000
--- a/iree/compiler/IR/Ops.td
+++ /dev/null
@@ -1,200 +0,0 @@
-// IREE ops for working with buffers and buffer views.
-// These are used by common transforms between the sequencer and interpreter and
-// allow us to share some of the common lowering passes from other dialects.
-
-#ifdef IREE_OPS
-#else
-#define IREE_OPS
-
-#ifdef IREE_OP_BASE
-#else
-include "iree/compiler/IR/OpBase.td"
-#endif // IREE_OP_BASE
-
-class IREE_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<IREE_Dialect, mnemonic, traits> {
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ print$cppClass(p, *this); }];
-}
-
-class IREE_PureOp<string mnemonic, list<OpTrait> traits = []> :
- IREE_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-// TODO(b/134575149): determine if we want multiple constant op types.
-def IREE_ConstantOp : IREE_PureOp<"constant", [
- AllShapesMatch<["value", "result"]>,
- AllElementTypesMatch<["value", "result"]>
-]> {
- let arguments = (ins ElementsAttr:$value);
- let results = (outs IREEHL_MemRef:$result);
-
- // TODO(b/132296600): make tablegen follow the style guide.
- let extraClassDeclaration = [{
- Attribute getValue() { return value(); }
- }];
-
- let builders = [OpBuilder<"Builder*, OperationState&, ElementsAttr">];
-
- // TODO(b/134575149): enable folder when we store the correct type.
- // let hasFolder = 1;
-}
-
-// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor.
-def IREE_TensorToMemRefOp : IREE_PureOp<"tensor_to_memref", [
- SameOperandsAndResultShape, SameOperandsAndResultElementType
-]> {
- let arguments = (ins AnyTensor);
- let results = (outs IREEHL_MemRef);
-
- let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
-
- let hasFolder = 1;
-}
-
-// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor.
-def IREE_MemRefToTensorOp : IREE_PureOp<"memref_to_tensor", [
- SameOperandsAndResultShape, SameOperandsAndResultElementType
-]> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs AnyTensor);
- let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
-
- let hasFolder = 1;
-}
-
-def IREE_ScalarToMemRefOp : IREE_PureOp<"scalar_to_memref", [
- SameOperandsAndResultElementType
-]> {
- let arguments = (ins IREEHL_Element);
- let results = (outs IREEHL_AnyScalar);
-
- let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
-
- let hasFolder = 1;
-}
-
-def IREE_MemRefToScalarOp : IREE_PureOp<"memref_to_scalar", [
- SameOperandsAndResultElementType
-]> {
- let arguments = (ins IREEHL_AnyScalar);
- let results = (outs IREEHL_Element);
-
- let builders = [OpBuilder<"Builder*, OperationState&, Value*">];
-
- let hasFolder = 1;
-}
-
-def IREE_Workload : TensorOf<[AnyInteger]>;
-
-def IREE_DispatchRegionOp : IREE_PureOp<"dispatch_region"> {
- let arguments = (ins
- IREE_Workload:$workload,
- Variadic<AnyType>:$args
- );
- let results = (outs Variadic<AnyType>);
- let regions = (region AnyRegion:$body);
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- Value *getWorkload() { return workload(); }
- Region& getBody() { return body(); }
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; }
- unsigned getNumArgOperands() { return getNumOperands() - 1; }
- Value *getArgOperand(unsigned i) {
- return getOperand(mapArgOperandToOpOperand(i));
- }
- void setArgOperand(unsigned i, Value *arg) {
- setOperand(mapArgOperandToOpOperand(i), arg);
- }
-
- operand_iterator arg_operand_begin() {
- return operand_begin() + mapArgOperandToOpOperand(0);
- }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<"Builder *builder, OperationState &state,"
- "ArrayRef<Type> resultTypes, Value *workload,"
- "ArrayRef<Value *> args,"
- "ArrayRef<NamedAttribute> attributes = {}">,
- ];
-}
-
-def IREE_ReductionRegionOp : IREE_PureOp<"reduction_region", [
- SameVariadicOperandSize,
-]> {
- let arguments = (ins
- IREE_Workload:$workload,
- Variadic<AnyType>:$operands,
- Variadic<AnyType>:$initial_values,
- OptionalAttr<I64ElementsAttr>:$dimensions,
- OptionalAttr<I64ElementsAttr>:$window_dimensions,
- OptionalAttr<I64ElementsAttr>:$window_strides,
- OptionalAttr<I64ElementsAttr>:$base_dilations,
- OptionalAttr<I64ElementsAttr>:$window_dilations,
- OptionalAttr<IREE_PaddingModeAttr>:$padding_mode
- );
- let results = (outs Variadic<AnyType>);
- let regions = (region AnyRegion:$body);
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- Value *getWorkload() { return workload(); }
- Region& getBody() { return body(); }
-
- bool isWindowed() {
- return window_dimensions().hasValue();
- }
-
- PaddingMode getPaddingMode() {
- return static_cast<PaddingMode>(padding_mode().getValue());
- }
-
- unsigned getNumReductionOperands() { return (getNumOperands() - 1) / 2; }
- operand_range getReductionOperands() { return getODSOperands(1); }
- operand_range getInitialValueOperands() { return getODSOperands(2); }
- }];
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<"Builder *builder, OperationState &state,"
- "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands,"
- "ArrayRef<Value *> initialValues,"
- "ArrayRef<int64_t> dimensions,"
- "ArrayRef<NamedAttribute> attributes = {}">,
- OpBuilder<"Builder *builder, OperationState &state,"
- "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands,"
- "ArrayRef<Value *> initialValues,"
- "ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,"
- "ArrayRef<int64_t> baseDilations, ArrayRef<int64_t> windowDilations,"
- "PaddingMode paddingMode,"
- "ArrayRef<NamedAttribute> attributes = {}">,
- ];
-}
-
-def IREE_ReturnOp : IREE_Op<"return", [Terminator]> {
- let arguments = (ins Variadic<AnyType>:$operands);
-
- let builders = [OpBuilder<
- "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
- >];
-}
-
-def IREE_LoadInputOp : IREE_PureOp<"load_input"> {
- let arguments = (ins IREEHL_MemRef:$src);
- let results = (outs AnyType);
-}
-
-def IREE_StoreOutputOp : IREE_Op<"store_output"> {
- let arguments = (ins AnyType:$src, IREEHL_MemRef:$dst);
-}
-
-#endif // IREE_OPS
diff --git a/iree/compiler/IR/Sequencer/BUILD b/iree/compiler/IR/Sequencer/BUILD
deleted file mode 100644
index ba722cb..0000000
--- a/iree/compiler/IR/Sequencer/BUILD
+++ /dev/null
@@ -1,76 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-load("@local_config_mlir//:tblgen.bzl", "gentbl")
-
-filegroup(
- name = "td_files",
- srcs = glob(["*.td"]),
-)
-
-cc_library(
- name = "Sequencer",
- srcs = [
- "HLDialect.cpp",
- "HLOps.cpp",
- "HLOps.cpp.inc",
- "LLDialect.cpp",
- "LLOps.cpp",
- "LLOps.cpp.inc",
- "OpWriters.cpp",
- ],
- hdrs = [
- "HLDialect.h",
- "HLOps.h",
- "HLOps.h.inc",
- "LLDialect.h",
- "LLOps.h",
- "LLOps.h.inc",
- "OpWriters.h",
- ],
- deps = [
- ":HLOpsGen",
- ":LLOpsGen",
- "//iree/compiler/IR",
- "//iree/compiler/Serialization",
- "//iree/compiler/Utils",
- "//iree/schemas/bytecode:sequencer_bytecode_v0",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- ],
- alwayslink = 1,
-)
-
-gentbl(
- name = "HLOpsGen",
- tbl_outs = [
- ("-gen-op-decls", "HLOps.h.inc"),
- ("-gen-op-defs", "HLOps.cpp.inc"),
- ],
- tblgen = "@local_config_mlir//:mlir-tblgen",
- td_file = "HLOps.td",
- td_srcs = [
- ":td_files",
- "@local_config_mlir//:include/mlir/IR/OpBase.td",
- "//iree/compiler/IR:OpBase.td",
- ],
-)
-
-gentbl(
- name = "LLOpsGen",
- tbl_outs = [
- ("-gen-op-decls", "LLOps.h.inc"),
- ("-gen-op-defs", "LLOps.cpp.inc"),
- ],
- tblgen = "@local_config_mlir//:mlir-tblgen",
- td_file = "LLOps.td",
- td_srcs = [
- ":td_files",
- "@local_config_mlir//:include/mlir/IR/OpBase.td",
- "//iree/compiler/IR:OpBase.td",
- ],
-)
diff --git a/iree/compiler/IR/Sequencer/HLDialect.cpp b/iree/compiler/IR/Sequencer/HLDialect.cpp
deleted file mode 100644
index 7b96908..0000000
--- a/iree/compiler/IR/Sequencer/HLDialect.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/HLDialect.h"
-
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "llvm/Support/SourceMgr.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-IREEHLSequencerDialect::IREEHLSequencerDialect(MLIRContext* context)
- : Dialect(getDialectNamespace(), context) {
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/Sequencer/HLOps.cpp.inc"
- >();
-}
-
-static DialectRegistration<IREEHLSequencerDialect> iree_hl_seq_dialect;
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/HLOps.cpp b/iree/compiler/IR/Sequencer/HLOps.cpp
deleted file mode 100644
index 63a5d6f..0000000
--- a/iree/compiler/IR/Sequencer/HLOps.cpp
+++ /dev/null
@@ -1,379 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREESeq {
-namespace HL {
-
-namespace {
-
-static LogicalResult verifyWorkload(Operation *op, Value *workload) {
- if (auto workloadType = workload->getType().dyn_cast<MemRefType>()) {
- if (workloadType.getNumElements() != 3) {
- return op->emitOpError("workload must be specified as (x,y,z) but has ")
- << workloadType.getNumElements()
- << " elements (type=" << workload->getType() << ")";
- }
- return success();
- }
- return op->emitOpError(
- "workload must be specified as an (x,y,z) memref but has type ")
- << workload->getType();
-}
-
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.call
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
- SymbolRefAttr calleeAttr;
- FunctionType calleeType;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- auto calleeLoc = parser.getNameLoc();
- if (parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(state.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.addTypesToList(calleeType.getResults(), state.types) ||
- parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
- state.operands)) {
- return failure();
- }
- return success();
-}
-
-static void printCallOp(OpAsmPrinter &p, CallOp op) {
- p << "iree_hl_seq.call " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
-}
-
-FunctionType CallOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.call_indirect
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallIndirectOp(OpAsmParser &parser,
- OperationState &result) {
- FunctionType calleeType;
- OpAsmParser::OperandType callee;
- llvm::SMLoc operandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- return failure(
- parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(result.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.resolveOperand(callee, calleeType, result.operands) ||
- parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
- result.operands) ||
- parser.addTypesToList(calleeType.getResults(), result.types));
-}
-
-static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
- p << "iree_hl_seq.call_indirect ";
- p.printOperand(op.getCallee());
- p << '(';
- auto operandRange = op.getOperands();
- p.printOperands(++operandRange.begin(), operandRange.end());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : " << op.getCallee()->getType();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.return
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type, 2> types;
- llvm::SMLoc loc = parser.getCurrentLocation();
- return failure(parser.parseOperandList(opInfo) ||
- (!opInfo.empty() && parser.parseColonTypeList(types)) ||
- parser.resolveOperands(opInfo, types, loc, state.operands));
-}
-
-static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
- p << "iree_hl_seq.return";
- if (op.getNumOperands() > 0) {
- p << ' ';
- p.printOperands(op.operand_begin(), op.operand_end());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
- Block *dest;
- SmallVector<Value *, 4> destOperands;
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
- return success();
-}
-
-static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
- p << "iree_hl_seq.br ";
- p.printSuccessorAndUseList(op.getOperation(), 0);
-}
-
-Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
-
-void BranchOp::setDest(Block *block) {
- return getOperation()->setSuccessor(block, 0);
-}
-
-void BranchOp::eraseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(0, index);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.cond_br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCondBranchOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<Value *, 4> destOperands;
- Block *dest;
- OpAsmParser::OperandType condInfo;
-
- // Parse the condition.
- Type int1Ty = parser.getBuilder().getI1Type();
- if (parser.parseOperand(condInfo) || parser.parseComma() ||
- parser.resolveOperand(condInfo, int1Ty, result.operands)) {
- return parser.emitError(parser.getNameLoc(),
- "expected condition type was boolean (i1)");
- }
-
- // Parse the true successor.
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
-
- // Parse the false successor.
- destOperands.clear();
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
-
- return success();
-}
-
-static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
- p << "iree_hl_seq.cond_br ";
- p.printOperand(op.getCondition());
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.dispatch
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseDispatchOp(OpAsmParser &parser, OperationState &state) {
- auto executableLoc = parser.getNameLoc();
-
- SymbolRefAttr executableAttr;
- SymbolRefAttr entryPointAttr;
- FunctionType entryPointType;
- if (failed(parser.parseAttribute(executableAttr, "executable",
- state.attributes)) ||
- failed(parser.parseColon()) || failed(parser.parseColon()) ||
- failed(parser.parseAttribute(entryPointAttr, "entry_point",
- state.attributes))) {
- return failure();
- }
-
- OpAsmParser::OperandType workloadArg;
- Type workloadArgType;
- if (failed(parser.parseLSquare()) ||
- failed(parser.parseOperand(workloadArg)) ||
- failed(parser.parseColonType(workloadArgType)) ||
- failed(parser.parseRSquare()) ||
- failed(parser.resolveOperand(workloadArg, workloadArgType,
- state.operands))) {
- return failure();
- }
-
- SmallVector<OpAsmParser::OperandType, 4> operands;
- if (failed(
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
- failed(parser.parseOptionalAttributeDict(state.attributes)) ||
- failed(parser.parseColonType(entryPointType)) ||
- failed(parser.addTypesToList(entryPointType.getResults(), state.types)) ||
- failed(parser.resolveOperands(operands, entryPointType.getInputs(),
- executableLoc, state.operands))) {
- return failure();
- }
- return success();
-}
-
-static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
- p << "iree_hl_seq.dispatch " << op.getExecutable()
- << "::" << op.getEntryPoint();
- p << "[";
- p.printOperand(op.getWorkload());
- p << " : ";
- p.printType(op.getWorkload()->getType());
- p << "](";
- p.printOperands(op.getArgOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
- "executable",
- "entry_point",
- });
- p << " : ";
- p.printType(op.getEntryPointType());
-}
-
-static LogicalResult verifyDispatchOp(DispatchOp op) {
- if (failed(verifyWorkload(op, op.getWorkload()))) {
- return failure();
- }
- return success();
-}
-
-FunctionType DispatchOp::getEntryPointType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getArgOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.rank
-//===----------------------------------------------------------------------===//
-
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
- Builder builder(getContext());
- if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) {
- return builder.getIntegerAttr(builder.getIntegerType(32),
- op0.getType().getRank());
- }
- return {};
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.shape
-//===----------------------------------------------------------------------===//
-
-void ShapeOp::build(Builder *builder, OperationState &state, Value *operand) {
- state.addOperands(operand);
- int64_t rank = 0;
- if (auto shapedType = operand->getType().dyn_cast<ShapedType>()) {
- rank = shapedType.getRank();
- }
- state.addTypes(MemRefType::get({rank}, builder->getIntegerType(32)));
-}
-
-OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
- Builder builder(getContext());
- if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) {
- return DenseIntElementsAttr::get(
- RankedTensorType::get({op0.getType().getRank()},
- builder.getIntegerType(32)),
- op0.getType().getShape());
- }
- return {};
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.length
-//===----------------------------------------------------------------------===//
-
-OpFoldResult LengthOp::fold(ArrayRef<Attribute> operands) {
- Builder builder(getContext());
- if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) {
- return builder.getIntegerAttr(builder.getIntegerType(32),
- op0.getNumElements());
- }
- return {};
-}
-
-//===----------------------------------------------------------------------===//
-// iree_hl_seq.concat
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct ConcatToCopies : public OpRewritePattern<ConcatOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ConcatOp concatOp,
- PatternRewriter &rewriter) const override {
- auto finalType = concatOp.getResult()->getType().cast<ShapedType>();
- auto loc = concatOp.getLoc();
- std::vector<Value *> dimPieces;
- auto dst =
- rewriter.create<IREESeq::HL::AllocHeapOp>(loc, finalType, dimPieces);
-
- llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0);
- auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset);
-
- auto concatDimension = concatOp.dimension().getZExtValue();
- llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0);
- for (auto *src : concatOp.srcs()) {
- auto srcShape = src->getType().cast<ShapedType>().getShape();
- auto lengths = createArrayConstant(rewriter, loc, srcShape);
- auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices);
- rewriter.create<IREESeq::HL::CopyOp>(loc, src, srcIndices, dst,
- dstIndicesOp, lengths);
- dstIndices[concatDimension] += srcShape[concatDimension];
- }
-
- concatOp.replaceAllUsesWith(dst.getResult());
-
- return matchSuccess();
- }
-};
-} // namespace
-
-void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<ConcatToCopies>(context);
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Sequencer/HLOps.cpp.inc"
-
-} // namespace HL
-} // namespace IREESeq
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/HLOps.h b/iree/compiler/IR/Sequencer/HLOps.h
deleted file mode 100644
index 9fb6a4d..0000000
--- a/iree/compiler/IR/Sequencer/HLOps.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
-#define IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
-
-#include "iree/compiler/IR/Types.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREESeq {
-namespace HL {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Sequencer/HLOps.h.inc"
-
-} // namespace HL
-} // namespace IREESeq
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
diff --git a/iree/compiler/IR/Sequencer/HLOps.td b/iree/compiler/IR/Sequencer/HLOps.td
deleted file mode 100644
index 1a32f76..0000000
--- a/iree/compiler/IR/Sequencer/HLOps.td
+++ /dev/null
@@ -1,429 +0,0 @@
-// IREE high-level sequencer op definitions.
-// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in
-// normal SSA form.
-//
-// Through lowering these high-level ops are converted to low-level ops in the
-// LLOps.td (iree_ll_seq.*). These map 1:1 with the bytecode, accept
-// only MemRef types, and generally use output parameters instead of return
-// types.
-//
-// The source of truth for bytecode opcodes is:
-// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h
-
-#ifdef IREE_SEQUENCER_HL_OPS
-#else
-#define IREE_SEQUENCER_HL_OPS
-
-#ifdef IREE_OP_BASE
-#else
-include "iree/compiler/IR/OpBase.td"
-#endif // IREE_OP_BASE
-
-def IREESeqHL_Dialect : Dialect {
- let name = "iree_hl_seq";
- let cppNamespace = "IREESeq::HL";
-}
-
-//===----------------------------------------------------------------------===//
-// Base op classes
-//===----------------------------------------------------------------------===//
-
-class IREESeqHL_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<IREESeqHL_Dialect, mnemonic, traits>;
-
-class IREESeqHL_PureOp<string mnemonic, list<OpTrait> traits = []> :
- IREESeqHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-//===----------------------------------------------------------------------===//
-// High-level sequencer ops
-//===----------------------------------------------------------------------===//
-
-def IREESeqHL_CallOp : IREESeqHL_PureOp<"call"> {
- let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>);
- let results = (outs Variadic<IREEHL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, FuncOp callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(callee.getType().getResults());
- }]>, OpBuilder<
- "Builder *builder, OperationState &result, StringRef callee,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getCallee() { return callee(); }
- FunctionType getCalleeType();
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqHL_CallIndirectOp : IREESeqHL_Op<"call_indirect"> {
- let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands);
- let results = (outs Variadic<IREEHL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.operands.push_back(callee);
- result.addOperands(operands);
- result.addTypes(callee->getType().cast<FunctionType>().getResults());
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- Value *getCallee() { return getOperand(0); }
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- operand_iterator arg_operand_begin() { return ++operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqHL_ReturnOp : IREESeqHL_Op<"return", [Terminator]> {
- let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
-
- let builders = [OpBuilder<
- "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
- >];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqHL_BranchOp : IREESeqHL_Op<"br", [Terminator]> {
- let arguments = (ins Variadic<IREEHL_MemRef>:$operands);
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Block *dest,"
- "ArrayRef<Value *> operands = {}", [{
- result.addSuccessor(dest, operands);
- }]>];
-
- let extraClassDeclaration = [{
- Block *getDest();
- void setDest(Block *block);
-
- /// Erase the operand at 'index' from the operand list.
- void eraseOperand(unsigned index);
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqHL_CondBranchOp : IREESeqHL_Op<"cond_br", [Terminator]> {
- let arguments = (ins
- IREEHL_BoolScalar:$condition,
- Variadic<IREEHL_MemRef>:$branchOperands
- );
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *condition, "
- "Block *trueDest, ArrayRef<Value *> trueOperands, "
- "Block *falseDest, ArrayRef<Value *> falseOperands", [{
- result.addOperands(condition);
- result.addSuccessor(trueDest, trueOperands);
- result.addSuccessor(falseDest, falseOperands);
- }]>];
-
- let extraClassDeclaration = [{
- // These are the indices into the dests list.
- enum { trueIndex = 0, falseIndex = 1 };
-
- // The condition operand is the first operand in the list.
- Value *getCondition() { return getOperand(0); }
-
- /// Return the destination if the condition is true.
- Block *getTrueDest() {
- return getOperation()->getSuccessor(trueIndex);
- }
-
- /// Return the destination if the condition is false.
- Block *getFalseDest() {
- return getOperation()->getSuccessor(falseIndex);
- }
-
- // Accessors for operands to the 'true' destination.
- Value *getTrueOperand(unsigned idx) {
- assert(idx < getNumTrueOperands());
- return getOperand(getTrueDestOperandIndex() + idx);
- }
-
- void setTrueOperand(unsigned idx, Value *value) {
- assert(idx < getNumTrueOperands());
- setOperand(getTrueDestOperandIndex() + idx, value);
- }
-
- operand_iterator true_operand_begin() {
- return operand_begin() + getTrueDestOperandIndex();
- }
- operand_iterator true_operand_end() {
- return true_operand_begin() + getNumTrueOperands();
- }
- operand_range getTrueOperands() {
- return {true_operand_begin(), true_operand_end()};
- }
-
- unsigned getNumTrueOperands() {
- return getOperation()->getNumSuccessorOperands(trueIndex);
- }
-
- /// Erase the operand at 'index' from the true operand list.
- void eraseTrueOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(trueIndex, index);
- }
-
- // Accessors for operands to the 'false' destination.
- Value *getFalseOperand(unsigned idx) {
- assert(idx < getNumFalseOperands());
- return getOperand(getFalseDestOperandIndex() + idx);
- }
- void setFalseOperand(unsigned idx, Value *value) {
- assert(idx < getNumFalseOperands());
- setOperand(getFalseDestOperandIndex() + idx, value);
- }
-
- operand_iterator false_operand_begin() { return true_operand_end(); }
- operand_iterator false_operand_end() {
- return false_operand_begin() + getNumFalseOperands();
- }
- operand_range getFalseOperands() {
- return {false_operand_begin(), false_operand_end()};
- }
-
- unsigned getNumFalseOperands() {
- return getOperation()->getNumSuccessorOperands(falseIndex);
- }
-
- /// Erase the operand at 'index' from the false operand list.
- void eraseFalseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(falseIndex, index);
- }
-
- private:
- /// Get the index of the first true destination operand.
- unsigned getTrueDestOperandIndex() { return 1; }
-
- /// Get the index of the first false destination operand.
- unsigned getFalseDestOperandIndex() {
- return getTrueDestOperandIndex() + getNumTrueOperands();
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqHL_DispatchOp : IREESeqHL_Op<"dispatch"> {
- let arguments = (ins
- SymbolRefAttr:$executable,
- SymbolRefAttr:$entry_point,
- IREEHL_IntMemRef:$workload,
- Variadic<IREEHL_MemRef>:$operands
- );
- let results = (outs Variadic<IREEHL_MemRef>);
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, StringRef executable,"
- "StringRef entry_point, Value *workload,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands({workload});
- result.addOperands(operands);
- result.addAttribute("executable", builder->getSymbolRefAttr(executable));
- result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getExecutable() { return executable(); }
- StringRef getEntryPoint() { return entry_point(); }
- FunctionType getEntryPointType();
-
- Value *getWorkload() { return getOperand(0); }
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- operand_iterator arg_operand_begin() { return operand_begin() + 1; }
- operand_iterator arg_operand_end() { return operand_end(); }
-
- operand_type_range getArgOperandTypes() {
- return {arg_operand_type_begin(), arg_operand_type_end()};
- }
- operand_type_iterator arg_operand_type_begin() {
- return operand_type_iterator(arg_operand_begin());
- }
- operand_type_iterator arg_operand_type_end() {
- return operand_type_iterator(arg_operand_end());
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
- let verifier = [{ return verify$cppClass(*this); }];
-}
-
-// TODO(b/142012496): Add trait that enables DCE but not CSE.
-def IREESeqHL_AllocHeapOp : IREESeqHL_Op<"alloc_heap"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins Variadic<IREEHL_IntMemRef>:$dim_pieces);
- let results = (outs IREEHL_MemRef);
-}
-
-def IREESeqHL_DiscardOp : IREESeqHL_Op<"discard"> {
- let arguments = (ins IREEHL_MemRef);
-}
-
-def IREESeqHL_RankOp : IREESeqHL_PureOp<"rank"> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_IntScalar);
-
- let hasFolder = 1;
-}
-
-def IREESeqHL_DimOp : IREESeqHL_PureOp<"dim"> {
- // TODO(benvanik) add dim attr (I32Attr:$dim)
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_IntScalar);
-}
-
-def IREESeqHL_ShapeOp : IREESeqHL_PureOp<"shape"> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_1DIntMemRef);
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *operand">];
-
- let hasFolder = 1;
-}
-
-def IREESeqHL_LengthOp : IREESeqHL_PureOp<"length"> {
- let arguments = (ins IREEHL_MemRef);
- let results = (outs IREEHL_IndexScalar);
-
- let hasFolder = 1;
-}
-
-def IREESeqHL_SliceOp :
- IREESeqHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>,
- AllTypesMatch<["indices", "lengths"]>]> {
- let arguments = (ins
- IREEHL_MemRef:$src,
- IREEHL_1DIndexMemRef:$indices,
- IREEHL_1DIndexMemRef:$lengths
- );
- let results = (outs IREEHL_MemRef:$result);
-}
-
-def IREESeqHL_CopyOp : IREESeqHL_Op<"copy", [
- AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>,
- AllRanksMatch<["src", "dst"]>,
- // The checks above are redundant with this one, but they give more specific
- // error messages.
- AllMatch<[
- Rank<"src">.result,
- Rank<"dst">.result,
- ElementCount<"srcIndices">.result,
- ElementCount<"dstIndices">.result,
- ElementCount<"lengths">.result
- ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">,
- AllElementTypesMatch<["src", "dst"]>
-]> {
- let arguments = (ins
- IREEHL_MemRef:$src,
- IREEHL_1DIndexMemRef:$srcIndices,
- IREEHL_MemRef:$dst,
- IREEHL_1DIndexMemRef:$dstIndices,
- IREEHL_1DIndexMemRef:$lengths
- );
-}
-
-def IREESeqHL_FillOp : IREESeqHL_Op<"fill"> {
- let arguments = (ins
- IREEHL_I32Scalar:$value,
- IREEHL_MemRef:$dst,
- IREEHL_1DIndexMemRef:$dstIndices,
- IREEHL_1DIndexMemRef:$lengths
- );
-}
-
-def IREESeqHL_CloneOp : IREESeqHL_PureOp<"clone", [SameOperandsAndResultType]> {
- let arguments = (ins IREEHL_MemRef:$src);
- let results = (outs IREEHL_MemRef);
-}
-
-// A pseudo op provided for convenience. This gets canonicalized to a series of
-// copies.
-def IREESeqHL_ConcatOp : IREESeqHL_PureOp<"concat"> {
- // TODO(b/135032064) Add type constraints when they support variadic
- let arguments = (ins
- Variadic<IREEHL_MemRef>:$srcs,
- I32Attr:$dimension
- );
- let results = (outs IREEHL_MemRef);
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqHL_AssignOp :
- IREESeqHL_PureOp<"assign", [SameOperandsAndResultType]> {
- let arguments = (ins IREEHL_MemRef:$src);
- let results = (outs IREEHL_MemRef);
-}
-
-def IREESeqHL_CondAssignOp : IREESeqHL_PureOp<"cond_assign"> {
- let arguments = (ins
- IREEHL_BoolScalar:$cond,
- IREEHL_MemRef:$lhs,
- IREEHL_MemRef:$rhs
- );
- let results = (outs IREEHL_MemRef);
-}
-
-def IREESeqHL_ReshapeOp : IREESeqHL_PureOp<"reshape"> {
- let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape);
- let results = (outs IREEHL_MemRef);
-}
-
-def IREESeqHL_TraceOp : IREESeqHL_Op<"trace"> {
- let arguments = (ins Variadic<IREEHL_MemRef>:$srcs);
-}
-
-def IREESeqHL_CondBreakOp : IREESeqHL_Op<"cond_break"> {
- let arguments = (ins IREEHL_BoolScalar:$cond);
-}
-
-def IREESeqHL_BreakOp : IREESeqHL_Op<"break">;
-
-#endif // IREE_SEQUENCER_HL_OPS
diff --git a/iree/compiler/IR/Sequencer/LLDialect.cpp b/iree/compiler/IR/Sequencer/LLDialect.cpp
deleted file mode 100644
index 872e27f..0000000
--- a/iree/compiler/IR/Sequencer/LLDialect.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/LLDialect.h"
-
-#include "iree/compiler/IR/Sequencer/LLOps.h"
-#include "llvm/Support/SourceMgr.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-IREELLSequencerDialect::IREELLSequencerDialect(MLIRContext* context)
- : Dialect(getDialectNamespace(), context) {
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/IR/Sequencer/LLOps.cpp.inc"
- >();
-}
-
-static DialectRegistration<IREELLSequencerDialect> iree_ll_seq_dialect;
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/LLOps.cpp b/iree/compiler/IR/Sequencer/LLOps.cpp
deleted file mode 100644
index c5b132b..0000000
--- a/iree/compiler/IR/Sequencer/LLOps.cpp
+++ /dev/null
@@ -1,671 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/LLOps.h"
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/STLExtras.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREESeq {
-namespace LL {
-
-namespace {
-
-static LogicalResult verifyWorkload(Operation *op, Value *workload) {
- if (auto workloadType = workload->getType().dyn_cast<MemRefType>()) {
- if (workloadType.getNumElements() != 3) {
- return op->emitOpError("workload must be specified as (x,y,z) but has ")
- << workloadType.getNumElements()
- << " elements (type=" << workload->getType() << ")";
- }
- return success();
- }
- return op->emitOpError(
- "workload must be specified as an (x,y,z) memref but has type ")
- << workload->getType();
-}
-
-static LogicalResult verifyWorkload(Operation *op, ElementsAttr workload) {
- if (workload.getNumElements() != 3) {
- return op->emitOpError("workload must be specified as (x,y,z) but has ")
- << workload.getNumElements() << " elements (value=" << workload
- << ")";
- }
- return success();
-}
-
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.constant
-//===----------------------------------------------------------------------===//
-
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
- return getValue();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.call
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) {
- SymbolRefAttr calleeAttr;
- FunctionType calleeType;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- auto calleeLoc = parser.getNameLoc();
- if (parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(state.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.addTypesToList(calleeType.getResults(), state.types) ||
- parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
- state.operands)) {
- return failure();
- }
- return success();
-}
-
-static void printCallOp(OpAsmPrinter &p, CallOp op) {
- p << "iree_ll_seq.call " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
-}
-
-FunctionType CallOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.call_import
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallImportOp(OpAsmParser &parser,
- OperationState &state) {
- SymbolRefAttr calleeAttr;
- FunctionType calleeType;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- auto calleeLoc = parser.getNameLoc();
- if (parser.parseAttribute(calleeAttr, "callee", state.attributes) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(state.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.addTypesToList(calleeType.getResults(), state.types) ||
- parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
- state.operands)) {
- return failure();
- }
- return success();
-}
-
-static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) {
- p << "iree_ll_seq.call_import " << op.getAttr("callee") << '(';
- p.printOperands(op.getOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : ";
- p.printType(op.getCalleeType());
-}
-
-FunctionType CallImportOp::getCalleeType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.call_indirect
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCallIndirectOp(OpAsmParser &parser,
- OperationState &result) {
- FunctionType calleeType;
- OpAsmParser::OperandType callee;
- llvm::SMLoc operandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- return failure(
- parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttributeDict(result.attributes) ||
- parser.parseColonType(calleeType) ||
- parser.resolveOperand(callee, calleeType, result.operands) ||
- parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
- result.operands) ||
- parser.addTypesToList(calleeType.getResults(), result.types));
-}
-
-static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) {
- p << "iree_ll_seq.call_indirect ";
- p.printOperand(op.getCallee());
- p << '(';
- auto operandRange = op.getOperands();
- p.printOperands(++operandRange.begin(), operandRange.end());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- p << " : " << op.getCallee()->getType();
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.return
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type, 2> types;
- llvm::SMLoc loc = parser.getCurrentLocation();
- return failure(parser.parseOperandList(opInfo) ||
- (!opInfo.empty() && parser.parseColonTypeList(types)) ||
- parser.resolveOperands(opInfo, types, loc, state.operands));
-}
-
-static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
- p << "iree_ll_seq.return";
- if (op.getNumOperands() > 0) {
- p << ' ';
- p.printOperands(op.operand_begin(), op.operand_end());
- p << " : ";
- interleaveComma(op.getOperandTypes(), p);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
- Block *dest;
- SmallVector<Value *, 4> destOperands;
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
- return success();
-}
-
-static void printBranchOp(OpAsmPrinter &p, BranchOp op) {
- p << "iree_ll_seq.br ";
- p.printSuccessorAndUseList(op.getOperation(), 0);
-}
-
-Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
-
-void BranchOp::setDest(Block *block) {
- return getOperation()->setSuccessor(block, 0);
-}
-
-void BranchOp::eraseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(0, index);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.cond_br
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCondBranchOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<Value *, 4> destOperands;
- Block *dest;
- OpAsmParser::OperandType condInfo;
-
- // Parse the condition.
- Type int1Ty = parser.getBuilder().getI1Type();
- if (parser.parseOperand(condInfo) || parser.parseComma() ||
- parser.resolveOperand(condInfo, int1Ty, result.operands)) {
- return parser.emitError(parser.getNameLoc(),
- "expected condition type was boolean (i1)");
- }
-
- // Parse the true successor.
- if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure();
- result.addSuccessor(dest, destOperands);
-
- // Parse the false successor.
- destOperands.clear();
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, destOperands))
- return failure();
- result.addSuccessor(dest, destOperands);
-
- return success();
-}
-
-static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) {
- p << "iree_ll_interp.cond_br ";
- p.printOperand(op.getCondition());
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
- p << ", ";
- p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.dynamic_dispatch
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseDynamicDispatchOp(OpAsmParser &parser,
- OperationState &state) {
- auto executableLoc = parser.getNameLoc();
-
- SymbolRefAttr executableAttr;
- SymbolRefAttr entryPointAttr;
- FunctionType entryPointType;
- if (failed(parser.parseAttribute(executableAttr, "executable",
- state.attributes)) ||
- failed(parser.parseColon()) || failed(parser.parseColon()) ||
- failed(parser.parseAttribute(entryPointAttr, "entry_point",
- state.attributes))) {
- return failure();
- }
-
- OpAsmParser::OperandType workloadArg;
- Type workloadArgType;
- if (failed(parser.parseLSquare()) ||
- failed(parser.parseOperand(workloadArg)) ||
- failed(parser.parseColonType(workloadArgType)) ||
- failed(parser.parseRSquare()) ||
- failed(parser.resolveOperand(workloadArg, workloadArgType,
- state.operands))) {
- return failure();
- }
-
- SmallVector<OpAsmParser::OperandType, 4> operands;
- if (failed(
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
- failed(parser.parseOptionalAttributeDict(state.attributes)) ||
- failed(parser.parseColonType(entryPointType)) ||
- failed(parser.addTypesToList(entryPointType.getResults(), state.types)) ||
- failed(parser.resolveOperands(operands, entryPointType.getInputs(),
- executableLoc, state.operands))) {
- return failure();
- }
- return success();
-}
-
-static void printDynamicDispatchOp(OpAsmPrinter &p, DynamicDispatchOp op) {
- p << "iree_ll_seq.dynamic_dispatch " << op.getExecutable()
- << "::" << op.getEntryPoint();
- p << "[";
- p.printOperand(op.getWorkload());
- p << " : ";
- p.printType(op.getWorkload()->getType());
- p << "](";
- p.printOperands(op.getArgOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
- "executable",
- "entry_point",
- });
- p << " : ";
- p.printType(op.getEntryPointType());
-}
-
-static LogicalResult verifyDynamicDispatchOp(DynamicDispatchOp op) {
- if (failed(verifyWorkload(op, op.getWorkload()))) {
- return failure();
- }
- return success();
-}
-
-FunctionType DynamicDispatchOp::getEntryPointType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getArgOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-namespace {
-struct MakeDynamicDispatchOpStatic
- : public OpRewritePattern<DynamicDispatchOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(DynamicDispatchOp dynamicDispatchOp,
- PatternRewriter &rewriter) const override {
- ElementsAttr workloadAttr;
- if (!matchPattern(dynamicDispatchOp.getWorkload(),
- m_Constant(&workloadAttr))) {
- return matchFailure();
- }
-
- SmallVector<Type, 8> resultTypes{dynamicDispatchOp.getResultTypes()};
- SmallVector<Value *, 8> operands{dynamicDispatchOp.getArgOperands()};
- rewriter.replaceOpWithNewOp<IREESeq::LL::StaticDispatchOp>(
- dynamicDispatchOp, dynamicDispatchOp.getExecutable(),
- dynamicDispatchOp.getEntryPoint(), workloadAttr, resultTypes, operands);
- return matchSuccess();
- }
-};
-} // namespace
-
-void DynamicDispatchOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<MakeDynamicDispatchOpStatic>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.static_dispatch
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseStaticDispatchOp(OpAsmParser &parser,
- OperationState &state) {
- auto executableLoc = parser.getNameLoc();
-
- SymbolRefAttr executableAttr;
- SymbolRefAttr entryPointAttr;
- FunctionType entryPointType;
- if (failed(parser.parseAttribute(executableAttr, "executable",
- state.attributes)) ||
- failed(parser.parseColon()) || failed(parser.parseColon()) ||
- failed(parser.parseAttribute(entryPointAttr, "entry_point",
- state.attributes))) {
- return failure();
- }
-
- ElementsAttr workloadAttr;
- if (failed(parser.parseLSquare()) ||
- failed(
- parser.parseAttribute(workloadAttr, "workload", state.attributes)) ||
- failed(parser.parseRSquare())) {
- return failure();
- }
-
- SmallVector<OpAsmParser::OperandType, 4> operands;
- if (failed(
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
- failed(parser.parseOptionalAttributeDict(state.attributes)) ||
- failed(parser.parseColonType(entryPointType)) ||
- failed(parser.addTypesToList(entryPointType.getResults(), state.types)) ||
- failed(parser.resolveOperands(operands, entryPointType.getInputs(),
- executableLoc, state.operands))) {
- return failure();
- }
- return success();
-}
-
-static void printStaticDispatchOp(OpAsmPrinter &p, StaticDispatchOp op) {
- p << "iree_ll_seq.static_dispatch " << op.getExecutable()
- << "::" << op.getEntryPoint();
- p << "[";
- p.printAttribute(op.getWorkload());
- p << "](";
- p.printOperands(op.getArgOperands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
- "executable",
- "entry_point",
- "workload",
- });
- p << " : ";
- p.printType(op.getEntryPointType());
-}
-
-static LogicalResult verifyStaticDispatchOp(StaticDispatchOp op) {
- if (failed(verifyWorkload(op, op.getWorkload()))) {
- return failure();
- }
- return success();
-}
-
-FunctionType StaticDispatchOp::getEntryPointType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(getArgOperandTypes());
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.shape
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldShapeOp : public OpRewritePattern<ShapeOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ShapeOp shapeOp,
- PatternRewriter &rewriter) const override {
- auto memRefType = shapeOp.input()->getType().cast<MemRefType>();
- if (memRefType.hasStaticShape()) {
- auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>(
- shapeOp.getLoc(),
- MemRefType::get({memRefType.getRank()}, rewriter.getIntegerType(64)),
- DenseIntElementsAttr::get(
- RankedTensorType::get({memRefType.getRank()},
- rewriter.getIntegerType(64)),
- memRefType.getShape()));
- replaceSubsequentUses(shapeOp, shapeOp.dst(), constantOp.getResult());
- rewriter.eraseOp(shapeOp);
- return matchSuccess();
- }
- return matchFailure();
- }
-};
-} // namespace
-
-void ShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<FoldShapeOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.length
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldLengthOp : public OpRewritePattern<LengthOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(LengthOp lengthOp,
- PatternRewriter &rewriter) const override {
- auto memRefType = lengthOp.input()->getType().cast<MemRefType>();
- if (memRefType.hasStaticShape()) {
- auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>(
- lengthOp.getLoc(), MemRefType::get({}, rewriter.getIntegerType(64)),
- DenseIntElementsAttr::get(
- RankedTensorType::get({}, rewriter.getIntegerType(64)),
- {memRefType.getNumElements()}));
- replaceSubsequentUses(lengthOp, lengthOp.dst(), constantOp.getResult());
- rewriter.eraseOp(lengthOp);
- return matchSuccess();
- }
- return matchFailure();
- }
-};
-} // namespace
-
-void LengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<FoldLengthOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.compute_offset
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldComputeOffsetOp : public OpRewritePattern<ComputeOffsetOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ComputeOffsetOp computeOffsetOp,
- PatternRewriter &rewriter) const override {
- ElementsAttr shapeAttr;
- ElementsAttr indicesAttr;
- if (!matchPattern(computeOffsetOp.shape(), m_Constant(&shapeAttr)) ||
- !matchPattern(computeOffsetOp.indices(), m_Constant(&indicesAttr))) {
- return matchFailure();
- }
-
- int64_t offset = 0;
- for (unsigned i = 0; i < indicesAttr.getNumElements(); ++i) {
- int64_t axisOffset =
- indicesAttr.getValue({i}).cast<IntegerAttr>().getInt();
- for (unsigned j = i + 1; j < shapeAttr.getNumElements(); ++j) {
- axisOffset *= shapeAttr.getValue({j}).cast<IntegerAttr>().getInt();
- }
- offset += axisOffset;
- }
- offset *= computeOffsetOp.elementSize().getZExtValue();
-
- auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>(
- computeOffsetOp.getLoc(),
- MemRefType::get({}, rewriter.getIntegerType(64)),
- DenseIntElementsAttr::get(
- RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset}));
- replaceSubsequentUses(computeOffsetOp, computeOffsetOp.dst(),
- constantOp.getResult());
- rewriter.eraseOp(computeOffsetOp);
- return matchSuccess();
- }
-};
-} // namespace
-
-void ComputeOffsetOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<FoldComputeOffsetOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.compute_range
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldComputeRangeOp : public OpRewritePattern<ComputeRangeOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(ComputeRangeOp computeRangeOp,
- PatternRewriter &rewriter) const override {
- ElementsAttr shapeAttr;
- ElementsAttr indicesAttr;
- ElementsAttr lengthsAttr;
- if (!matchPattern(computeRangeOp.shape(), m_Constant(&shapeAttr)) ||
- !matchPattern(computeRangeOp.indices(), m_Constant(&indicesAttr)) ||
- !matchPattern(computeRangeOp.lengths(), m_Constant(&lengthsAttr))) {
- return matchFailure();
- }
-
- int64_t offset = 0;
- int64_t length = computeRangeOp.elementSize().getZExtValue();
- for (unsigned i = 0; i < indicesAttr.getNumElements(); ++i) {
- int64_t axisOffset =
- indicesAttr.getValue({i}).cast<IntegerAttr>().getInt();
- for (unsigned j = i + 1; j < shapeAttr.getNumElements(); ++j) {
- axisOffset *= shapeAttr.getValue({j}).cast<IntegerAttr>().getInt();
- }
- offset += axisOffset;
- length *= lengthsAttr.getValue({i}).cast<IntegerAttr>().getInt();
- }
- offset *= computeRangeOp.elementSize().getZExtValue();
-
- auto offsetConstantOp = rewriter.create<IREESeq::LL::ConstantOp>(
- computeRangeOp.getLoc(),
- MemRefType::get({}, rewriter.getIntegerType(64)),
- DenseIntElementsAttr::get(
- RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset}));
- replaceSubsequentUses(computeRangeOp, computeRangeOp.dstOffset(),
- offsetConstantOp.getResult());
- auto lengthConstantOp = rewriter.create<IREESeq::LL::ConstantOp>(
- computeRangeOp.getLoc(),
- MemRefType::get({}, rewriter.getIntegerType(64)),
- DenseIntElementsAttr::get(
- RankedTensorType::get({}, rewriter.getIntegerType(64)), {length}));
- replaceSubsequentUses(computeRangeOp, computeRangeOp.dstLength(),
- lengthConstantOp.getResult());
- rewriter.eraseOp(computeRangeOp);
- return matchSuccess();
- }
-};
-} // namespace
-
-void ComputeRangeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<FoldComputeRangeOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.dynamic_copy
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct MakeDynamicCopyOpStatic : public OpRewritePattern<DynamicCopyOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(DynamicCopyOp dynamicCopyOp,
- PatternRewriter &rewriter) const override {
- ElementsAttr srcOffsetAttr;
- ElementsAttr dstOffsetAttr;
- ElementsAttr lengthAttr;
- if (!matchPattern(dynamicCopyOp.srcOffset(), m_Constant(&srcOffsetAttr)) ||
- !matchPattern(dynamicCopyOp.dstOffset(), m_Constant(&dstOffsetAttr)) ||
- !matchPattern(dynamicCopyOp.length(), m_Constant(&lengthAttr))) {
- return matchFailure();
- }
-
- rewriter.replaceOpWithNewOp<IREESeq::LL::StaticCopyOp>(
- dynamicCopyOp, dynamicCopyOp.src(),
- srcOffsetAttr.getValue({}).cast<IntegerAttr>(), dynamicCopyOp.dst(),
- dstOffsetAttr.getValue({}).cast<IntegerAttr>(),
- lengthAttr.getValue({}).cast<IntegerAttr>());
- return matchSuccess();
- }
-};
-} // namespace
-
-void DynamicCopyOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<MakeDynamicCopyOpStatic>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// iree_ll_seq.dynamic_fill
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct MakeDynamicFillOpStatic : public OpRewritePattern<DynamicFillOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(DynamicFillOp dynamicFillOp,
- PatternRewriter &rewriter) const override {
- ElementsAttr valueAttr;
- ElementsAttr dstOffsetAttr;
- ElementsAttr lengthAttr;
- if (!matchPattern(dynamicFillOp.value(), m_Constant(&valueAttr)) ||
- !matchPattern(dynamicFillOp.dstOffset(), m_Constant(&dstOffsetAttr)) ||
- !matchPattern(dynamicFillOp.length(), m_Constant(&lengthAttr))) {
- return matchFailure();
- }
-
- rewriter.replaceOpWithNewOp<IREESeq::LL::StaticFillOp>(
- dynamicFillOp, valueAttr.getValue({}).cast<IntegerAttr>(),
- dynamicFillOp.dst(), dstOffsetAttr.getValue({}).cast<IntegerAttr>(),
- lengthAttr.getValue({}).cast<IntegerAttr>());
- return matchSuccess();
- }
-};
-} // namespace
-
-void DynamicFillOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<MakeDynamicFillOpStatic>(context);
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Sequencer/LLOps.cpp.inc"
-
-} // namespace LL
-} // namespace IREESeq
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/LLOps.h b/iree/compiler/IR/Sequencer/LLOps.h
deleted file mode 100644
index 9a29ba0..0000000
--- a/iree/compiler/IR/Sequencer/LLOps.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
-#define IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
-
-#include "iree/compiler/IR/Types.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREESeq {
-namespace LL {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/Sequencer/LLOps.h.inc"
-
-} // namespace LL
-} // namespace IREESeq
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
diff --git a/iree/compiler/IR/Sequencer/LLOps.td b/iree/compiler/IR/Sequencer/LLOps.td
deleted file mode 100644
index 6f352fb..0000000
--- a/iree/compiler/IR/Sequencer/LLOps.td
+++ /dev/null
@@ -1,579 +0,0 @@
-// IREE low-level sequencer op definitions.
-// These map 1:1 with the bytecode, accept only MemRef types and generally use
-// output parameters instead of return types.
-//
-// The source of truth for bytecode opcodes is:
-// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h
-//
-// Note that in this dialect we cannot use folders: they require that all
-// operands are possible to make constants where we use output arguments that
-// will never be constant. Instead we can use canonicalization patterns to
-// match constant input operands and do the folding by replacing output operands
-// with the new values.
-
-#ifdef IREE_SEQUENCER_LL_OPS
-#else
-#define IREE_SEQUENCER_LL_OPS
-
-#ifdef IREE_OP_BASE
-#else
-include "iree/compiler/IR/OpBase.td"
-#endif // IREE_OP_BASE
-
-def IREESeqLL_Dialect : Dialect {
- let name = "iree_ll_seq";
- let cppNamespace = "IREESeq::LL";
-}
-
-//===----------------------------------------------------------------------===//
-// Base op classes
-//===----------------------------------------------------------------------===//
-
-class IREESeqLL_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<IREESeqLL_Dialect, mnemonic, traits> {
- bit hasCustomSerializer = 0;
-}
-
-class IREESeqLL_PureOp<string mnemonic, list<OpTrait> traits = []> :
- IREESeqLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-class IREESeqLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef,
- list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> {
- let arguments = (ins type:$input, type:$dst);
-}
-
-class IREESeqLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef,
- list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> {
- let arguments = (ins type:$lhs, type:$rhs, type:$dst);
-}
-
-class IREESeqLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef,
- list<OpTrait> traits = []>
- : IREESeqLL_Op<mnemonic, traits> {
- let arguments = (ins type : $a, type : $b, type : $c, type : $dst);
-}
-
-//===----------------------------------------------------------------------===//
-// Low-level sequencer ops
-//===----------------------------------------------------------------------===//
-
-def IREESeqLL_ConstantOp : IREESeqLL_PureOp<"constant"> {
- let arguments = (ins ElementsAttr:$value);
- let results = (outs IREELL_MemRef);
-
- // TODO(b/132296600): make tablegen follow the style guide.
- let extraClassDeclaration = [{
- Attribute getValue() { return value(); }
- }];
-
- let hasFolder = 1;
-}
-
-def IREESeqLL_CallOp : IREESeqLL_Op<"call"> {
- let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, FuncOp callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(callee.getType().getResults());
- }]>, OpBuilder<
- "Builder *builder, OperationState &result, StringRef callee,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getCallee() { return callee(); }
- FunctionType getCalleeType();
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-// TODO(benvanik): add verifier that target isExternal.
-def IREESeqLL_CallImportOp : IREESeqLL_Op<"call_import"> {
- let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>);
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, FuncOp callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(callee.getType().getResults());
- }]>, OpBuilder<
- "Builder *builder, OperationState &result, StringRef callee,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands(operands);
- result.addAttribute("callee", builder->getSymbolRefAttr(callee));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getCallee() { return callee(); }
- FunctionType getCalleeType();
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqLL_CallIndirectOp : IREESeqLL_Op<"call_indirect"> {
- let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands);
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *callee,"
- "ArrayRef<Value *> operands = {}", [{
- result.operands.push_back(callee);
- result.addOperands(operands);
- result.addTypes(callee->getType().cast<FunctionType>().getResults());
- }]>];
-
- let extraClassDeclaration = [{
- Value *getCallee() { return getOperand(0); }
-
- /// Get the argument operands to the called function.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
-
- operand_iterator arg_operand_begin() { return ++operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqLL_ReturnOp : IREESeqLL_Op<"return", [Terminator]> {
- let arguments = (ins Variadic<IREELL_MemRef>:$operands);
-
- let builders = [OpBuilder<
- "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
- >];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqLL_BranchOp : IREESeqLL_Op<"br", [Terminator]> {
- let arguments = (ins Variadic<IREELL_MemRef>:$operands);
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Block *dest, "
- "ArrayRef<Value *> operands = {}", [{
- result.addSuccessor(dest, operands);
- }]>];
-
- let extraClassDeclaration = [{
- Block *getDest();
- void setDest(Block *block);
-
- /// Erase the operand at 'index' from the operand list.
- void eraseOperand(unsigned index);
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqLL_CondBranchOp : IREESeqLL_Op<"cond_br", [Terminator]> {
- let arguments = (ins
- IREELL_BoolScalar:$condition,
- Variadic<IREELL_MemRef>:$branchOperands
- );
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *condition,"
- "Block *trueDest, ArrayRef<Value *> trueOperands,"
- "Block *falseDest, ArrayRef<Value *> falseOperands", [{
- result.addOperands(condition);
- result.addSuccessor(trueDest, trueOperands);
- result.addSuccessor(falseDest, falseOperands);
- }]>];
-
- let extraClassDeclaration = [{
- // These are the indices into the dests list.
- enum { trueIndex = 0, falseIndex = 1 };
-
- // The condition operand is the first operand in the list.
- Value *getCondition() { return getOperand(0); }
-
- /// Return the destination if the condition is true.
- Block *getTrueDest() {
- return getOperation()->getSuccessor(trueIndex);
- }
-
- /// Return the destination if the condition is false.
- Block *getFalseDest() {
- return getOperation()->getSuccessor(falseIndex);
- }
-
- // Accessors for operands to the 'true' destination.
- Value *getTrueOperand(unsigned idx) {
- assert(idx < getNumTrueOperands());
- return getOperand(getTrueDestOperandIndex() + idx);
- }
-
- void setTrueOperand(unsigned idx, Value *value) {
- assert(idx < getNumTrueOperands());
- setOperand(getTrueDestOperandIndex() + idx, value);
- }
-
- operand_iterator true_operand_begin() {
- return operand_begin() + getTrueDestOperandIndex();
- }
- operand_iterator true_operand_end() {
- return true_operand_begin() + getNumTrueOperands();
- }
- operand_range getTrueOperands() {
- return {true_operand_begin(), true_operand_end()};
- }
-
- unsigned getNumTrueOperands() {
- return getOperation()->getNumSuccessorOperands(trueIndex);
- }
-
- /// Erase the operand at 'index' from the true operand list.
- void eraseTrueOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(trueIndex, index);
- }
-
- // Accessors for operands to the 'false' destination.
- Value *getFalseOperand(unsigned idx) {
- assert(idx < getNumFalseOperands());
- return getOperand(getFalseDestOperandIndex() + idx);
- }
- void setFalseOperand(unsigned idx, Value *value) {
- assert(idx < getNumFalseOperands());
- setOperand(getFalseDestOperandIndex() + idx, value);
- }
-
- operand_iterator false_operand_begin() { return true_operand_end(); }
- operand_iterator false_operand_end() {
- return false_operand_begin() + getNumFalseOperands();
- }
- operand_range getFalseOperands() {
- return {false_operand_begin(), false_operand_end()};
- }
-
- unsigned getNumFalseOperands() {
- return getOperation()->getNumSuccessorOperands(falseIndex);
- }
-
- /// Erase the operand at 'index' from the false operand list.
- void eraseFalseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(falseIndex, index);
- }
-
- private:
- /// Get the index of the first true destination operand.
- unsigned getTrueDestOperandIndex() { return 1; }
-
- /// Get the index of the first false destination operand.
- unsigned getFalseDestOperandIndex() {
- return getTrueDestOperandIndex() + getNumTrueOperands();
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
-}
-
-def IREESeqLL_DynamicDispatchOp : IREESeqLL_Op<"dynamic_dispatch"> {
- let arguments = (ins
- SymbolRefAttr:$executable,
- SymbolRefAttr:$entry_point,
- IREELL_IntMemRef:$workload,
- Variadic<IREELL_MemRef>:$operands
- );
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, StringRef executable,"
- "StringRef entry_point, Value *workload,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addOperands({workload});
- result.addOperands(operands);
- result.addAttribute("executable", builder->getSymbolRefAttr(executable));
- result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getExecutable() { return executable(); }
- StringRef getEntryPoint() { return entry_point(); }
- FunctionType getEntryPointType();
-
- Value *getWorkload() { return getOperand(0); }
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- operand_iterator arg_operand_begin() { return operand_begin() + 1; }
- operand_iterator arg_operand_end() { return operand_end(); }
-
- operand_type_range getArgOperandTypes() {
- return {arg_operand_type_begin(), arg_operand_type_end()};
- }
- operand_type_iterator arg_operand_type_begin() {
- return operand_type_iterator(arg_operand_begin());
- }
- operand_type_iterator arg_operand_type_end() {
- return operand_type_iterator(arg_operand_end());
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
- let verifier = [{ return verify$cppClass(*this); }];
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_StaticDispatchOp : IREESeqLL_Op<"static_dispatch"> {
- let arguments = (ins
- SymbolRefAttr:$executable,
- SymbolRefAttr:$entry_point,
- I32ElementsAttr:$workload,
- Variadic<IREELL_MemRef>:$operands
- );
- let results = (outs Variadic<IREELL_MemRef>);
-
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, StringRef executable,"
- "StringRef entry_point, ElementsAttr workload,"
- "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
- result.addAttribute("workload", workload);
- result.addOperands(operands);
- result.addAttribute("executable", builder->getSymbolRefAttr(executable));
- result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point));
- result.addTypes(results);
- }]>];
-
- let extraClassDeclaration = [{
- // TODO(b/132296600): make tablegen follow the style guide.
- StringRef getExecutable() { return executable(); }
- StringRef getEntryPoint() { return entry_point(); }
- FunctionType getEntryPointType();
-
- ElementsAttr getWorkload() { return workload(); }
-
- // TODO(b/133879130): make tablegen support variadic operand accessors.
- operand_range getArgOperands() {
- return {arg_operand_begin(), arg_operand_end()};
- }
- operand_iterator arg_operand_begin() { return operand_begin(); }
- operand_iterator arg_operand_end() { return operand_end(); }
-
- operand_type_range getArgOperandTypes() {
- return {arg_operand_type_begin(), arg_operand_type_end()};
- }
- operand_type_iterator arg_operand_type_begin() {
- return operand_type_iterator(arg_operand_begin());
- }
- operand_type_iterator arg_operand_type_end() {
- return operand_type_iterator(arg_operand_end());
- }
- }];
-
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ return print$cppClass(p, *this); }];
- let verifier = [{ return verify$cppClass(*this); }];
-}
-
-def IREESeqLL_AllocStaticOp : IREESeqLL_PureOp<"alloc_static"> {
- // TODO(benvanik): attributes and args.
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_AllocStackOp : IREESeqLL_PureOp<"alloc_stack"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces);
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_AllocStackInitOp : IREESeqLL_PureOp<"alloc_stack_init"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces);
- let results = (outs IREELL_MemRef);
-}
-
-// TODO(b/142012496): Add trait that enables DCE but not CSE.
-def IREESeqLL_AllocHeapOp : IREESeqLL_Op<"alloc_heap"> {
- // TODO(benvanik): attributes and args.
- let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces);
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_DiscardOp : IREESeqLL_Op<"discard"> {
- let arguments = (ins IREELL_MemRef);
-}
-
-def IREESeqLL_ShapeOp : IREESeqLL_Op<"shape"> {
- let arguments = (ins IREELL_MemRef:$input, IREELL_I32MemRef:$dst);
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_LengthOp : IREESeqLL_Op<"length"> {
- let arguments = (ins IREELL_MemRef:$input, IREELL_I32Scalar:$dst);
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_ComputeOffsetOp : IREESeqLL_Op<"compute_offset"> {
- let arguments = (ins
- IREELL_1DIntMemRef:$shape,
- I8Attr:$elementSize,
- IREELL_1DIntMemRef:$indices,
- IREELL_I32Scalar:$dst
- );
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_ComputeRangeOp : IREESeqLL_Op<"compute_range"> {
- let arguments = (ins
- IREELL_1DIntMemRef:$shape,
- I8Attr:$elementSize,
- IREELL_1DIntMemRef:$indices,
- IREELL_1DIntMemRef:$lengths,
- IREELL_I32Scalar:$dstOffset,
- IREELL_I32Scalar:$dstLength
- );
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_DynamicSliceOp : IREESeqLL_PureOp<"dynamic_slice", [
- AllElementTypesMatch<["src", "result"]>
-]> {
- let arguments = (ins
- IREELL_MemRef:$src,
- IREELL_IntScalar:$offset,
- IREELL_IntScalar:$length
- );
- let results = (outs IREELL_MemRef:$result);
-}
-
-def IREESeqLL_StaticSliceOp : IREESeqLL_PureOp<"static_slice", [
- AllElementTypesMatch<["src", "result"]>
-]> {
- let arguments = (ins
- IREELL_MemRef:$src,
- I64Attr:$offset,
- I64Attr:$length
- );
- let results = (outs IREELL_MemRef:$result);
-}
-
-def IREESeqLL_DynamicCopyOp : IREESeqLL_Op<"dynamic_copy"> {
- let arguments = (ins
- IREELL_MemRef:$src,
- IREELL_IndexScalar:$srcOffset,
- IREELL_MemRef:$dst,
- IREELL_IndexScalar:$dstOffset,
- IREELL_IndexScalar:$length
- );
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_StaticCopyOp : IREESeqLL_Op<"static_copy"> {
- let arguments = (ins
- IREELL_MemRef:$src,
- I64Attr:$srcOffset,
- IREELL_MemRef:$dst,
- I64Attr:$dstOffset,
- I64Attr:$length
- );
-}
-
-def IREESeqLL_DynamicFillOp : IREESeqLL_Op<"dynamic_fill"> {
- let arguments = (ins
- IREELL_I32Scalar:$value,
- IREELL_MemRef:$dst,
- IREELL_IndexScalar:$dstOffset,
- IREELL_IndexScalar:$length
- );
-
- let hasCanonicalizer = 1;
-}
-
-def IREESeqLL_StaticFillOp : IREESeqLL_Op<"static_fill"> {
- let arguments = (ins
- I32Attr:$value,
- IREELL_MemRef:$dst,
- I64Attr:$dstOffset,
- I64Attr:$length
- );
-}
-
-def IREESeqLL_CloneOp :
- IREESeqLL_PureOp<"clone", [SameOperandsAndResultType]> {
- let arguments = (ins IREELL_MemRef:$src);
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_AssignOp :
- IREESeqLL_Op<"assign", [SameOperandsAndResultType]> {
- let arguments = (ins IREELL_MemRef:$src);
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_CondAssignOp : IREESeqLL_Op<"cond_assign"> {
- let arguments = (ins
- IREELL_BoolScalar:$cond,
- IREELL_MemRef:$lhs,
- IREELL_MemRef:$rhs
- );
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_ReshapeOp : IREESeqLL_Op<"reshape"> {
- let arguments = (ins IREELL_MemRef:$input, IREELL_1DIntMemRef:$shape);
- let results = (outs IREELL_MemRef);
-}
-
-def IREESeqLL_TraceOp : IREESeqLL_Op<"trace"> {
- let arguments = (ins Variadic<IREELL_MemRef>:$srcs);
-}
-
-def IREESeqLL_CondBreakOp : IREESeqLL_Op<"cond_break"> {
- let arguments = (ins IREELL_BoolScalar:$cond);
-}
-
-def IREESeqLL_BreakOp : IREESeqLL_Op<"break">;
-
-#endif // IREE_SEQUENCER_LL_OPS
diff --git a/iree/compiler/IR/Sequencer/OpWriters.cpp b/iree/compiler/IR/Sequencer/OpWriters.cpp
deleted file mode 100644
index 592c2d7..0000000
--- a/iree/compiler/IR/Sequencer/OpWriters.cpp
+++ /dev/null
@@ -1,266 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/OpWriters.h"
-
-#include "iree/compiler/IR/Sequencer/LLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Serialization/BytecodeWriter.h"
-#include "iree/compiler/Utils/Macros.h"
-#include "iree/schemas/bytecode/sequencer_bytecode_v0.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OpImplementation.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Sequencer ops
-//===----------------------------------------------------------------------===//
-
-LogicalResult writeOp(IREESeq::LL::ConstantOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kConstant));
- auto memRefType = op.getType().dyn_cast<MemRefType>();
- if (!memRefType) {
- return op.emitError()
- << "Constant has an unsupported type; must be a memref: "
- << op.getType();
- }
- RETURN_IF_FAILURE(writer->WriteConstant(memRefType, op.getAttr("value")));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::CallOp op, BytecodeWriter *writer) {
- auto module = op.getOperation()->getParentOfType<ModuleOp>();
- auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
- // TODO(benvanik): switch with kCallTail if attr exists.
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall));
- RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::CallImportOp op, BytecodeWriter *writer) {
- auto module = op.getOperation()->getParentOfType<ModuleOp>();
- auto callee = module.lookupSymbol<FuncOp>(op.getCallee());
- // TODO(benvanik): transforms to convert Call->CallImport.
- // TODO(benvanik): switch with kCallTail if attr exists.
- if (callee.isExternal()) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallImport));
- } else {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall));
- }
- RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::CallIndirectOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallIndirect));
- RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::BranchOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kBranch));
- RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest()));
- RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands()));
- for (int i = 0; i < op.getNumOperands(); ++i) {
- // Copy src->dst.
- RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i)));
- }
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::CondBranchOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCondBranch));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition()));
- RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest()));
- RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands()));
- for (int i = 0; i < op.getNumTrueOperands(); ++i) {
- // Copy src->dst.
- RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i)));
- }
- RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest()));
- RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands()));
- for (int i = 0; i < op.getNumFalseOperands(); ++i) {
- // Copy src->dst.
- RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i)));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i)));
- }
- return success();
-}
-
-LogicalResult writeDispatchOpExecutableRef(Operation *op, StringRef executable,
- StringRef entryPoint,
- BytecodeWriter *writer) {
- auto module = op->getParentOfType<ModuleOp>();
- auto multiArchExecutableOp =
- module.lookupSymbol<IREE::MultiArchExecutableOp>(executable);
- if (!multiArchExecutableOp) {
- return op->emitError() << "Executable @" << executable.str()
- << " not found in module";
- }
-
- auto executableOrdinalAttr = multiArchExecutableOp.getAttr("iree.ordinal")
- .dyn_cast_or_null<IntegerAttr>();
- if (!executableOrdinalAttr) {
- return op->emitError() << "No ordinal assigned to executable";
- }
- int executableOrdinal = executableOrdinalAttr.getInt();
-
- // TODO(benvanik): move an export table to the MAE to make this cleaner.
- auto executableOp =
- cast<IREE::ExecutableOp>(multiArchExecutableOp.getBlock().front());
- auto entryPointOp =
- executableOp.getInnerModule().lookupSymbol<FuncOp>(entryPoint);
- if (!entryPointOp) {
- return op->emitError() << "Entry point @" << entryPoint.str()
- << " not found in executable @" << executable.str();
- }
- if (!entryPointOp.getAttr("iree.ordinal")) {
- return op->emitError() << "No ordinal assigned to entry point";
- }
- int entryPointOrdinal =
- entryPointOp.getAttr("iree.ordinal").cast<IntegerAttr>().getInt();
-
- RETURN_IF_FAILURE(writer->WriteUint32(executableOrdinal));
- RETURN_IF_FAILURE(writer->WriteUint16(entryPointOrdinal));
-
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::DynamicDispatchOp op,
- BytecodeWriter *writer) {
- RETURN_IF_FAILURE(
- writer->WriteOpcode(iree::SequencerOpcode::kDynamicDispatch));
- RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(),
- op.getEntryPoint(), writer));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getWorkload()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- // TODO(benvanik): support output arg group (or change to tags).
- RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::StaticDispatchOp op,
- BytecodeWriter *writer) {
- RETURN_IF_FAILURE(
- writer->WriteOpcode(iree::SequencerOpcode::kStaticDispatch));
- RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(),
- op.getEntryPoint(), writer));
- auto workloadAttr = op.getWorkload();
- RETURN_IF_FAILURE(
- writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({0}).getInt()));
- RETURN_IF_FAILURE(
- writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({1}).getInt()));
- RETURN_IF_FAILURE(
- writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({2}).getInt()));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
- // TODO(benvanik): support output arg group (or change to tags).
- RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::AllocHeapOp op, BytecodeWriter *writer) {
- auto memRefType = op.getType().cast<MemRefType>();
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kAllocHeap));
- RETURN_IF_FAILURE(writer->WriteInt32(0));
- RETURN_IF_FAILURE(writer->WriteTypeIndex(memRefType.getElementType()));
- RETURN_IF_FAILURE(writer->WriteShapePieces(memRefType));
- RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::ComputeRangeOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kComputeRange));
- RETURN_IF_FAILURE(writer->WriteLocal(op.shape()));
- RETURN_IF_FAILURE(writer->WriteUint8(op.elementSize().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.indices()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.lengths()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.dstOffset()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.dstLength()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::StaticSliceOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticSlice));
- RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.offset().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getResult()->getType()));
- RETURN_IF_FAILURE(
- writer->WriteShapePieces(op.getResult()->getType().cast<ShapedType>()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::StaticCopyOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticCopy));
- RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.srcOffset().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));
- return success();
-}
-
-LogicalResult writeOp(IREESeq::LL::StaticFillOp op, BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticFill));
- RETURN_IF_FAILURE(writer->WriteInt32(op.value().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue()));
- RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));
- return success();
-}
-
-} // namespace
-
-void registerSequencerCustomWriters(VMFunctionBuilder *builder) {
-#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \
- builder->RegisterCustomWriter( \
- op_type::getOperationName(), \
- +[](Operation *op, BytecodeWriter *writer) { \
- return writeOp(cast<op_type>(op), writer); \
- });
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ConstantOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallImportOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallIndirectOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::BranchOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CondBranchOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::DynamicDispatchOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticDispatchOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::AllocHeapOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ComputeRangeOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticSliceOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticCopyOp);
- REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticFillOp);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/OpWriters.h b/iree/compiler/IR/Sequencer/OpWriters.h
deleted file mode 100644
index a0039af..0000000
--- a/iree/compiler/IR/Sequencer/OpWriters.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
-#define IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
-
-#include "iree/compiler/Serialization/VMFunctionBuilder.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Registers custom op writers with the builder.
-// Ops not registered will use the generic writer.
-void registerSequencerCustomWriters(VMFunctionBuilder *builder);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
diff --git a/iree/compiler/IR/Sequencer/test/BUILD b/iree/compiler/IR/Sequencer/test/BUILD
deleted file mode 100644
index 4df56ac..0000000
--- a/iree/compiler/IR/Sequencer/test/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- "//iree/tools:iree-run-mlir",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/IR/StructureOps.cpp b/iree/compiler/IR/StructureOps.cpp
deleted file mode 100644
index 1ee1c72..0000000
--- a/iree/compiler/IR/StructureOps.cpp
+++ /dev/null
@@ -1,250 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/StructureOps.h"
-
-#include "iree/compiler/IR/Types.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/STLExtras.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-
-//===----------------------------------------------------------------------===//
-// Generic printers and parsers.
-//===----------------------------------------------------------------------===//
-
-// Parses an op that has no inputs and no outputs.
-static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
- if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
- return success();
-}
-
-// Prints an op that has no inputs and no outputs.
-static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
- printer << op->getName();
- printer.printOptionalAttrDict(op->getAttrs());
-}
-
-//===----------------------------------------------------------------------===//
-// iree.module
-//===----------------------------------------------------------------------===//
-
-void ModuleOp::build(Builder *builder, OperationState &state) {
- ensureTerminator(*state.addRegion(), *builder, state.location);
-}
-
-static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
- Region *body = state.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) {
- return failure();
- }
- if (parser.parseOptionalAttributeDict(state.attributes)) {
- return failure();
- }
- ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
- return success();
-}
-
-static void printModuleOp(OpAsmPrinter &printer, Operation *op) {
- printer << op->getName();
- printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/false);
- printer.printOptionalAttrDict(op->getAttrs());
-}
-
-//===----------------------------------------------------------------------===//
-// iree.multi_arch_executable
-//===----------------------------------------------------------------------===//
-
-void MultiArchExecutableOp::build(Builder *builder, OperationState &state,
- StringRef name) {
- state.addAttribute(SymbolTable::getSymbolAttrName(),
- builder->getStringAttr(name));
- ensureTerminator(*state.addRegion(), *builder, state.location);
-}
-
-static ParseResult parseMultiArchExecutableOp(OpAsmParser &parser,
- OperationState &state) {
- auto &builder = parser.getBuilder();
-
- // Parse the name as a symbol reference attr and then convert to a string.
- SymbolRefAttr nameAttr;
- if (failed(parser.parseAttribute(nameAttr, SymbolTable::getSymbolAttrName(),
- state.attributes))) {
- return failure();
- }
- state.attributes.back().second = builder.getStringAttr(nameAttr.getValue());
-
- if (succeeded(parser.parseOptionalLSquare())) {
- IntegerAttr ordinalAttr;
- if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32),
- "iree.ordinal", state.attributes)) ||
- failed(parser.parseRSquare())) {
- return failure();
- }
- }
-
- if (failed(parser.parseLParen()) || failed(parser.parseRParen())) {
- return failure();
- }
-
- Region *body = state.addRegion();
- if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) {
- return failure();
- }
- if (succeeded(parser.parseOptionalKeyword("attributes"))) {
- if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
- }
-
- MultiArchExecutableOp::ensureTerminator(*body, builder, state.location);
-
- return success();
-}
-
-static void printMultiArchExecutableOp(OpAsmPrinter &printer,
- MultiArchExecutableOp op) {
- printer << op.getOperationName() << " @" << op.sym_name();
- if (auto ordinalAttr =
- op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) {
- printer << "[" << ordinalAttr.getInt() << "]";
- }
- printer << "()";
-
- printer.printRegion(op.body(), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/false);
-
- // Print out executable attributes, if present.
- SmallVector<StringRef, 2> ignoredAttrs = {
- SymbolTable::getSymbolAttrName(),
- "iree.ordinal",
- };
- SmallVector<NamedAttribute, 4> attrs(
- llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) {
- return llvm::count(ignoredAttrs, attr.first) == 0;
- }));
- if (!attrs.empty()) {
- printer << "\n attributes ";
- printer.printOptionalAttrDict(attrs);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// iree.executable
-//===----------------------------------------------------------------------===//
-
-void ExecutableOp::build(Builder *builder, OperationState &state,
- IREE::ExecutableFormat format) {
- state.addAttribute("format",
- builder->getI32IntegerAttr(static_cast<uint32_t>(format)));
- ensureTerminator(*state.addRegion(), *builder, state.location);
-}
-
-static ParseResult parseExecutableOp(OpAsmParser &parser,
- OperationState &state) {
- auto &builder = parser.getBuilder();
-
- if (succeeded(parser.parseOptionalLSquare())) {
- IntegerAttr ordinalAttr;
- if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32),
- "iree.ordinal", state.attributes)) ||
- failed(parser.parseRSquare())) {
- return failure();
- }
- }
-
- IntegerAttr executableOrdinalAttr;
- StringAttr formatAttr;
- llvm::SMLoc formatLoc;
- if (failed(parser.parseLParen()) ||
- failed(parser.getCurrentLocation(&formatLoc)) ||
- failed(parser.parseAttribute(formatAttr, "format", state.attributes))) {
- return failure();
- }
- auto format = symbolizeExecutableFormat(formatAttr.getValue());
- if (!format.hasValue()) {
- return parser.emitError(formatLoc)
- << "Unknown executable format " << formatAttr.getValue();
- }
- state.attributes.back().second =
- builder.getI32IntegerAttr(static_cast<int32_t>(format.getValue()));
-
- Region *body = state.addRegion();
- if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) {
- return failure();
- }
- if (succeeded(parser.parseOptionalKeyword("attributes"))) {
- if (failed(parser.parseOptionalAttributeDict(state.attributes))) {
- return failure();
- }
- }
-
- ExecutableOp::ensureTerminator(*body, parser.getBuilder(), state.location);
-
- return success();
-}
-
-static void printExecutableOp(OpAsmPrinter &printer, ExecutableOp op) {
- printer << op.getOperationName();
- if (auto ordinalAttr =
- op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) {
- printer << "[" << ordinalAttr.getInt() << "]";
- }
- printer << "(";
- auto format = symbolizeExecutableFormat(op.format());
- if (format.hasValue()) {
- printer << stringifyExecutableFormat(format.getValue());
- } else {
- printer << "INVALID FORMAT";
- }
- printer << ")";
-
- printer.printRegion(op.body(), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/false);
-
- // Print out executable attributes, if present.
- SmallVector<StringRef, 2> ignoredAttrs = {
- "iree.ordinal",
- "format",
- };
- SmallVector<NamedAttribute, 4> attrs(
- llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) {
- return llvm::count(ignoredAttrs, attr.first) == 0;
- }));
- if (!attrs.empty()) {
- printer << "\n attributes ";
- printer.printOptionalAttrDict(attrs);
- }
-}
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/StructureOps.cpp.inc"
-
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/StructureOps.h b/iree/compiler/IR/StructureOps.h
deleted file mode 100644
index 003310b..0000000
--- a/iree/compiler/IR/StructureOps.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_STRUCTUREOPS_H_
-#define IREE_COMPILER_IR_STRUCTUREOPS_H_
-
-#include <cstdint>
-
-#include "iree/compiler/IR/Types.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/FunctionSupport.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-
-#define GET_OP_CLASSES
-#include "iree/compiler/IR/StructureOps.h.inc"
-
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_STRUCTUREOPS_H_
diff --git a/iree/compiler/IR/StructureOps.td b/iree/compiler/IR/StructureOps.td
deleted file mode 100644
index 189d3a4..0000000
--- a/iree/compiler/IR/StructureOps.td
+++ /dev/null
@@ -1,122 +0,0 @@
-// Structural ops such as 'module' and 'executable'.
-// These are used to organize IREE IR into regions representing ops that act at
-// the sequencer level (coarse control flow/scheduling) and ops that perform
-// actual work (math/etc) on runtime execution backends.
-
-#ifdef IREE_STRUCTURE_OPS
-#else
-#define IREE_STRUCTURE_OPS
-
-#ifdef IREE_OP_BASE
-#else
-include "iree/compiler/IR/OpBase.td"
-#endif // IREE_OP_BASE
-
-class IREE_StructureOp<string mnemonic, list<OpTrait> traits = []> :
- Op<IREE_Dialect, mnemonic, traits> {
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ print$cppClass(p, *this); }];
-}
-
-def IREE_ModuleOp :
- IREE_StructureOp<"module", [
- SingleBlockImplicitTerminator<"ModuleEndOp">,
- NativeOpTrait<"SymbolTable">
- ]> {
- let regions = (region SizedRegion<1>:$body);
- let extraClassDeclaration = [{
- Block& getBlock() {
- return this->getOperation()->getRegion(0).front();
- }
- }];
-
- let skipDefaultBuilders = 1;
- let builders = [OpBuilder<"Builder *, OperationState &state">];
-}
-
-def IREE_ModuleEndOp :
- IREE_StructureOp<"_module_end", [
- IREE_ModuleOnly,
- Terminator
- ]> {
- let parser = [{ return parseNoIOOp(parser, result); }];
- let printer = [{ printNoIOOp(getOperation(), p); }];
-}
-
-def IREE_MultiArchExecutableOp :
- IREE_StructureOp<"multi_arch_executable", [
- // TODO(benvanik): make iree.module work and make this IREE_ModuleOnly.
- SingleBlockImplicitTerminator<"MultiArchExecutableEndOp">
- ]> {
- let arguments = (ins
- StrAttr:$sym_name,
- OptionalAttr<I32Attr>:$ordinal
- );
-
- let regions = (region SizedRegion<1>:$body);
- let extraClassDeclaration = [{
- StringRef getName() {
- return this->getOperation()->template getAttrOfType<StringAttr>(
- ::mlir::SymbolTable::getSymbolAttrName()).getValue();
- }
-
- Region& getBody() {
- return this->getOperation()->getRegion(0);
- }
- Block& getBlock() {
- return this->getOperation()->getRegion(0).front();
- }
- }];
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<"Builder *builder, OperationState &state, StringRef name">,
- ];
-}
-
-def IREE_MultiArchExecutableEndOp :
- IREE_StructureOp<"_multi_arch_executable_end", [
- IREE_MultiArchExecutableOnly,
- Terminator
- ]> {
- let parser = [{ return parseNoIOOp(parser, result); }];
- let printer = [{ printNoIOOp(getOperation(), p); }];
-}
-
-def IREE_ExecutableOp :
- IREE_StructureOp<"executable", [
- SingleBlockImplicitTerminator<"ExecutableEndOp">,
- NativeOpTrait<"SymbolTable">
- ]> {
- let arguments = (ins
- IREE_ExecutableFormatAttr:$format,
- OptionalAttr<I32Attr>:$ordinal
- );
-
- let regions = (region SizedRegion<1>:$body);
- let extraClassDeclaration = [{
- Region& getBody() {
- return this->getOperation()->getRegion(0);
- }
- Block& getBlock() {
- return this->getOperation()->getRegion(0).front();
- }
- ::mlir::ModuleOp getInnerModule() {
- return *getBlock().getOps<::mlir::ModuleOp>().begin();
- }
- }];
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<[{Builder *builder, OperationState &state,
- ExecutableFormat executable_format}]>,
- ];
-}
-
-def IREE_ExecutableEndOp :
- IREE_StructureOp<"_executable_end", [Terminator, IREE_ExecutableOnly]> {
- let parser = [{ return parseNoIOOp(parser, result); }];
- let printer = [{ printNoIOOp(getOperation(), p); }];
-}
-
-#endif // IREE_STRUCTURE_OPS
diff --git a/iree/compiler/IR/Traits.cpp b/iree/compiler/IR/Traits.cpp
deleted file mode 100644
index 8a7081a..0000000
--- a/iree/compiler/IR/Traits.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Traits.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// TODO(benvanik): traits.
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Types.cpp b/iree/compiler/IR/Types.cpp
deleted file mode 100644
index bc49235..0000000
--- a/iree/compiler/IR/Types.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Types.h"
-
-#include "iree/compiler/IR/Enums.cpp.inc"
-
-namespace mlir {
-namespace iree_compiler {
-
-// static
-DeviceType DeviceType::get(MLIRContext *context) {
- return Base::get(context, TypeKind::Device);
-}
-
-// static
-DeviceGroupType DeviceGroupType::get(MLIRContext *context) {
- return Base::get(context, TypeKind::DeviceGroup);
-}
-
-// static
-CommandBufferType CommandBufferType::get(MLIRContext *context) {
- return Base::get(context, TypeKind::CommandBuffer);
-}
-
-// static
-EventType EventType::get(MLIRContext *context) {
- return Base::get(context, TypeKind::Event);
-}
-
-// static
-SemaphoreType SemaphoreType::get(MLIRContext *context) {
- return Base::get(context, TypeKind::Semaphore);
-}
-
-// static
-FenceType FenceType::get(MLIRContext *context) {
- return Base::get(context, TypeKind::Fence);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/IR/Types.h b/iree/compiler/IR/Types.h
deleted file mode 100644
index da53d53..0000000
--- a/iree/compiler/IR/Types.h
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_IR_TYPES_H_
-#define IREE_COMPILER_IR_TYPES_H_
-
-#include <cstdint>
-
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/StringSwitch.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
-
-// Order matters.
-#include "iree/compiler/IR/Enums.h.inc"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace TypeKind {
-enum Kind {
- Device = Type::FIRST_IREE_TYPE,
- DeviceGroup,
- CommandBuffer,
- Event,
- Semaphore,
- Fence,
-};
-} // namespace TypeKind
-
-// clang-format off
-#define IREE_TYPE_TABLE(map) \
- map("device", TypeKind::Device, DeviceType) \
- map("device_group", TypeKind::DeviceGroup, DeviceGroupType) \
- map("command_buffer", TypeKind::CommandBuffer, CommandBufferType) \
- map("event", TypeKind::Event, EventType) \
- map("semaphore", TypeKind::Semaphore, SemaphoreType) \
- map("fence", TypeKind::Fence, FenceType)
-// clang-format on
-
-// iree.device mapping to a runtime-resolved device type.
-class DeviceType : public Type::TypeBase<DeviceType, Type> {
- public:
- using Base::Base;
-
- static bool kindof(unsigned kind) { return kind == TypeKind::Device; }
-
- static DeviceType get(MLIRContext *context);
-};
-
-// iree.device_group relating multiple iree.device requirements with each other.
-class DeviceGroupType : public Type::TypeBase<DeviceGroupType, Type> {
- public:
- using Base::Base;
-
- static bool kindof(unsigned kind) { return kind == TypeKind::DeviceGroup; }
-
- static DeviceGroupType get(MLIRContext *context);
-};
-
-// iree.command_buffer mapping to an iree::hal::CommandBuffer.
-class CommandBufferType : public Type::TypeBase<CommandBufferType, Type> {
- public:
- using Base::Base;
-
- static bool kindof(unsigned kind) { return kind == TypeKind::CommandBuffer; }
-
- static CommandBufferType get(MLIRContext *context);
-};
-
-// iree.event mapping to an iree::hal::Event.
-class EventType : public Type::TypeBase<EventType, Type> {
- public:
- using Base::Base;
-
- static bool kindof(unsigned kind) { return kind == TypeKind::Event; }
-
- static EventType get(MLIRContext *context);
-};
-
-// iree.semaphore mapping to an iree::hal::Semaphore.
-class SemaphoreType : public Type::TypeBase<SemaphoreType, Type> {
- public:
- using Base::Base;
-
- static bool kindof(unsigned kind) { return kind == TypeKind::Semaphore; }
-
- static SemaphoreType get(MLIRContext *context);
-};
-
-// iree.fence mapping to an iree::hal::Fence.
-class FenceType : public Type::TypeBase<FenceType, Type> {
- public:
- using Base::Base;
-
- static bool kindof(unsigned kind) { return kind == TypeKind::Fence; }
-
- static FenceType get(MLIRContext *context);
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_IR_TYPES_H_
diff --git a/iree/compiler/IR/test/BUILD b/iree/compiler/IR/test/BUILD
deleted file mode 100644
index 4df56ac..0000000
--- a/iree/compiler/IR/test/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- "//iree/tools:iree-run-mlir",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/Serialization/BUILD b/iree/compiler/Serialization/BUILD
deleted file mode 100644
index 6bee31b..0000000
--- a/iree/compiler/Serialization/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-# Serialization for the VM bytecode.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Serialization",
- srcs = [
- "BytecodeTables.cpp",
- "BytecodeWriter.cpp",
- "VMDeviceTableBuilder.cpp",
- "VMExecutableTableBuilder.cpp",
- "VMFunctionBuilder.cpp",
- "VMFunctionTableBuilder.cpp",
- "VMModuleBuilder.cpp",
- "VMSourceMapBuilder.cpp",
- ],
- hdrs = [
- "BytecodeTables.h",
- "BytecodeWriter.h",
- "VMDeviceTableBuilder.h",
- "VMExecutableTableBuilder.h",
- "VMFunctionBuilder.h",
- "VMFunctionTableBuilder.h",
- "VMModuleBuilder.h",
- "VMSourceMapBuilder.h",
- ],
- deps = [
- "//iree/compiler/IR",
- "//iree/compiler/Utils",
- "//iree/schemas",
- "//iree/schemas/bytecode:bytecode_v0",
- "//iree/schemas/bytecode:interpreter_bytecode_v0",
- "//iree/schemas/bytecode:sequencer_bytecode_v0",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- ],
-)
diff --git a/iree/compiler/Serialization/BytecodeTables.cpp b/iree/compiler/Serialization/BytecodeTables.cpp
deleted file mode 100644
index 7d0cc9d..0000000
--- a/iree/compiler/Serialization/BytecodeTables.cpp
+++ /dev/null
@@ -1,73 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/BytecodeTables.h"
-
-#include "llvm/ADT/STLExtras.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Info tables mapping 1:1 with bytecode ops.
-//
-// Note that we ensure the table is 256 elements long exactly to make sure
-// that unused opcodes are handled gracefully.
-#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \
- { \
- name, \
- flags, \
- {operand_encodings}, \
- },
-
-static const OpcodeInfo kInterpreterInfoTable[256] = {
- IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)};
-
-static const OpcodeInfo kSequencerInfoTable[256] = {
- IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)};
-
-#undef DECLARE_INFO
-
-} // namespace
-
-llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName(
- StringRef name) {
- for (int i = 0; i < llvm::array_lengthof(kInterpreterInfoTable); ++i) {
- if (name == kInterpreterInfoTable[i].mnemonic) {
- return static_cast<iree::InterpreterOpcode>(i);
- }
- }
- return llvm::None;
-}
-
-const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode) {
- return kInterpreterInfoTable[static_cast<uint8_t>(opcode)];
-}
-
-llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name) {
- for (int i = 0; i < llvm::array_lengthof(kSequencerInfoTable); ++i) {
- if (name == kSequencerInfoTable[i].mnemonic) {
- return static_cast<iree::SequencerOpcode>(i);
- }
- }
- return llvm::None;
-}
-
-const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode) {
- return kSequencerInfoTable[static_cast<uint8_t>(opcode)];
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/BytecodeTables.h b/iree/compiler/Serialization/BytecodeTables.h
deleted file mode 100644
index bee8763..0000000
--- a/iree/compiler/Serialization/BytecodeTables.h
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
-#define IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
-
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-#include "iree/schemas/bytecode/sequencer_bytecode_v0.h"
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/StringRef.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-struct OpcodeInfo {
- const char* mnemonic = nullptr;
- iree::OpcodeFlagBitfield flags = iree::OpcodeFlagBitfield::kDefault;
- union {
- const char operands_value[8] = {0};
- const iree::OperandEncoding operands[8];
- };
-};
-
-// Returns an opcode - if found - for the given interpreter op.
-llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName(
- StringRef name);
-
-// Returns the info for the given interpreter opcode.
-const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode);
-
-// Returns an opcode - if found - for the given sequencer op.
-llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name);
-
-// Returns the info for the given sequencer opcode.
-const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
diff --git a/iree/compiler/Serialization/BytecodeWriter.cpp b/iree/compiler/Serialization/BytecodeWriter.cpp
deleted file mode 100644
index a357c3a..0000000
--- a/iree/compiler/Serialization/BytecodeWriter.cpp
+++ /dev/null
@@ -1,334 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/BytecodeWriter.h"
-
-#include <algorithm>
-
-#include "iree/compiler/Utils/Macros.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-LogicalResult BytecodeWriter::WriteCount(int count) {
- if (count > UINT8_MAX) {
- // TODO(benvanik): varints?
- llvm::errs() << "Too many items: " << count
- << "; only 0-UINT8_MAX are supported";
- return failure();
- }
- return WriteUint8(static_cast<uint8_t>(count));
-}
-
-LogicalResult BytecodeWriter::WriteTypeIndex(Type type) {
- iree::BuiltinType type_index;
- if (type.isInteger(8)) {
- type_index = iree::BuiltinType::kI8;
- } else if (type.isInteger(16)) {
- type_index = iree::BuiltinType::kI16;
- } else if (type.isInteger(32)) {
- type_index = iree::BuiltinType::kI32;
- } else if (type.isInteger(64)) {
- type_index = iree::BuiltinType::kI64;
- } else if (type.isF16()) {
- type_index = iree::BuiltinType::kF16;
- } else if (type.isF32()) {
- type_index = iree::BuiltinType::kF32;
- } else if (type.isF64()) {
- type_index = iree::BuiltinType::kF64;
- } else {
- // TODO(benvanik): support unknown types as BuiltinType::kOpaque?
- return emitError(UnknownLoc::get(type.getContext()))
- << "Type " << type << " cannot be represented by a builtin type";
- }
- return WriteUint8(static_cast<uint8_t>(type_index));
-}
-
-LogicalResult BytecodeWriter::WriteFunctionOrdinal(FuncOp function) {
- auto functionOrdinal = function.getAttrOfType<IntegerAttr>("iree.ordinal");
- if (!functionOrdinal) {
- return function.emitError() << "Ordinal not assigned to function";
- }
- RETURN_IF_FAILURE(WriteUint32(functionOrdinal.getInt()));
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteImportOrdinal(FuncOp function) {
- // For now this is the same as internal function ordinals, though we could
- // probably shrink it.
- return WriteFunctionOrdinal(function);
-}
-
-LogicalResult BytecodeWriter::WriteConstant(MemRefType memRefType,
- Attribute baseAttr) {
- // All types are memrefs, so we only need the element type.
- RETURN_IF_FAILURE(WriteTypeIndex(memRefType.getElementType()));
-
- // Write shape (we could optimize this for cases of scalars and such).
- RETURN_IF_FAILURE(WriteCount(memRefType.getRank()));
- for (int i = 0; i < memRefType.getRank(); ++i) {
- RETURN_IF_FAILURE(WriteInt32(memRefType.getDimSize(i)));
- }
-
- if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) {
- RETURN_IF_FAILURE(
- WriteUint8(static_cast<uint8_t>(iree::ConstantEncoding::kSplat)));
- return WriteAttributeData(attr.getSplatValue());
- }
- RETURN_IF_FAILURE(
- WriteUint8(static_cast<uint8_t>(iree::ConstantEncoding::kDense)));
- return WriteAttributeData(baseAttr);
-}
-
-LogicalResult BytecodeWriter::WriteAttributeData(Attribute baseAttr) {
- if (auto attr = baseAttr.dyn_cast<BoolAttr>()) {
- return WriteUint8(attr.getValue() ? 1 : 0);
- } else if (auto attr = baseAttr.dyn_cast<IntegerAttr>()) {
- if (attr.getType().isIndex()) {
- int32_t value = static_cast<int32_t>(attr.getInt());
- return WriteBytes(&value, 4);
- } else {
- int bitWidth = attr.getValue().getBitWidth();
- switch (bitWidth) {
- case 8:
- case 16:
- case 32:
- case 64:
- return WriteBytes(attr.getValue().getRawData(), bitWidth / 8);
- default:
- return emitError(UnknownLoc::get(baseAttr.getContext()))
- << "Bit width for integers must be one of 8,16,32,64; others "
- "not implemented: "
- << bitWidth;
- }
- }
- } else if (auto attr = baseAttr.dyn_cast<FloatAttr>()) {
- int bitWidth = attr.getType().getIntOrFloatBitWidth();
- auto bitcastValue = attr.getValue().bitcastToAPInt();
- switch (bitWidth) {
- case 16:
- case 32:
- case 64:
- return WriteBytes(bitcastValue.getRawData(), bitWidth / 8);
- default:
- return emitError(UnknownLoc::get(baseAttr.getContext()))
- << "Bit width for floats must be one of 16,32,64; others "
- "not implemented: "
- << bitWidth;
- }
- } else if (auto attr = baseAttr.dyn_cast<StringAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<ArrayAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<AffineMapAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<IntegerSetAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<TypeAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<SymbolRefAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) {
- return WriteAttributeData(attr.getSplatValue());
- } else if (auto attr = baseAttr.dyn_cast<DenseIntElementsAttr>()) {
- int elementCount = attr.getType().getNumElements();
- if (elementCount == 0) {
- return success();
- }
- int bitWidth = attr.getType().getElementTypeBitWidth();
- int byteWidth = bitWidth / 8;
- auto dst = ReserveBytes(elementCount * byteWidth);
- if (dst.empty()) return failure();
- uint8_t *dstPtr = dst.data();
- for (auto element : attr) {
- assert(element.getBitWidth() == bitWidth);
- std::memcpy(dstPtr, element.getRawData(), byteWidth);
- dstPtr += byteWidth;
- }
- return success();
- } else if (auto attr = baseAttr.dyn_cast<DenseFPElementsAttr>()) {
- int elementCount = attr.getType().getNumElements();
- if (elementCount == 0) {
- return success();
- }
- int bitWidth = attr.getType().getElementTypeBitWidth();
- auto dst = ReserveBytes(elementCount * bitWidth / 8);
- if (dst.empty()) return failure();
- uint8_t *dstPtr = dst.data();
- for (auto element : attr) {
- auto bitcastValue = element.bitcastToAPInt();
- std::memcpy(dstPtr, bitcastValue.getRawData(),
- bitcastValue.getBitWidth() / 8);
- dstPtr += bitWidth / 8;
- }
- return success();
- } else if (auto attr = baseAttr.dyn_cast<DenseElementsAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<OpaqueElementsAttr>()) {
- // TODO(benvanik): other attribute encodings.
- } else if (auto attr = baseAttr.dyn_cast<SparseElementsAttr>()) {
- // TODO(benvanik): other attribute encodings.
- }
- return emitError(UnknownLoc::get(baseAttr.getContext()))
- << "Serializer for attribute kind "
- << static_cast<int>(baseAttr.getKind()) << " not implemented";
-}
-
-Optional<int> BytecodeWriter::LookupLocalOrdinal(Value *value) {
- int ordinal;
- auto it = localMap_.find(value);
- if (it != localMap_.end()) {
- ordinal = it->second;
- } else {
- ordinal = localMap_.size();
- localMap_.insert({value, ordinal});
- }
- if (ordinal > UINT16_MAX) {
- // TODO(benvanik): varints?
- emitError(UnknownLoc::get(value->getContext()))
- << "Too many ordinals: " << ordinal
- << "; only 0-UINT16_MAX are supported";
- return llvm::None;
- }
- return ordinal;
-}
-
-LogicalResult BytecodeWriter::PrepareLocal(Value *value) {
- if (!LookupLocalOrdinal(value).hasValue()) return failure();
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteLocal(Value *value) {
- auto ordinal = LookupLocalOrdinal(value);
- if (!ordinal.hasValue()) {
- return failure();
- }
- if (ordinal.getValue() > UINT16_MAX) {
- // TODO(benvanik): varints?
- return emitError(UnknownLoc::get(value->getContext()))
- << "Too many locals: " << ordinal.getValue()
- << "; only 0-UINT16_MAX are supported";
- }
- return WriteUint16(static_cast<uint16_t>(ordinal.getValue()));
-}
-
-LogicalResult BytecodeWriter::WriteLocals(
- llvm::iterator_range<Operation::operand_iterator> values) {
- int count = std::distance(values.begin(), values.end());
- RETURN_IF_FAILURE(WriteCount(count));
- for (auto *value : values) {
- RETURN_IF_FAILURE(WriteLocal(value));
- }
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteLocals(
- llvm::iterator_range<Operation::result_iterator> values) {
- int count = std::distance(values.begin(), values.end());
- RETURN_IF_FAILURE(WriteCount(count));
- for (auto *value : values) {
- RETURN_IF_FAILURE(WriteLocal(value));
- }
- return success();
-}
-
-MutableArrayRef<uint8_t> BytecodeWriter::ReserveBytes(size_t dataLength) {
- int offset = bytecode_.size();
- bytecode_.resize(offset + dataLength);
- return MutableArrayRef<uint8_t>(
- reinterpret_cast<uint8_t *>(bytecode_.data()) + offset, dataLength);
-}
-
-LogicalResult BytecodeWriter::WriteBytes(const void *data, size_t dataLength) {
- auto dst = ReserveBytes(dataLength);
- if (dataLength != dst.size()) {
- return failure();
- }
- std::memcpy(dst.data(), data, dst.size());
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteUint8(uint8_t value) {
- return WriteBytes(&value, sizeof(value));
-}
-
-LogicalResult BytecodeWriter::WriteUint16(uint16_t value) {
- return WriteBytes(&value, sizeof(value));
-}
-
-LogicalResult BytecodeWriter::WriteInt32(int32_t value) {
- return WriteBytes(&value, sizeof(value));
-}
-
-LogicalResult BytecodeWriter::WriteUint32(uint32_t value) {
- return WriteBytes(&value, sizeof(value));
-}
-
-LogicalResult BytecodeWriter::WriteElementsAttrInt32(ElementsAttr attr) {
- int elementCount = attr.getType().getNumElements();
- RETURN_IF_FAILURE(WriteCount(elementCount));
- for (auto value : attr.getValues<int32_t>()) {
- RETURN_IF_FAILURE(WriteInt32(value));
- }
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteShapePieces(const ShapedType &type) {
- RETURN_IF_FAILURE(WriteCount(type.getRank()));
- for (int64_t dim : type.getShape()) {
- RETURN_IF_FAILURE(WriteInt32(dim));
- }
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteShapePieces(ElementsAttr pieces) {
- return WriteElementsAttrInt32(pieces);
-}
-
-LogicalResult BytecodeWriter::MarkBlockOffset(Block *block) {
- blockOffsets_[block] = bytecode_.size();
- return success();
-}
-
-LogicalResult BytecodeWriter::WriteBlockOffset(Block *targetBlock) {
- // Reserve space for the offset and stash for later fixup.
- blockOffsetFixups_.push_back({targetBlock, bytecode_.size()});
- bytecode_.resize(bytecode_.size() + sizeof(int32_t));
- return success();
-}
-
-LogicalResult BytecodeWriter::FixupOffsets() {
- for (const auto &fixup : blockOffsetFixups_) {
- auto it = blockOffsets_.find(fixup.first);
- if (it == blockOffsets_.end()) {
- llvm::errs() << "Block offset not found: " << fixup.first;
- return failure();
- }
- std::memcpy(bytecode_.data() + fixup.second, &it->second, sizeof(int32_t));
- }
- blockOffsetFixups_.clear();
- return success();
-}
-
-std::vector<uint8_t> BytecodeWriter::Finish() {
- localMap_.clear();
- return std::move(bytecode_);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/BytecodeWriter.h b/iree/compiler/Serialization/BytecodeWriter.h
deleted file mode 100644
index dbadd1c..0000000
--- a/iree/compiler/Serialization/BytecodeWriter.h
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
-#define IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
-
-#include <cstddef>
-#include <utility>
-#include <vector>
-
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-#include "llvm/ADT/Optional.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Block.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class BytecodeWriter {
- public:
- int offset() const { return bytecode_.size(); }
-
- int local_count() const { return localMap_.size(); }
-
- template <typename T>
- LogicalResult WriteOpcode(T value) {
- static_assert(sizeof(T) == sizeof(uint8_t), "Opcode enum size mismatch");
- return WriteUint8(static_cast<uint8_t>(value));
- }
-
- LogicalResult WriteCount(int count);
-
- LogicalResult WriteTypeIndex(Type type);
-
- LogicalResult WriteFunctionOrdinal(FuncOp function);
- LogicalResult WriteImportOrdinal(FuncOp function);
-
- LogicalResult WriteConstant(MemRefType memRefType, Attribute baseAttr);
- LogicalResult WriteAttributeData(Attribute baseAttr);
-
- llvm::Optional<int> LookupLocalOrdinal(Value *value);
- LogicalResult PrepareLocal(Value *value);
- LogicalResult WriteLocal(Value *value);
- LogicalResult WriteLocals(
- llvm::iterator_range<Operation::operand_iterator> values);
- LogicalResult WriteLocals(
- llvm::iterator_range<Operation::result_iterator> values);
-
- LogicalResult WriteBytes(const void *data, size_t dataLength);
- MutableArrayRef<uint8_t> ReserveBytes(size_t dataLength);
- LogicalResult WriteUint8(uint8_t value);
- LogicalResult WriteUint16(uint16_t value);
- LogicalResult WriteInt32(int32_t value);
- LogicalResult WriteUint32(uint32_t value);
-
- LogicalResult WriteElementsAttrInt32(ElementsAttr attr);
-
- LogicalResult WriteShapePieces(const ShapedType &type);
- LogicalResult WriteShapePieces(ElementsAttr pieces);
-
- LogicalResult MarkBlockOffset(Block *block);
- LogicalResult WriteBlockOffset(Block *targetBlock);
- LogicalResult FixupOffsets();
-
- std::vector<uint8_t> Finish();
-
- private:
- std::vector<uint8_t> bytecode_;
-
- llvm::DenseMap<Value *, int> localMap_;
-
- llvm::DenseMap<Block *, size_t> blockOffsets_;
- std::vector<std::pair<Block *, size_t>> blockOffsetFixups_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
diff --git a/iree/compiler/Serialization/VMDeviceTableBuilder.cpp b/iree/compiler/Serialization/VMDeviceTableBuilder.cpp
deleted file mode 100644
index 7703014..0000000
--- a/iree/compiler/Serialization/VMDeviceTableBuilder.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/VMDeviceTableBuilder.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-VMDeviceTableBuilder::VMDeviceTableBuilder(
- ::flatbuffers::FlatBufferBuilder *fbb)
- : fbb_(fbb) {}
-
-LogicalResult VMDeviceTableBuilder::AddDevice(
- ::flatbuffers::Offset<iree::DeviceDef> deviceDef) {
- deviceDefs_.push_back(deviceDef);
- return success();
-}
-
-LogicalResult VMDeviceTableBuilder::AddDeviceGroup(
- ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef) {
- deviceGroupDefs_.push_back(deviceGroupDef);
- return success();
-}
-
-::flatbuffers::Offset<iree::DeviceTableDef> VMDeviceTableBuilder::Finish() {
- auto devicesOffset = fbb_->CreateVector(deviceDefs_);
- auto deviceGroupsOffset = fbb_->CreateVector(deviceGroupDefs_);
- iree::DeviceTableDefBuilder dtdb(*fbb_);
- dtdb.add_devices(devicesOffset);
- dtdb.add_device_groups(deviceGroupsOffset);
- return dtdb.Finish();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/VMDeviceTableBuilder.h b/iree/compiler/Serialization/VMDeviceTableBuilder.h
deleted file mode 100644
index 9170baf..0000000
--- a/iree/compiler/Serialization/VMDeviceTableBuilder.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
-#define IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/schemas/device_table_def_generated.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class VMDeviceTableBuilder {
- public:
- explicit VMDeviceTableBuilder(::flatbuffers::FlatBufferBuilder *fbb);
-
- LogicalResult AddDevice(::flatbuffers::Offset<iree::DeviceDef> deviceDef);
-
- LogicalResult AddDeviceGroup(
- ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef);
-
- ::flatbuffers::Offset<iree::DeviceTableDef> Finish();
-
- private:
- ::flatbuffers::FlatBufferBuilder *fbb_;
- std::vector<::flatbuffers::Offset<iree::DeviceDef>> deviceDefs_;
- std::vector<::flatbuffers::Offset<iree::DeviceGroupDef>> deviceGroupDefs_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
diff --git a/iree/compiler/Serialization/VMExecutableTableBuilder.cpp b/iree/compiler/Serialization/VMExecutableTableBuilder.cpp
deleted file mode 100644
index 7a5004a..0000000
--- a/iree/compiler/Serialization/VMExecutableTableBuilder.cpp
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/VMExecutableTableBuilder.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-VMExecutableTableBuilder::VMExecutableTableBuilder(
- ::flatbuffers::FlatBufferBuilder *fbb)
- : fbb_(fbb) {}
-
-LogicalResult VMExecutableTableBuilder::AddMultiArchExecutable(
- ::flatbuffers::Offset<iree::MultiArchExecutableDef>
- multiArchExecutableDef) {
- multiArchExecutableDefs_.push_back(multiArchExecutableDef);
- return success();
-}
-
-::flatbuffers::Offset<iree::ExecutableTableDef>
-VMExecutableTableBuilder::Finish() {
- auto multiArchExecutablesOffset =
- fbb_->CreateVector(multiArchExecutableDefs_);
- iree::ExecutableTableDefBuilder etdb(*fbb_);
- etdb.add_multi_arch_executables(multiArchExecutablesOffset);
- return etdb.Finish();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/VMExecutableTableBuilder.h b/iree/compiler/Serialization/VMExecutableTableBuilder.h
deleted file mode 100644
index 1125a4d..0000000
--- a/iree/compiler/Serialization/VMExecutableTableBuilder.h
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
-#define IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/schemas/executable_table_def_generated.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class VMExecutableTableBuilder {
- public:
- explicit VMExecutableTableBuilder(::flatbuffers::FlatBufferBuilder *fbb);
-
- LogicalResult AddMultiArchExecutable(
- ::flatbuffers::Offset<iree::MultiArchExecutableDef>
- multiArchExecutableDef);
-
- ::flatbuffers::Offset<iree::ExecutableTableDef> Finish();
-
- private:
- ::flatbuffers::FlatBufferBuilder *fbb_;
- std::vector<::flatbuffers::Offset<iree::MultiArchExecutableDef>>
- multiArchExecutableDefs_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMFunctionBuilder.cpp b/iree/compiler/Serialization/VMFunctionBuilder.cpp
deleted file mode 100644
index cc2fee2..0000000
--- a/iree/compiler/Serialization/VMFunctionBuilder.cpp
+++ /dev/null
@@ -1,359 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/VMFunctionBuilder.h"
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Serialization/BytecodeTables.h"
-#include "iree/compiler/Utils/Macros.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-#include "iree/schemas/type_def_generated.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/Module.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-LogicalResult WriteGenericIreeOp(Block *block, Operation *op,
- BytecodeWriter *writer) {
- // Strip the dialect name from the op name and lookup the opcode.
- // TODO(benvanik): adjust for supporting sequencer opcodes.
-
- auto opName = op->getName().getStringRef();
- auto dialect = op->getDialect();
- if (!dialect) {
- return op->emitOpError() << "Op does not belong to a registered dialect";
- }
-
- auto dialectNamespace = dialect->getNamespace();
- std::unique_ptr<OpcodeInfo> operandInfo;
- auto strippedOpName = opName.substr(opName.find('.') + 1).str();
- if (dialectNamespace == "iree_ll_seq") {
- auto opcode = GetSequencerOpcodeByName(strippedOpName);
- if (!opcode.hasValue()) {
- return op->emitOpError()
- << "No sequencer opcode found for op; is it a pseudo op?";
- }
- RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue()));
- operandInfo =
- std::make_unique<OpcodeInfo>(GetSequencerOpcodeInfo(opcode.getValue()));
- } else if (dialectNamespace == "iree_ll_interp" ||
- // TODO(gcmn) remove special case for IREE dialect?
- dialectNamespace == IREEDialect::getDialectNamespace()) {
- auto opcode = GetInterpreterOpcodeByName(strippedOpName);
- if (!opcode.hasValue()) {
- return op->emitOpError()
- << "No interpreter opcode found for op; is it a pseudo op?";
- }
- RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue()));
- operandInfo = std::make_unique<OpcodeInfo>(
- GetInterpreterOpcodeInfo(opcode.getValue()));
- } else {
- return op->emitOpError()
- << "Op belongs to unknown dialect " << dialectNamespace.str();
- }
- // Write inputs and outputs based on the bytecode encoding.
- int operandIndex = 0;
- int resultIndex = 0;
- for (int i = 0; i < llvm::array_lengthof(operandInfo->operands); ++i) {
- auto op_encoding = operandInfo->operands[i];
- if (op_encoding == iree::OperandEncoding::kNone) break;
- switch (op_encoding) {
- case iree::OperandEncoding::kInputSlot:
- case iree::OperandEncoding::kOutputSlot: {
- auto *value = op->getOperand(operandIndex++);
- RETURN_IF_FAILURE(writer->WriteLocal(value));
- break;
- }
- case iree::OperandEncoding::kVariadicInputSlots:
- case iree::OperandEncoding::kVariadicOutputSlots: {
- int count = op->getNumOperands() - operandIndex;
- RETURN_IF_FAILURE(writer->WriteCount(count));
- for (; count; --count) {
- auto *value = op->getOperand(operandIndex++);
- RETURN_IF_FAILURE(writer->WriteLocal(value));
- }
- break;
- }
- case iree::OperandEncoding::kResultSlot: {
- auto *value = op->getResult(resultIndex++);
- RETURN_IF_FAILURE(writer->WriteLocal(value));
- break;
- }
- case iree::OperandEncoding::kVariadicResultSlots: {
- int count = op->getNumResults() - resultIndex;
- RETURN_IF_FAILURE(writer->WriteCount(count));
- for (; count; --count) {
- auto *value = op->getResult(resultIndex++);
- RETURN_IF_FAILURE(writer->WriteLocal(value));
- }
- break;
- }
- case iree::OperandEncoding::kConstant:
- case iree::OperandEncoding::kFunctionOrdinal:
- case iree::OperandEncoding::kBlockOffset:
- case iree::OperandEncoding::kTypeIndex:
- case iree::OperandEncoding::kIndex:
- case iree::OperandEncoding::kIndexList:
- case iree::OperandEncoding::kCmpIPredicate:
- case iree::OperandEncoding::kCmpFPredicate:
- return op->emitOpError()
- << "Operand encoding " << static_cast<char>(op_encoding)
- << " not supported by generic writer for " << opName.str();
- return failure();
- default:
- return op->emitOpError()
- << "Operand encoding " << static_cast<char>(op_encoding) << " ("
- << static_cast<int>(op_encoding) << ") not recognized (typo?)";
- }
- }
-
- return success();
-}
-
-} // namespace
-
-VMFunctionBuilder::VMFunctionBuilder(FuncOp function,
- VMFunctionTableBuilder *functionTable,
- ::flatbuffers::FlatBufferBuilder *fbb)
- : context_(function.getContext()),
- function_(function),
- functionTable_(functionTable),
- fbb_(fbb) {}
-
-void VMFunctionBuilder::RegisterCustomWriter(StringRef operationName,
- CustomWriterFn writerFn) {
- customWriters_.insert({operationName, writerFn});
-}
-
-LogicalResult VMFunctionBuilder::ConvertBytecode() {
- BytecodeWriter writer;
- sourceMap_ = {};
-
- RETURN_IF_FAILURE(BeginFunction(function_, &writer));
- for (auto &block : function_.getBlocks()) {
- RETURN_IF_FAILURE(BeginBlock(&block, &writer));
- for (auto &op : block.getOperations()) {
- if (failed(WriteOperation(&block, &op, &writer))) {
- op.emitError() << "Unable to serialize operation";
- return failure();
- }
- }
- RETURN_IF_FAILURE(EndBlock(&block, block.getTerminator(), &writer));
- }
- RETURN_IF_FAILURE(EndFunction(function_, &writer));
-
- int localCount = writer.local_count();
- auto bodyBytes = writer.Finish();
- auto bodyOffset = fbb_->CreateVector(
- reinterpret_cast<const int8_t *>(bodyBytes.data()), bodyBytes.size());
- iree::BytecodeDefBuilder bdb(*fbb_);
- bdb.add_local_count(localCount);
- bdb.add_contents(bodyOffset);
- bytecodeDef_ = bdb.Finish();
-
- return success();
-}
-
-::flatbuffers::Offset<iree::FunctionDef> VMFunctionBuilder::Finish() {
- using TypeDefVector =
- ::flatbuffers::Vector<::flatbuffers::Offset<iree::TypeDef>>;
-
- const auto &functionType = function_.getType();
- std::vector<::flatbuffers::Offset<iree::TypeDef>> inputs;
- for (const auto &type : functionType.getInputs()) {
- auto typeOffset = SerializeType(type, fbb_);
- if (typeOffset.IsNull()) return {};
- inputs.push_back(typeOffset);
- }
- ::flatbuffers::Offset<TypeDefVector> inputsOffset;
- if (!inputs.empty()) {
- inputsOffset = fbb_->CreateVector(inputs);
- }
-
- std::vector<::flatbuffers::Offset<iree::TypeDef>> results;
- for (const auto &type : functionType.getResults()) {
- auto typeOffset = SerializeType(type, fbb_);
- if (typeOffset.IsNull()) return {};
- results.push_back(typeOffset);
- }
- ::flatbuffers::Offset<TypeDefVector> resultsOffset;
- if (!results.empty()) {
- resultsOffset = fbb_->CreateVector(results);
- }
- iree::FunctionTypeDefBuilder ftb(*fbb_);
- ftb.add_inputs(inputsOffset);
- ftb.add_results(resultsOffset);
- auto functionTypeOffset = ftb.Finish();
-
- // TODO(benvanik): strip names of internal functions.
- auto nameOffset = fbb_->CreateString(function_.getName().str());
- iree::FunctionDefBuilder fdb(*fbb_);
- fdb.add_name(nameOffset);
- fdb.add_type(functionTypeOffset);
- fdb.add_bytecode(bytecodeDef_);
- return fdb.Finish();
-}
-
-LogicalResult VMFunctionBuilder::BeginFunction(FuncOp function,
- BytecodeWriter *writer) {
- // Assign value slots for all arguments and results.
- // Keeping them at the front will make it easier to find during debugging
- // and makes spans easier to compute at runtime.
- for (auto argument : function.getArguments()) {
- RETURN_IF_FAILURE(writer->PrepareLocal(argument));
- }
- return success();
-}
-
-LogicalResult VMFunctionBuilder::EndFunction(FuncOp function,
- BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->FixupOffsets());
- return success();
-}
-
-LogicalResult VMFunctionBuilder::BeginBlock(Block *block,
- BytecodeWriter *writer) {
- RETURN_IF_FAILURE(writer->MarkBlockOffset(block));
- return success();
-}
-
-LogicalResult VMFunctionBuilder::EndBlock(Block *block, Operation *op,
- BytecodeWriter *writer) {
- return success();
-}
-
-LogicalResult VMFunctionBuilder::WriteOperation(Block *block, Operation *baseOp,
- BytecodeWriter *writer) {
- if (!baseOp->getLoc().isa<UnknownLoc>()) {
- sourceMap_.locations.push_back({writer->offset(), baseOp->getLoc()});
- }
-
- // Check registered writers first to allow overrides.
- auto writerIt = customWriters_.find(baseOp->getName().getStringRef());
- if (writerIt != customWriters_.end()) {
- return writerIt->second(baseOp, writer);
- }
-
- // Fallback to using the generic writer.
- if (baseOp->getAbstractOperation()->dialect.getNamespace().startswith(
- "iree")) {
- RETURN_IF_FAILURE(WriteGenericIreeOp(block, baseOp, writer));
- } else {
- return baseOp->emitError()
- << "Unsupported op " << baseOp->getName().getStringRef().str()
- << "; incorrectly outlined or not yet implemented";
- }
- return success();
-}
-
-::flatbuffers::Offset<iree::TypeDef> VMFunctionBuilder::SerializeType(
- Type type, ::flatbuffers::FlatBufferBuilder *fbb) {
- ::flatbuffers::Offset<void> typeDefUnion;
- iree::TypeDefUnion typeUnionType;
- if (auto memRefType = type.dyn_cast<MemRefType>()) {
- auto memRefTypeOffset = SerializeMemRefType(memRefType, fbb_);
- if (memRefTypeOffset.IsNull()) return {};
- typeDefUnion = memRefTypeOffset.Union();
- typeUnionType = iree::TypeDefUnion::MemRefTypeDef;
- } else if (auto deviceType = type.dyn_cast<DeviceType>()) {
- typeDefUnion = iree::CreateDeviceTypeDef(*fbb).Union();
- typeUnionType = iree::TypeDefUnion::DeviceTypeDef;
- } else if (auto commandBufferType = type.dyn_cast<CommandBufferType>()) {
- typeDefUnion = iree::CreateCommandBufferTypeDef(*fbb).Union();
- typeUnionType = iree::TypeDefUnion::CommandBufferTypeDef;
- } else if (auto eventType = type.dyn_cast<EventType>()) {
- typeDefUnion = iree::CreateEventTypeDef(*fbb).Union();
- typeUnionType = iree::TypeDefUnion::EventTypeDef;
- } else if (auto semaphoreType = type.dyn_cast<SemaphoreType>()) {
- typeDefUnion = iree::CreateSemaphoreTypeDef(*fbb).Union();
- typeUnionType = iree::TypeDefUnion::SemaphoreTypeDef;
- } else if (auto fenceType = type.dyn_cast<FenceType>()) {
- typeDefUnion = iree::CreateFenceTypeDef(*fbb).Union();
- typeUnionType = iree::TypeDefUnion::FenceTypeDef;
- } else {
- function_.emitError() << "Function " << function_.getName().str()
- << " has unsupported I/O with type " << type;
- return {};
- }
-
- iree::TypeDefBuilder tdb(*fbb);
- tdb.add_type_union_type(typeUnionType);
- tdb.add_type_union(typeDefUnion);
- return tdb.Finish();
-}
-
-::flatbuffers::Offset<iree::MemRefTypeDef>
-VMFunctionBuilder::SerializeMemRefType(const MemRefType &type,
- ::flatbuffers::FlatBufferBuilder *fbb) {
- auto elementTypeOffset = SerializeElementType(type.getElementType(), fbb);
- if (elementTypeOffset.IsNull()) return {};
- std::vector<int> shape;
- for (int dim : type.getShape()) {
- shape.push_back(dim);
- }
- auto shapeOffset = fbb->CreateVector(shape);
- iree::MemRefTypeDefBuilder tb(*fbb);
- tb.add_element_type(elementTypeOffset);
- tb.add_shape(shapeOffset);
- tb.add_memory_space(type.getMemorySpace());
- return tb.Finish();
-}
-
-::flatbuffers::Offset<iree::ElementTypeDef>
-VMFunctionBuilder::SerializeElementType(const Type &genericType,
- ::flatbuffers::FlatBufferBuilder *fbb) {
- ::flatbuffers::Offset<void> typeDefUnion;
- iree::ElementTypeDefUnion typeUnionType;
- if (auto type = genericType.dyn_cast<FloatType>()) {
- iree::FloatTypeDefBuilder tb(*fbb);
- tb.add_width(type.getWidth());
- typeDefUnion = tb.Finish().Union();
- typeUnionType = iree::ElementTypeDefUnion::FloatTypeDef;
- } else if (auto type = genericType.dyn_cast<IntegerType>()) {
- iree::IntegerTypeDefBuilder tb(*fbb);
- tb.add_width(type.getWidth());
- typeDefUnion = tb.Finish().Union();
- typeUnionType = iree::ElementTypeDefUnion::IntegerTypeDef;
- } else if (auto type = genericType.dyn_cast<OpaqueType>()) {
- auto dialectOffset = fbb->CreateString(type.getDialectNamespace().c_str());
- auto typeDataOffset = fbb->CreateString(type.getTypeData().data());
- iree::UnknownTypeDefBuilder tb(*fbb);
- tb.add_dialect(dialectOffset);
- tb.add_type_data(typeDataOffset);
- typeDefUnion = tb.Finish().Union();
- typeUnionType = iree::ElementTypeDefUnion::UnknownTypeDef;
- } else {
- function_.emitError()
- << "Unimplemented type encoding: " << genericType
- << "; ensure IREE lowering passes are converting types to the IREE "
- "set";
- return {};
- }
-
- iree::ElementTypeDefBuilder tdb(*fbb);
- tdb.add_type_union_type(typeUnionType);
- tdb.add_type_union(typeDefUnion);
- return tdb.Finish();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/VMFunctionBuilder.h b/iree/compiler/Serialization/VMFunctionBuilder.h
deleted file mode 100644
index 129afeb..0000000
--- a/iree/compiler/Serialization/VMFunctionBuilder.h
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
-#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
-
-#include "iree/compiler/Serialization/BytecodeWriter.h"
-#include "iree/compiler/Serialization/VMFunctionTableBuilder.h"
-#include "iree/compiler/Serialization/VMSourceMapBuilder.h"
-#include "iree/schemas/bytecode_def_generated.h"
-#include "iree/schemas/function_def_generated.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class VMFunctionBuilder {
- public:
- using CustomWriterFn =
- std::function<LogicalResult(Operation *, BytecodeWriter *writer)>;
-
- VMFunctionBuilder(FuncOp function, VMFunctionTableBuilder *functionTable,
- ::flatbuffers::FlatBufferBuilder *fbb);
- ~VMFunctionBuilder() = default;
-
- void RegisterCustomWriter(StringRef operationName, CustomWriterFn writerFn);
-
- const VMFunctionSourceMap &source_map() const { return sourceMap_; }
-
- LogicalResult ConvertBytecode();
-
- ::flatbuffers::Offset<iree::FunctionDef> Finish();
-
- ::flatbuffers::Offset<iree::TypeDef> SerializeType(
- Type type, ::flatbuffers::FlatBufferBuilder *fbb);
- ::flatbuffers::Offset<iree::MemRefTypeDef> SerializeMemRefType(
- const MemRefType &genericType, ::flatbuffers::FlatBufferBuilder *fbb);
- ::flatbuffers::Offset<iree::ElementTypeDef> SerializeElementType(
- const Type &genericType, ::flatbuffers::FlatBufferBuilder *fbb);
-
- private:
- LogicalResult BeginFunction(FuncOp function, BytecodeWriter *writer);
- LogicalResult EndFunction(FuncOp function, BytecodeWriter *writer);
- LogicalResult BeginBlock(Block *block, BytecodeWriter *writer);
- LogicalResult EndBlock(Block *block, Operation *op, BytecodeWriter *writer);
-
- LogicalResult WriteOperation(Block *block, Operation *baseOp,
- BytecodeWriter *writer);
-
- llvm::StringMap<CustomWriterFn> customWriters_;
-
- MLIRContext *context_;
- FuncOp function_;
- VMFunctionTableBuilder *functionTable_;
- ::flatbuffers::FlatBufferBuilder *fbb_;
- ::flatbuffers::Offset<iree::BytecodeDef> bytecodeDef_;
- VMFunctionSourceMap sourceMap_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMFunctionTableBuilder.cpp b/iree/compiler/Serialization/VMFunctionTableBuilder.cpp
deleted file mode 100644
index 090b520..0000000
--- a/iree/compiler/Serialization/VMFunctionTableBuilder.cpp
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/VMFunctionTableBuilder.h"
-
-#include "iree/compiler/Serialization/VMSourceMapBuilder.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-VMFunctionTableBuilder::VMFunctionTableBuilder(
- ::flatbuffers::FlatBufferBuilder *fbb)
- : fbb_(fbb) {}
-
-bool VMFunctionTableBuilder::IsFunctionDeclared(FuncOp funcOp) {
- return functionSet_.count(funcOp.getName()) != 0;
-}
-
-LogicalResult VMFunctionTableBuilder::DeclareFunction(FuncOp funcOp,
- LinkageType linkageType) {
- if (functionSet_.count(funcOp.getName())) {
- return funcOp.emitError() << "Function has already been declared/defined";
- }
- auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal");
- if (!functionOrdinal) {
- return funcOp.emitError() << "Ordinal not assigned to function";
- }
- int ordinal = functionOrdinal.getInt();
- functionDefs_.resize(
- std::max(functionDefs_.size(), static_cast<size_t>(ordinal) + 1u));
- functionSourceMaps_.resize(
- std::max(functionDefs_.size(), static_cast<size_t>(ordinal) + 1u));
- functionSet_.insert({funcOp.getName()});
- switch (linkageType) {
- case LinkageType::kInternal:
- break;
- case LinkageType::kImport:
- importIndices_.push_back(ordinal);
- break;
- case LinkageType::kExport:
- exportIndices_.push_back(ordinal);
- break;
- }
- return success();
-}
-
-LogicalResult VMFunctionTableBuilder::DefineFunction(
- FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef,
- VMFunctionSourceMap functionSourceMap) {
- auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal");
- if (!functionOrdinal) {
- return funcOp.emitError() << "Ordinal not assigned to function";
- }
- int ordinal = functionOrdinal.getInt();
- if (!functionDefs_[ordinal].IsNull()) {
- return funcOp.emitOpError() << "Function has already been defined";
- }
- functionDefs_[ordinal] = functionDef;
- functionSourceMaps_[ordinal] = std::move(functionSourceMap);
- return success();
-}
-
-::flatbuffers::Offset<iree::FunctionTableDef> VMFunctionTableBuilder::Finish() {
- auto functionsOffset = fbb_->CreateVector(functionDefs_);
- auto importsOffset = fbb_->CreateVector(importIndices_);
- auto exportsOffset = fbb_->CreateVector(exportIndices_);
- iree::FunctionTableDefBuilder ftdb(*fbb_);
- ftdb.add_functions(functionsOffset);
- ftdb.add_imports(importsOffset);
- ftdb.add_exports(exportsOffset);
- return ftdb.Finish();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/VMFunctionTableBuilder.h b/iree/compiler/Serialization/VMFunctionTableBuilder.h
deleted file mode 100644
index fc87abf..0000000
--- a/iree/compiler/Serialization/VMFunctionTableBuilder.h
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
-#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
-
-#include <string>
-#include <vector>
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Serialization/VMSourceMapBuilder.h"
-#include "iree/schemas/function_def_generated.h"
-#include "iree/schemas/function_table_def_generated.h"
-#include "llvm/ADT/StringSet.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/OperationSupport.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-enum class LinkageType {
- kInternal,
- kImport,
- kExport,
-};
-
-class VMFunctionTableBuilder {
- public:
- explicit VMFunctionTableBuilder(::flatbuffers::FlatBufferBuilder *fbb);
-
- int max_function_ordinal() const { return functionDefs_.size(); }
-
- ArrayRef<VMFunctionSourceMap> function_source_maps() {
- return llvm::makeArrayRef(functionSourceMaps_);
- }
-
- // Returns true if |funcOp| has already been declared in the table.
- bool IsFunctionDeclared(FuncOp funcOp);
-
- // Declares |funcOp| with the given |linkageType|.
- // Fails if the function has already been declared or defined.
- LogicalResult DeclareFunction(FuncOp funcOp, LinkageType linkageType);
-
- // Defines |funcOp| using the given |functionDef|.
- LogicalResult DefineFunction(
- FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef,
- VMFunctionSourceMap functionSourceMap);
-
- ::flatbuffers::Offset<iree::FunctionTableDef> Finish();
-
- private:
- ::flatbuffers::FlatBufferBuilder *fbb_;
- llvm::StringSet<> functionSet_;
- std::vector<::flatbuffers::Offset<iree::FunctionDef>> functionDefs_;
- std::vector<VMFunctionSourceMap> functionSourceMaps_;
- std::vector<int> importIndices_;
- std::vector<int> exportIndices_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMModuleBuilder.cpp b/iree/compiler/Serialization/VMModuleBuilder.cpp
deleted file mode 100644
index 50cea2e..0000000
--- a/iree/compiler/Serialization/VMModuleBuilder.cpp
+++ /dev/null
@@ -1,70 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/VMModuleBuilder.h"
-
-#include "iree/schemas/executable_table_def_generated.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-VMModuleBuilder::VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb)
- : fbb_(fbb),
- deviceTable_(fbb),
- functionTable_(fbb),
- executableTable_(fbb),
- sourceMap_(fbb) {}
-
-::flatbuffers::Offset<iree::ModuleDef> VMModuleBuilder::Finish() {
- auto nameOffset = fbb_->CreateString("module");
- auto deviceTableOffset = deviceTable_.Finish();
- if (deviceTableOffset.IsNull()) return {};
- auto functionTableOffset = functionTable_.Finish();
- if (functionTableOffset.IsNull()) return {};
- auto executableTableOffset = executableTable_.Finish();
- if (executableTableOffset.IsNull()) return {};
-
- for (int function_ordinal = 0;
- function_ordinal < functionTable_.function_source_maps().size();
- ++function_ordinal) {
- if (failed(sourceMap_.AddFunction(
- function_ordinal,
- functionTable_.function_source_maps()[function_ordinal]))) {
- return {};
- }
- }
- auto sourceMapOffset =
- sourceMap_.Finish(functionTable_.max_function_ordinal());
- if (sourceMapOffset.IsNull()) return {};
-
- iree::ModuleDefBuilder mdb(*fbb_);
- mdb.add_name(nameOffset);
- mdb.add_device_table(deviceTableOffset);
- mdb.add_function_table(functionTableOffset);
- mdb.add_executable_table(executableTableOffset);
- mdb.add_source_map(sourceMapOffset);
- return mdb.Finish();
-}
-
-std::vector<uint8_t> VMModuleBuilder::Serialize(
- ::flatbuffers::Offset<iree::ModuleDef> module_def) {
- FinishModuleDefBuffer(*fbb_, module_def);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb_->GetSize());
- std::memcpy(bytes.data(), fbb_->GetBufferPointer(), bytes.size());
- return bytes;
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/VMModuleBuilder.h b/iree/compiler/Serialization/VMModuleBuilder.h
deleted file mode 100644
index 1f991fa..0000000
--- a/iree/compiler/Serialization/VMModuleBuilder.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
-#define IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
-
-#include <vector>
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Serialization/VMDeviceTableBuilder.h"
-#include "iree/compiler/Serialization/VMExecutableTableBuilder.h"
-#include "iree/compiler/Serialization/VMFunctionTableBuilder.h"
-#include "iree/compiler/Serialization/VMSourceMapBuilder.h"
-#include "iree/schemas/module_def_generated.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class VMModuleBuilder {
- public:
- explicit VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb);
-
- ::flatbuffers::FlatBufferBuilder *fbb() const { return fbb_; }
- VMDeviceTableBuilder *device_table() { return &deviceTable_; }
- VMFunctionTableBuilder *function_table() { return &functionTable_; }
- VMExecutableTableBuilder *executable_table() { return &executableTable_; }
- VMSourceMapBuilder *source_map() { return &sourceMap_; }
-
- ::flatbuffers::Offset<iree::ModuleDef> Finish();
-
- std::vector<uint8_t> Serialize(
- ::flatbuffers::Offset<iree::ModuleDef> module_def);
-
- private:
- ::flatbuffers::FlatBufferBuilder *fbb_;
-
- VMDeviceTableBuilder deviceTable_;
- VMFunctionTableBuilder functionTable_;
- VMExecutableTableBuilder executableTable_;
- VMSourceMapBuilder sourceMap_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMSourceMapBuilder.cpp b/iree/compiler/Serialization/VMSourceMapBuilder.cpp
deleted file mode 100644
index c0b0637..0000000
--- a/iree/compiler/Serialization/VMSourceMapBuilder.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Serialization/VMSourceMapBuilder.h"
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/schemas/source_map_def_generated.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/Location.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-VMSourceMapBuilder::VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb)
- : fbb_(fbb) {}
-
-int VMSourceMapBuilder::GetUniqueString(std::string value) {
- auto it = stringTableMap_.find(value);
- if (it != stringTableMap_.end()) {
- return it->second;
- }
- int stringIndex = stringTable_.size();
- stringTableMap_.insert({value, stringIndex});
- stringTable_.push_back(std::move(value));
- return stringIndex;
-}
-
-LogicalResult VMSourceMapBuilder::AddFunction(
- int functionOrdinal, VMFunctionSourceMap functionSourceMap) {
- if (functionMaps_.size() <= functionOrdinal) {
- functionMaps_.resize(functionOrdinal + 1);
- }
- functionMaps_[functionOrdinal] = std::move(functionSourceMap);
- return success();
-}
-
-::flatbuffers::Offset<iree::SourceMapDef> VMSourceMapBuilder::Finish(
- int maxFunctionOrdinal) {
- // NOTE: we always ensure the source map table is the same size as the
- // function table so that lookups at runtime can be validated once at load
- // time (ensuring the tables match up) instead of on each lookup.
- if (maxFunctionOrdinal < functionMaps_.size()) {
- llvm::errs() << "Max function ordinal defined as " << maxFunctionOrdinal
- << " but there are " << functionMaps_.size()
- << " function source maps present";
- return {};
- }
- functionMaps_.resize(maxFunctionOrdinal);
-
- std::vector<::flatbuffers::Offset<iree::FunctionSourceMapDef>> functionDefs;
- functionDefs.resize(maxFunctionOrdinal);
- for (int i = 0; i < functionMaps_.size(); ++i) {
- const auto &functionMap = functionMaps_[i];
- functionDefs[i] = SerializeVMFunctionSourceMap(functionMap);
- if (functionDefs[i].IsNull()) return {};
- }
-
- auto functionTableOffset = fbb_->CreateVector(functionDefs);
- auto stringTableOffset = fbb_->CreateVectorOfStrings(stringTable_);
- iree::SourceMapDefBuilder smdb(*fbb_);
- smdb.add_function_table(functionTableOffset);
- smdb.add_string_table(stringTableOffset);
- return smdb.Finish();
-}
-
-::flatbuffers::Offset<iree::FunctionSourceMapDef>
-VMSourceMapBuilder::SerializeVMFunctionSourceMap(
- const VMFunctionSourceMap &functionMap) {
- if (functionMap.locations.empty()) {
- // Empty table. This ensures that we still have a non-null value in the
- // function table but doesn't waste much space.
- iree::FunctionSourceMapDefBuilder fsmdb(*fbb_);
- return fsmdb.Finish();
- }
-
- LocationOffsetTable locationOffsetTable;
- std::vector<iree::BytecodeSourceLocation> bytecodeMap;
- for (const auto &offset_location : functionMap.locations) {
- int locationIndex =
- SerializeLocation(offset_location.second, &locationOffsetTable);
- bytecodeMap.push_back({offset_location.first, locationIndex});
- }
- auto locationTableOffset =
- fbb_->CreateVector(locationOffsetTable.locationDefs);
- auto bytecodeMapOffset = fbb_->CreateVectorOfStructs(bytecodeMap);
-
- iree::FunctionSourceMapDefBuilder fsmdb(*fbb_);
- fsmdb.add_location_table(locationTableOffset);
- fsmdb.add_bytecode_map(bytecodeMapOffset);
- return fsmdb.Finish();
-}
-
-int VMSourceMapBuilder::SerializeLocation(
- const Location &location, LocationOffsetTable *locationOffsetTable) {
- auto existingIt = locationOffsetTable->locationMap.find(location);
- if (existingIt != locationOffsetTable->locationMap.end()) {
- return existingIt->getSecond();
- }
-
- iree::LocationDefUnion locationUnionType;
- ::flatbuffers::Offset<void> locationUnionOffset;
- if (auto fileLoc = location.dyn_cast<FileLineColLoc>()) {
- locationUnionType = iree::LocationDefUnion::FileLocationDef;
- int filenameIndex = GetUniqueString(fileLoc.getFilename().str());
- iree::FileLocationDefBuilder lb(*fbb_);
- lb.add_filename(filenameIndex);
- lb.add_line(fileLoc.getLine());
- lb.add_column(fileLoc.getColumn());
- locationUnionOffset = lb.Finish().Union();
- } else if (auto nameLoc = location.dyn_cast<NameLoc>()) {
- locationUnionType = iree::LocationDefUnion::NameLocationDef;
- int nameIndex = GetUniqueString(nameLoc.getName().str());
- iree::NameLocationDefBuilder lb(*fbb_);
- lb.add_name(nameIndex);
- locationUnionOffset = lb.Finish().Union();
- } else if (auto callSiteLoc = location.dyn_cast<CallSiteLoc>()) {
- locationUnionType = iree::LocationDefUnion::CallSiteLocationDef;
- int calleeIndex =
- SerializeLocation(callSiteLoc.getCallee(), locationOffsetTable);
- int callerIndex =
- SerializeLocation(callSiteLoc.getCaller(), locationOffsetTable);
- iree::CallSiteLocationDefBuilder lb(*fbb_);
- lb.add_callee_location(calleeIndex);
- lb.add_caller_location(callerIndex);
- locationUnionOffset = lb.Finish().Union();
- } else if (auto fusedLoc = location.dyn_cast<FusedLoc>()) {
- locationUnionType = iree::LocationDefUnion::FusedLocationDef;
- std::vector<int> locationIndices;
- locationIndices.reserve(fusedLoc.getLocations().size());
- for (const auto &child_loc : fusedLoc.getLocations()) {
- int child_index = SerializeLocation(child_loc, locationOffsetTable);
- locationIndices.push_back(child_index);
- }
- auto locationIndicesOffset = fbb_->CreateVector(locationIndices);
- iree::FusedLocationDefBuilder lb(*fbb_);
- lb.add_locations(locationIndicesOffset);
- locationUnionOffset = lb.Finish().Union();
- } else {
- llvm_unreachable("Unimplemented location kind");
- }
-
- iree::LocationDefBuilder ldb(*fbb_);
- ldb.add_location_union_type(locationUnionType);
- ldb.add_location_union(locationUnionOffset);
- int locationIndex = locationOffsetTable->locationDefs.size();
- locationOffsetTable->locationDefs.push_back(ldb.Finish());
- locationOffsetTable->locationMap.insert({location, locationIndex});
- return locationIndex;
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Serialization/VMSourceMapBuilder.h b/iree/compiler/Serialization/VMSourceMapBuilder.h
deleted file mode 100644
index 6c202ac..0000000
--- a/iree/compiler/Serialization/VMSourceMapBuilder.h
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
-#define IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
-
-#include <vector>
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/schemas/source_map_def_generated.h"
-#include "llvm/ADT/StringMap.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-struct VMFunctionSourceMap {
- std::vector<std::pair<int, Location>> locations;
-};
-
-class VMSourceMapBuilder {
- public:
- explicit VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb);
-
- LogicalResult AddFunction(int functionOrdinal,
- VMFunctionSourceMap functionSourceMap);
-
- ::flatbuffers::Offset<iree::SourceMapDef> Finish(int maxFunctionOrdinal);
-
- private:
- struct LocationOffsetTable {
- std::vector<::flatbuffers::Offset<iree::LocationDef>> locationDefs;
- llvm::DenseMap<Location, int> locationMap;
- };
-
- int GetUniqueString(std::string value);
-
- ::flatbuffers::Offset<iree::FunctionSourceMapDef>
- SerializeVMFunctionSourceMap(const VMFunctionSourceMap &functionMap);
- int SerializeLocation(const Location &location,
- LocationOffsetTable *locationOffsetTable);
-
- ::flatbuffers::FlatBufferBuilder *fbb_;
- std::vector<std::string> stringTable_;
- llvm::StringMap<int> stringTableMap_;
- std::vector<VMFunctionSourceMap> functionMaps_;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
diff --git a/iree/compiler/Transforms/AggressiveOpElimination.cpp b/iree/compiler/Transforms/AggressiveOpElimination.cpp
deleted file mode 100644
index 31c706b..0000000
--- a/iree/compiler/Transforms/AggressiveOpElimination.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <deque>
-#include <memory>
-
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "iree/compiler/IR/Interpreter/LLOps.h"
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/Sequencer/LLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "mlir/Analysis/Dominance.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Block.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-template <typename T>
-struct EraseUnused : public OpRewritePattern<T> {
- using OpRewritePattern<T>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(T op,
- PatternRewriter &rewriter) const override {
- if (op.use_empty()) {
- rewriter.eraseOp(op);
- return this->matchSuccess();
- }
- return this->matchFailure();
- }
-};
-
-void populateAggressiveOpEliminationPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<EraseUnused<LoadOp>, EraseUnused<AllocOp>,
- EraseUnused<IREESeq::HL::AllocHeapOp>,
- EraseUnused<IREESeq::LL::AllocHeapOp>,
- EraseUnused<IREEInterp::HL::AllocHeapOp>,
- EraseUnused<IREEInterp::LL::AllocHeapOp>>(ctx);
-}
-
-} // namespace
-
-// TODO(b/142012496) Make these be handled by normal DCE.
-class AggressiveOpEliminationPass
- : public FunctionPass<AggressiveOpEliminationPass> {
- public:
- void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateAggressiveOpEliminationPatterns(patterns, &getContext());
-
- applyPatternsGreedily(getFunction(), patterns);
- }
-};
-
-std::unique_ptr<OpPassBase<FuncOp>> createAggressiveOpEliminationPass() {
- return std::make_unique<AggressiveOpEliminationPass>();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/BUILD b/iree/compiler/Transforms/BUILD
deleted file mode 100644
index 4060276..0000000
--- a/iree/compiler/Transforms/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Transforms",
- srcs = [
- "AggressiveOpElimination.cpp",
- "AssignFunctionOrdinals.cpp",
- "ConvertFromTupleCallingConvention.cpp",
- "ConvertToMemRefCallingConvention.cpp",
- "DropUnreachableFunctions.cpp",
- "DropUnusedExecutables.cpp",
- "LegalizeTypeStorage.cpp",
- "LowerStdToIreeDialect.cpp",
- "LowerXLAToIreeDialect.cpp",
- ],
- hdrs = [
- "ConversionUtils.h",
- "Passes.h",
- "Rewrites.h",
- ],
- deps = [
- "//iree/compiler/IR",
- "//iree/compiler/IR/Interpreter",
- "//iree/compiler/IR/Sequencer",
- "//iree/compiler/Utils",
- "@llvm//:support",
- "@local_config_mlir//:Analysis",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:TransformUtils",
- "@local_config_mlir//:Transforms",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- ],
- alwayslink = 1,
-)
diff --git a/iree/compiler/Transforms/ConversionUtils.h b/iree/compiler/Transforms/ConversionUtils.h
deleted file mode 100644
index 18ccd6e..0000000
--- a/iree/compiler/Transforms/ConversionUtils.h
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
-#define IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
-
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-template <typename SrcOp, typename DstOp>
-struct UnaryOpLowering : public OpConversionPattern<SrcOp> {
- using OpConversionPattern<SrcOp>::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- SrcOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *value = loadAccessValue(op.getLoc(), operands[0], rewriter);
- value = wrapAsMemRef(value, op, rewriter);
-
- auto dstType = convertTypeToMemRef(op.getResult());
- auto dstOp = rewriter.create<DstOp>(op.getLoc(), dstType, value);
- auto result = dstOp.getResult();
- result = wrapAsTensor(result, op, rewriter);
-
- rewriter.replaceOp(
- op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)});
- return this->matchSuccess();
- }
-};
-
-template <typename SrcOp, typename DstOp>
-struct BinaryOpLowering : public OpConversionPattern<SrcOp> {
- using OpConversionPattern<SrcOp>::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- SrcOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *lhsValue = loadAccessValue(op.getLoc(), operands[0], rewriter);
- auto *rhsValue = loadAccessValue(op.getLoc(), operands[1], rewriter);
- auto dstType = convertTypeToMemRef(op.getResult());
-
- lhsValue = wrapAsMemRef(lhsValue, op, rewriter);
- rhsValue = wrapAsMemRef(rhsValue, op, rewriter);
-
- auto midOp =
- rewriter.create<DstOp>(op.getLoc(), dstType, lhsValue, rhsValue);
- auto result = midOp.getResult();
- result = wrapAsTensor(result, op, rewriter);
-
- rewriter.replaceOp(
- op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)});
- return this->matchSuccess();
- }
-};
-
-template <typename SrcOp, typename DstOp>
-struct TernaryOpLowering : public OpConversionPattern<SrcOp> {
- using OpConversionPattern<SrcOp>::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- SrcOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *aValue = loadAccessValue(op.getLoc(), operands[0], rewriter);
- auto *bValue = loadAccessValue(op.getLoc(), operands[1], rewriter);
- auto *cValue = loadAccessValue(op.getLoc(), operands[2], rewriter);
-
- aValue = wrapAsMemRef(aValue, op, rewriter);
- bValue = wrapAsMemRef(bValue, op, rewriter);
- cValue = wrapAsMemRef(cValue, op, rewriter);
-
- auto dstType = convertTypeToMemRef(op.getResult());
- auto dstOp =
- rewriter.create<DstOp>(op.getLoc(), dstType, aValue, bValue, cValue);
- auto result = dstOp.getResult();
- result = wrapAsTensor(result, op, rewriter);
-
- rewriter.replaceOp(
- op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)});
- return this->matchSuccess();
- }
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
diff --git a/iree/compiler/Transforms/ConvertToMemRefCallingConvention.cpp b/iree/compiler/Transforms/ConvertToMemRefCallingConvention.cpp
deleted file mode 100644
index 6bebd56..0000000
--- a/iree/compiler/Transforms/ConvertToMemRefCallingConvention.cpp
+++ /dev/null
@@ -1,398 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-void copyOperationAttrs(Operation *oldOp, Operation *newOp) {
- for (const auto &oldAttr : oldOp->getAttrs()) {
- newOp->setAttr(oldAttr.first, oldAttr.second);
- }
-}
-
-FunctionType getMemRefFunctionType(FunctionType type) {
- Builder builder(type.getContext());
- llvm::SmallVector<Type, 8> replacementInputs;
- for (auto type : type.getInputs()) {
- auto memRefType = convertTypeToMemRef(type);
- if (!memRefType) {
- return nullptr;
- }
- replacementInputs.push_back(memRefType);
- }
- llvm::SmallVector<Type, 8> replacementResults;
- for (auto type : type.getResults()) {
- auto memRefType = convertTypeToMemRef(type);
- if (!memRefType) {
- return nullptr;
- }
- replacementResults.push_back(memRefType);
- }
- return builder.getFunctionType(replacementInputs, replacementResults);
-}
-
-bool insertLoad(BlockArgument *oldArg, BlockArgument *newArg,
- OpBuilder &builder, BlockAndValueMapping *mapping) {
- auto loc = oldArg->getOwner()->getParent()->getLoc();
-
- // If old arg was a memref we don't need to change anything. We still need
- // to remap so that the use lists match through conversion, though.
- if (oldArg->getType().isa<MemRefType>()) {
- mapping->map(oldArg, newArg);
- return false;
- } else if (oldArg->getType().isa<TensorType>()) {
- auto castOp = builder.create<IREE::MemRefToTensorOp>(loc, newArg);
- mapping->map(oldArg, castOp.getResult());
- return false;
- }
-
- // Insert the load we'll use to unbox the value.
- auto loadedValue = builder.create<LoadOp>(loc, newArg, ArrayRef<Value *>{});
- mapping->map(oldArg, loadedValue);
-
- return false;
-}
-
-bool insertLoad(Operation *oldOp, Value *oldValue, Value *newValue,
- OpBuilder &builder, BlockAndValueMapping *mapping) {
- // If old value was a memref we don't need to change anything.
- if (oldValue->getType().isa<MemRefType>()) {
- mapping->map(oldValue, newValue);
- return false;
- } else if (oldValue->getType().isa<TensorType>()) {
- auto castOp =
- builder.create<IREE::MemRefToTensorOp>(oldOp->getLoc(), newValue);
- mapping->map(oldValue, castOp.getResult());
- return false;
- }
-
- assert(newValue->getType().isa<MemRefType>());
-
- // Insert the load we'll use to unbox the value.
- auto loadedValue =
- builder.create<LoadOp>(oldOp->getLoc(), newValue, ArrayRef<Value *>{});
- mapping->map(oldValue, loadedValue);
-
- return false;
-}
-
-Value *insertStore(Operation *oldOp, Value *oldValue, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- auto *newValue = mapping->lookupOrNull(oldValue);
- if (!newValue) {
- return nullptr;
- }
-
- // If the previous value was already a memref we don't need to change
- // anything.
- // TODO(benvanik): ensure indices make sense.
- if (oldValue->getType().isa<MemRefType>()) {
- return newValue;
- } else if (oldValue->getType().isa<TensorType>()) {
- auto castOp =
- builder.create<IREE::TensorToMemRefOp>(oldOp->getLoc(), newValue);
- return castOp.getResult();
- }
-
- // Look back up and see if we can find the memref the value was loaded from.
- if (auto *sourceMemRef = resolveValueToSourceMemRef(oldValue, oldOp)) {
- return mapping->lookupOrNull(sourceMemRef);
- }
-
- // Allocate the memref to store the value.
- auto newStorage = builder.create<AllocOp>(
- oldOp->getLoc(), convertTypeToMemRef(oldValue->getType()));
-
- // Insert the store we'll use to box the value.
- builder.create<StoreOp>(oldOp->getLoc(), newValue, newStorage,
- ArrayRef<Value *>{});
-
- return newStorage;
-}
-
-bool convertCallOp(CallOp *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- llvm::SmallVector<Value *, 4> newArgs;
- for (auto *oldArg : oldOp->getOperands()) {
- auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
- if (!newArg) {
- return true;
- }
- newArgs.push_back(newArg);
- }
-
- SmallVector<Type, 4> resultTypes;
- for (auto oldType : oldOp->getOperation()->getResultTypes()) {
- resultTypes.push_back(convertTypeToMemRef(oldType));
- }
- auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(),
- resultTypes, newArgs);
- copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
-
- for (int i = 0; i < newOp.getNumResults(); ++i) {
- auto *oldResult = oldOp->getResult(i);
- auto *newResult = newOp.getResult(i);
- if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder,
- mapping)) {
- return true;
- }
- }
-
- return false;
-}
-
-bool convertCallIndirectOp(CallIndirectOp *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- // TODO(benvanik): support wrapping callee values.
- oldOp->emitError("CallIndirectOp not yet supported");
- return true;
-#if 0
- llvm::SmallVector<Value *, 4> newArgs;
- for (auto *oldArg : oldOp->getArgOperands()) {
- auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
- if (!newArg) {
- return true;
- }
- newArgs.push_back(newArg);
- }
-
- auto newOp = builder.create<CallIndirectOp>(oldOp->getLoc(),
- oldOp->getCallee(), newArgs);
- copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
-
- for (int i = 0; i < newOp.getNumResults(); ++i) {
- auto *oldResult = oldOp->getResult(i);
- auto *newResult = newOp.getResult(i);
- if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder,
- mapping)) {
- return true;
- }
- }
-
- return false;
-#endif // 0
-}
-
-bool convertReturnOp(Operation *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- BlockAndValueMapping returnMapping;
- for (auto *oldArg : oldOp->getOperands()) {
- auto *newArg = insertStore(oldOp, oldArg, builder, mapping);
- if (!newArg) {
- return true;
- }
- returnMapping.map(oldArg, newArg);
- }
-
- builder.clone(*oldOp, returnMapping);
- return false;
-}
-
-bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- llvm::SmallVector<Value *, 4> newArgs;
- for (auto *oldArg : oldOp->getOperands()) {
- auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
- if (!newArg) {
- return true;
- }
- newArgs.push_back(newArg);
- }
-
- auto *dest = mapping->lookupOrNull(oldOp->getDest());
- if (!dest) {
- oldOp->emitError("Destination block mapping not found");
- return true;
- }
-
- auto newOp = builder.create<BranchOp>(oldOp->getLoc(), dest, newArgs);
- copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
-
- return false;
-}
-
-bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- llvm::SmallVector<Value *, 4> trueArgs;
- for (auto *oldArg : oldOp->getTrueOperands()) {
- auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
- if (!newArg) {
- return true;
- }
- trueArgs.push_back(newArg);
- }
- llvm::SmallVector<Value *, 4> falseArgs;
- for (auto *oldArg : oldOp->getFalseOperands()) {
- auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
- if (!newArg) {
- return true;
- }
- falseArgs.push_back(newArg);
- }
-
- auto *trueDest = mapping->lookupOrNull(oldOp->getTrueDest());
- if (!trueDest) {
- oldOp->emitError("True destination block mapping not found");
- return true;
- }
- auto *falseDest = mapping->lookupOrNull(oldOp->getFalseDest());
- if (!falseDest) {
- oldOp->emitError("False destination block mapping not found");
- return true;
- }
-
- // Lowering will take care of the condition store.
- auto *newCondition = mapping->lookupOrNull(oldOp->getCondition());
- if (!newCondition) {
- oldOp->emitError("Condition value mapping not found");
- return false;
- }
-
- auto newOp = builder.create<CondBranchOp>(
- oldOp->getLoc(), newCondition, trueDest, trueArgs, falseDest, falseArgs);
- copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
-
- return false;
-}
-
-bool convertOperation(Operation *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- if (isa<ConstantOp>(oldOp)) {
- builder.clone(*oldOp, *mapping);
- return false;
- } else if (auto callOp = dyn_cast<CallOp>(oldOp)) {
- return convertCallOp(&callOp, builder, mapping);
- } else if (auto callIndirectOp = dyn_cast<CallIndirectOp>(oldOp)) {
- return convertCallIndirectOp(&callIndirectOp, builder, mapping);
- } else if (isa<ReturnOp>(oldOp) || isa<IREE::ReturnOp>(oldOp)) {
- return convertReturnOp(oldOp, builder, mapping);
- } else if (auto branchOp = dyn_cast<BranchOp>(oldOp)) {
- return convertBranchOp(&branchOp, builder, mapping);
- } else if (auto condBranchOp = dyn_cast<CondBranchOp>(oldOp)) {
- return convertCondBranchOp(&condBranchOp, builder, mapping);
- } else {
- builder.clone(*oldOp, *mapping);
- return false;
- }
-}
-
-bool convertFunction(FuncOp oldFunc, FuncOp newFunc) {
- OpBuilder builder(newFunc.getBody());
- BlockAndValueMapping mapping;
-
- // Create new blocks matching the expected arguments of the old ones.
- // This sets up the block mappings to enable us to reference blocks forward
- // during conversion.
- newFunc.getBlocks().clear();
- for (auto &oldBlock : oldFunc.getBlocks()) {
- auto *newBlock = builder.createBlock(&newFunc.getBody());
- for (auto *oldArg : oldBlock.getArguments()) {
- // Replace the block args with memrefs.
- auto memRefType = convertTypeToMemRef(oldArg->getType());
- if (!memRefType) return true;
- auto *newArg = newBlock->addArgument(memRefType);
-
- // Insert loads to preserve type, if needed.
- // This will replace all uses of the oldArg with the loaded value from
- // newArg so that the block contents are still using unwrapped values.
- if (insertLoad(oldArg, newArg, builder, &mapping)) {
- return true;
- }
- }
- mapping.map(&oldBlock, newBlock);
- }
-
- // Convert all ops in the blocks.
- for (auto &oldBlock : oldFunc.getBlocks()) {
- builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock));
- for (auto &oldOp : oldBlock.getOperations()) {
- if (convertOperation(&oldOp, builder, &mapping)) {
- return true;
- }
- }
- }
-
- return false;
-}
-
-} // namespace
-
-class ConvertToMemRefCallingConventionPass
- : public ModulePass<ConvertToMemRefCallingConventionPass> {
- public:
- void runOnModule() override {
- auto module = getModule();
-
- // Build a list of (oldFunc, newFunc) for all functions we need to
- // replace. This will ensure that when we go to convert function bodies we
- // have only new functions defined.
- std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions;
-
- for (auto oldFunc : module.getOps<FuncOp>()) {
- // Create the replacement function, ensuring that we copy attributes.
- auto functionType = getMemRefFunctionType(oldFunc.getType());
- if (!functionType) {
- return signalPassFailure();
- }
-
- auto newFunc = FuncOp::create(oldFunc.getLoc(), oldFunc.getName(),
- functionType, oldFunc.getDialectAttrs());
- convertedFunctions.push_back({oldFunc, newFunc});
-
- // Perform the actual body conversion now.
- if (convertFunction(oldFunc, newFunc)) {
- return signalPassFailure();
- }
- }
-
- // Replace functions in the module.
- for (auto &pair : convertedFunctions) {
- pair.first.erase();
- module.push_back(pair.second);
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>>
-createConvertToMemRefCallingConventionPass() {
- return std::make_unique<ConvertToMemRefCallingConventionPass>();
-}
-
-static PassRegistration<ConvertToMemRefCallingConventionPass> pass(
- "convert-to-memref-calling-convention",
- "Convert functions to use a memref-based calling convention.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/DropUnreachableFunctions.cpp b/iree/compiler/Transforms/DropUnreachableFunctions.cpp
deleted file mode 100644
index 6a5748c..0000000
--- a/iree/compiler/Transforms/DropUnreachableFunctions.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/ModuleUtils.h"
-#include "llvm/ADT/SetVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Drops all functions in a module that are not reachable by functions with the
-// "iree.module.export" attribute.
-class DropUnreachableModuleFunctionsPass
- : public ModulePass<DropUnreachableModuleFunctionsPass> {
- public:
- void runOnModule() override {
- dropUnusedFunctions(getModule(), {"iree.module.export"});
- }
-};
-
-// Drops all functions in a module that are not reachable by functions with the
-// "iree.executable.export" attribute.
-class DropUnreachableExecutableFunctionsPass
- : public ModulePass<DropUnreachableExecutableFunctionsPass> {
- public:
- void runOnModule() override {
- dropUnusedFunctions(getModule(), {"iree.executable.export"});
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>>
-createDropUnreachableModuleFunctionsPass() {
- return std::make_unique<DropUnreachableModuleFunctionsPass>();
-}
-
-std::unique_ptr<OpPassBase<ModuleOp>>
-createDropUnreachableExecutableFunctionsPass() {
- return std::make_unique<DropUnreachableExecutableFunctionsPass>();
-}
-
-static PassRegistration<DropUnreachableModuleFunctionsPass> moduleFunctionsPass(
- "iree-drop-unreachable-module-functions",
- "Drop all functions not reachable from an exported function");
-
-static PassRegistration<DropUnreachableExecutableFunctionsPass>
- executableFunctionsPass(
- "iree-drop-unreachable-executable-functions",
- "Drop all functions not reachable from an exported function");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/DropUnusedExecutables.cpp b/iree/compiler/Transforms/DropUnusedExecutables.cpp
deleted file mode 100644
index e10f81c..0000000
--- a/iree/compiler/Transforms/DropUnusedExecutables.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "llvm/ADT/SetVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Drops all executables in a module that are not used by any dispatch
-// sequencer op.
-class DropUnusedExecutablesPass : public ModulePass<DropUnusedExecutablesPass> {
- public:
- void runOnModule() override {
- DenseSet<StringRef> usedExecutableNames;
- for (auto funcOp : getModule().getOps<FuncOp>()) {
- funcOp.walk([&](IREESeq::HL::DispatchOp op) {
- usedExecutableNames.insert(op.getExecutable());
- });
- }
- DenseSet<Operation *> deadExecutables;
- for (auto executableOp :
- getModule().getOps<IREE::MultiArchExecutableOp>()) {
- if (usedExecutableNames.count(executableOp.getName()) == 0) {
- deadExecutables.insert(executableOp);
- }
- }
- for (auto executableOp : deadExecutables) {
- executableOp->erase();
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createDropUnusedExecutablesPass() {
- return std::make_unique<DropUnusedExecutablesPass>(); // NOLINT
-}
-
-static PassRegistration<DropUnusedExecutablesPass> executableFunctionsPass(
- "iree-drop-unused-executables",
- "Drop all executables not reachable from a dispatch/reduce op.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/BUILD b/iree/compiler/Transforms/Interpreter/BUILD
deleted file mode 100644
index e5816fa..0000000
--- a/iree/compiler/Transforms/Interpreter/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-# Transforms specific to the IREE interpreter.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Interpreter",
- srcs = [
- "ExpandReductionsToOps.cpp",
- "LowerInterpreterDialect.cpp",
- "LowerStdToInterpreterDialect.cpp",
- "LowerToInterpreterDialect.cpp",
- "LowerXLAToInterpreterDialect.cpp",
- "MakeExecutableABI.cpp",
- ],
- hdrs = [
- "Passes.h",
- "Rewrites.h",
- ],
- deps = [
- "//iree/compiler/IR",
- "//iree/compiler/IR/Interpreter",
- "//iree/compiler/Serialization",
- "//iree/compiler/Transforms",
- "//iree/compiler/Utils",
- "//iree/schemas/bytecode:interpreter_bytecode_v0",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:TransformUtils",
- "@local_config_mlir//:Transforms",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot",
- ],
- alwayslink = 1,
-)
diff --git a/iree/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp b/iree/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp
deleted file mode 100644
index 41706fe..0000000
--- a/iree/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp
+++ /dev/null
@@ -1,216 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Interpreter/HLDialect.h"
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Transforms/ConversionUtils.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-LogicalResult convertReductionOp(FuncOp entryPoint, FuncOp applyFunc,
- Operation *elementOp, OpBuilder &builder) {
- // Ensure that this op is pass-through and does not interact with any other
- // ops within the function.
- // TODO(b/139313439): support fused reductions.
- for (auto *operand : elementOp->getOperands()) {
- if (operand->getDefiningOp() != nullptr) {
- return elementOp->emitOpError()
- << "Fused reductions are not supported (operand not sourced from "
- "block args)";
- }
- }
- for (auto *result : elementOp->getResults()) {
- for (auto *user : result->getUsers()) {
- if (!user->isKnownTerminator()) {
- return elementOp->emitOpError() << "Fused reductions are not supported "
- "(result used by non-terminator)";
- }
- }
- }
-
- // Determine the index of the args we care about. We'll use these to match up
- // the operands of the entry point with our application.
- // Our arguments are expanded tuples like <lhs0, rhs0>, <lhs1, rhs1>, so this
- // index gets the offset * 2.
- auto &applyEntryBlock = applyFunc.getBlocks().front();
- int setIndex = std::distance(applyEntryBlock.args_begin(),
- llvm::find(applyEntryBlock.getArguments(),
- elementOp->getOperand(0))) /
- 2;
-
- // Map to the args from the entry point.
- auto &entryPointEntryBlock = entryPoint.getBlocks().front();
- Value *srcArg = entryPointEntryBlock.getArgument(setIndex);
- Value *initArg = entryPointEntryBlock.getArgument(
- applyFunc.getNumArguments() / 2 + setIndex);
- Value *dstArg =
- entryPointEntryBlock.getArgument(applyFunc.getNumArguments() + setIndex);
- auto dstType = dstArg->getType().cast<ShapedType>();
- Type elementType = dstType.getElementType();
- auto loc = elementOp->getLoc();
- auto dimensionAttr = entryPoint.getAttrOfType<IntegerAttr>(
- "iree.executable.reduction.dimension");
-
- Operation *expandedOp = nullptr;
- if (isa<IREEInterp::HL::AddFOp>(elementOp) ||
- isa<IREEInterp::HL::AddIOp>(elementOp)) {
- if (elementType.isa<FloatType>()) {
- expandedOp = builder.create<IREEInterp::HL::ReduceSumFOp>(
- loc, dstType, srcArg, initArg, dimensionAttr);
- } else {
- expandedOp = builder.create<IREEInterp::HL::ReduceSumIOp>(
- loc, dstType, srcArg, initArg, dimensionAttr);
- }
- } else if (isa<IREEInterp::HL::MinFOp>(elementOp) ||
- isa<IREEInterp::HL::MinISOp>(elementOp) ||
- isa<IREEInterp::HL::MinIUOp>(elementOp)) {
- if (elementType.isa<FloatType>()) {
- expandedOp = builder.create<IREEInterp::HL::ReduceMinFOp>(
- loc, dstType, srcArg, initArg, dimensionAttr);
- } else {
- expandedOp = builder.create<IREEInterp::HL::ReduceMinIOp>(
- loc, dstType, srcArg, initArg, dimensionAttr);
- }
- } else if (isa<IREEInterp::HL::MaxFOp>(elementOp) ||
- isa<IREEInterp::HL::MaxISOp>(elementOp) ||
- isa<IREEInterp::HL::MaxIUOp>(elementOp)) {
- if (elementType.isa<FloatType>()) {
- expandedOp = builder.create<IREEInterp::HL::ReduceMaxFOp>(
- loc, dstType, srcArg, initArg, dimensionAttr);
- } else {
- expandedOp = builder.create<IREEInterp::HL::ReduceMaxIOp>(
- loc, dstType, srcArg, initArg, dimensionAttr);
- }
- }
- if (!expandedOp) {
- return elementOp->emitOpError()
- << "No matching expanded reduction op for elemental op";
- }
- llvm::SmallVector<int64_t, 4> zeroOffset(dstType.getRank(), 0);
- auto zeroIndices = createArrayConstant(builder, loc, zeroOffset);
- auto lengths = createArrayConstant(builder, loc, dstType.getShape());
- builder.create<IREEInterp::HL::CopyOp>(
- loc, expandedOp->getResult(0), zeroIndices, dstArg, zeroIndices, lengths);
-
- return success();
-}
-
-// Replaces the given elemental |funcOp| with a widened reduction.
-LogicalResult expandReductionFunction(FuncOp entryFunc) {
- if (!entryFunc.empty()) {
- return entryFunc.emitError()
- << "Function has already been expanded or has existing contents";
- } else if (!entryFunc.getAttr("iree.executable.reduction.dimension")) {
- return entryFunc.emitError() << "Windowed reductions are not yet supported";
- }
- auto applySym =
- entryFunc.getAttrOfType<SymbolRefAttr>("iree.executable.reduction.apply");
- if (!applySym) {
- return entryFunc.emitError() << "No reduction application function defined";
- }
- auto applyFunc = entryFunc.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
- applySym.getValue());
- if (!applyFunc) {
- return entryFunc.emitError()
- << "Unable to find apply function " << applySym;
- }
-
- auto *entryBlock = entryFunc.addEntryBlock();
- OpBuilder builder(entryBlock);
-
- if (applyFunc.getBlocks()
- .front()
- .walk([&](Operation *op) {
- if (!op->isKnownTerminator()) {
- if (failed(
- convertReductionOp(entryFunc, applyFunc, op, builder))) {
- return WalkResult::interrupt();
- }
- }
- return WalkResult::advance();
- })
- .wasInterrupted()) {
- return applyFunc.emitError() << "Unable to convert apply func";
- }
-
- builder.create<IREE::ReturnOp>(builder.getUnknownLoc());
-
- // Remove the apply function as we have inlined it.
- applyFunc.erase();
- entryFunc.removeAttr("iree.executable.reduction.apply");
- entryFunc.removeAttr("iree.executable.reduction.dimension");
-
- return success();
-}
-
-// Limited lowering of reductions to fat reduce_* ops.
-//
-// The specific subset this supports is:
-// * 'min', 'max', and 'add' computations, with function names matching the
-// computation
-// * one op per reduction (no fusions yet).
-// Note: computations and shapes are not validated.
-//
-// TODO(b/139410773): Implement more generally, supporting custom computations.
-class ExpandReductionsToOpsPass : public ModulePass<ExpandReductionsToOpsPass> {
- public:
- void runOnModule() override {
- auto module = getModule();
- SmallVector<FuncOp, 4> reductionFuncs;
- for (auto funcOp : module.getOps<FuncOp>()) {
- if (funcOp.getAttr("iree.executable.reduction.apply")) {
- reductionFuncs.push_back(funcOp);
- }
- }
- for (auto funcOp : reductionFuncs) {
- if (failed(expandReductionFunction(funcOp))) {
- return signalPassFailure();
- }
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<OpPassBase<ModuleOp>> createExpandReductionsToOpsPass() {
- return std::make_unique<ExpandReductionsToOpsPass>();
-}
-
-static PassRegistration<ExpandReductionsToOpsPass> pass(
- "iree-expand-reductions-to-ops",
- "Expands IREE reduction functions to their interpreter ops");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp
deleted file mode 100644
index 22b6a8a..0000000
--- a/iree/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp
+++ /dev/null
@@ -1,253 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Interpreter/HLDialect.h"
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "iree/compiler/IR/Interpreter/LLDialect.h"
-#include "iree/compiler/IR/Interpreter/LLOps.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Serialization/BytecodeTables.h"
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Allocator.h"
-#include "llvm/Support/Casting.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct LowerBranchOpPattern
- : public OpRewritePattern<IREEInterp::HL::BranchOp> {
- using OpRewritePattern<IREEInterp::HL::BranchOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(IREEInterp::HL::BranchOp op,
- PatternRewriter &rewriter) const {
- SmallVector<Value *, 8> operands{op.getOperation()->getOperands()};
-
- rewriter.replaceOpWithNewOp<IREEInterp::LL::BranchOp>(op, op.getDest(),
- operands);
- return matchSuccess();
- }
-};
-
-struct LowerCondCondBranchOpPattern
- : public OpRewritePattern<IREEInterp::HL::CondBranchOp> {
- using OpRewritePattern<IREEInterp::HL::CondBranchOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(IREEInterp::HL::CondBranchOp op,
- PatternRewriter &rewriter) const {
- SmallVector<Value *, 8> trueOperands{op.getTrueOperands()};
- SmallVector<Value *, 8> falseOperands{op.getFalseOperands()};
-
- rewriter.replaceOpWithNewOp<IREEInterp::LL::CondBranchOp>(
- op, op.getCondition(), op.getTrueDest(), trueOperands,
- op.getFalseDest(), falseOperands);
- return matchSuccess();
- }
-};
-
-// Returns true if the op defined by |opName| (like 'iree_ll_interp.reshape')
-// uses output operands for results (like iree_ll_interp.add_i) or returns real
-// results.
-bool opTakesOutputOperands(llvm::StringRef opName) {
- if (!opName.consume_front("iree_ll_interp.")) {
- assert(false && "op not part of IREE LL Interpreter dialect");
- return false;
- }
- auto opcode = GetInterpreterOpcodeByName(opName.str());
- assert(opcode.hasValue() && "op has no corresponding opcode");
- const auto &info = GetInterpreterOpcodeInfo(opcode.getValue());
- for (auto &operand : info.operands) {
- if (operand == iree::OperandEncoding::kOutputSlot ||
- operand == iree::OperandEncoding::kVariadicOutputSlots) {
- return true;
- }
- }
- return false;
-}
-
-template <typename SrcOp, typename DstOp>
-class SimpleOpLowering : public OpRewritePattern<SrcOp> {
- using OpRewritePattern<SrcOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(SrcOp op,
- PatternRewriter &rewriter) const {
- SmallVector<Value *, 8> operands{op.getOperation()->getOperands()};
-
- // Most ops take results as output operands to populate during execution.
- // Certain ops, like reshape, return references to existing memrefs and
- // should still retain their results.
- if (!opTakesOutputOperands(DstOp::getOperationName())) {
- SmallVector<Type, 8> resultTypes{op.getOperation()->getResultTypes()};
-
- rewriter.replaceOpWithNewOp<DstOp>(op, resultTypes, operands,
- op.getAttrs());
- return this->matchSuccess();
- }
-
- SmallVector<Value *, 4> replacementValues;
- for (Value *result : op.getOperation()->getResults()) {
- auto memRefType = result->getType().cast<MemRefType>();
- if (!memRefType.hasStaticShape()) {
- // TODO(benvanik): real thing here - dynamic shaping required.
- // This should emit a shape calculation based on the operation. Most
- // are likely simple and by running DCE after this we can clean up
- // parts that are static or unused.
- op.emitOpError() << "uses unsupported dynamic shapes";
- return this->matchFailure();
- }
- ArrayRef<Value *> dim_pieces;
- auto allocOp = rewriter.create<IREEInterp::LL::AllocHeapOp>(
- op.getLoc(), memRefType, dim_pieces);
- operands.push_back(allocOp);
- replacementValues.push_back(allocOp);
- }
- ArrayRef<Type> resultTypes;
- rewriter.create<DstOp>(op.getLoc(), resultTypes, operands, op.getAttrs());
- rewriter.replaceOp(op, replacementValues);
- return this->matchSuccess();
- }
-};
-
-} // namespace
-
-void populateInterpreterLoweringPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx);
- patterns.insert<
- SimpleOpLowering<IREE::ConstantOp, IREEInterp::LL::ConstantOp>,
- SimpleOpLowering<IREEInterp::HL::CopyOp, IREEInterp::LL::DynamicCopyOp>,
- SimpleOpLowering<IREEInterp::HL::SliceOp,
- IREEInterp::LL::DynamicSliceOp>>(ctx);
-#define SAME_NAME_SIMPLE_PATTERN(op_name) \
- SimpleOpLowering<IREEInterp::HL::op_name, IREEInterp::LL::op_name>
- // clang-format off
- patterns.insert<
- SAME_NAME_SIMPLE_PATTERN(AssignOp),
- SAME_NAME_SIMPLE_PATTERN(AbsFOp),
- SAME_NAME_SIMPLE_PATTERN(AbsIOp),
- SAME_NAME_SIMPLE_PATTERN(AddFOp),
- SAME_NAME_SIMPLE_PATTERN(AddIOp),
- SAME_NAME_SIMPLE_PATTERN(AllocHeapOp),
- SAME_NAME_SIMPLE_PATTERN(AndOp),
- SAME_NAME_SIMPLE_PATTERN(Atan2FOp),
- SAME_NAME_SIMPLE_PATTERN(BreakOp),
- SAME_NAME_SIMPLE_PATTERN(BroadcastOp),
- SAME_NAME_SIMPLE_PATTERN(CallOp),
- SAME_NAME_SIMPLE_PATTERN(CallIndirectOp),
- SAME_NAME_SIMPLE_PATTERN(CeilFOp),
- SAME_NAME_SIMPLE_PATTERN(ClampFOp),
- SAME_NAME_SIMPLE_PATTERN(CloneOp),
- SAME_NAME_SIMPLE_PATTERN(CmpFOp),
- SAME_NAME_SIMPLE_PATTERN(CmpIOp),
- SAME_NAME_SIMPLE_PATTERN(CondAssignOp),
- SAME_NAME_SIMPLE_PATTERN(ConvertSSOp),
- SAME_NAME_SIMPLE_PATTERN(ConvertUUOp),
- SAME_NAME_SIMPLE_PATTERN(ConvertSUOp),
- SAME_NAME_SIMPLE_PATTERN(ConvertUSOp),
- SAME_NAME_SIMPLE_PATTERN(CondBreakOp),
- SAME_NAME_SIMPLE_PATTERN(CosFOp),
- SAME_NAME_SIMPLE_PATTERN(DimOp),
- SAME_NAME_SIMPLE_PATTERN(DivFOp),
- SAME_NAME_SIMPLE_PATTERN(DivISOp),
- SAME_NAME_SIMPLE_PATTERN(DivIUOp),
- SAME_NAME_SIMPLE_PATTERN(ExpFOp),
- SAME_NAME_SIMPLE_PATTERN(LogFOp),
- SAME_NAME_SIMPLE_PATTERN(RsqrtFOp),
- SAME_NAME_SIMPLE_PATTERN(FloorFOp),
- SAME_NAME_SIMPLE_PATTERN(LengthOp),
- SAME_NAME_SIMPLE_PATTERN(MatMulFOp),
- SAME_NAME_SIMPLE_PATTERN(MatMulIOp),
- SAME_NAME_SIMPLE_PATTERN(MaxFOp),
- SAME_NAME_SIMPLE_PATTERN(MaxISOp),
- SAME_NAME_SIMPLE_PATTERN(MaxIUOp),
- SAME_NAME_SIMPLE_PATTERN(MinFOp),
- SAME_NAME_SIMPLE_PATTERN(MinISOp),
- SAME_NAME_SIMPLE_PATTERN(MinIUOp),
- SAME_NAME_SIMPLE_PATTERN(MulAddFOp),
- SAME_NAME_SIMPLE_PATTERN(MulAddIOp),
- SAME_NAME_SIMPLE_PATTERN(MulFOp),
- SAME_NAME_SIMPLE_PATTERN(MulIOp),
- SAME_NAME_SIMPLE_PATTERN(NotOp),
- SAME_NAME_SIMPLE_PATTERN(OrOp),
- SAME_NAME_SIMPLE_PATTERN(PadOp),
- SAME_NAME_SIMPLE_PATTERN(RankOp),
- SAME_NAME_SIMPLE_PATTERN(ReduceSumIOp),
- SAME_NAME_SIMPLE_PATTERN(ReduceSumFOp),
- SAME_NAME_SIMPLE_PATTERN(ReduceMinIOp),
- SAME_NAME_SIMPLE_PATTERN(ReduceMinFOp),
- SAME_NAME_SIMPLE_PATTERN(ReduceMaxIOp),
- SAME_NAME_SIMPLE_PATTERN(ReduceMaxFOp),
- SAME_NAME_SIMPLE_PATTERN(ReshapeOp),
- SAME_NAME_SIMPLE_PATTERN(ReturnOp),
- SAME_NAME_SIMPLE_PATTERN(SelectOp),
- SAME_NAME_SIMPLE_PATTERN(ShapeOp),
- SAME_NAME_SIMPLE_PATTERN(ShiftLeftOp),
- SAME_NAME_SIMPLE_PATTERN(ShiftRightArithmeticOp),
- SAME_NAME_SIMPLE_PATTERN(ShiftRightLogicalOp),
- SAME_NAME_SIMPLE_PATTERN(SinFOp),
- SAME_NAME_SIMPLE_PATTERN(SplitOp),
- SAME_NAME_SIMPLE_PATTERN(SubFOp),
- SAME_NAME_SIMPLE_PATTERN(SubIOp),
- SAME_NAME_SIMPLE_PATTERN(TanhFOp),
- SAME_NAME_SIMPLE_PATTERN(TileOp),
- SAME_NAME_SIMPLE_PATTERN(TraceOp),
- SAME_NAME_SIMPLE_PATTERN(TransposeOp),
- SAME_NAME_SIMPLE_PATTERN(ReverseOp),
- SAME_NAME_SIMPLE_PATTERN(XorOp)>(ctx);
- // clang-format on
-#undef SAME_NAME_SIMPLE_PATTERN
-}
-
-namespace {
-class LowerInterpreterDialectPass
- : public FunctionPass<LowerInterpreterDialectPass> {
- public:
- void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateInterpreterLoweringPatterns(patterns, &getContext());
-
- ConversionTarget target(getContext());
- target.addLegalDialect<IREELLInterpreterDialect>();
- target.addLegalOp<FuncOp, IREE::ReturnOp>();
- if (failed(applyFullConversion(getFunction(), target, patterns))) {
- return signalPassFailure();
- }
- }
-};
-} // namespace
-
-std::unique_ptr<OpPassBase<FuncOp>> createLowerInterpreterDialectPass() {
- return std::make_unique<LowerInterpreterDialectPass>();
-}
-
-static PassRegistration<LowerInterpreterDialectPass> pass(
- "lower-iree-interpreter-hl-to-ll", "Lowers IREE HL ops to IREE LL ops");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp
deleted file mode 100644
index 9263b2a..0000000
--- a/iree/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp
+++ /dev/null
@@ -1,303 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Interpreter/HLDialect.h"
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "iree/compiler/IR/Interpreter/LLDialect.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Transforms/ConversionUtils.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/Support/Allocator.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct CallOpLowering : public OpConversionPattern<CallOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- CallOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto callOp = cast<CallOp>(op);
- auto calleeType = callOp.getCalleeType();
- rewriter.replaceOpWithNewOp<IREEInterp::HL::CallOp>(
- op, callOp.getCallee(), calleeType.getResults(), operands);
- return matchSuccess();
- }
-};
-
-struct CallIndirectOpLowering : public OpConversionPattern<CallIndirectOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- CallIndirectOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto callOp = cast<CallIndirectOp>(op);
- rewriter.replaceOpWithNewOp<IREEInterp::HL::CallIndirectOp>(
- op, callOp.getCallee(), operands);
- return matchSuccess();
- }
-};
-
-struct ReturnOpLowering : public OpConversionPattern<ReturnOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- ReturnOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREEInterp::HL::ReturnOp>(op, operands);
- return matchSuccess();
- }
-};
-
-struct BranchOpLowering : public OpConversionPattern<BranchOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- BranchOp op, ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREEInterp::HL::BranchOp>(op, destinations[0],
- operands[0]);
- return this->matchSuccess();
- }
-};
-
-struct CondBranchOpLowering : public OpConversionPattern<CondBranchOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- CondBranchOp op, ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *condValue = loadAccessValue(op.getLoc(), properOperands[0], rewriter);
- rewriter.replaceOpWithNewOp<IREEInterp::HL::CondBranchOp>(
- op, condValue, destinations[IREEInterp::HL::CondBranchOp::trueIndex],
- operands[IREEInterp::HL::CondBranchOp::trueIndex],
- destinations[IREEInterp::HL::CondBranchOp::falseIndex],
- operands[IREEInterp::HL::CondBranchOp::falseIndex]);
- return this->matchSuccess();
- }
-};
-
-template <typename SrcOp, typename DstOp>
-struct CompareOpLowering : public OpConversionPattern<SrcOp> {
- using OpConversionPattern<SrcOp>::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- SrcOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto lhValue = loadAccessValue(op.getLoc(), operands[0], rewriter);
- auto rhValue = loadAccessValue(op.getLoc(), operands[1], rewriter);
-
- lhValue = wrapAsMemRef(lhValue, op, rewriter);
- rhValue = wrapAsMemRef(rhValue, op, rewriter);
-
- // TODO(benvanik): map predicate to stable value.
- auto predicate =
- rewriter.getI32IntegerAttr(static_cast<int32_t>(op.getPredicate()));
-
- auto dstType = convertTypeToMemRef(op.getResult());
- auto midOp = rewriter.create<DstOp>(op.getLoc(), dstType, predicate,
- lhValue, rhValue);
-
- auto result = wrapAsTensor(midOp.getResult(), op, rewriter);
- rewriter.replaceOp(
- op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)});
- return this->matchSuccess();
- }
-};
-
-struct CmpIOpLowering
- : public CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp> {
- using CompareOpLowering::CompareOpLowering;
-};
-
-struct CmpFOpLowering
- : public CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp> {
- using CompareOpLowering::CompareOpLowering;
-};
-
-struct AllocOpLowering : public OpConversionPattern<AllocOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- AllocOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- // TODO(benvanik): replace with length computation.
- rewriter.replaceOpWithNewOp<IREEInterp::HL::AllocHeapOp>(op, op.getType(),
- operands);
- return matchSuccess();
- }
-};
-
-struct DeallocOpLowering : public OpConversionPattern<DeallocOp> {
- using OpConversionPattern::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- DeallocOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREEInterp::HL::DiscardOp>(op, operands[0]);
- return matchSuccess();
- }
-};
-
-struct LoadOpLowering : public OpRewritePattern<LoadOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(LoadOp loadOp,
- PatternRewriter &rewriter) const override {
- if (loadOp.getMemRefType().getRank() != 0) {
- loadOp.emitError() << "Cannot lower load of non-scalar";
- return matchFailure();
- }
- ArrayRef<Value *> dimPieces;
- auto dst =
- rewriter
- .create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(), dimPieces)
- .getResult();
- auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {});
- rewriter.create<IREEInterp::HL::CopyOp>(
- loadOp.getLoc(), loadOp.getMemRef(),
- /*srcIndices=*/emptyArrayMemref, dst,
- /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref);
-
- rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst);
-
- return matchSuccess();
- }
-};
-
-struct StoreOpLowering : public OpRewritePattern<StoreOp> {
- using OpRewritePattern::OpRewritePattern;
- PatternMatchResult matchAndRewrite(StoreOp storeOp,
- PatternRewriter &rewriter) const override {
- if (storeOp.getMemRefType().getRank() != 0) {
- storeOp.emitError() << "Cannot lower store of non-scalar";
- return matchFailure();
- }
-
- auto src = rewriter.create<IREE::ScalarToMemRefOp>(
- storeOp.getLoc(), storeOp.getValueToStore());
-
- auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {});
- rewriter.replaceOpWithNewOp<IREEInterp::HL::CopyOp>(
- storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(),
- /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref);
-
- return matchSuccess();
- }
-};
-
-#define UNARY_OP_LOWERING(StdOpType, IREEOpType) \
- struct StdOpType##Lowering : public UnaryOpLowering<StdOpType, IREEOpType> { \
- using UnaryOpLowering::UnaryOpLowering; \
- };
-
-#define BINARY_OP_LOWERING(StdOpType, IREEOpType) \
- struct StdOpType##Lowering \
- : public BinaryOpLowering<StdOpType, IREEOpType> { \
- using BinaryOpLowering::BinaryOpLowering; \
- };
-
-#define TERNARY_OP_LOWERING(StdOpType, IREEOpType) \
- struct StdOpType##Lowering \
- : public TernaryOpLowering<StdOpType, IREEOpType> { \
- using TernaryOpLowering::TernaryOpLowering; \
- };
-
-// UNARY_OP_LOWERING(RankOp, IREEInterp::HL::RankOp);
-UNARY_OP_LOWERING(DimOp, IREEInterp::HL::DimOp);
-// UNARY_OP_LOWERING(ShapeOp, IREEInterp::HL::ShapeOp);
-// UNARY_OP_LOWERING(LengthOp, IREEInterp::HL::LengthOp);
-
-// UNARY_OP_LOWERING(NotOp, IREEInterp::HL::NotOp);
-BINARY_OP_LOWERING(AndOp, IREEInterp::HL::AndOp);
-BINARY_OP_LOWERING(OrOp, IREEInterp::HL::OrOp);
-// BINARY_OP_LOWERING(XorOp, IREEInterp::HL::XorOp);
-// BINARY_OP_LOWERING(ShiftLeftOp, IREEInterp::HL::ShiftLeftOp);
-// BINARY_OP_LOWERING(ShiftRightLogicalOp, IREEInterp::HL::ShiftRightLogicalOp);
-// BINARY_OP_LOWERING(ShiftRightArithmeticOp,
-// IREEInterp::HL::ShiftRightArithmeticOp);
-
-BINARY_OP_LOWERING(AddIOp, IREEInterp::HL::AddIOp);
-BINARY_OP_LOWERING(AddFOp, IREEInterp::HL::AddFOp);
-BINARY_OP_LOWERING(SubIOp, IREEInterp::HL::SubIOp);
-BINARY_OP_LOWERING(SubFOp, IREEInterp::HL::SubFOp);
-// UNARY_OP_LOWERING(AbsIOp, IREEInterp::HL::AbsIOp);
-// UNARY_OP_LOWERING(AbsFOp, IREEInterp::HL::AbsFOp);
-BINARY_OP_LOWERING(MulIOp, IREEInterp::HL::MulIOp);
-BINARY_OP_LOWERING(MulFOp, IREEInterp::HL::MulFOp);
-BINARY_OP_LOWERING(DivISOp, IREEInterp::HL::DivISOp);
-BINARY_OP_LOWERING(DivIUOp, IREEInterp::HL::DivIUOp);
-BINARY_OP_LOWERING(DivFOp, IREEInterp::HL::DivFOp);
-// BINARY_OP_LOWERING(MulAddIOp, IREEInterp::HL::MulAddIOp);
-// BINARY_OP_LOWERING(MulAddFOp, IREEInterp::HL::MulAddFOp);
-// UNARY_OP_LOWERING(ExpFOp, IREEInterp::HL::ExpFOp);
-// UNARY_OP_LOWERING(LogFOp, IREEInterp::HL::LogFOp);
-// UNARY_OP_LOWERING(RsqrtFOp, IREEInterp::HL::RsqrtFOp);
-// UNARY_OP_LOWERING(CosFOp, IREEInterp::HL::CosFOp);
-// UNARY_OP_LOWERING(SinFOp, IREEInterp::HL::SinFOp);
-// UNARY_OP_LOWERING(TanhFOp, IREEInterp::HL::TanhFOp);
-// UNARY_OP_LOWERING(Atan2FOp, IREEInterp::HL::Atan2FOp);
-
-// BINARY_OP_LOWERING(MinISOp, IREEInterp::HL::MinISOp);
-// BINARY_OP_LOWERING(MinIUOp, IREEInterp::HL::MinIUOp);
-// BINARY_OP_LOWERING(MinFOp, IREEInterp::HL::MinFOp);
-// BINARY_OP_LOWERING(MaxISOp, IREEInterp::HL::MaxISOp);
-// BINARY_OP_LOWERING(MaxIUOp, IREEInterp::HL::MaxIUOp);
-// BINARY_OP_LOWERING(MaxFOp, IREEInterp::HL::MaxFOp);
-// TERNARY_OP_LOWERING(ClampFOp, IREEInterp::HL::ClampFOp);
-// UNARY_OP_LOWERING(FloorFOp, IREEInterp::HL::FloorFOp);
-// UNARY_OP_LOWERING(CeilFOp, IREEInterp::HL::CeilFOp);
-
-} // namespace
-
-void populateLowerStdToInterpreterPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<
- // Control flow.
- CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
- BranchOpLowering, CondBranchOpLowering, CmpIOpLowering, CmpFOpLowering,
- // Memory management.
- AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering,
- // Shape operations.
- DimOpLowering,
- // Logical ops.
- AndOpLowering, OrOpLowering,
- // Arithmetic ops.
- AddIOpLowering, AddFOpLowering, SubIOpLowering, SubFOpLowering,
- MulIOpLowering, MulFOpLowering, DivISOpLowering, DivIUOpLowering,
- DivFOpLowering>(ctx);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp
deleted file mode 100644
index 19524b0..0000000
--- a/iree/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Interpreter/HLDialect.h"
-#include "iree/compiler/IR/Interpreter/LLDialect.h"
-#include "iree/compiler/Transforms/Interpreter/Rewrites.h"
-#include "iree/compiler/Transforms/Rewrites.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-class LowerToInterpreterDialectPass
- : public FunctionPass<LowerToInterpreterDialectPass> {
- public:
- void runOnFunction() override {
- OwningRewritePatternList patterns;
- auto* ctx = &getContext();
- xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx);
- xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx);
- populateLowerStdToIreePatterns(patterns, ctx);
- populateLowerStdToInterpreterPatterns(patterns, ctx);
- populateLowerXlaToIreePatterns(patterns, ctx);
- populateLowerXlaToInterpreterPatterns(patterns, ctx);
-
- ConversionTarget target(getContext());
- target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>();
- target.addLegalOp<FuncOp, ReturnOp>();
- if (failed(applyFullConversion(getFunction(), target, patterns))) {
- return signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<OpPassBase<FuncOp>> createLowerToInterpreterDialectPass() {
- return std::make_unique<LowerToInterpreterDialectPass>();
-}
-
-static PassRegistration<LowerToInterpreterDialectPass> pass(
- "lower-to-iree-interpreter",
- "Convert all ops to the IREE interpreter dialect");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
deleted file mode 100644
index 7bbb0d6..0000000
--- a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
+++ /dev/null
@@ -1,565 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Interpreter/HLDialect.h"
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Transforms/ConversionUtils.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// TODO(suderman): tablegen this? or something a bit more flexible.
-
-#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \
- struct XlaOpType##Lowering \
- : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
- using UnaryOpLowering::UnaryOpLowering; \
- };
-
-#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \
- struct XlaOpType##Lowering \
- : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
- using TernaryOpLowering::TernaryOpLowering; \
- };
-
-UNARY_OP_LOWERING(CopyOp, IREEInterp::HL::CloneOp);
-UNARY_OP_LOWERING(ExpOp, IREEInterp::HL::ExpFOp);
-UNARY_OP_LOWERING(LogOp, IREEInterp::HL::LogFOp);
-UNARY_OP_LOWERING(FloorOp, IREEInterp::HL::FloorFOp);
-UNARY_OP_LOWERING(RsqrtOp, IREEInterp::HL::RsqrtFOp);
-UNARY_OP_LOWERING(TanhOp, IREEInterp::HL::TanhFOp);
-TERNARY_OP_LOWERING(SelectOp, IREEInterp::HL::SelectOp);
-
-#undef UNARY_OP_LOWERING
-#undef TERNARY_OP_LOWERING
-
-template <typename T>
-static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter,
- Location loc, Value *input,
- MemRefType targetType) {
- auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape());
- return rewriter.create<T>(loc, targetType, input, shapeOp);
-}
-
-static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op,
- Value *tensor) {
- return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op,
- rewriter);
-}
-
-template <typename SrcOp>
-class XlaOpLowering : public OpConversionPattern<SrcOp> {
- using OpConversionPattern<SrcOp>::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- SrcOp srcOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value *, 4> memrefOperands;
- for (auto operand : operands) {
- memrefOperands.push_back(inputAsMemref(rewriter, srcOp, operand));
- }
-
- if (auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter)) {
- rewriter.replaceOp(srcOp,
- wrapAsTensor(dstOp->getResult(0), srcOp, rewriter));
- return this->matchSuccess();
- }
- return this->matchFailure();
- }
-
- protected:
- virtual Operation *rewriteInternal(
- SrcOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?");
- }
-};
-
-struct BroadcastInDimOpLowering
- : public XlaOpLowering<xla_hlo::BroadcastInDimOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::BroadcastInDimOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *inputValue = operands[0];
- auto inputType = inputValue->getType().cast<MemRefType>();
- auto finalType = convertTypeToMemRef(*op);
-
- // Reshape to scalar and broadcast.
- auto createFinal = createShapeTargetingOp<IREEInterp::HL::BroadcastOp>;
- llvm::SmallVector<int64_t, 6> intermediateShape{};
-
- // Or reshape to final rank and tile.
- if (inputType.getNumElements() != 1) {
- createFinal = createShapeTargetingOp<IREEInterp::HL::TileOp>;
-
- intermediateShape = llvm::SmallVector<int64_t, 6>(finalType.getRank(), 1);
- auto inputShape = inputType.getShape();
- auto dimensions = op->broadcast_dimensions();
- for (size_t i = 0; i < inputType.getRank(); ++i) {
- auto index = dimensions->getValue(i).cast<IntegerAttr>().getInt();
- intermediateShape[index] = inputShape[i];
- }
- }
-
- auto intermediateType =
- MemRefType::get(intermediateShape, inputType.getElementType());
- auto reshapeOp = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
- rewriter, op->getLoc(), inputValue, intermediateType);
- return createFinal(rewriter, op->getLoc(), reshapeOp->getResult(0),
- finalType);
- }
-};
-
-struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto finalType = convertTypeToMemRef(*op);
-
- return rewriter.create<IREEInterp::HL::ConcatOp>(
- op->getLoc(), finalType, operands,
- rewriter.getI32IntegerAttr(op->dimension().getZExtValue()));
- }
-};
-
-struct DotOpLowering : public XlaOpLowering<xla_hlo::DotOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::DotOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *lhsValue = operands[0];
- auto *rhsValue = operands[1];
-
- auto finalType = convertTypeToMemRef(*op);
- auto elementType = finalType.getElementType();
- if (!elementType.isa<FloatType>()) {
- op->emitOpError("xla_hlo.dot only supports floating point values");
- }
-
- Operation *matMulOp = rewriter
- .create<IREEInterp::HL::MatMulFOp>(
- op->getLoc(), finalType, lhsValue, rhsValue)
- .getOperation();
- return matMulOp;
- }
-};
-
-struct DynamicUpdateSliceOpLowering
- : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto operand = operands[0];
- auto update = operands[1];
-
- auto updateType = update->getType().cast<ShapedType>();
- Value *lengthConstant =
- createArrayConstant(rewriter, op->getLoc(), updateType.getShape());
-
- auto startIndices = makeArrayRef(operands).drop_front(2);
- const int rank = startIndices.size();
- llvm::SmallVector<Value *, 4> valuesToConcat;
- valuesToConcat.reserve(startIndices.size());
- auto type = getElementTypeOrSelf(startIndices.front());
-
- // To generate the offset matrix we need to convert the variadic tensors
- // into a reshaped and concated value.
- for (auto index : startIndices) {
- auto reshapedIndex = rewriter.create<IREEInterp::HL::ReshapeOp>(
- op->getLoc(), MemRefType::get({1}, type), index,
- createArrayConstant(rewriter, op->getLoc(), {1}));
- valuesToConcat.push_back(reshapedIndex);
- }
-
- auto dstOffset = rewriter
- .create<IREEInterp::HL::ConcatOp>(
- op->getLoc(), MemRefType::get({rank}, type),
- valuesToConcat, rewriter.getI32IntegerAttr(0))
- .getResult();
-
- llvm::SmallVector<int64_t, 4> zero_offset;
- zero_offset.resize(updateType.getRank(), 0);
- auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset);
-
- auto copiedOperand = rewriter.create<IREEInterp::HL::CloneOp>(
- op->getLoc(), operand->getType(), operand);
-
- rewriter
- .create<IREEInterp::HL::CopyOp>(op->getLoc(), update, srcOffset,
- copiedOperand, dstOffset,
- lengthConstant)
- .getOperation();
-
- return copiedOperand;
- }
-};
-
-template <typename XlaOpType, typename IreeFloatOpType, typename IreeIntOpType>
-struct BinaryFloatIntOpLowering : public XlaOpLowering<XlaOpType> {
- using XlaOpLowering<XlaOpType>::XlaOpLowering;
-
- Operation *rewriteInternal(
- XlaOpType *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *lhs = operands[0];
- auto *rhs = operands[1];
- auto inputType = lhs->getType().cast<MemRefType>();
- auto elementType = inputType.getElementType();
-
- if (elementType.isa<FloatType>()) {
- return rewriter.create<IreeFloatOpType>(op->getLoc(), inputType, lhs,
- rhs);
- }
-
- return rewriter.create<IreeIntOpType>(op->getLoc(), inputType, lhs, rhs);
- }
-};
-
-struct MaxOpLowering
- : public BinaryFloatIntOpLowering<xla_hlo::MaxOp, IREEInterp::HL::MaxFOp,
- IREEInterp::HL::MaxISOp> {
- using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering;
-};
-
-struct MinOpLowering
- : public BinaryFloatIntOpLowering<xla_hlo::MinOp, IREEInterp::HL::MinFOp,
- IREEInterp::HL::MinISOp> {
- using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering;
-};
-
-struct ConvertLowering : public XlaOpLowering<xla_hlo::ConvertOp> {
- using XlaOpLowering<xla_hlo::ConvertOp>::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::ConvertOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *operand = operands[0];
- auto *result = op->getResult();
-
- auto operandType = operand->getType().cast<MemRefType>().getElementType();
- auto resultType = result->getType().cast<ShapedType>().getElementType();
-
- auto newResultType = convertTypeToMemRef(result);
-
-#define ConvertCase(InType, OutType, NewOp) \
- { \
- if (operandType.isa<InType>() && resultType.isa<OutType>()) { \
- return rewriter.create<NewOp>(op->getLoc(), newResultType, operand); \
- } \
- }
- ConvertCase(IntegerType, IntegerType, IREEInterp::HL::ConvertSSOp);
- ConvertCase(IntegerType, FloatType, IREEInterp::HL::ConvertSFOp);
- ConvertCase(FloatType, IntegerType, IREEInterp::HL::ConvertFSOp);
- ConvertCase(FloatType, FloatType, IREEInterp::HL::ConvertFFOp);
-#undef ConvertCase
-
- return nullptr;
- }
-};
-
-// Lowers a subset of gathers along axis 0 that are really just a slice and
-// reshape.
-struct GatherOpLowering : public OpConversionPattern<xla_hlo::GatherOp> {
- using OpConversionPattern::OpConversionPattern;
-
- // TODO(gcmn): This only handles a minimal number of cases. When XLA
- // redefines gather to be simpler, lower it properly.
- PatternMatchResult matchAndRewrite(
- xla_hlo::GatherOp gatherOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (gatherOp.index_vector_dim() != 0) {
- gatherOp.emitRemark()
- << "Couldn't lower gather with index_vector_dim != 0";
- return matchFailure();
- }
- if (gatherOp.start_index_map().getType().getRank() != 1 ||
- gatherOp.start_index_map().getValue(0).cast<IntegerAttr>().getValue() !=
- 0) {
- gatherOp.emitRemark()
- << "Couldn't lower gather with start_index_map != [0]";
- return matchFailure();
- }
- if (gatherOp.collapsed_slice_dims().getType().getRank() != 1 ||
- gatherOp.collapsed_slice_dims()
- .getValue(0)
- .cast<IntegerAttr>()
- .getValue() != 0) {
- gatherOp.emitRemark()
- << "Couldn't lower gather with collapsed_dims != [0]";
- return matchFailure();
- }
-
- auto resultType = gatherOp.getResult()->getType().cast<RankedTensorType>();
- if (gatherOp.offset_dims().getType().getNumElements() !=
- resultType.getRank()) {
- gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != "
- "[0,...,rank of output]";
- return matchFailure();
- }
- for (auto it : llvm::enumerate(gatherOp.offset_dims())) {
- if (it.index() != it.value()) {
- gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != "
- "[0,...,rank of output]";
- return matchFailure();
- }
- }
-
- for (auto it : llvm::enumerate(resultType.getShape())) {
- if (gatherOp.slice_sizes()
- .getValue(it.index() + 1)
- .cast<IntegerAttr>()
- .getValue() != it.value()) {
- gatherOp.emitRemark()
- << "Couldn't lower gather with slice_sizes not [1] + final shape";
- return matchFailure();
- }
- }
-
- auto inputType = gatherOp.operand()->getType().cast<RankedTensorType>();
-
- auto startIndices =
- inputAsMemref(rewriter, gatherOp, gatherOp.start_indices());
- auto startIndicesType = startIndices->getType().cast<MemRefType>();
- if (startIndicesType.getNumElements() != inputType.getRank()) {
- auto extraDims = inputType.getRank() - startIndicesType.getNumElements();
- auto elementType = startIndicesType.getElementType();
-
- if (startIndicesType.getRank() != 1) {
- startIndices = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
- rewriter, gatherOp.getLoc(), startIndices,
- MemRefType::get({1}, elementType))
- ->getResult(0);
- }
-
- llvm::SmallVector<int64_t, 4> zeroes;
- zeroes.resize(extraDims, 0);
-
- auto elementsAttr = DenseIntElementsAttr::get(
- RankedTensorType::get(zeroes.size(), elementType),
- llvm::makeArrayRef(zeroes));
-
- auto extraStartIndices =
- rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), elementsAttr);
-
- auto memrefOutputType =
- MemRefType::get({inputType.getRank()}, elementType);
-
- SmallVector<Value *, 2> valuesToConcat = {startIndices,
- extraStartIndices};
- startIndices = rewriter.create<IREEInterp::HL::ConcatOp>(
- gatherOp.getLoc(), memrefOutputType, valuesToConcat,
- rewriter.getI32IntegerAttr(0));
- }
-
- auto sliceSizeValues = gatherOp.slice_sizes().getValues<int64_t>();
- std::vector<int64_t> sliceSizes = {sliceSizeValues.begin(),
- sliceSizeValues.end()};
- auto dstType = MemRefType::get(sliceSizes, inputType.getElementType());
-
- auto src = inputAsMemref(rewriter, gatherOp, gatherOp.operand());
- std::vector<Value *> dim_pieces;
- auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>(
- gatherOp.getLoc(), dstType, dim_pieces);
- auto lengths = rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(),
- gatherOp.slice_sizes());
- llvm::SmallVector<int64_t, 4> zero_offset;
- zero_offset.resize(dstType.getRank(), 0);
- auto dstIndices =
- createArrayConstant(rewriter, gatherOp.getLoc(), zero_offset);
-
- rewriter.create<IREEInterp::HL::CopyOp>(
- gatherOp.getLoc(), src, startIndices, dst, dstIndices, lengths);
-
- auto reshaped = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
- rewriter, gatherOp.getLoc(), dst, convertTypeToMemRef(gatherOp));
- rewriter.replaceOp(
- gatherOp, wrapAsTensor(reshaped->getResult(0), gatherOp, rewriter));
-
- return matchSuccess();
- }
-};
-
-struct SliceOpLowering : public XlaOpLowering<xla_hlo::SliceOp> {
- using XlaOpLowering<xla_hlo::SliceOp>::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::SliceOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- // XLA slice has value semantics, whereas the IREE slice creates a view. We
- // lower it to a copy if all strides are one which may be transformed to a
- // slice by later optimizations.
- auto isNotOne = [](APInt stride) { return stride != 1; };
- if (llvm::any_of(op->strides(), isNotOne)) {
- op->emitRemark() << "Could not lower slice op with non-singular strides";
- return nullptr;
- }
-
- auto finalType = convertTypeToMemRef(*op);
- auto src = operands[0];
- std::vector<Value *> dim_pieces;
- auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>(
- op->getLoc(), finalType, dim_pieces);
- auto srcIndices =
- rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices());
- auto lengths =
- createArrayConstant(rewriter, op->getLoc(), finalType.getShape());
-
- llvm::SmallVector<int64_t, 4> zero_offset;
- zero_offset.resize(finalType.getRank(), 0);
- auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset);
-
- rewriter.create<IREEInterp::HL::CopyOp>(op->getLoc(), src, srcIndices, dst,
- dstIndices, lengths);
- return dst;
- }
-};
-
-struct PadOpLowering : public XlaOpLowering<xla_hlo::PadOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::PadOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *src = operands[0];
- auto *paddingValue = operands[1];
-
- // TODO(b/140836672) Support negative padding
- for (int i = 0; i < op->edge_padding_high().getNumElements(); ++i) {
- if (op->edge_padding_high().getValue<IntegerAttr>(i).getInt() < 0 ||
- op->edge_padding_low().getValue<IntegerAttr>(i).getInt() < 0) {
- op->emitRemark() << "Could not lower pad op with negative padding";
- return nullptr;
- }
- }
-
- auto edgePaddingLowOp =
- rewriter.create<IREE::ConstantOp>(op->getLoc(), op->edge_padding_low());
- auto edgePaddingHighOp = rewriter.create<IREE::ConstantOp>(
- op->getLoc(), op->edge_padding_high());
- auto interiorPaddingOp =
- rewriter.create<IREE::ConstantOp>(op->getLoc(), op->interior_padding());
-
- return rewriter.create<IREEInterp::HL::PadOp>(
- op->getLoc(), convertTypeToMemRef(*op), src, paddingValue,
- edgePaddingLowOp, edgePaddingHighOp, interiorPaddingOp);
- }
-};
-
-struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- return createShapeTargetingOp<IREEInterp::HL::ReshapeOp>(
- rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op));
- }
-};
-
-struct TransposeOpLowering : public XlaOpLowering<xla_hlo::TransposeOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::TransposeOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto permutationOp =
- rewriter.create<IREE::ConstantOp>(op->getLoc(), op->permutation());
-
- return rewriter.create<IREEInterp::HL::TransposeOp>(
- op->getLoc(), convertTypeToMemRef(*op), operands[0], permutationOp);
- }
-};
-
-struct ReverseOpLowering : public XlaOpLowering<xla_hlo::ReverseOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::ReverseOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto reverseOp =
- rewriter.create<IREE::ConstantOp>(op->getLoc(), op->dimensions());
-
- return rewriter.create<IREEInterp::HL::ReverseOp>(
- op->getLoc(), convertTypeToMemRef(*op), operands[0], reverseOp);
- }
-};
-
-} // namespace
-
-void populateLowerXlaToInterpreterPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns
- .insert<BroadcastInDimOpLowering, ConcatOpLowering, ConvertLowering,
- CopyOpLowering, DotOpLowering, DynamicUpdateSliceOpLowering,
- ExpOpLowering, FloorOpLowering, GatherOpLowering, LogOpLowering,
- MaxOpLowering, MinOpLowering, PadOpLowering, ReshapeOpLowering,
- ReverseOpLowering, RsqrtOpLowering, SelectOpLowering,
- SliceOpLowering, TransposeOpLowering, TanhOpLowering>(ctx);
-}
-
-namespace {
-// Just for testing these passes.
-// TODO(b/141337493) can we get rid of this pass entirely?
-class LowerXLAToInterpreterDialectPass
- : public FunctionPass<LowerXLAToInterpreterDialectPass> {
- public:
- void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateLowerXlaToInterpreterPatterns(patterns, &getContext());
-
- ConversionTarget target(getContext());
- target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>();
- if (failed(applyPartialConversion(getFunction(), target, patterns))) {
- return signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-static PassRegistration<LowerXLAToInterpreterDialectPass> pass(
- "lower-xla-to-iree-interpreter",
- "Convert all XLA functions to the IREE dialect");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/MakeExecutableABI.cpp b/iree/compiler/Transforms/Interpreter/MakeExecutableABI.cpp
deleted file mode 100644
index e304ba5..0000000
--- a/iree/compiler/Transforms/Interpreter/MakeExecutableABI.cpp
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Interpreter/HLOps.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Replaces a load_input op with valid IR that loads the input value.
-LogicalResult replaceLoadInputOp(IREE::LoadInputOp bindOp) {
- OpBuilder builder(bindOp);
-
- Value *newValue = nullptr;
- auto dstType = bindOp.getResult()->getType();
- if (dstType.isa<TensorType>()) {
- auto castOp =
- builder.create<IREE::MemRefToTensorOp>(bindOp.getLoc(), bindOp.src());
- newValue = castOp.getResult();
- } else if (dstType.isIntOrIndexOrFloat()) {
- auto loadOp = builder.create<LoadOp>(bindOp.getLoc(), dstType, bindOp.src(),
- ArrayRef<Value *>{});
- newValue = loadOp.getResult();
- } else {
- return bindOp.emitError()
- << "Unsupported input destination type " << dstType;
- }
-
- bindOp.replaceAllUsesWith(newValue);
- bindOp.erase();
-
- return success();
-}
-
-// Replaces a store_output op with valid IR that stores the output value.
-LogicalResult replaceStoreOutputOp(IREE::StoreOutputOp bindOp) {
- OpBuilder builder(bindOp);
-
- auto srcType = bindOp.src()->getType();
- if (srcType.isa<MemRefType>()) {
- // Already stored into the output.
- } else if (srcType.isa<TensorType>()) {
- auto castOp =
- builder.create<IREE::TensorToMemRefOp>(bindOp.getLoc(), bindOp.src());
-
- // Insert a copy to our output parameter.
- auto dst = bindOp.dst()->getType().cast<ShapedType>();
- if (!dst.hasStaticShape()) {
- return bindOp.emitError()
- << "Dynamic output args are not yet implemented";
- }
-
- auto zeroValues = llvm::SmallVector<int64_t, 4>(dst.getRank());
- auto zeros = createArrayConstant(builder, bindOp.getLoc(), zeroValues);
- auto lengths =
- createArrayConstant(builder, bindOp.getLoc(), dst.getShape());
- builder.create<IREEInterp::HL::CopyOp>(bindOp.getLoc(), castOp.getResult(),
- zeros, bindOp.dst(), zeros, lengths);
- } else if (srcType.isIntOrIndexOrFloat()) {
- builder.create<StoreOp>(bindOp.getLoc(), bindOp.src(), bindOp.dst(),
- ArrayRef<Value *>{});
- } else {
- return bindOp.emitError() << "Unsupported output src type " << srcType;
- }
-
- bindOp.erase();
-
- return success();
-}
-
-// Strips iree.bind_* ops from |func|.
-LogicalResult stripBindingOps(FuncOp func) {
- // Find iree.load_input ops to replace with memref_to_tensor if needed.
- SmallVector<IREE::LoadInputOp, 8> bindInputOps;
- func.walk([&](IREE::LoadInputOp bindOp) { bindInputOps.push_back(bindOp); });
- for (auto &bindOp : bindInputOps) {
- if (failed(replaceLoadInputOp(bindOp))) {
- return failure();
- }
- }
-
- // Find iree.store_output ops and replace with tensor_to_memref if needed.
- SmallVector<IREE::StoreOutputOp, 8> bindOutputOps;
- func.walk(
- [&](IREE::StoreOutputOp bindOp) { bindOutputOps.push_back(bindOp); });
- for (auto &bindOp : bindOutputOps) {
- if (failed(replaceStoreOutputOp(bindOp))) {
- return failure();
- }
- }
-
- return success();
-}
-
-} // namespace
-
-// Finds iree.executable.export functions and fixes up bindings.
-// For the interpreter this really just means stripping the bind ops entirely.
-class MakeExecutableABIPass : public ModulePass<MakeExecutableABIPass> {
- public:
- void runOnModule() override {
- auto module = getModule();
- for (auto func : module.getOps<FuncOp>()) {
- if (func.getAttr("iree.executable.export")) {
- if (failed(stripBindingOps(func))) {
- return signalPassFailure();
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createMakeExecutableABIPass() {
- return std::make_unique<MakeExecutableABIPass>();
-}
-
-static PassRegistration<MakeExecutableABIPass> pass(
- "iree-make-executable-abi",
- "Makes functions match the IREE dispatch executable ABI.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/BUILD b/iree/compiler/Transforms/Interpreter/test/BUILD
deleted file mode 100644
index 59cfc89..0000000
--- a/iree/compiler/Transforms/Interpreter/test/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-# Tests for lowering MLIR in various dialects to IREE interpreter bytecode.
-
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/BUILD b/iree/compiler/Transforms/Interpreter/test/xla/BUILD
deleted file mode 100644
index 8ad177e..0000000
--- a/iree/compiler/Transforms/Interpreter/test/xla/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-# Tests specific to lowering XLA to IREE.
-
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- "//iree/tools:iree-run-mlir",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/LegalizeTypeStorage.cpp b/iree/compiler/Transforms/LegalizeTypeStorage.cpp
deleted file mode 100644
index ab80e05..0000000
--- a/iree/compiler/Transforms/LegalizeTypeStorage.cpp
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/DenseSet.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-bool convertOperation(Operation *oldOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- OperationState state(oldOp->getLoc(), oldOp->getName());
- if (oldOp->getNumSuccessors() == 0) {
- // Non-branching operations can just add all the operands.
- for (auto *oldOperand : oldOp->getOperands()) {
- state.operands.push_back(mapping->lookupOrDefault(oldOperand));
- }
- } else {
- // We add the operands separated by nullptr's for each successor.
- unsigned firstSuccOperand = oldOp->getNumSuccessors()
- ? oldOp->getSuccessorOperandIndex(0)
- : oldOp->getNumOperands();
- auto opOperands = oldOp->getOpOperands();
- unsigned i = 0;
- for (; i != firstSuccOperand; ++i) {
- state.operands.push_back(mapping->lookupOrDefault(opOperands[i].get()));
- }
- for (unsigned succ = 0, e = oldOp->getNumSuccessors(); succ != e; ++succ) {
- state.successors.push_back(
- mapping->lookupOrDefault(oldOp->getSuccessor(succ)));
- // Add sentinel to delineate successor operands.
- state.operands.push_back(nullptr);
- // Remap the successors operands.
- for (auto *operand : oldOp->getSuccessorOperands(succ)) {
- state.operands.push_back(mapping->lookupOrDefault(operand));
- }
- }
- }
- for (const auto &oldType : oldOp->getResultTypes()) {
- state.types.push_back(legalizeType(oldType));
- }
- state.attributes = {oldOp->getAttrs().begin(), oldOp->getAttrs().end()};
- auto newOp = builder.createOperation(state);
- for (int i = 0; i < newOp->getNumResults(); ++i) {
- mapping->map(oldOp->getResult(i), newOp->getResult(i));
- }
- return false;
-}
-
-bool convertFunction(FuncOp oldFunction, FuncOp newFunction) {
- OpBuilder builder(newFunction.getBody());
- BlockAndValueMapping mapping;
-
- // Create new blocks matching the expected arguments of the old ones.
- // This sets up the block mappings to enable us to reference blocks forward
- // during conversion.
- newFunction.getBlocks().clear();
- for (auto &oldBlock : oldFunction.getBlocks()) {
- auto *newBlock = builder.createBlock(&newFunction.getBody());
- mapping.map(&oldBlock, newBlock);
- for (auto *oldArg : oldBlock.getArguments()) {
- auto *newArg = newBlock->addArgument(legalizeType(oldArg->getType()));
- mapping.map(oldArg, newArg);
- }
- }
-
- // Convert all ops in the blocks.
- for (auto &oldBlock : oldFunction.getBlocks()) {
- builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock));
- for (auto &oldOp : oldBlock.getOperations()) {
- if (convertOperation(&oldOp, builder, &mapping)) {
- return true;
- }
- }
- }
-
- return false;
-}
-
-} // namespace
-
-class LegalizeTypeStoragePass : public ModulePass<LegalizeTypeStoragePass> {
- public:
- void runOnModule() override {
- auto module = getModule();
-
- // Build a list of (oldFunction, newFunction) for all functions we need to
- // replace. This will ensure that when we go to convert function bodies we
- // have only new functions defined.
- std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions;
-
- for (auto oldFunction : module.getOps<FuncOp>()) {
- // Create the replacement function, ensuring that we copy attributes.
- auto newFunction = FuncOp::create(
- oldFunction.getLoc(), oldFunction.getName(),
- legalizeType(oldFunction.getType()).cast<FunctionType>(),
- oldFunction.getDialectAttrs());
- convertedFunctions.push_back({oldFunction, newFunction});
-
- // Perform the actual body conversion now that we have proper signatures.
- if (convertFunction(oldFunction, newFunction)) {
- return signalPassFailure();
- }
- }
-
- // Replace functions in the module.
- for (auto &pair : convertedFunctions) {
- pair.first.erase();
- module.push_back(pair.second);
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeTypeStoragePass() {
- return std::make_unique<LegalizeTypeStoragePass>();
-}
-
-static PassRegistration<LegalizeTypeStoragePass> pass(
- "iree-legalize-type-storage",
- "Legalizes types to ones supported by the IREE VM.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/LowerStdToIreeDialect.cpp b/iree/compiler/Transforms/LowerStdToIreeDialect.cpp
deleted file mode 100644
index b2aef0c..0000000
--- a/iree/compiler/Transforms/LowerStdToIreeDialect.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct ConstantOpLowering : public OpRewritePattern<ConstantOp> {
- using OpRewritePattern::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(ConstantOp op,
- PatternRewriter &rewriter) const override {
- if (auto elementsValue = op.getValue().dyn_cast<ElementsAttr>()) {
- auto ireeConst =
- rewriter.create<IREE::ConstantOp>(op.getLoc(), elementsValue);
-
- auto result = wrapAsTensor(ireeConst.getResult(), op, rewriter);
- rewriter.replaceOp(op, result);
- return matchSuccess();
- }
-
- auto type = op.getValue().getType();
- if (!type.isIntOrFloat()) {
- return matchFailure();
- }
- auto elementsValue =
- DenseElementsAttr::get(RankedTensorType::get({}, type), op.getValue());
- auto ireeConst =
- rewriter.create<IREE::ConstantOp>(op.getLoc(), elementsValue);
- rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(op, ireeConst);
- return matchSuccess();
- }
-};
-
-struct ExtractElementOpLowering : public OpRewritePattern<ExtractElementOp> {
- using OpRewritePattern::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(ExtractElementOp op,
- PatternRewriter &rewriter) const override {
- Value *memRefInput =
- wrapAsMemRef(loadAccessValue(op.getLoc(), op.getAggregate(), rewriter),
- op, rewriter);
-
- SmallVector<Value *, 4> indices = {op.indices().begin(),
- op.indices().end()};
- rewriter.replaceOpWithNewOp<LoadOp>(op, memRefInput, indices);
- return matchSuccess();
- }
-};
-
-} // namespace
-
-void populateLowerStdToIreePatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<ConstantOpLowering, ExtractElementOpLowering>(ctx);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/LowerXLAToIreeDialect.cpp b/iree/compiler/Transforms/LowerXLAToIreeDialect.cpp
deleted file mode 100644
index 7a7f87f..0000000
--- a/iree/compiler/Transforms/LowerXLAToIreeDialect.cpp
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "mlir/IR/PatternMatch.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct ConstOpLowering : public OpRewritePattern<xla_hlo::ConstOp> {
- using OpRewritePattern::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(xla_hlo::ConstOp op,
- PatternRewriter &rewriter) const override {
- auto ireeConst = rewriter.create<IREE::ConstantOp>(op.getLoc(), op.value());
- rewriter.replaceOp(op, wrapAsTensor(ireeConst, op, rewriter));
- return matchSuccess();
- }
-};
-
-} // namespace
-
-void populateLowerXlaToIreePatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<ConstOpLowering>(ctx);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp b/iree/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp
deleted file mode 100644
index 467903b..0000000
--- a/iree/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "llvm/ADT/DenseMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class AssignExecutableOrdinalsPass
- : public ModulePass<AssignExecutableOrdinalsPass> {
- public:
- void runOnModule() override {
- Builder builder(getModule());
- int nextExecutableOrdinal = 0;
- for (auto multiArchExecutableOp :
- getModule().getOps<IREE::MultiArchExecutableOp>()) {
- multiArchExecutableOp.setAttr(
- "iree.ordinal", builder.getI32IntegerAttr(nextExecutableOrdinal++));
-
- // We'll scan for all entry points in the first executable. Then on all
- // other executables we can reuse the ordinals (ensuring that iteration
- // order does not matter).
- llvm::DenseMap<StringRef, FuncOp> entryPointMap;
- for (auto executableOp :
- multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
- executableOp.setAttr("iree.ordinal",
- multiArchExecutableOp.getAttr("iree.ordinal"));
- int nextEntryPointOrdinal = 0;
- for (auto funcOp : executableOp.getInnerModule().getOps<FuncOp>()) {
- if (!funcOp.getAttr("iree.executable.export")) continue;
- auto it = entryPointMap.find(funcOp.getName());
- if (it == entryPointMap.end()) {
- funcOp.setAttr("iree.ordinal",
- builder.getI32IntegerAttr(nextEntryPointOrdinal++));
- entryPointMap.insert({funcOp.getName(), funcOp});
- } else {
- funcOp.setAttr("iree.ordinal", it->second.getAttr("iree.ordinal"));
- }
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableOrdinalsPass() {
- return std::make_unique<AssignExecutableOrdinalsPass>();
-}
-
-static PassRegistration<AssignExecutableOrdinalsPass> pass(
- "iree-assign-executable-ordinals",
- "Assigns executable and entry point ordinals");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp b/iree/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp
deleted file mode 100644
index 2fcd6e3..0000000
--- a/iree/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Sequencer/LLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "llvm/ADT/StringMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct WorkloadInfo {
- SmallVector<ElementsAttr, 4> staticWorkloads;
- SmallVector<Value *, 4> dynamicWorkloads;
-};
-
-// Finds all dispatches and records their workload attributes mapped by
-// (executable ordinal, entry point ordinal).
-llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos(
- ModuleOp moduleOp) {
- llvm::StringMap<llvm::StringMap<WorkloadInfo>> workloadInfos;
- for (auto funcOp : moduleOp.getOps<FuncOp>()) {
- funcOp.walk([&](IREESeq::LL::DynamicDispatchOp op) {
- auto &workloadInfo =
- workloadInfos[op.getExecutable()][op.getEntryPoint()];
- workloadInfo.dynamicWorkloads.push_back(op.getWorkload());
- });
- funcOp.walk([&](IREESeq::LL::StaticDispatchOp op) {
- auto &workloadInfo =
- workloadInfos[op.getExecutable()][op.getEntryPoint()];
- for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) {
- if (existingWorkloadAttr == op.getWorkload()) {
- return; // Already present, ignore.
- }
- }
- workloadInfo.staticWorkloads.push_back(op.getWorkload());
- });
- }
- return workloadInfos;
-}
-
-// Adds attributes to the given executable entry point describing the workload
-// info to the backends that will be processing them.
-LogicalResult attributeExecutableEntryPointWorkload(
- FuncOp entryPointOp, const WorkloadInfo &workloadInfo) {
- if (!workloadInfo.dynamicWorkloads.empty()) {
- return entryPointOp.emitError() << "Dynamic workloads not yet supported";
- }
- if (workloadInfo.staticWorkloads.size() != 1) {
- return entryPointOp.emitError() << "Static workload sizes differ in shape";
- }
-
- // Easy because we just support static workloads now.
- // When this code is adapted to support dynamic workloads we'll want to put
- // a pair of attrs describing which dimensions may be static and which args
- // have the dynamic values to reference.
- entryPointOp.setAttr("iree.executable.workload",
- workloadInfo.staticWorkloads.front());
-
- return success();
-}
-
-} // namespace
-
-class AssignExecutableWorkloadAttrsPass
- : public ModulePass<AssignExecutableWorkloadAttrsPass> {
- public:
- void runOnModule() override {
- Builder builder(getModule());
-
- // Find all dispatches and capture their workload information.
- // We store this information by executable and then entry point ordinal.
- auto executableWorkloadInfos = gatherExecutableWorkloadInfos(getModule());
-
- // Process each executable with the workload information.
- for (auto &executableIt : executableWorkloadInfos) {
- auto multiArchExecutableOp = cast<IREE::MultiArchExecutableOp>(
- getModule().lookupSymbol(executableIt.first()));
- for (auto executableOp :
- multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
- for (auto &entryPointIt : executableIt.second) {
- auto funcOp = cast<FuncOp>(
- executableOp.getInnerModule().lookupSymbol(entryPointIt.first()));
- if (failed(attributeExecutableEntryPointWorkload(
- funcOp, entryPointIt.second))) {
- return signalPassFailure();
- }
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>>
-createAssignExecutableWorkloadAttrsPass() {
- return std::make_unique<AssignExecutableWorkloadAttrsPass>();
-}
-
-static PassRegistration<AssignExecutableWorkloadAttrsPass> pass(
- "iree-assign-executable-workload-attrs",
- "Assigns executable entrypoint workload attributes");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/BUILD b/iree/compiler/Transforms/Sequencer/BUILD
deleted file mode 100644
index 777e28b..0000000
--- a/iree/compiler/Transforms/Sequencer/BUILD
+++ /dev/null
@@ -1,46 +0,0 @@
-# Transforms specific to the IREE sequencer.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Sequencer",
- srcs = [
- "AssignExecutableOrdinals.cpp",
- "AssignExecutableWorkloadAttrs.cpp",
- "FoldCompatibleDispatchRegions.cpp",
- "IdentifyDispatchRegions.cpp",
- "IdentifyReductionRegions.cpp",
- "LegalizeInputs.cpp",
- "LowerSequencerDialect.cpp",
- "LowerStdToSequencerDialect.cpp",
- "LowerToSequencerDialect.cpp",
- "LowerXLAToSequencerDialect.cpp",
- "OutlineDispatchRegions.cpp",
- "OutlineReductionRegions.cpp",
- "RematerializeDispatchConstants.cpp",
- ],
- hdrs = [
- "Passes.h",
- "Rewrites.h",
- ],
- deps = [
- "//iree/compiler/IR",
- "//iree/compiler/IR/Sequencer",
- "//iree/compiler/Transforms",
- "//iree/compiler/Utils",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:TransformUtils",
- "@local_config_mlir//:Transforms",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot",
- ],
-)
diff --git a/iree/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp
deleted file mode 100644
index d4ea4a0..0000000
--- a/iree/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/DispatchUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Identifies dispatch regions that have compatible workloads and folds them.
-// This relies on CSE having deduped workloads to simplify the logic to simply
-// looking for dispatch regions using the same values.
-class FoldCompatibleDispatchRegionsPass
- : public FunctionPass<FoldCompatibleDispatchRegionsPass> {
- public:
- void runOnFunction() override {
- auto func = getFunction();
- for (auto &block : func) {
- if (failed(mergeBlockDispatchRegions(func, &block))) {
- return signalPassFailure();
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<FuncOp>> createFoldCompatibleDispatchRegionsPass() {
- return std::make_unique<FoldCompatibleDispatchRegionsPass>();
-}
-
-static PassRegistration<FoldCompatibleDispatchRegionsPass> pass(
- "iree-fold-compatible-dispatch-regions",
- "Folds dispatch regions that have compatible workloads.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp b/iree/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp
deleted file mode 100644
index 93be8a9..0000000
--- a/iree/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp
+++ /dev/null
@@ -1,259 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <algorithm>
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/DispatchUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Returns true if the given |op| can be dispatched in all cases.
-// Other passes may handle special cases of these ops but this initial
-// identification is conservative.
-bool isDispatchableOp(Operation *op) {
- if (op->getDialect() && op->getDialect()->getNamespace().startswith("iree")) {
- // Ignore things we've already produced as they should only relate to
- // sequencer operations.
- return false;
- } else if (op->isKnownTerminator()) {
- // Currently we skip all terminators as we want to leave them in the block
- // to keep it valid. Future folding passes may take care of them if they are
- // worth bringing into the dispatch region.
- return false;
- } else if (isa<CallOp>(op)) {
- // This may be handled by a control-flow folding pass later once we have
- // done our initial analysis and know what functions are compatible.
- return false;
- } else if (isa<CallIndirectOp>(op)) {
- // Indirect calls are not supported in dispatch code.
- return false;
- } else if (isa<AllocOp>(op)) {
- // Allocations are sequencer ops.
- // Note that we could support static allocations (convert to stack/etc).
- return false;
- } else if (isa<ConstantOp>(op)) {
- // Constants are handled in the RematerializeDispatchConstants pass.
- // We do that independently so that we can more easily see the use of
- // constants across all dispatches instead of just on an individual basis
- // as we do here.
- return false;
- } else if (isa<xla_hlo::DynamicUpdateSliceOp>(op)) {
- // TODO(benvanik): lower these to the sequencer dialect prior to ID'ing.
- return false;
- }
- return true;
-}
-
-// Returns true if the given |op| can have other ops fused into it.
-// This is sketchy and it'd be nice to define this as an op property instead.
-//
-// What we are looking for in foldable ops is whether the execution of the op
-// when fused has some possible benefit (or at least, a non-negative cost).
-// Eventually we want to allow backends to vote on this and allow multiple
-// folding strategies within the same executable. For now we just hardcode what
-// we know for the ops we have.
-//
-// Preconditions: isDispatchableOp(op) == true.
-bool isFusionRootOp(Operation *op) {
- if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
- // We have hand-written kernels for these right now we want to stand alone.
- // When we do a bit more magic we should allow these ops to fold.
- return false;
- }
- return true;
-}
-
-// Returns true if the given |op| can be fused into other ops.
-//
-// Ops that perform narrowing on shapes (such as reduction ops) should not
-// generally be fused with other downstream ops (probably...). This avoids
-// potential oversampling and indexing issues and allows backends to perform
-// more efficient rooted cascading reduction dispatches.
-//
-// Preconditions: isDispatchableOp(op) == true.
-bool isFusableOp(Operation *op) {
- if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
- return false;
- } else if (isa<xla_hlo::ReduceOp>(op)) {
- // Reduction is usually a dedicated root operation - we can shove things in
- // the front of it but not behind.
- return false;
- }
- return true;
-}
-
-// Puts all of the |unsortedOps| into |sortedOps| in an arbitrary topological
-// order.
-// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
-//
-// Preconditions: |unsortedOps| has no cycles within the set of ops.
-std::vector<Operation *> sortOpsTopologically(
- const llvm::SetVector<Operation *> &unsortedOps) {
- llvm::SetVector<Operation *> unmarkedOps;
- unmarkedOps.insert(unsortedOps.begin(), unsortedOps.end());
- llvm::SetVector<Operation *> markedOps;
-
- using VisitFn = std::function<void(Operation * op)>;
- VisitFn visit = [&](Operation *op) {
- if (markedOps.count(op) > 0) return;
- for (auto *result : op->getResults()) {
- for (auto *user : result->getUsers()) {
- // Don't visit ops not in our set.
- if (unsortedOps.count(user) == 0) continue;
- visit(user);
- }
- }
- markedOps.insert(op);
- };
-
- while (!unmarkedOps.empty()) {
- auto *op = unmarkedOps.pop_back_val();
- visit(op);
- }
-
- auto sortedOps = markedOps.takeVector();
- std::reverse(sortedOps.begin(), sortedOps.end());
- return sortedOps;
-}
-
-// Recursively traverses the IR DAG along the operand edges to find ops we are
-// able to fuse and appends them to |subgraph|.
-void gatherFusionOps(Operation *op, llvm::SetVector<Operation *> *subgraph) {
- // Skip ops that are used outside of the subgraph we are building.
- for (auto *result : op->getResults()) {
- if (result->use_empty() || result->hasOneUse()) continue;
- for (auto *user : result->getUsers()) {
- if (subgraph->count(user) == 0) {
- // Op that consumes the result is not (yet) in the subgraph.
- // For now we'll ignore these as it may represent a fork that we don't
- // want to join too early.
- return;
- }
- }
- }
-
- // Walk backward up to ops providing our input operands.
- for (auto *operand : op->getOperands()) {
- auto *sourceOp = operand->getDefiningOp();
- if (!sourceOp) continue;
- if (subgraph->count(sourceOp) == 0) {
- if (isDispatchableOp(sourceOp) && isFusableOp(sourceOp)) {
- gatherFusionOps(sourceOp, subgraph);
- }
- }
- }
-
- subgraph->insert(op);
-}
-
-// Finds all ops that can be fused together with the given |rootOp| by searching
-// backwards in the op order through input edges.
-// Returns a topologically sorted list of all fused ops with |rootOp| at the
-// end.
-std::vector<Operation *> findFusionSubgraphFromRoot(Operation *rootOp) {
- if (!isFusionRootOp(rootOp)) {
- return {rootOp};
- }
- llvm::SetVector<Operation *> subgraph;
- subgraph.insert(rootOp);
- gatherFusionOps(rootOp, &subgraph);
- return sortOpsTopologically(subgraph);
-}
-
-// Identifies ranges of dispatchable ops and moves them into dispatch regions.
-LogicalResult identifyBlockDispatchRegions(FuncOp func, Block *block) {
- // Fixed point iteration until we can no longer fuse anything.
- bool didFindAnyNewRegions;
- do {
- // Iterate in reverse so we root further along in the op list.
- didFindAnyNewRegions = false;
- for (auto &rootOp : llvm::reverse(*block)) {
- if (!isDispatchableOp(&rootOp)) {
- // Op should remain at the sequencer level.
- continue;
- }
-
- // Attempt to find all operations, including rootOp, that can be fused.
- // The ops will be sorted in topological order with rootOp as the last op.
- // Worst case we may end up with a subgraph of only the rootOp.
- auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp);
-
- // Compute the workload based on the output shape.
- // When variadic all output shapes match so we can just take the first.
- auto *workload = calculateWorkload(&rootOp, rootOp.getResult(0));
-
- // Try to build a dispatch region from this root.
- if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) {
- return failure();
- }
-
- // Successfully created a dispatch region from the ops and we must now
- // start over again as we've likely trashed the whole block structure.
- didFindAnyNewRegions = true;
- break;
- }
- } while (didFindAnyNewRegions);
- return success();
-}
-
-} // namespace
-
-// Identifies dispatchable ops and moves them into iree.dispatch_regions.
-// Some ops, such as call, will be deferred until following passes.
-class IdentifyDispatchRegionsPass
- : public FunctionPass<IdentifyDispatchRegionsPass> {
- public:
- void runOnFunction() override {
- auto func = getFunction();
- for (auto &block : func) {
- if (failed(identifyBlockDispatchRegions(func, &block))) {
- return signalPassFailure();
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass() {
- return std::make_unique<IdentifyDispatchRegionsPass>();
-}
-
-static PassRegistration<IdentifyDispatchRegionsPass> pass(
- "iree-identify-dispatch-regions",
- "Conservatively identifies dispatch regions in functions.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp b/iree/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp
deleted file mode 100644
index 40d4ac2..0000000
--- a/iree/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <algorithm>
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Utils/DispatchUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Builds a new iree.reduction_region with the given |invocationRegion|.
-// The new region will be inserted after |originalOp|.
-//
-// All |invocationRegion| ops must be compatible with the |workload| specified
-// as they will all be dispatched with the same workgroup structure. The
-// |invocationRegion| will not be modified.
-LogicalResult buildReductionRegion(Operation *originalOp,
- ArrayRef<Value *> operands,
- ArrayRef<Value *> initialValues,
- ArrayRef<int64_t> dimensions,
- Region &invocationRegion) {
- OpBuilder parentBuilder(originalOp);
-
- // Compute the workload based on the output shape.
- // When variadic all output shapes match so we can just take the first.
- auto *workload = calculateWorkload(originalOp, originalOp->getResult(0));
-
- // Build the region op and add it to the parent block.
- SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()};
- auto reductionRegionOp = parentBuilder.create<IREE::ReductionRegionOp>(
- originalOp->getLoc(), resultTypes, workload, operands, initialValues,
- dimensions);
-
- // Create the block and setup the arg mapping for captured values.
- BlockAndValueMapping mapping;
- invocationRegion.cloneInto(&reductionRegionOp.getBody(), mapping);
-
- // Replace xla_hlo.return -> iree.return.
- OpBuilder regionBuilder(reductionRegionOp.getBody());
- reductionRegionOp.walk([&](xla_hlo::ReturnOp returnOp) {
- regionBuilder.setInsertionPoint(returnOp);
- SmallVector<Value *, 4> returnValues(returnOp.getOperands());
- regionBuilder.create<IREE::ReturnOp>(returnOp.getLoc(), returnValues);
- returnOp.erase();
- });
-
- // Replace usage of values with the results of the region.
- for (int i = 0; i < originalOp->getNumResults(); ++i) {
- originalOp->getResult(i)->replaceAllUsesWith(
- reductionRegionOp.getResult(i));
- }
-
- return success();
-}
-
-// Converts an xla_hlo::ReduceOp to a reduction region and inlines the target
-// computation into the region body.
-LogicalResult buildReductionRegionFromXLAReduceOp(xla_hlo::ReduceOp reduceOp) {
- SmallVector<Value *, 4> operands(reduceOp.getOperands());
- OperandAdaptor<xla_hlo::ReduceOp> adaptor(operands);
-
- SmallVector<int64_t, 4> dimensions;
- for (auto dim : reduceOp.dimensions().getIntValues()) {
- dimensions.push_back(dim.getSExtValue());
- }
-
- // Create the iree.reduction_region.
- if (failed(buildReductionRegion(reduceOp, adaptor.operands(),
- adaptor.init_values(), dimensions,
- reduceOp.body()))) {
- return failure();
- }
-
- // Remove original XLA reduction op.
- reduceOp.erase();
-
- return success();
-}
-
-// Identifies reduction ops and moves them into reduction regions.
-LogicalResult identifyBlockReductionRegions(FuncOp funcOp, Block *block) {
- // Fixed point iteration until we can no longer fuse anything.
- bool didFindAnyNewRegions;
- do {
- // Iterate in reverse so we root further along in the op list.
- didFindAnyNewRegions = false;
- for (auto &rootOp : llvm::reverse(*block)) {
- if (auto reduceOp = dyn_cast<xla_hlo::ReduceOp>(rootOp)) {
- if (failed(buildReductionRegionFromXLAReduceOp(reduceOp))) {
- return failure();
- }
-
- // Successfully created a dispatch region from the ops and we must now
- // start over again as we've likely trashed the whole block structure.
- didFindAnyNewRegions = true;
- break;
- }
- }
- } while (didFindAnyNewRegions);
- return success();
-}
-
-} // namespace
-
-// Identifies reduction ops and moves their targets into iree.reduction_regions.
-class IdentifyReductionRegionsPass
- : public ModulePass<IdentifyReductionRegionsPass> {
- public:
- void runOnModule() override {
- for (auto funcOp : getModule().getOps<FuncOp>()) {
- for (auto &block : funcOp) {
- if (failed(identifyBlockReductionRegions(funcOp, &block))) {
- return signalPassFailure();
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass() {
- return std::make_unique<IdentifyReductionRegionsPass>(); // NOLINT
-}
-
-static PassRegistration<IdentifyReductionRegionsPass> pass(
- "iree-identify-reduction-regions",
- "Identifies reduction regions based on input reduction ops.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LegalizeInputs.cpp b/iree/compiler/Transforms/Sequencer/LegalizeInputs.cpp
deleted file mode 100644
index 5d79791..0000000
--- a/iree/compiler/Transforms/Sequencer/LegalizeInputs.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/Transforms/Rewrites.h"
-#include "iree/compiler/Transforms/Sequencer/Rewrites.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-class LegalizeInputOpsPass
- : public FunctionPass<LegalizeInputOpsPass> {
- public:
- void runOnFunction() override {
- OwningRewritePatternList patterns;
- xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
-
- ConversionTarget target(getContext());
- target.addLegalDialect<xla_hlo::XlaHloDialect, StandardOpsDialect>();
- target.addLegalOp<FuncOp, ReturnOp>();
- target.addIllegalOp<xla_hlo::DotGeneralOp, xla_hlo::WhileOp>();
- if (failed(applyFullConversion(getFunction(), target, patterns))) {
- return signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<OpPassBase<FuncOp>> createLegalizeInputOpsPass() {
- return std::make_unique<LegalizeInputOpsPass>();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp
deleted file mode 100644
index ea3efa6..0000000
--- a/iree/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp
+++ /dev/null
@@ -1,305 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Sequencer/HLDialect.h"
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/Sequencer/LLDialect.h"
-#include "iree/compiler/IR/Sequencer/LLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-template <typename SrcOp>
-class SequencerLoweringPattern : public OpConversionPattern<SrcOp> {
- public:
- SequencerLoweringPattern(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern<SrcOp>(context), typeConverter_(typeConverter) {}
-
- protected:
- TypeConverter &typeConverter_;
-};
-
-// Returns an integer scalar memref containing the offset specified by |indices|
-// within |type|.
-Value *computeOffset(Location loc, Value *reference, Value *indices,
- OpBuilder &builder) {
- auto referenceType = reference->getType().cast<ShapedType>();
- auto *shapeMemRef = builder
- .create<IREESeq::LL::AllocHeapOp>(
- loc,
- MemRefType::get({referenceType.getRank()},
- builder.getIntegerType(32)),
- ArrayRef<Value *>{})
- .getResult();
- builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef);
- auto *resultMemRef =
- builder
- .create<IREESeq::LL::AllocHeapOp>(
- loc, MemRefType::get({}, builder.getIntegerType(32)),
- ArrayRef<Value *>{})
- .getResult();
- auto elementSizeAttr = builder.getIntegerAttr(
- builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8);
- builder.create<IREESeq::LL::ComputeOffsetOp>(
- loc, shapeMemRef, elementSizeAttr, indices, resultMemRef);
- return resultMemRef;
-}
-
-// Returns a tuple of (offset, length) integer scalar memrefs with the range
-// specified by |indices| and |lengths| within |type|.
-std::pair<Value *, Value *> computeRange(Location loc, Value *reference,
- Value *indices, Value *lengths,
- OpBuilder &builder) {
- auto referenceType = reference->getType().cast<ShapedType>();
- auto *shapeMemRef = builder
- .create<IREESeq::LL::AllocHeapOp>(
- loc,
- MemRefType::get({referenceType.getRank()},
- builder.getIntegerType(32)),
- ArrayRef<Value *>{})
- .getResult();
- builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef);
- auto *offsetMemRef =
- builder
- .create<IREESeq::LL::AllocHeapOp>(
- loc, MemRefType::get({}, builder.getIntegerType(32)),
- ArrayRef<Value *>{})
- .getResult();
- auto *lengthMemRef =
- builder
- .create<IREESeq::LL::AllocHeapOp>(
- loc, MemRefType::get({}, builder.getIntegerType(32)),
- ArrayRef<Value *>{})
- .getResult();
- auto elementSizeAttr = builder.getIntegerAttr(
- builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8);
- builder.create<IREESeq::LL::ComputeRangeOp>(loc, shapeMemRef, elementSizeAttr,
- indices, lengths, offsetMemRef,
- lengthMemRef);
- return {offsetMemRef, lengthMemRef};
-}
-
-struct LowerSliceOpPattern
- : public SequencerLoweringPattern<IREESeq::HL::SliceOp> {
- using SequencerLoweringPattern::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- IREESeq::HL::SliceOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- OperandAdaptor<IREESeq::HL::SliceOp> operandAdaptor(operands);
- auto range = computeRange(op.getLoc(), operandAdaptor.src(),
- operandAdaptor.indices(),
- operandAdaptor.lengths(), rewriter);
- rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicSliceOp>(
- op, typeConverter_.convertType(op.getType()),
- ArrayRef<Value *>{operandAdaptor.src(), range.first, range.second},
- op.getAttrs());
- return matchSuccess();
- }
-};
-
-struct LowerShapeOpPattern
- : public SequencerLoweringPattern<IREESeq::HL::ShapeOp> {
- using SequencerLoweringPattern::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- IREESeq::HL::ShapeOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *shapeMemRef =
- rewriter
- .create<IREESeq::LL::AllocHeapOp>(
- op.getLoc(),
- MemRefType::get({op.getType().cast<ShapedType>().getRank()},
- rewriter.getIntegerType(64)),
- ArrayRef<Value *>{})
- .getResult();
- op.replaceAllUsesWith(shapeMemRef);
- rewriter.replaceOpWithNewOp<IREESeq::LL::ShapeOp>(op, operands[0],
- shapeMemRef);
- return matchSuccess();
- }
-};
-
-struct LowerCopyOpPattern
- : public SequencerLoweringPattern<IREESeq::HL::CopyOp> {
- using SequencerLoweringPattern::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- IREESeq::HL::CopyOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- OperandAdaptor<IREESeq::HL::CopyOp> operandAdaptor(operands);
- auto *srcOffsetMemRef =
- computeOffset(op.getLoc(), operandAdaptor.src(),
- operandAdaptor.srcIndices(), rewriter);
- auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(),
- operandAdaptor.dstIndices(),
- operandAdaptor.lengths(), rewriter);
- rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicCopyOp>(
- op, operandAdaptor.src(), srcOffsetMemRef, operandAdaptor.dst(),
- dstRange.first, dstRange.second);
- return matchSuccess();
- }
-};
-
-struct LowerFillOpPattern
- : public SequencerLoweringPattern<IREESeq::HL::FillOp> {
- using SequencerLoweringPattern::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- IREESeq::HL::FillOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- OperandAdaptor<IREESeq::HL::FillOp> operandAdaptor(operands);
- auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(),
- operandAdaptor.dstIndices(),
- operandAdaptor.lengths(), rewriter);
- rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicFillOp>(
- op, operandAdaptor.value(), operandAdaptor.dst(), dstRange.first,
- dstRange.second);
- return matchSuccess();
- }
-};
-
-struct LowerBranchOpPattern
- : public SequencerLoweringPattern<IREESeq::HL::BranchOp> {
- using SequencerLoweringPattern<
- IREESeq::HL::BranchOp>::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- IREESeq::HL::BranchOp op, ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREESeq::LL::BranchOp>(op, destinations[0],
- operands[0]);
- return matchSuccess();
- }
-};
-
-struct LowerCondCondBranchOpPattern
- : public SequencerLoweringPattern<IREESeq::HL::CondBranchOp> {
- using SequencerLoweringPattern<
- IREESeq::HL::CondBranchOp>::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- IREESeq::HL::CondBranchOp op, ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREESeq::LL::CondBranchOp>(
- op, properOperands[0],
- destinations[IREESeq::HL::CondBranchOp::trueIndex],
- operands[IREESeq::HL::CondBranchOp::trueIndex],
- destinations[IREESeq::HL::CondBranchOp::falseIndex],
- operands[IREESeq::HL::CondBranchOp::falseIndex]);
- return matchSuccess();
- }
-};
-
-// Rewrites an op into one with all the same operands, results, and attributes.
-// Operands and results in the ops must have the same order and attributes must
-// have the same name. They must also be constructed properly by the default
-// builders.
-template <typename SRC, typename DST>
-struct LowerIdenticalOpPattern : public SequencerLoweringPattern<SRC> {
- using SequencerLoweringPattern<SRC>::SequencerLoweringPattern;
-
- PatternMatchResult matchAndRewrite(
- SRC op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type, 8> originalResultTypes{
- op.getOperation()->getResultTypes()};
- SmallVector<Type, 8> resultTypes;
- if (failed(this->typeConverter_.convertTypes(originalResultTypes,
- resultTypes))) {
- op.emitOpError() << "Failed to convert result types";
- return this->matchFailure();
- }
- rewriter.replaceOpWithNewOp<DST>(op, resultTypes, operands, op.getAttrs());
- return this->matchSuccess();
- }
-};
-
-} // namespace
-
-class LowerSequencerDialectPass : public ModulePass<LowerSequencerDialectPass> {
- public:
- void runOnModule() override {
- auto *ctx = &getContext();
- LLTypeConverter typeConverter(ctx);
- OwningRewritePatternList patterns;
- patterns.insert<
- LowerIdenticalOpPattern<IREE::ConstantOp, IREESeq::LL::ConstantOp>,
- LowerIdenticalOpPattern<IREESeq::HL::DispatchOp,
- IREESeq::LL::DynamicDispatchOp>,
- LowerShapeOpPattern, LowerCopyOpPattern, LowerSliceOpPattern,
- LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx, typeConverter);
-#define IDENTICAL_OP_LOWERING(op_name) \
- LowerIdenticalOpPattern<IREESeq::HL::op_name, IREESeq::LL::op_name>
- patterns.insert<
- IDENTICAL_OP_LOWERING(AllocHeapOp), IDENTICAL_OP_LOWERING(CloneOp),
- IDENTICAL_OP_LOWERING(ReshapeOp), IDENTICAL_OP_LOWERING(CallOp),
- IDENTICAL_OP_LOWERING(ReturnOp)>(ctx, typeConverter);
-#undef IDENTICAL_OP_LOWERING
-
- mlir::populateFuncOpTypeConversionPattern(patterns, ctx, typeConverter);
- ConversionTarget target(*ctx);
- target.addLegalDialect<IREELLSequencerDialect>();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- return typeConverter.isSignatureLegal(op.getType());
- });
-
- // TODO(b/142791494): The conversion framework will recurse into the
- // executable if we just call it on the top-level module. This can't be a
- // function pass because type conversion replaces the original functions.
- auto funcsIt = getModule().getOps<FuncOp>();
- SmallVector<Operation *, 4> funcs(funcsIt.begin(), funcsIt.end());
-
- if (failed(applyFullConversion(funcs, target, patterns, &typeConverter))) {
- return signalPassFailure();
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createLowerSequencerDialectPass() {
- return std::make_unique<LowerSequencerDialectPass>();
-}
-
-static PassRegistration<LowerSequencerDialectPass> pass(
- "iree-lower-sequencer-dialect",
- "Lowers the IREE HL sequencer dialect to the LL sequencer dialect.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp
deleted file mode 100644
index 3f839ab..0000000
--- a/iree/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp
+++ /dev/null
@@ -1,204 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Sequencer/HLDialect.h"
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-template <typename T>
-class SequencerConversionPattern : public OpConversionPattern<T> {
- using OpConversionPattern<T>::OpConversionPattern;
-};
-
-struct CallOpLowering : public SequencerConversionPattern<CallOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- CallOp callOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type, 4> resultTypes(callOp.getResultTypes());
- rewriter.replaceOpWithNewOp<IREESeq::HL::CallOp>(callOp, callOp.getCallee(),
- resultTypes, operands);
-
- return matchSuccess();
- }
-};
-
-struct CallIndirectOpLowering
- : public SequencerConversionPattern<CallIndirectOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- CallIndirectOp callOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREESeq::HL::CallIndirectOp>(
- callOp, callOp.getCallee(), operands);
- return matchSuccess();
- }
-};
-
-struct ReturnOpLowering : public SequencerConversionPattern<ReturnOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- ReturnOp returnOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value *, 4> newOperands;
- newOperands.reserve(operands.size());
- for (auto *operand : operands) {
- newOperands.push_back(wrapAsMemRef(operand, returnOp, rewriter));
- }
- rewriter.replaceOpWithNewOp<IREESeq::HL::ReturnOp>(returnOp, newOperands);
- return matchSuccess();
- }
-};
-
-struct BranchOpLowering : public SequencerConversionPattern<BranchOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
- PatternMatchResult matchAndRewrite(
- BranchOp branchOp, ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREESeq::HL::BranchOp>(
- branchOp, destinations[0], operands[0]);
- return this->matchSuccess();
- }
-};
-
-struct CondBranchOpLowering : public SequencerConversionPattern<CondBranchOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
- PatternMatchResult matchAndRewrite(
- CondBranchOp condBranchOp, ArrayRef<Value *> properOperands,
- ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto *condValue =
- loadAccessValue(condBranchOp.getLoc(), properOperands[0], rewriter);
- rewriter.replaceOpWithNewOp<IREESeq::HL::CondBranchOp>(
- condBranchOp, condValue,
- destinations[IREESeq::HL::CondBranchOp::trueIndex],
- operands[IREESeq::HL::CondBranchOp::trueIndex],
- destinations[IREESeq::HL::CondBranchOp::falseIndex],
- operands[IREESeq::HL::CondBranchOp::falseIndex]);
- return this->matchSuccess();
- }
-};
-
-struct AllocOpLowering : public SequencerConversionPattern<AllocOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
- PatternMatchResult matchAndRewrite(
- AllocOp allocOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- // TODO(benvanik): replace with length computation.
- rewriter.replaceOpWithNewOp<IREESeq::HL::AllocHeapOp>(
- allocOp, allocOp.getType(), operands);
- return matchSuccess();
- }
-};
-
-struct DeallocOpLowering : public SequencerConversionPattern<DeallocOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
- PatternMatchResult matchAndRewrite(
- DeallocOp deallocOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREESeq::HL::DiscardOp>(deallocOp, operands[0]);
- return matchSuccess();
- }
-};
-
-struct LoadOpLowering : public SequencerConversionPattern<LoadOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
- PatternMatchResult matchAndRewrite(
- LoadOp loadOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (loadOp.getMemRefType().getRank() != 0) {
- loadOp.emitError() << "Cannot lower load of non-scalar";
- return matchFailure();
- }
- ArrayRef<Value *> dimPieces;
- auto dst = rewriter.create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(),
- dimPieces);
- auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {});
- rewriter.create<IREESeq::HL::CopyOp>(loadOp.getLoc(), loadOp.getMemRef(),
- /*srcIndices=*/emptyArrayMemref, dst,
- /*dstIndices=*/emptyArrayMemref,
- /*lengths=*/emptyArrayMemref);
-
- rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst);
-
- return matchSuccess();
- }
-};
-
-struct StoreOpLowering : public SequencerConversionPattern<StoreOp> {
- using SequencerConversionPattern::SequencerConversionPattern;
- PatternMatchResult matchAndRewrite(
- StoreOp storeOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (storeOp.getMemRefType().getRank() != 0) {
- storeOp.emitError() << "Cannot lower store of non-scalar";
- return matchFailure();
- }
-
- auto src = rewriter.create<IREE::ScalarToMemRefOp>(
- storeOp.getLoc(), storeOp.getValueToStore());
-
- auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {});
- rewriter.replaceOpWithNewOp<IREESeq::HL::CopyOp>(
- storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(),
- /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref);
-
- return matchSuccess();
- }
-};
-
-} // namespace
-
-void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context) {
- patterns.insert<
- // Control flow.
- CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
- BranchOpLowering, CondBranchOpLowering,
- // Memory management.
- AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering>(
- context);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp
deleted file mode 100644
index 658120b..0000000
--- a/iree/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Sequencer/HLDialect.h"
-#include "iree/compiler/IR/Sequencer/LLDialect.h"
-#include "iree/compiler/Transforms/Rewrites.h"
-#include "iree/compiler/Transforms/Sequencer/Rewrites.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-class LowerToSequencerDialectPass
- : public FunctionPass<LowerToSequencerDialectPass> {
- public:
- void runOnFunction() override {
- OwningRewritePatternList patterns;
- auto* ctx = &getContext();
- xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx);
- xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx);
- populateLowerStdToIreePatterns(patterns, ctx);
- populateLowerStdToSequencerPatterns(patterns, ctx);
- populateLowerXlaToIreePatterns(patterns, ctx);
- populateLowerXlaToSequencerPatterns(patterns, ctx);
-
- ConversionTarget target(getContext());
- target.addLegalDialect<IREEHLSequencerDialect, IREEDialect>();
- target.addLegalOp<FuncOp>();
- if (failed(applyFullConversion(getFunction(), target, patterns))) {
- return signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<OpPassBase<FuncOp>> createLowerToSequencerDialectPass() {
- return std::make_unique<LowerToSequencerDialectPass>();
-}
-
-static PassRegistration<LowerToSequencerDialectPass> pass(
- "lower-to-iree-sequencer", "Convert all ops to the IREE sequencer dialect");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp
deleted file mode 100644
index 785015c..0000000
--- a/iree/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp
+++ /dev/null
@@ -1,224 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/IR/Dialect.h"
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Sequencer/HLDialect.h"
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Transforms/ConversionUtils.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/OpCreationUtils.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// TODO(suderman): tablegen this? or something a bit more flexible.
-
-#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \
- struct XlaOpType##Lowering \
- : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
- using UnaryOpLowering::UnaryOpLowering; \
- };
-
-#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \
- struct XlaOpType##Lowering \
- : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
- using TernaryOpLowering::TernaryOpLowering; \
- };
-
-UNARY_OP_LOWERING(CopyOp, IREESeq::HL::CloneOp);
-
-#undef UNARY_OP_LOWERING
-#undef TERNARY_OP_LOWERING
-
-template <typename T>
-static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter,
- Location loc, Value *input,
- MemRefType targetType) {
- auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape());
- return rewriter.create<T>(loc, targetType, input, shapeOp);
-}
-
-static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op,
- Value *tensor) {
- return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op,
- rewriter);
-}
-
-template <typename SrcOp>
-class XlaOpLowering : public OpConversionPattern<SrcOp> {
- public:
- using OpConversionPattern<SrcOp>::OpConversionPattern;
-
- PatternMatchResult matchAndRewrite(
- SrcOp op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto srcOp = cast<SrcOp>(op);
-
- SmallVector<Value *, 4> memrefOperands;
- for (auto operand : operands) {
- memrefOperands.push_back(inputAsMemref(rewriter, op, operand));
- }
-
- auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter);
- rewriter.replaceOp(op, wrapAsTensor(dstOp->getResult(0), srcOp, rewriter));
- return this->matchSuccess();
- }
-
- protected:
- virtual Operation *rewriteInternal(
- SrcOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?");
- }
-};
-
-struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto finalType = convertTypeToMemRef(*op);
-
- return rewriter.create<IREESeq::HL::ConcatOp>(
- op->getLoc(), finalType, operands,
- rewriter.getI32IntegerAttr(op->dimension().getZExtValue()));
- }
-};
-
-struct DynamicUpdateSliceLowering
- : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto operand = operands[0];
- auto update = operands[1];
-
- auto updateType = update->getType().cast<ShapedType>();
- Value *lengthConstant =
- createArrayConstant(rewriter, op->getLoc(), updateType.getShape());
-
- auto startIndices = makeArrayRef(operands).drop_front(2);
- const int rank = startIndices.size();
- llvm::SmallVector<Value *, 4> valuesToConcat;
- valuesToConcat.reserve(startIndices.size());
- auto type = getElementTypeOrSelf(startIndices.front());
-
- // To generate the offset matrix we need to convert the variadic tensors
- // into a reshaped and concated value.
- for (auto index : startIndices) {
- auto reshapedIndex = rewriter.create<IREESeq::HL::ReshapeOp>(
- op->getLoc(), MemRefType::get({1}, type), index,
- createArrayConstant(rewriter, op->getLoc(), {1}));
- valuesToConcat.push_back(reshapedIndex);
- }
-
- auto dstOffset = rewriter
- .create<IREESeq::HL::ConcatOp>(
- op->getLoc(), MemRefType::get({rank}, type),
- valuesToConcat, rewriter.getI32IntegerAttr(0))
- .getResult();
-
- llvm::SmallVector<int64_t, 4> zero_offset;
- zero_offset.resize(updateType.getRank(), 0);
- auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset);
-
- auto copiedOperand = rewriter.create<IREESeq::HL::CloneOp>(
- op->getLoc(), operand->getType(), operand);
-
- rewriter
- .create<IREESeq::HL::CopyOp>(op->getLoc(), update, srcOffset,
- copiedOperand, dstOffset, lengthConstant)
- .getOperation();
-
- return copiedOperand;
- }
-};
-
-struct SliceLowering : public XlaOpLowering<xla_hlo::SliceOp> {
- using XlaOpLowering::XlaOpLowering;
- Operation *rewriteInternal(
- xla_hlo::SliceOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- // XLA slice has value semantics, whereas the IREE slice creates a view. We
- // lower it to a copy if all strides are one which may be transformed to a
- // slice by later optimizations.
- auto isNotOne = [](APInt stride) { return stride != 1; };
- if (llvm::any_of(op->strides(), isNotOne)) {
- op->emitRemark() << "Could not lower slice op with non-singular strides";
- return nullptr;
- }
-
- auto finalType = convertTypeToMemRef(*op);
-
- auto src = operands[0];
- std::vector<Value *> dim_pieces;
- auto dst = rewriter.create<IREESeq::HL::AllocHeapOp>(op->getLoc(),
- finalType, dim_pieces);
- auto srcIndices =
- rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices());
- auto lengths =
- createArrayConstant(rewriter, op->getLoc(), finalType.getShape());
-
- llvm::SmallVector<int64_t, 4> zero_offset;
- zero_offset.resize(finalType.getRank(), 0);
- auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset);
-
- rewriter.create<IREESeq::HL::CopyOp>(op->getLoc(), src, srcIndices, dst,
- dstIndices, lengths);
- return dst;
- }
-};
-
-struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> {
- using XlaOpLowering::XlaOpLowering;
-
- Operation *rewriteInternal(
- xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter) const override {
- return createShapeTargetingOp<IREESeq::HL::ReshapeOp>(
- rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op));
- }
-};
-
-} // namespace
-
-void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<ConcatOpLowering, CopyOpLowering, DynamicUpdateSliceLowering,
- ReshapeOpLowering, SliceLowering>(ctx);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp b/iree/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp
deleted file mode 100644
index fff3e8c..0000000
--- a/iree/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp
+++ /dev/null
@@ -1,242 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <utility>
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Utils/DispatchUtils.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/SetVector.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Inserts a load from a wrapped memref (as inserted via insertDispatcherStore).
-// Returns the value in the original type.
-Value *insertDispatcheeLoad(Operation *op, Type originalType, Value *value,
- OpBuilder &builder) {
- // If old value was a memref we don't need to change anything.
- if (originalType.isa<MemRefType>()) {
- return value;
- }
-
- auto loadInputOp =
- builder.create<IREE::LoadInputOp>(op->getLoc(), originalType, value);
- value->replaceAllUsesWith(loadInputOp.getResult());
- loadInputOp.setOperand(value);
- return loadInputOp.getResult();
-}
-
-// Marshals args and results as buffers for the given region.
-// Beyond inserting the appropriate tensor-to-memref ops we avoid mutating the
-// interior of the dispatch region as much as possible.
-LogicalResult marshalDispatchSite(IREE::DispatchRegionOp regionOp) {
- auto &entryBlock = regionOp.getBody().getBlocks().front();
- OpBuilder dispatcherBuilder(regionOp);
- OpBuilder dispatcheeBuilder(&entryBlock, entryBlock.begin());
-
- // Wrap input operands and unwrap in the entry block.
- SmallVector<Value *, 8> newArgs;
- for (int i = 0; i < regionOp.getNumArgOperands(); ++i) {
- // Wrap the input outside of the region.
- auto *blockArg = entryBlock.getArgument(i);
- Type originalType = blockArg->getType();
- auto *originalArg = regionOp.getArgOperand(i);
- auto *wrappedArg =
- insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
- newArgs.push_back(wrappedArg);
- blockArg->setType(wrappedArg->getType());
-
- // Unwrap the block arg value and replace all of the uses with the newly
- // unwrapped value.
- insertDispatcheeLoad(regionOp, originalType, blockArg, dispatcheeBuilder);
- }
-
- // Allocate output arguments and replace the return values with those.
- SmallVector<Type, 8> newResults;
- SmallVector<std::pair<int, Value *>, 8> resultIndicesToOutputArgs;
- SmallVector<int, 8> deadResultIndices;
- SmallVector<std::pair<Value *, Value *>, 8> replacedResults;
- for (int i = 0; i < regionOp.getNumResults(); ++i) {
- auto *result = regionOp.getResult(i);
- auto convertedType = convertTypeToMemRef(result->getType());
-
- // Allocate output buffer in the dispatcher to pass in to the region.
- Value *allocatedValue = allocateDispatchOutputBuffer(
- regionOp.getLoc(), convertedType, dispatcherBuilder);
- if (!allocatedValue) {
- regionOp.emitError("unable to allocate result value");
- return failure();
- }
- newArgs.push_back(allocatedValue);
-
- auto *newBlockArg = entryBlock.addArgument(allocatedValue->getType());
- resultIndicesToOutputArgs.push_back({i, newBlockArg});
-
- // NOTE: right now we always replace results. If we want to allow return
- // values we can avoid killing them here.
- deadResultIndices.push_back(i);
- replacedResults.push_back({result, allocatedValue});
- }
-
- // Remove dead results from return statements.
- regionOp.walk([&](IREE::ReturnOp returnOp) {
- // Replace the results we were returning with stores to output arguments.
- OpBuilder builder(returnOp);
- for (auto resultToArg : resultIndicesToOutputArgs) {
- auto *value = returnOp.getOperand(resultToArg.first);
- auto *outputArg = resultToArg.second;
- builder.create<IREE::StoreOutputOp>(returnOp.getLoc(), value, outputArg);
- }
-
- // Filter out the results that are now dead.
- SmallVector<Value *, 8> newOperands(returnOp.getOperands());
- for (int i = deadResultIndices.size() - 1; i >= 0; --i) {
- newOperands.erase(newOperands.begin() + deadResultIndices[i]);
- }
- returnOp.getOperation()->setOperands(newOperands);
- });
-
- // Clone the region op with the new args/results.
- auto newRegionOp = dispatcherBuilder.create<IREE::DispatchRegionOp>(
- regionOp.getLoc(), newResults, regionOp.getWorkload(), newArgs);
- newRegionOp.getBody().takeBody(regionOp.getBody());
-
- // Marshal back the results by replacing uses of the original with loads from
- // the new output arg.
- for (auto &it : replacedResults) {
- insertDispatcherLoad(regionOp, it.first, it.second, dispatcherBuilder);
- }
-
- // Remove original region.
- regionOp.erase();
-
- return success();
-}
-
-// Converts a dispatch_region into a dispatch to the outlined region function.
-LogicalResult convertToDispatchOp(IREE::DispatchRegionOp regionOp,
- IREE::MultiArchExecutableOp executable,
- FuncOp entryPoint) {
- // Insert at the same place as the original region.
- OpBuilder dispatcherBuilder(regionOp);
-
- // Ensure workload is a memref.
- auto *workload =
- wrapAsMemRef(regionOp.getWorkload(), regionOp, dispatcherBuilder);
-
- // Create the dispatch op to the executable function.
- SmallVector<Value *, 8> operandValues(regionOp.getArgOperands());
- auto dispatchOp = dispatcherBuilder.create<IREESeq::HL::DispatchOp>(
- regionOp.getLoc(), executable.getName(), entryPoint.getName(), workload,
- entryPoint.getType().getResults(), operandValues);
-
- // Replace uses of the existing results with the new results.
- for (int i = 0; i < regionOp.getNumResults(); ++i) {
- regionOp.getResult(i)->replaceAllUsesWith(dispatchOp.getResult(i));
- }
-
- // Erase original region.
- regionOp.erase();
-
- return success();
-}
-
-// Outlines a dispatch region into an iree.multi_arch_executable.
-LogicalResult outlineDispatchRegion(IREE::DispatchRegionOp regionOp,
- int outlinedRegionOrdinal) {
- // Build function type matching 1:1 with the region signature.
- SmallVector<Type, 8> operandTypes;
- for (auto *arg : regionOp.getArgOperands()) {
- operandTypes.push_back(arg->getType());
- }
- SmallVector<Type, 8> resultTypes(regionOp.getResultTypes());
- auto functionType =
- FunctionType::get(operandTypes, resultTypes, regionOp.getContext());
-
- // Create the executable with the region cloned into it.
- IREE::MultiArchExecutableOp multiArchExecutable;
- FuncOp outlinedFunc;
- std::tie(multiArchExecutable, outlinedFunc) = createRegionExecutable(
- regionOp, functionType,
- "_dispatch_" + std::to_string(outlinedRegionOrdinal));
- outlinedFunc.setAttr("iree.executable.export",
- UnitAttr::get(regionOp.getContext()));
-
- // Finally convert the dispatch region into a dispatch to the outlined func.
- return convertToDispatchOp(regionOp, multiArchExecutable, outlinedFunc);
-}
-
-} // namespace
-
-class OutlineDispatchRegionsPass
- : public ModulePass<OutlineDispatchRegionsPass> {
- public:
- void runOnModule() override {
- auto module = getModule();
-
- ModuleManager moduleManager(module);
- auto funcs = module.getOps<FuncOp>();
- SmallVector<FuncOp, 4> funcOps(funcs.begin(), funcs.end());
- for (auto func : funcOps) {
- // Perform marshaling of the dispatcher and dispatchee I/O.
- // This inserts the required stores and loads to make everything memrefs
- // and adds the iree.load_input/iree.store_output ops to the dispatchee.
- if (func.walk([&](IREE::DispatchRegionOp op) {
- if (failed(marshalDispatchSite(op))) {
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- })
- .wasInterrupted()) {
- return signalPassFailure();
- }
-
- // Outline all of the iree.dispatch_region ops in this function.
- SmallVector<IREE::DispatchRegionOp, 8> dispatchRegionOps;
- func.walk(
- [&](IREE::DispatchRegionOp op) { dispatchRegionOps.push_back(op); });
- for (int i = 0; i < dispatchRegionOps.size(); ++i) {
- if (failed(outlineDispatchRegion(dispatchRegionOps[i], i))) {
- return signalPassFailure();
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createOutlineDispatchRegionsPass() {
- return std::make_unique<OutlineDispatchRegionsPass>();
-}
-
-static PassRegistration<OutlineDispatchRegionsPass> pass(
- "iree-outline-dispatch-regions",
- "Outlines dispatch regions into standalone functions");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp b/iree/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp
deleted file mode 100644
index 9636d1b..0000000
--- a/iree/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp
+++ /dev/null
@@ -1,307 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <utility>
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/Sequencer/HLOps.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Utils/DispatchUtils.h"
-#include "iree/compiler/Utils/MemRefUtils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Determines the shapes involved with reducing this dimension.
-SmallVector<int64_t, 4> calculateResultShape(Value *input,
- int windowDimension) {
- SmallVector<int64_t, 4> resultShape;
- for (auto it :
- llvm::enumerate(input->getType().cast<ShapedType>().getShape())) {
- if (it.index() != windowDimension) {
- resultShape.push_back(it.value());
- }
- }
- return resultShape;
-}
-
-// Creates an executable that holds the given elemental reduction region.
-// The executable will have an entry point taking the specified reduction values
-// and writing the results to output arguments.
-std::pair<IREE::MultiArchExecutableOp, FuncOp> createReductionExecutable(
- IREE::ReductionRegionOp regionOp, int outlinedRegionOrdinal,
- int separatedReductionIndex, int reductionDimension,
- SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs) {
- Builder builder(regionOp.getContext());
-
- // Build function type matching 1:1 with the region signature.
- SmallVector<Type, 8> elementalOperandTypes;
- SmallVector<Type, 8> elementalResultTypes;
- for (auto *arg : regionOp.getInitialValueOperands()) {
- // (in0, in1) -> out0
- elementalOperandTypes.push_back(arg->getType());
- elementalOperandTypes.push_back(arg->getType());
- elementalResultTypes.push_back(arg->getType());
- }
- auto elementalFunctionType = FunctionType::get(
- elementalOperandTypes, elementalResultTypes, regionOp.getContext());
-
- // Create the executable with the region cloned into it.
- IREE::MultiArchExecutableOp multiArchExecutable;
- FuncOp elementalFunc;
- std::tie(multiArchExecutable, elementalFunc) = createRegionExecutable(
- regionOp, elementalFunctionType,
- "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" +
- std::to_string(separatedReductionIndex));
-
- // Create a new entry point that we can use with the signature for this
- // dimension.
- SmallVector<Type, 8> allOperandTypes;
- auto inputTypes =
- llvm::map_range(inputs, [](Value *value) { return value->getType(); });
- allOperandTypes.append(inputTypes.begin(), inputTypes.end());
- auto initialValueTypes = llvm::map_range(
- initialValues, [](Value *value) { return value->getType(); });
- allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end());
- for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
- auto shapedType = resultType.value().cast<ShapedType>();
- allOperandTypes.push_back(MemRefType::get(
- calculateResultShape(inputs[resultType.index()], reductionDimension),
- shapedType.getElementType()));
- }
- auto entryFuncType = FunctionType::get(allOperandTypes, ArrayRef<Type>{},
- regionOp.getContext());
- auto entryFunc =
- FuncOp::create(regionOp.getLoc(),
- (elementalFunc.getName() + "_entry").str(), entryFuncType);
- entryFunc.setAttr("iree.executable.export",
- UnitAttr::get(regionOp.getContext()));
- elementalFunc.getOperation()->getBlock()->push_back(entryFunc);
- entryFunc.getOperation()->moveBefore(elementalFunc);
- entryFunc.setAttr("iree.executable.reduction",
- UnitAttr::get(regionOp.getContext()));
- entryFunc.setAttr("iree.executable.reduction.apply",
- builder.getSymbolRefAttr(elementalFunc));
-
- return {multiArchExecutable, entryFunc};
-}
-
-// Converts a reduction_region into a dispatch to the outlined region function
-// for a single reduction dimension.
-// Returns the results of the reduction or empty if the construction fails.
-SmallVector<Value *, 4> convertToDispatchOp(
- IREE::ReductionRegionOp regionOp, IREE::MultiArchExecutableOp executable,
- FuncOp entryFunc, int reductionDimension,
- SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs,
- OpBuilder &dispatcherBuilder) {
- // Allocate output args and replace the return values with those.
- SmallVector<Value *, 4> resultValues;
- for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
- // Allocate output buffer in the dispatcher to pass in to the region.
- auto shapedType = resultType.value().cast<ShapedType>();
- Value *allocatedValue = allocateDispatchOutputBuffer(
- regionOp.getLoc(),
- MemRefType::get(calculateResultShape(inputs[resultType.index()],
- reductionDimension),
- shapedType.getElementType()),
- dispatcherBuilder);
- if (!allocatedValue) {
- regionOp.emitError("unable to allocate result value");
- return {};
- }
- resultValues.push_back(allocatedValue);
- }
-
- // Calculate workload from the result shape.
- auto *workload =
- wrapAsMemRef(calculateWorkload(regionOp, resultValues.front()), regionOp,
- dispatcherBuilder);
-
- // Create the reduce op to the executable function.
- std::vector<Value *> allOperands;
- allOperands.insert(allOperands.end(), inputs.begin(), inputs.end());
- allOperands.insert(allOperands.end(), initialValues.begin(),
- initialValues.end());
- allOperands.insert(allOperands.end(), resultValues.begin(),
- resultValues.end());
- dispatcherBuilder.create<IREESeq::HL::DispatchOp>(
- regionOp.getLoc(), executable.getName(), entryFunc.getName(), workload,
- ArrayRef<Type>{}, allOperands);
-
- return resultValues;
-}
-
-// Outlines a reduction region into one or more iree.multi_arch_executables.
-// This separates the reduction into multiple dispatches, one for each reduction
-// dimension (thankfully XLA's operation semantics state this is ok). We then
-// special case the first dispatch such that it takes the constant initial
-// values so that we don't have to materialize a buffer for them.
-LogicalResult outlineReductionRegion(IREE::ReductionRegionOp regionOp,
- int outlinedRegionOrdinal) {
- // Insert at the same place as the original region.
- OpBuilder dispatcherBuilder(regionOp);
-
- // Wrap input operands in memrefs.
- SmallVector<Value *, 4> initialValues{llvm::map_range(
- regionOp.getInitialValueOperands(), [&](Value *originalArg) {
- return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
- })};
- SmallVector<Value *, 4> temps{
- llvm::map_range(regionOp.getReductionOperands(), [&](Value *originalArg) {
- return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder);
- })};
-
- // Create one dispatch per dimension being reduced.
- // We'll do this by chaining the original input through with the temporary
- // reduction results. The results we end up with will be the originally
- // requested shape and we can just substitute them.
- if (regionOp.isWindowed()) {
- auto windowDimensions = regionOp.window_dimensions().getValue();
- auto windowStrides = regionOp.window_strides().getValue();
- auto baseDilations = regionOp.base_dilations().getValue();
- auto windowDilations = regionOp.window_dilations().getValue();
- SmallVector<std::tuple<int64_t, int64_t, int64_t, int64_t>, 4>
- sortedWindowAttrs;
- for (uint64_t i = 0; i < windowDimensions.getNumElements(); ++i) {
- int64_t windowDimension =
- windowDimensions.getValue<IntegerAttr>({i}).getInt();
- int64_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt();
- int64_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt();
- int64_t windowDilation =
- windowDilations.getValue<IntegerAttr>({i}).getInt();
- sortedWindowAttrs.push_back(
- {windowDimension, windowStride, baseDilation, windowDilation});
- }
- llvm::sort(sortedWindowAttrs,
- [](std::tuple<int64_t, int64_t, int64_t, int64_t> a,
- std::tuple<int64_t, int64_t, int64_t, int64_t> b) {
- return std::get<0>(a) - std::get<0>(b);
- });
- for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) {
- int64_t windowDimension = std::get<0>(windowAttrs.value());
- int64_t windowStride = std::get<1>(windowAttrs.value());
- int64_t baseDilation = std::get<2>(windowAttrs.value());
- int64_t windowDilation = std::get<3>(windowAttrs.value());
- IREE::MultiArchExecutableOp multiArchExecutable;
- FuncOp entryFunc;
- std::tie(multiArchExecutable, entryFunc) = createReductionExecutable(
- regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension,
- initialValues, temps);
- entryFunc.setAttr("iree.executable.reduction.padding_mode",
- dispatcherBuilder.getI32IntegerAttr(
- regionOp.padding_mode().getValue()));
- entryFunc.setAttr("iree.executable.reduction.window_dimension",
- dispatcherBuilder.getI32IntegerAttr(windowDimension));
- entryFunc.setAttr("iree.executable.reduction.window_stride",
- dispatcherBuilder.getI32IntegerAttr(windowStride));
- entryFunc.setAttr("iree.executable.reduction.base_dilation",
- dispatcherBuilder.getI32IntegerAttr(baseDilation));
- entryFunc.setAttr("iree.executable.reduction.window_dilation",
- dispatcherBuilder.getI32IntegerAttr(windowDilation));
- temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc,
- windowDimension, initialValues,
- std::move(temps), dispatcherBuilder);
- if (temps.empty()) {
- return regionOp.emitOpError()
- << "Failed to construct reduction for windowed dimension "
- << windowDimension;
- }
- }
- } else {
- auto dimensions = regionOp.dimensions().getValue();
- SmallVector<int64_t, 4> sortedDimensions;
- for (uint64_t i = 0; i < dimensions.getNumElements(); ++i) {
- sortedDimensions.push_back(
- dimensions.getValue<IntegerAttr>({i}).getInt());
- }
- llvm::sort(sortedDimensions, [](int64_t a, int64_t b) { return a - b; });
- for (auto dimension : llvm::enumerate(sortedDimensions)) {
- IREE::MultiArchExecutableOp multiArchExecutable;
- FuncOp entryFunc;
- std::tie(multiArchExecutable, entryFunc) = createReductionExecutable(
- regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(),
- initialValues, temps);
- entryFunc.setAttr("iree.executable.reduction.dimension",
- dispatcherBuilder.getI32IntegerAttr(dimension.value()));
- temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc,
- dimension.value(), initialValues,
- std::move(temps), dispatcherBuilder);
- if (temps.empty()) {
- return regionOp.emitOpError()
- << "Failed to construct reduction for dimension "
- << dimension.value();
- }
- }
- }
- for (auto it : llvm::enumerate(regionOp.getResults())) {
- insertDispatcherLoad(regionOp, it.value(), temps[it.index()],
- dispatcherBuilder);
- }
-
- // Erase original region.
- regionOp.erase();
-
- return success();
-}
-
-} // namespace
-
-class OutlineReductionRegionsPass
- : public ModulePass<OutlineReductionRegionsPass> {
- public:
- void runOnModule() override {
- auto module = getModule();
-
- ModuleManager moduleManager(module);
- auto funcs = module.getOps<FuncOp>();
- SmallVector<FuncOp, 4> funcOps(funcs.begin(), funcs.end());
- for (auto func : funcOps) {
- // Outline all of the iree.reduction_region ops in this function.
- std::vector<IREE::ReductionRegionOp> reductionRegionOps;
- func.walk([&](IREE::ReductionRegionOp op) {
- reductionRegionOps.push_back(op);
- });
- for (int i = 0; i < reductionRegionOps.size(); ++i) {
- if (failed(outlineReductionRegion(reductionRegionOps[i], i))) {
- return signalPassFailure();
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass() {
- return std::make_unique<OutlineReductionRegionsPass>(); // NOLINT
-}
-
-static PassRegistration<OutlineReductionRegionsPass> pass(
- "iree-outline-reduction-regions",
- "Outlines reduction regions into standalone functions");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp b/iree/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp
deleted file mode 100644
index c541e4e..0000000
--- a/iree/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <algorithm>
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/DispatchUtils.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Chosen randomly for now. We can measure and see what makes sense.
-constexpr int64_t kMaxRematerializedConstantSizeInBytes = 1 * 1024;
-
-// Returns true if the constant value is under a certain threshold.
-// This threshold is fixed for all backends as a value that is assumed small
-// enough to be worth inlining possibly several times (at the cost of binary
-// bloat).
-bool isConstantSmall(ConstantOp constantOp) {
- if (auto shapedType = constantOp.getType().dyn_cast<ShapedType>()) {
- return shapedType.getSizeInBits() / 8 <=
- kMaxRematerializedConstantSizeInBytes;
- }
-
- // Assume anything unshaped is small. This may not always be true in custom
- // dialects but is in std for now.
- return true;
-}
-
-// Returns true if the dispatch region is allowed to have constants inside.
-// Certain regions that may get replaced or turned into kernel imports shouldn't
-// have the constants moved into them as they'll just get lost.
-bool canDispatchRegionContainConstants(
- IREE::DispatchRegionOp dispatchRegionOp) {
- for (auto &block : dispatchRegionOp.getBody()) {
- for (auto &op : block) {
- if (isa<xla_hlo::DotOp>(&op)) {
- return false;
- }
- }
- }
- return true;
-}
-
-// Rematerializes a constant inside of all dispatch regions that use it.
-// Afterward the constant is only removed if there are no other uses within the
-// non-dispatch block (such as by sequencer ops).
-LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) {
- Value *constantValue = constantOp.getResult();
- SmallVector<IREE::DispatchRegionOp, 4> usingRegionOps;
- for (auto *user : constantValue->getUsers()) {
- if (auto dispatchRegionOp = dyn_cast<IREE::DispatchRegionOp>(user)) {
- // Ensure this isn't just the workload and is used as an arg.
- if (std::find(dispatchRegionOp.arg_operand_begin(),
- dispatchRegionOp.arg_operand_end(),
- constantValue) != dispatchRegionOp.arg_operand_end()) {
- if (canDispatchRegionContainConstants(dispatchRegionOp)) {
- usingRegionOps.push_back(dispatchRegionOp);
- }
- }
- }
- }
- for (auto &dispatchRegionOp : usingRegionOps) {
- if (failed(inlineDispatchRegionOperandsUsingValue(dispatchRegionOp,
- constantValue))) {
- return failure();
- }
- }
-
- // Remove if there are no other uses within the block.
- if (constantOp.use_empty()) {
- constantOp.erase();
- }
-
- return success();
-}
-
-} // namespace
-
-// Finds constant arguments to dispatch regions that are too small to be worth
-// putting into constant pools. This prevents things like a CSE'd scalar
-// constant of 0.0 being passed by reference to a bunch of regions. Later
-// backend-specific passes running on the dispatch regions may also be able to
-// improve their constant propagation chances by having the full constant value
-// available.
-//
-// Note that this currently only operates at the block level. Constants that are
-// pushed across branches are assumed to have been rematerialized within blocks
-// already, but if that isn't the case then this pass can be extended to do
-// that.
-class RematerializeDispatchConstantsPass
- : public FunctionPass<RematerializeDispatchConstantsPass> {
- public:
- void runOnFunction() override {
- for (auto &block : getFunction()) {
- SmallVector<ConstantOp, 8> smallConstantOps;
- for (auto constantOp : block.getOps<ConstantOp>()) {
- if (isConstantSmall(constantOp)) {
- smallConstantOps.push_back(constantOp);
- }
- }
- // Note: we iterate in reverse so that the rematerialized constants appear
- // in the same order they did originally (as insertion is at the top).
- for (auto constantOp : llvm::reverse(smallConstantOps)) {
- if (failed(rematerializeConstantInDispatchRegions(constantOp))) {
- return signalPassFailure();
- }
- }
- }
- }
-};
-
-std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass() {
- return std::make_unique<RematerializeDispatchConstantsPass>();
-}
-
-static PassRegistration<RematerializeDispatchConstantsPass> pass(
- "iree-rematerialize-dispatch-constants",
- "Rematerializes small previously-CSE'd constants into dispatch regions.");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/Rewrites.h b/iree/compiler/Transforms/Sequencer/Rewrites.h
deleted file mode 100644
index 87fa105..0000000
--- a/iree/compiler/Transforms/Sequencer/Rewrites.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
-#define IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
-
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Adds rewrite patterns for lowering IREE Sequencer HL ops (iree_hl_seq.*)
-// to LL ops (iree_ll_seq.*).
-void populateSequencerLoweringPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
-
-// Adds rewrite patterns for lowering xla_hlo ops to Sequencer HL ops.
-void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
-
-// Adds rewrite patterns for lowering standard ops to Sequencer HL ops.
-void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
diff --git a/iree/compiler/Transforms/Sequencer/test/BUILD b/iree/compiler/Transforms/Sequencer/test/BUILD
deleted file mode 100644
index 74a48a3..0000000
--- a/iree/compiler/Transforms/Sequencer/test/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-# Tests specific to the sequencer.
-
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/test/BUILD b/iree/compiler/Transforms/test/BUILD
deleted file mode 100644
index 79700a8..0000000
--- a/iree/compiler/Transforms/test/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-# Tests for common transforms.
-
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/Interpreter/BUILD b/iree/compiler/Translation/Interpreter/BUILD
deleted file mode 100644
index 8dff827..0000000
--- a/iree/compiler/Translation/Interpreter/BUILD
+++ /dev/null
@@ -1,31 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Interpreter",
- srcs = ["InterpreterExecutableTranslation.cpp"],
- hdrs = ["InterpreterExecutableTranslation.h"],
- deps = [
- "//iree/compiler/IR",
- "//iree/compiler/IR/Interpreter",
- "//iree/compiler/Serialization",
- "//iree/compiler/Transforms",
- "//iree/compiler/Transforms/Interpreter",
- "//iree/compiler/Utils",
- "//iree/schemas",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:Transforms",
- "@local_config_mlir//:Translation",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
- ],
- alwayslink = 1,
-)
diff --git a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp b/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp
deleted file mode 100644
index d21d6e0..0000000
--- a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp
+++ /dev/null
@@ -1,289 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h"
-
-#include <cstdint>
-#include <iostream>
-#include <vector>
-
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/minireflect.h"
-#include "iree/compiler/IR/ConfigOps.h"
-#include "iree/compiler/IR/Interpreter/OpWriters.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Serialization/VMFunctionBuilder.h"
-#include "iree/compiler/Serialization/VMFunctionTableBuilder.h"
-#include "iree/compiler/Serialization/VMModuleBuilder.h"
-#include "iree/compiler/Transforms/Interpreter/Passes.h"
-#include "iree/compiler/Transforms/Passes.h"
-#include "iree/compiler/Utils/Macros.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "iree/compiler/Utils/TranslationUtils.h"
-#include "iree/schemas/executable_def_generated.h"
-#include "iree/schemas/module_def_generated.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Passes.h"
-#include "mlir/Translation.h"
-#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Builds a pass pipeline that optimizes and legalizes the module to the form
-// expected by translation.
-void buildLegalizeInputPassPipeline(PassManager *passManager) {
- // Standard passes that shake out a lot of garbage.
- // Some may have been run prior to translation but this ensures we are always
- // in a known state.
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createLoopFusionPass());
- passManager->addPass(createLoopInvariantCodeMotionPass());
- passManager->addPass(createMemRefDataFlowOptPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createSimplifyAffineStructuresPass());
- passManager->addPass(createCSEPass());
- passManager->addPass(createCanonicalizerPass());
-
- // Eliminate ops we don't care about based on a lack of side-effects.
- // IREE does not guarantee exception/error behavior of dead ops.
- passManager->addPass(createAggressiveOpEliminationPass());
-
- // Expand uses of tuples into independent args/results.
- passManager->addPass(createConvertFromTupleCallingConventionPass());
- passManager->addPass(createCanonicalizerPass());
-}
-
-// Builds a pass pipeline that converts functions to the iree_hl_interp dialect.
-void buildInterpreterConversionPassPipeline(PassManager *passManager) {
- // We don't need the IREE binding ops anymore, as we match the calling
- // convention exactly (we're the same VM).
- passManager->addPass(createMakeExecutableABIPass());
-
- // Convert to the memref calling convention and optimize away as many
- // loads and stores as we can prior to progressing.
- passManager->addPass(createConvertToMemRefCallingConventionPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createMemRefDataFlowOptPass());
-
- // Convert various dialects to IREE opcodes and cleanup leftover conversions.
- passManager->addPass(createLowerToInterpreterDialectPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createAggressiveOpEliminationPass());
-
- // Widen reduction functions (that have iree.executable.reduction attrs) to
- // use their primitive IREE ops.
- passManager->addPass(createExpandReductionsToOpsPass());
-
- // Convert any uses of index to int32_t (as we explicitly don't want to
- // support dynamic index width).
- // This also looks for other weird types (i1, etc).
- passManager->addPass(createLegalizeTypeStoragePass());
-
- // Perform any last-minute optimizations to trim down the IR.
- passManager->addPass(createAggressiveOpEliminationPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createLoopFusionPass());
- passManager->addPass(createLoopInvariantCodeMotionPass());
- passManager->addPass(createMemRefDataFlowOptPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createCSEPass());
- passManager->addPass(createCanonicalizerPass());
-
- // Drop all functions that are not reachable.
- passManager->addPass(createDropUnreachableExecutableFunctionsPass());
-}
-
-// Builds a pass pipeline that lowers the iree_hl_interp dialect to the
-// iree_ll_interp dialect and prepares for serialization.
-void buildInterpreterLoweringPassPipeline(PassManager *passManager) {
- // Lower iree_hl_interp -> iree_ll_interp.
- passManager->addPass(createLowerInterpreterDialectPass());
-
- // Assign ordinals used by the bytecode to reference executables and
- // functions.
- passManager->addPass(createAssignFunctionOrdinalsPass());
-}
-
-class InterpreterTranslator {
- public:
- explicit InterpreterTranslator(ExecutableTranslationOptions options)
- : options_(options) {}
-
- const ExecutableTranslationOptions &options() const { return options_; }
-
- std::unique_ptr<iree::ExecutableDefT> translateExecutable(
- IREE::ExecutableOp executableOp);
-
- private:
- LogicalResult translateExecutableModule(IREE::ExecutableOp executableOp,
- ModuleOp moduleOp,
- VMModuleBuilder *moduleBuilder);
- LogicalResult declareFunction(FuncOp function,
- VMModuleBuilder *moduleBuilder);
- LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder);
-
- ExecutableTranslationOptions options_;
-};
-
-std::unique_ptr<iree::ExecutableDefT>
-InterpreterTranslator::translateExecutable(IREE::ExecutableOp executableOp) {
- auto moduleOp = executableOp.getInnerModule();
-
- // Run all passes to go from input to the iree_ll_interp dialect.
- auto executableConversionPasses =
- createPassManager(moduleOp.getContext(), options());
- buildLegalizeInputPassPipeline(executableConversionPasses.get());
- buildInterpreterConversionPassPipeline(executableConversionPasses.get());
- buildInterpreterLoweringPassPipeline(executableConversionPasses.get());
- if (failed(runPassPipeline(options(), executableConversionPasses.get(),
- moduleOp))) {
- executableOp.emitError() << "Failed to run conversion passes";
- return {};
- }
-
- // Build the module bytecode.
- ::flatbuffers::FlatBufferBuilder fbb;
- VMModuleBuilder moduleBuilder(&fbb);
- if (failed(
- translateExecutableModule(executableOp, moduleOp, &moduleBuilder))) {
- executableOp.emitError() << "Failed to translate executable module";
- return {};
- }
- auto moduleDef = moduleBuilder.Finish();
- if (moduleDef.IsNull()) {
- moduleOp.emitError() << "Failed to verify completed module def";
- return {};
- }
- auto bytes = moduleBuilder.Serialize(moduleDef);
- if (bytes.empty()) {
- moduleOp.emitError() << "Failed to serialize final module def";
- return {};
- }
-
- OpBuilder builder(executableOp);
- executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>(
- IREE::ExecutableFormat::IreeBytecode)));
-
- auto executableDef = std::make_unique<iree::ExecutableDefT>();
- executableDef->format =
- static_cast<uint32_t>(IREE::ExecutableFormat::IreeBytecode);
- executableDef->supported_features = iree::ExecutableFeature::kDebugging;
- executableDef->contents = std::move(bytes);
- return executableDef;
-}
-
-LogicalResult InterpreterTranslator::translateExecutableModule(
- IREE::ExecutableOp executableOp, ModuleOp moduleOp,
- VMModuleBuilder *moduleBuilder) {
- // Declare functions first so that we get stable indices during declaration
- // (as call ops need to use the function table).
- for (auto function : moduleOp.getOps<FuncOp>()) {
- RETURN_IF_FAILURE(declareFunction(function, moduleBuilder));
- }
-
- // Define functions now that all functions have been declared.
- for (auto function : moduleOp.getOps<FuncOp>()) {
- RETURN_IF_FAILURE(defineFunction(function, moduleBuilder));
- }
-
- return success();
-}
-
-LogicalResult InterpreterTranslator::declareFunction(
- FuncOp function, VMModuleBuilder *moduleBuilder) {
- auto *functionTable = moduleBuilder->function_table();
- if (functionTable->IsFunctionDeclared(function)) {
- // Already declared.
- return success();
- }
-
- LinkageType linkageType;
- if (function.isExternal()) {
- linkageType = LinkageType::kImport;
- } else if (function.getAttr("iree.executable.export")) {
- linkageType = LinkageType::kExport;
- } else {
- linkageType = LinkageType::kInternal;
- }
- if (failed(functionTable->DeclareFunction(function, linkageType))) {
- return function.emitError() << "Unable to declare function";
- }
-
- // Import functions must have their definition defined here so we get their
- // type. Internal and export functions will be defined during conversion.
- if (linkageType == LinkageType::kImport) {
- VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
- moduleBuilder->fbb());
- auto functionOffset = functionBuilder.Finish();
- if (functionOffset.IsNull()) {
- return function.emitError()
- << "Failed to create import function bytecode";
- }
- RETURN_IF_FAILURE(
- functionTable->DefineFunction(function, functionOffset, {}));
- }
-
- return success();
-}
-
-LogicalResult InterpreterTranslator::defineFunction(
- FuncOp function, VMModuleBuilder *moduleBuilder) {
- VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
- moduleBuilder->fbb());
- registerInterpreterCustomWriters(&functionBuilder);
- RETURN_IF_FAILURE(functionBuilder.ConvertBytecode());
- auto functionOffset = functionBuilder.Finish();
- if (functionOffset.IsNull()) {
- return function.emitError() << "Failed to serialize function";
- }
- RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction(
- function, functionOffset, functionBuilder.source_map()));
- return success();
-}
-
-} // namespace
-
-llvm::Optional<ExecutableTranslationResult>
-translateExecutableToInterpreterExecutable(
- ArrayRef<IREE::ExecutableOp> executableOps,
- ExecutableTranslationOptions options) {
- InterpreterTranslator translator(options);
- ExecutableTranslationResult translationResult;
- for (auto executableOp : llvm::make_early_inc_range(executableOps)) {
- auto executableDef = translator.translateExecutable(executableOp);
- if (!executableDef) {
- executableOp.emitError() << "Failed to translate one or more executables";
- return llvm::None;
- }
- translationResult.executable_defs.push_back(std::move(executableDef));
- }
- return translationResult;
-}
-
-static ExecutableTranslationRegistration
- InterpreterExecutableTranslationRegistration(
- "interpreter-bytecode", translateExecutableToInterpreterExecutable);
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h b/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h
deleted file mode 100644
index 41949c7..0000000
--- a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
-#define IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
-
-#include <vector>
-
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/TranslationUtils.h"
-#include "mlir/IR/Module.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Translates an MLIR module into a bytecode interpreter executable.
-// These executables are stored as IREE modules as defined in the
-// https://github.com/google/iree/tree/master/iree/schemas/module_def.fbs
-// schema.
-llvm::Optional<ExecutableTranslationResult>
-translateExecutableToInterpreterExecutable(
- ArrayRef<IREE::ExecutableOp> executableOps,
- ExecutableTranslationOptions options = {});
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
diff --git a/iree/compiler/Translation/SPIRV/AffineExprCodegen.h b/iree/compiler/Translation/SPIRV/AffineExprCodegen.h
deleted file mode 100644
index b21182b..0000000
--- a/iree/compiler/Translation/SPIRV/AffineExprCodegen.h
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- AffineExprCodegen.h -------------------------------------*- C++//-*-===//
-//
-// Code-generation for Affine Expression.
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
-#define IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
-
-#include "iree/compiler/Translation/SPIRV/XLAIndexPropagation.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/IR/AffineExprVisitor.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Codegenerator for affine expressions.
-class AffineExprCodegen : public AffineExprVisitor<AffineExprCodegen, Value *> {
- public:
- explicit AffineExprCodegen(spirv::ModuleOp module,
- IndexComputationCache &tensorIndices)
- : builder(module.getContext()),
- location(module.getLoc()),
- tensorIndices(tensorIndices) {}
-
- Value *visitAddExpr(AffineBinaryOpExpr expr) {
- auto operand1 = getValueInternal(expr.getLHS());
- auto operand2 = getValueInternal(expr.getRHS());
- return builder.create<spirv::IAddOp>(location, operand1, operand2);
- }
- Value *visitMulExpr(AffineBinaryOpExpr expr) {
- auto operand1 = getValueInternal(expr.getLHS());
- auto operand2 = getValueInternal(expr.getRHS());
- return builder.create<spirv::IMulOp>(location, operand1, operand2);
- }
- Value *visitModExpr(AffineBinaryOpExpr expr) {
- auto operand1 = getValueInternal(expr.getLHS());
- auto operand2 = getValueInternal(expr.getRHS());
- return builder.create<spirv::SModOp>(location, operand1, operand2);
- }
- Value *visitFloorDivExpr(AffineBinaryOpExpr expr) {
- auto operand1 = getValueInternal(expr.getLHS());
- auto operand2 = getValueInternal(expr.getRHS());
- return builder.create<spirv::SDivOp>(location, operand1, operand2);
- }
- Value *visitCeilDivExpr(AffineBinaryOpExpr expr) {
- // TODO(ravishankarm): Implement ceil div expr codegen.
- llvm_unreachable("Unimplemented affine AffineCeilDivExpr codegen");
- return nullptr;
- }
- Value *visitConstantExpr(AffineConstantExpr expr) {
- return builder.create<spirv::ConstantOp>(
- location, builder.getIntegerType(32),
- builder.getI32IntegerAttr(expr.getValue()));
- }
- Value *visitDimExpr(AffineDimExpr expr) {
- return threadDimToDstValue.lookup(expr.getPosition());
- }
- Value *visitSymbolExpr(AffineSymbolExpr expr) {
- // TODO(ravishankarm): Implement symbol expr codegen.
- llvm_unreachable("Unimplemented affine AffineSymbolExpr codegen");
- return nullptr;
- }
-
- /// Set the value that contains the workitem ID along a particular
- /// dimension. 0 -> x-dimension, 1 -> y-dimension, etc.
- void setDimDstValue(unsigned dimID, Value *value) {
- threadDimToDstValue[dimID] = value;
- }
-
- /// Generates the scalar value for a affine expression.
- Value *getValue(AffineExpr expr, OpBuilder::InsertPoint ip, Location loc) {
- auto &val = exprToDstValue[expr];
- if (!val) {
- location = loc;
- builder.restoreInsertionPoint(ip);
- val = visit(expr);
- }
- return val;
- }
-
- /// Returns a list of indices of a particular tensor in the source dialect
- /// needed within the dispatch function (obtained from the
- /// IndexComputationCache)
- SmallVector<AffineMap, 4> getIndices(Value *value) {
- SmallVector<AffineMap, 4> indices;
- for (auto &index : tensorIndices[value]) {
- indices.push_back(index.first);
- }
- return indices;
- }
-
- /// For a given tensor in the source dialect and index, return the index of
- /// all operands needed to compute the result.
- ArrayRef<AffineMap> getOperandIndices(Value *value, AffineMap index) {
- return tensorIndices[value][index];
- }
-
- private:
- /// Returns the Value corresponding to the AffineExpr `expr` by either
- /// previously generated value for the same index, or by generating the value.
- /// This version assumes the insertion point/Location has already been set.
- Value *getValueInternal(AffineExpr expr) {
- auto &val = exprToDstValue[expr];
- if (!val) {
- val = visit(expr);
- }
- return val;
- }
-
- OpBuilder builder;
-
- Location location;
-
- /// Map from launch dimension to scalar value.
- DenseMap<unsigned, Value *> threadDimToDstValue;
-
- /// Cache of affine expression to scalar value. TODO(ravishankarm) : Might
- /// need to be changed if we are handling control flow within the dispatch
- /// function.
- DenseMap<AffineExpr, Value *> exprToDstValue;
-
- /// Map from tensor value in source dialect to list of indices of the tensor
- /// needed within a workitem to compute the results of the dispatch function.
- IndexComputationCache &tensorIndices;
-};
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
diff --git a/iree/compiler/Translation/SPIRV/BUILD b/iree/compiler/Translation/SPIRV/BUILD
deleted file mode 100644
index b0ad00b..0000000
--- a/iree/compiler/Translation/SPIRV/BUILD
+++ /dev/null
@@ -1,50 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "SPIRV",
- srcs = [
- "AffineExprCodegen.h",
- "EmbeddedKernels.cpp",
- "IREEIndexComputation.cpp",
- "IREEToSPIRV.cpp",
- "IREEToSPIRVPass.cpp",
- "IndexComputation.cpp",
- "SPIRVExecutableTranslation.cpp",
- "SPIRVLowering.cpp",
- "SPIRVLowering.h",
- "XLAIndexPropagation.cpp",
- ],
- hdrs = [
- "EmbeddedKernels.h",
- "IREEIndexComputation.h",
- "IREEToSPIRV.h",
- "IREEToSPIRVPass.h",
- "IndexComputation.h",
- "SPIRVExecutableTranslation.h",
- "XLAIndexPropagation.h",
- ],
- deps = [
- "//iree/compiler/IR",
- "//iree/compiler/Translation/SPIRV/Kernels",
- "//iree/compiler/Utils",
- "//iree/schemas",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:SPIRVDialect",
- "@local_config_mlir//:SPIRVDialectRegistration",
- "@local_config_mlir//:SPIRVSerialization",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:Transforms",
- "@local_config_mlir//:Translation",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- ],
- alwayslink = 1,
-)
diff --git a/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp b/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp
deleted file mode 100644
index d50f3cf..0000000
--- a/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp
+++ /dev/null
@@ -1,219 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Translation/SPIRV/EmbeddedKernels.h"
-
-#include "iree/compiler/Translation/SPIRV/Kernels/Kernels.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Module.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Reads the SPIR-V code for the embedded kernel with the given file name.
-// If the kernel under Kernels/ is 'matmul.comp' then |kernelName| would be
-// 'matmul.spv' (because it's been compiled).
-std::vector<uint32_t> readEmbeddedKernelCode(std::string kernelName) {
- auto *fileToc = spirv_kernels::Kernels_create();
- for (int i = 0; i < spirv_kernels::Kernels_size(); ++i) {
- if (std::strcmp(fileToc[i].name, kernelName.c_str()) == 0) {
- std::vector<uint32_t> code;
- code.resize(fileToc[i].size / 4);
- std::memcpy(code.data(), fileToc[i].data, fileToc[i].size);
- return code;
- }
- }
- return {};
-}
-
-// Adds a storage buffer binding to the descriptor set layout.
-void addDescriptorSetLayoutBinding(uint32_t binding,
- iree::VkDescriptorSetLayoutDefT *dsl) {
- auto bindingDef = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>();
- bindingDef->binding = binding;
- bindingDef->descriptor_count = 1;
- bindingDef->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
- bindingDef->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT
- dsl->bindings.push_back(std::move(bindingDef));
-}
-
-// Adds a specialization map entry for |constant_id| set to a 4-byte int value.
-void addSpecializationMapEntry(
- uint32_t constant_id, uint32_t value,
- iree::VkSpecializationInfoDefT *specializationInfoDef) {
- auto specValue = std::make_unique<iree::VkSpecializationMapEntryDefT>();
- specValue->constant_id = constant_id;
- specValue->uint32_value = value;
- specializationInfoDef->map_entries.push_back(std::move(specValue));
-}
-
-LogicalResult buildReductionExecutable(IREE::ExecutableOp executableOp,
- FuncOp entryFuncOp,
- iree::SpirVExecutableDefT *out_def) {
- auto funcType = entryFuncOp.getType();
- auto arg0 = funcType.getInput(0).cast<ShapedType>();
- if (!arg0.getElementType().isF32()) {
- // When we do other types we'll need other shaders.
- return entryFuncOp.emitOpError()
- << "Only floating point reduction is implemented";
- }
-
- auto module = executableOp.getInnerModule();
- auto applyFuncAttr = entryFuncOp.getAttrOfType<SymbolRefAttr>(
- "iree.executable.reduction.apply");
- auto applyFuncOp = module.lookupSymbol(applyFuncAttr.getValue());
-
- // TODO(benvanik): specialize (template on shapes/types/etc).
- std::string kernelName = "reduce_untiled.spv";
- llvm::Optional<uint32_t> operationId;
- applyFuncOp->walk([&](Operation *op) {
- if (isa<xla_hlo::AddOp>(op)) {
- operationId = 0;
- } else if (isa<xla_hlo::MaxOp>(op)) {
- operationId = 1;
- } else if (isa<xla_hlo::MinOp>(op)) {
- operationId = 2;
- }
- });
- if (!operationId.hasValue()) {
- applyFuncOp->dump();
- return applyFuncOp->emitOpError() << "Unsupported reduction operator";
- }
-
- out_def->tag = "__reduce__";
- out_def->entry_points = {"main"};
-
- out_def->code = readEmbeddedKernelCode(kernelName);
-
- // arg0, arg1, ret0
- auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
- pipelineLayoutDef->buffer_binding_set = 0;
- auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
- addDescriptorSetLayoutBinding(0, dsl.get());
- addDescriptorSetLayoutBinding(1, dsl.get());
- addDescriptorSetLayoutBinding(2, dsl.get());
- pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
- out_def->pipeline_layout = std::move(pipelineLayoutDef);
-
- // See the shader source for documentation on the values of A/B/C/R.
- int64_t reductionDimension =
- entryFuncOp
- .getAttrOfType<IntegerAttr>("iree.executable.reduction.dimension")
- .getInt();
- uint32_t r = arg0.getDimSize(reductionDimension);
- uint32_t a = 1;
- for (int i = 0; i < reductionDimension; ++i) {
- a *= arg0.getDimSize(i);
- }
- uint32_t b = 1;
- for (int i = reductionDimension + 1; i < arg0.getRank(); ++i) {
- b *= arg0.getDimSize(i);
- }
- uint32_t c = b;
-
- auto specializationInfoDef =
- std::make_unique<iree::VkSpecializationInfoDefT>();
- addSpecializationMapEntry(/*kOperationId*/ 100, operationId.getValue(),
- specializationInfoDef.get());
- addSpecializationMapEntry(/*kA*/ 101, a, specializationInfoDef.get());
- addSpecializationMapEntry(/*kB*/ 102, b, specializationInfoDef.get());
- addSpecializationMapEntry(/*kC*/ 103, c, specializationInfoDef.get());
- addSpecializationMapEntry(/*kR*/ 104, r, specializationInfoDef.get());
- out_def->specialization_info = std::move(specializationInfoDef);
-
- return success();
-}
-
-// Builds a SPIR-V executable from a well-known matmul executable.
-// |out_def| will be populated with all required information for serialization.
-LogicalResult buildMatMulExecutable(IREE::ExecutableOp executableOp,
- FuncOp entryFuncOp, xla_hlo::DotOp dotOp,
- iree::SpirVExecutableDefT *out_def) {
- auto arg0 = dotOp.getOperand(0)->getType().cast<ShapedType>();
- auto arg1 = dotOp.getOperand(1)->getType().cast<ShapedType>();
-
- out_def->tag = "__matmul__";
- out_def->entry_points = {"main"};
-
- // TODO(benvanik): specialize (template on shapes/types/etc).
- out_def->code = readEmbeddedKernelCode("matmul.spv");
-
- // arg0, arg1, ret0
- auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
- pipelineLayoutDef->buffer_binding_set = 0;
- auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
- addDescriptorSetLayoutBinding(0, dsl.get());
- addDescriptorSetLayoutBinding(1, dsl.get());
- addDescriptorSetLayoutBinding(2, dsl.get());
- pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
- out_def->pipeline_layout = std::move(pipelineLayoutDef);
-
- // Shapes of [arg0, arg1, ret0].
- // arg0 = [b0, m, k]
- // arg1 = [b0, k, n]
- // ret0 = [b0, m, n]
- // Note that we handle both batched (rank 3) and unbatched (rank 2).
- uint32_t m = arg0.getRank() == 3 ? arg0.getDimSize(1) : arg0.getDimSize(0);
- uint32_t k = arg0.getRank() == 3 ? arg0.getDimSize(2) : arg0.getDimSize(1);
- uint32_t n = arg1.getRank() == 3 ? arg1.getDimSize(2) : arg1.getDimSize(1);
- auto specializationInfoDef =
- std::make_unique<iree::VkSpecializationInfoDefT>();
- addSpecializationMapEntry(/*kMatrixM*/ 100, m, specializationInfoDef.get());
- addSpecializationMapEntry(/*kMatrixK*/ 101, k, specializationInfoDef.get());
- addSpecializationMapEntry(/*kMatrixN*/ 102, n, specializationInfoDef.get());
- out_def->specialization_info = std::move(specializationInfoDef);
-
- return success();
-}
-
-} // namespace
-
-bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp,
- iree::SpirVExecutableDefT *out_def) {
- auto module = executableOp.getInnerModule();
- for (auto funcOp : module.getOps<FuncOp>()) {
- if (funcOp.getAttr("iree.executable.reduction")) {
- if (failed(buildReductionExecutable(executableOp, funcOp, out_def))) {
- executableOp.emitOpError() << "Failed to splat in the reduction kernel";
- return false;
- }
- return true;
- }
-
- for (auto &block : funcOp) {
- for (auto &op : block) {
- if (isa<xla_hlo::ConvOp>(&op)) {
- executableOp.emitOpError() << "Conv not yet implemented";
- return false;
- } else if (auto dotOp = dyn_cast_or_null<xla_hlo::DotOp>(&op)) {
- if (failed(buildMatMulExecutable(executableOp, funcOp, dotOp,
- out_def))) {
- executableOp.emitOpError()
- << "Failed to splat in the matmul kernel";
- return false;
- }
- return true;
- }
- }
- }
- }
- return false;
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/EmbeddedKernels.h b/iree/compiler/Translation/SPIRV/EmbeddedKernels.h
deleted file mode 100644
index 951a5ff..0000000
--- a/iree/compiler/Translation/SPIRV/EmbeddedKernels.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
-#define IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Tries to match the |executableOp| against an embedded kernel and if matched
-// will populate |out_def| with the kernel.
-// Returns true if the kernel matched and was populated.
-bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp,
- iree::SpirVExecutableDefT* out_def);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
diff --git a/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp b/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp
deleted file mode 100644
index 54ad2f5..0000000
--- a/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- IREEIndexComputation.cpp --------------------------------*- C++//-*-===//
-//
-// Implementaiton of Index Propagation for IREE statements that are used in
-// dispatch functions.
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Translation/SPIRV/IREEIndexComputation.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-//===----------------------------------------------------------------------===//
-// IREELoadInputOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult IREELoadIndexPropagation::propagateIndexMap(
- Operation *operation, IndexComputationCache &indexMap) const {
- auto loadOp = cast<IREE::LoadInputOp>(operation);
- auto result = operation->getResult(0);
- auto src = loadOp.src();
- auto resultType = result->getType().dyn_cast<RankedTensorType>();
- auto srcType = src->getType().dyn_cast<MemRefType>();
- if (!resultType || !srcType || resultType.getShape() != srcType.getShape()) {
- return loadOp.emitError(
- "mismatch in shape of the result tensor and source memref");
- }
- // Initialize the storage for the src.
- indexMap[src];
- for (auto &resultIndexMap : indexMap[operation->getResult(0)]) {
- indexMap[src][resultIndexMap.first];
- resultIndexMap.second.push_back(resultIndexMap.first);
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// IREEStoreOutputOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult IREEStoreIndexPropagation::propagateIndexMap(
- Operation *operation, IndexComputationCache &indexMap) const {
- auto storeOp = cast<IREE::StoreOutputOp>(operation);
- auto src = storeOp.src();
- auto srcType = src->getType().dyn_cast<ShapedType>();
- if (!srcType || !srcType.hasStaticShape()) {
- return storeOp.emitError(
- "can only handle store with src being tensor of static shape");
- }
-
- SmallVector<int64_t, 3> launchSize;
- if (failed(getLaunchSize(operation, launchSize))) {
- return failure();
- }
-
- // The launch dimensions are [x, y, z] co-ordinates. The reverse of this is
- // used to determine the location of the tensor element computed by a
- // workitem. The choice is failry arbitrary but is done to enable the common
- // case where consecutive workitems compute "logically" adjacent tensor
- // elements.
- Builder builder(storeOp.getContext());
- SmallVector<AffineExpr, 4> affineExprs;
- int64_t numElements = 1;
- for (size_t i = launchSize.size(); i > 0; --i) {
- // If launchSize along any dimension is 1, just use 0 for the index. This is
- // not just an optimization. If you have an output of type memref<f32> which
- // is lowered to !spv.ptr<!spv.struct<f32>, StorageBuffer> with launchSize
- // <1>, then spv.AccessChain requires the indices to be a constant.
- if (launchSize[i - 1] == 1) {
- affineExprs.push_back(builder.getAffineConstantExpr(0));
- } else {
- affineExprs.push_back(builder.getAffineDimExpr(i - 1));
- }
- numElements *= launchSize[i - 1];
- }
- auto launchMap = AffineMap::get(launchSize.size(), 0, affineExprs);
-
- // The stored tensor can be a reshape of the launch dimension. It still
- // retains the requirement that each workitem is computing a single element
- // of the stored tensor.
- AffineMap srcMap;
- SmallVector<int64_t, 3> revLaunchSize(reverse(launchSize));
- if (numElements != srcType.getNumElements() ||
- failed(getReshapeOperandMap(builder, launchMap, revLaunchSize,
- srcType.getShape(), srcMap))) {
- return storeOp.emitError(
- "unable to map from launch id to element to compute within a "
- "workitem");
- }
- indexMap[src][srcMap];
- return success();
-}
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IREEIndexComputation.h b/iree/compiler/Translation/SPIRV/IREEIndexComputation.h
deleted file mode 100644
index 13d9216..0000000
--- a/iree/compiler/Translation/SPIRV/IREEIndexComputation.h
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- IREEIndexComputation.h ----------------------------------*- C++//-*-===//
-//
-// Index Propagation for IREE statements that are used in dispatch functions.
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_H
-#define IREE_COMPILER_TRANSLATION_SPIRV_H
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Translation/SPIRV/XLAIndexPropagation.h"
-#include "mlir/IR/Function.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Gets the launch size associated with the dispatch function that this op is
-/// part of.
-inline LogicalResult getLaunchSize(Operation *op,
- SmallVectorImpl<int64_t> &launchSize) {
- auto funcOp = op->getParentOfType<FuncOp>();
- if (!funcOp || !funcOp.getAttr("iree.executable.export")) {
- return op->emitError(
- "expected operation to be in dispatch function to get launch size");
- }
- auto workloadAttr =
- funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workload");
- if (!workloadAttr) {
- op->emitError(
- "unable to find workload size, missing attribute "
- "iree.executable.workload in dispatch function");
- }
- launchSize.clear();
- for (auto value : workloadAttr.getValues<APInt>()) {
- launchSize.push_back(value.getSExtValue());
- }
- // Drop trailing ones.
- auto dropFrom = launchSize.size() - 1;
- while (dropFrom > 0 && launchSize[dropFrom] == 1) {
- --dropFrom;
- }
- if (dropFrom > 0) {
- launchSize.erase(std::next(launchSize.begin(), dropFrom + 1),
- launchSize.end());
- }
- return success();
-}
-
-/// Index propagation for iree.load_input operation. This operation is
-/// essentially a copy from a memref to a tensor. So just copy the index map to
-/// the memref operand from the result tensor.
-class IREELoadIndexPropagation final
- : public IndexPropagationOp<IREE::LoadInputOp> {
- public:
- using IndexPropagationOp<IREE::LoadInputOp>::IndexPropagationOp;
-
- LogicalResult propagateIndexMap(
- Operation *operation, IndexComputationCache &indexMap) const override;
-};
-
-/// Index propagation for iree.store_output operation. The launch size is
-/// assumed to match the shape of the tensor that is being stored. This
-/// operation acts as a seed for the index propogation. Each workitem is assumed
-/// to compute a single element of this tensor. The range of the index map is
-/// the reverse of the launch dimension.
-class IREEStoreIndexPropagation final
- : public IndexPropagationOp<IREE::StoreOutputOp> {
- public:
- using IndexPropagationOp<IREE::StoreOutputOp>::IndexPropagationOp;
-
- LogicalResult propagateIndexMap(
- Operation *operation, IndexComputationCache &indexMap) const override;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_H
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRV.cpp b/iree/compiler/Translation/SPIRV/IREEToSPIRV.cpp
deleted file mode 100644
index 909ce62..0000000
--- a/iree/compiler/Translation/SPIRV/IREEToSPIRV.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- IREEToSPIRV.cpp -----------------------------------------*- C++//-*-===//
-//
-// Translation of IREE statements in dispatch functions to SPIR-V.
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Translation/SPIRV/IREEToSPIRV.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// IREE::LoadInputOp is essentially a memcpy. Just update the `valueCache` with
-/// the value of the operand.
-LogicalResult IREELoadOpSPIRVLowering::lowerOperation(
- Operation *op, OpBuilder &builder, AffineMap index,
- ArrayRef<Value *> operands, ValueCache &valueCache) const {
- auto loadOp = cast<IREE::LoadInputOp>(op);
- auto result = loadOp.getResult();
- valueCache.setOperandDstValue(result, index, operands[0]);
- return success();
-}
-
-/// IREE::StoreOp needs to write to the spv.globalVariable created for the
-/// memref that holds the result of the dispatch function.
-LogicalResult IREEStoreOpSPIRVLowering::lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
- auto storeOp = cast<IREE::StoreOutputOp>(op);
- auto src = storeOp.src();
- auto indices = affineExprCodegen.getIndices(src);
- if (indices.size() != 1) {
- return storeOp.emitError(
- "expected to compute a single element of the tensor that is stored "
- "into the output memref");
- }
- auto var = inputBuffers.lookup(storeOp.dst());
- if (!var) {
- return storeOp.emitError(
- "unable to find spv.globalVariable that corresponds to the dst memref");
- }
- auto ptr = genPointerOffset(builder, storeOp.getLoc(), affineExprCodegen,
- indices[0], var);
- auto scalarValue = valueCache.getOperandDstValue(src, indices[0]);
- builder.create<spirv::StoreOp>(storeOp.getLoc(), ptr, scalarValue,
- /*memory_access = */ nullptr,
- /*alignment = */ nullptr);
- return success();
-}
-
-/// IREE::ReturnOp in dispatch functions lowered to SPIR-V should have no
-/// operands.
-LogicalResult IREEReturnOpSPIRVLowering::lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
- builder.create<spirv::ReturnOp>(op->getLoc());
- return success();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRV.h b/iree/compiler/Translation/SPIRV/IREEToSPIRV.h
deleted file mode 100644
index 461e3f3..0000000
--- a/iree/compiler/Translation/SPIRV/IREEToSPIRV.h
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- IREEToSPIRV.h -------------------------------------------*- C++//-*-===//
-//
-// Translation of IREE statements in dispatch functions to SPIR-V.
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
-#define IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Translation/SPIRV/SPIRVLowering.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Translation of iree.load_input operation.
-class IREELoadOpSPIRVLowering final
- : public SPIRVOpLowering<IREE::LoadInputOp> {
- public:
- using SPIRVOpLowering<IREE::LoadInputOp>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index, ArrayRef<Value *> operands,
- ValueCache &valueCache) const override;
-};
-
-/// Translation of iree.return operation.
-class IREEReturnOpSPIRVLowering final : public SPIRVOpLowering<IREE::ReturnOp> {
- public:
- using SPIRVOpLowering<IREE::ReturnOp>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &codegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override;
-};
-
-/// Translation of iree.store_output operation.
-class IREEStoreOpSPIRVLowering final
- : public SPIRVOpLowering<IREE::StoreOutputOp> {
- public:
- using SPIRVOpLowering<IREE::StoreOutputOp>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &codegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp b/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp
deleted file mode 100644
index 7f1d75c..0000000
--- a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp
+++ /dev/null
@@ -1,196 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- IREEToSPIRVPass.cpp -------------------------------------*- C++//-*-===//
-//
-// Pass to translate iree executables for vulkan-spirv.
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h"
-
-#include "iree/compiler/Translation/SPIRV/IREEIndexComputation.h"
-#include "iree/compiler/Translation/SPIRV/IREEToSPIRV.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-class IREEToSPIRVPass : public ModulePass<IREEToSPIRVPass> {
- void runOnModule() override;
-};
-
-} // namespace
-
-void IREEToSPIRVPass::runOnModule() {
- auto module = getModule();
- OpBuilder builder(module.getBodyRegion());
-
- // Initialize the index computation.
- IndexPropagationList<IndexPropagationOp<ConstantOp>,
- // IREE-specific ops:
- IndexPropagationOp<IREE::ReturnOp>,
- IREELoadIndexPropagation, IREEStoreIndexPropagation,
- // Standard dialect unary elementwise ops:
- NoBroadcastPwOpIndexPropagation<SIToFPOp>,
- NoBroadcastPwOpIndexPropagation<SignExtendIOp>,
- // Standard dialect binary elementwise ops:
- NoBroadcastPwOpIndexPropagation<AddFOp>,
- NoBroadcastPwOpIndexPropagation<AddIOp>,
- NoBroadcastPwOpIndexPropagation<AndOp>,
- NoBroadcastPwOpIndexPropagation<CmpFOp>,
- NoBroadcastPwOpIndexPropagation<CmpIOp>,
- NoBroadcastPwOpIndexPropagation<DivFOp>,
- NoBroadcastPwOpIndexPropagation<DivISOp>,
- NoBroadcastPwOpIndexPropagation<DivIUOp>,
- NoBroadcastPwOpIndexPropagation<MulFOp>,
- NoBroadcastPwOpIndexPropagation<MulIOp>,
- NoBroadcastPwOpIndexPropagation<OrOp>,
- NoBroadcastPwOpIndexPropagation<RemFOp>,
- NoBroadcastPwOpIndexPropagation<RemISOp>,
- NoBroadcastPwOpIndexPropagation<RemIUOp>,
- NoBroadcastPwOpIndexPropagation<SubFOp>,
- NoBroadcastPwOpIndexPropagation<SubFOp>,
- NoBroadcastPwOpIndexPropagation<SubIOp>,
- NoBroadcastPwOpIndexPropagation<TruncateIOp>,
- NoBroadcastPwOpIndexPropagation<XOrOp>,
- NoBroadcastPwOpIndexPropagation<ZeroExtendIOp>,
- // XLA unary elementwise ops:
- NoBroadcastPwOpIndexPropagation<xla_hlo::AbsOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::CeilOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::ConvertOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::CosOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::ExpOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::FloorOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::LogOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::NegOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::RsqrtOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::SignOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::TanhOp>,
- // XLA binary elementwise ops:
- NoBroadcastPwOpIndexPropagation<xla_hlo::AddOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::AndOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::DivOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::MaxOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::MinOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::MulOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::SubOp>,
- // XLA other ops:
- // TODO(ravishankarm): conv, dot.
- // TODO(ravishankarm): gather.
- // TODO(ravishankarm): pad.
- // TODO(ravishankarm): slice.
- NoBroadcastPwOpIndexPropagation<xla_hlo::CopyOp>,
- ReshapeOpIndexPropagation<xla_hlo::ReshapeOp>,
- NoBroadcastPwOpIndexPropagation<xla_hlo::SelectOp>,
- XLABroadcastOpIndexPropagation,
- XLABroadcastInDimOpIndexPropagation,
- XLAReverseOpIndexPropagation,
- XLATransposeOpIndexPropagation>
- indexPropagation;
-
- // Initialize the spir-v codegenerator.
- SPIRVCodegen<
- ConstantOpSPIRVLowering,
- // IREE-specific ops:
- IREELoadOpSPIRVLowering, IREEReturnOpSPIRVLowering,
- IREEStoreOpSPIRVLowering,
- // Standard dialect unary elementwise ops:
- // Standard dialect binary elementwise ops:
- SPIRVPwOpLowering<AddFOp, spirv::FAddOp>,
- SPIRVPwOpLowering<DivFOp, spirv::FDivOp>,
- SPIRVPwOpLowering<MulFOp, spirv::FMulOp>,
- SPIRVPwOpLowering<SubFOp, spirv::FSubOp>,
- SPIRVPwOpLowering<AddIOp, spirv::IAddOp>,
- SPIRVPwOpLowering<DivISOp, spirv::SDivOp>,
- SPIRVPwOpLowering<MulIOp, spirv::IMulOp>,
- SPIRVPwOpLowering<SubIOp, spirv::ISubOp>,
- // XLA unary elementwise ops:
- SPIRVPwOpLowering<xla_hlo::AbsOp, spirv::GLSLSAbsOp, spirv::GLSLFAbsOp>,
- SPIRVPwOpLowering<xla_hlo::CeilOp, spirv::GLSLCeilOp>,
- // TODO(ravishankarm): xla_hlo::ConvertOp
- SPIRVPwOpLowering<xla_hlo::CosOp, spirv::GLSLCosOp>,
- SPIRVPwOpLowering<xla_hlo::ExpOp, spirv::GLSLExpOp>,
- SPIRVPwOpLowering<xla_hlo::FloorOp, spirv::GLSLFloorOp>,
- SPIRVPwOpLowering<xla_hlo::LogOp, spirv::GLSLLogOp>,
- SPIRVPwOpLowering<xla_hlo::NegOp, spirv::FNegateOp>,
- SPIRVPwOpLowering<xla_hlo::RsqrtOp, spirv::GLSLInverseSqrtOp>,
- SPIRVPwOpLowering<xla_hlo::SignOp, spirv::GLSLSSignOp,
- spirv::GLSLFSignOp>,
- SPIRVPwOpLowering<xla_hlo::TanhOp, spirv::GLSLTanhOp>,
- // XLA binary elementwise ops:
- SPIRVPwOpLowering<xla_hlo::AddOp, spirv::IAddOp, spirv::FAddOp>,
- SPIRVPwOpLowering<xla_hlo::AndOp, spirv::LogicalAndOp>,
- SPIRVPwOpLowering<xla_hlo::DivOp, spirv::FDivOp>,
- SPIRVPwOpLowering<xla_hlo::MaxOp, spirv::GLSLSMaxOp, spirv::GLSLFMaxOp>,
- SPIRVPwOpLowering<xla_hlo::MinOp, spirv::GLSLSMinOp, spirv::GLSLFMinOp>,
- SPIRVPwOpLowering<xla_hlo::MulOp, spirv::IMulOp, spirv::FMulOp>,
- SPIRVPwOpLowering<xla_hlo::SubOp, spirv::ISubOp, spirv::FSubOp>,
- // XLA other ops:
- CmpFOpSPIRVLowering,
- SPIRVPwOpLowering<xla_hlo::SelectOp, spirv::SelectOp>,
- SPIRVIndexOpLowering<xla_hlo::BroadcastOp>,
- SPIRVIndexOpLowering<xla_hlo::BroadcastInDimOp>,
- SPIRVIndexOpLowering<xla_hlo::CopyOp>,
- SPIRVIndexOpLowering<xla_hlo::ReshapeOp>,
- SPIRVIndexOpLowering<xla_hlo::ReverseOp>,
- SPIRVIndexOpLowering<xla_hlo::TransposeOp>>
- spirvCodegen;
-
- // Create a spirv.module Op.
- auto spvModule = builder.create<spirv::ModuleOp>(
- module.getLoc(),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::AddressingModel::Logical)),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::MemoryModel::GLSL450)));
- SmallVector<StringRef, 2> caps;
- caps.push_back(spirv::stringifyCapability(spirv::Capability::Shader));
- spvModule.setAttr("capabilities", builder.getStrArrayAttr(caps));
- SmallVector<StringRef, 2> exts;
- exts.push_back("SPV_KHR_storage_buffer_storage_class");
- spvModule.setAttr("extensions", builder.getStrArrayAttr(exts));
-
- for (auto funcOp : module.getOps<FuncOp>()) {
- // TODO(ravishankarm): FuncOps in executable that are not dispatch functions
- // are not lowered to SPIR-V. Fix this limitation.
- if (!funcOp.getAttr("iree.executable.export")) continue;
-
- IndexComputationCache indexMap;
- if (failed(indexPropagation.propagate(funcOp.getBody(), indexMap))) {
- return signalPassFailure();
- }
- // dumpIndexCache(indexMap);
-
- ValueCache valueCache;
- AffineExprCodegen affineExprCodegen(spvModule, indexMap);
- if (failed(spirvCodegen.codegen(spvModule, funcOp, affineExprCodegen,
- valueCache))) {
- return signalPassFailure();
- }
- }
-}
-
-std::unique_ptr<OpPassBase<ModuleOp>> createIREEToSPIRVPass() {
- return std::make_unique<IREEToSPIRVPass>();
-}
-static PassRegistration<IREEToSPIRVPass> pass(
- "convert-iree-to-spirv",
- "Convert IREE dispatch functions to SPIR-V dialect");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation.cpp b/iree/compiler/Translation/SPIRV/IndexComputation.cpp
deleted file mode 100644
index fdb8464..0000000
--- a/iree/compiler/Translation/SPIRV/IndexComputation.cpp
+++ /dev/null
@@ -1,269 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- IndexComputation.cpp ------------------------------------*- C++//-*-===//
-//
-// For an IREE dispatch function, compute the map from workitem ID to index of
-// tensor computed within that workitem.
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Translation/SPIRV/IndexComputation.h"
-
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/raw_ostream.h"
-
-static llvm::cl::opt<bool> doAffineExprSimplify(
- "simplify-spirv-affine-exprs",
- llvm::cl::desc("Simplify affine expressions during code-generation."),
- llvm::cl::init(true));
-
-namespace mlir {
-namespace iree_compiler {
-
-//===----------------------------------------------------------------------===//
-// Reshape Utility Functions
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Handles shapes for scalars. Shape of scalars are represented as empty vetor,
-/// i.e. {}. Its easier to do index propogation to handle the scalar as vector
-/// of size 1.
-inline SmallVector<int64_t, 4> handleIfScalar(ArrayRef<int64_t> shape) {
- SmallVector<int64_t, 4> resultShape;
- if (shape.empty()) {
- return {1};
- }
- return SmallVector<int64_t, 4>(shape.begin(), shape.end());
-}
-
-/// Reshapes are often used to either add a dimension of size 1 or remove a
-/// dimension of size 1. Recognizing such cases can make the code-generation
-/// easier. The AffineMap needs to either add a constant 0 in the range for such
-/// added dimensions or drop those dimensions.
-inline LogicalResult getAffineExprForAddOrRemoveDimension(
- Builder &builder, ArrayRef<AffineExpr> resultExprs,
- ArrayRef<int64_t> resultShape, ArrayRef<int64_t> operandShape,
- SmallVectorImpl<AffineExpr> &operandExprs) {
- auto resultIndex = resultShape.size();
- auto operandIndex = operandShape.size();
- operandExprs.resize(operandShape.size());
- // Try to match up the dimensions of the operand and result by ignoring any
- // dimensions of size of 1 that are introduced.
- while (resultIndex > 0 && operandIndex > 0) {
- if (resultShape[resultIndex - 1] == -1 ||
- operandShape[operandIndex - 1] == -1) {
- return failure();
- }
- if (resultShape[resultIndex - 1] == operandShape[operandIndex - 1]) {
- operandExprs[operandIndex - 1] = resultExprs[resultIndex - 1];
- resultIndex--;
- operandIndex--;
- continue;
- }
- if (resultShape[resultIndex - 1] == 1) {
- // This is a dimension that is added on the operand. This affine
- // expression corresponding to this dimension is dropped.
- resultIndex--;
- continue;
- }
- if (operandShape[operandIndex - 1] == 1) {
- // This is a dimension of size 1 of the operand that is dropped. Add a
- // constant expr 0.
- operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
- operandIndex--;
- continue;
- }
- return failure();
- }
- // Any remaining dimensions should be 1.
- while (resultIndex > 0) {
- if (resultShape[resultIndex - 1] != 1) {
- return failure();
- }
- resultIndex--;
- }
- while (operandIndex > 0) {
- if (operandShape[operandIndex - 1] != 1) {
- return failure();
- }
- // This is a dimension of size 1 that is dropped. Add a constant expression
- // 0.
- operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
- operandIndex--;
- }
- return success();
-}
-
-/// Constructs the strides of an array assuming a row-major packed layout.
-// TODO(ravishankarm): This assumes the shape are static. When using dynamic
-// shapes, parameters of each dimension can be used to construct AffineExpr for
-// strides along each dimension. Note that multiplying two symbolic constants is
-// technically not affine, but you could use another symbol to represent the
-// product, so it should be still representable as affine exprs.
-inline LogicalResult getRowMajorPackedStrides(
- Builder &builder, ArrayRef<int64_t> shape,
- SmallVectorImpl<AffineExpr> &strides) {
- strides.resize(shape.size());
- int64_t stride = 1;
- for (auto dim : enumerate(reverse(shape))) {
- if (dim.value() < 0) {
- // TODO(ravishankarm) : Better error message.
- return failure();
- }
- strides[shape.size() - 1 - dim.index()] =
- builder.getAffineConstantExpr(stride);
- stride *= dim.value();
- }
- return success();
-}
-
-/// Linearizes the index of the result position accessed using the shape of the
-/// result tensor and delinearizes it to get the position of the operand.
-inline LogicalResult getAffineExprForReshape(
- Builder &builder, unsigned numDims, unsigned numSymbols,
- ArrayRef<AffineExpr> resultExprs, ArrayRef<int64_t> resultShape,
- ArrayRef<int64_t> operandShape, SmallVectorImpl<AffineExpr> &operandExprs) {
- // To linearize the index, assume that the memory is laid out in
- // packed-row-major layout based on the shape.
- // TODO(ravishankarm) : When there is stride information, use that to map from
- // index to memory location.
- SmallVector<AffineExpr, 4> resultStrides;
- if (failed(getRowMajorPackedStrides(builder, resultShape, resultStrides))) {
- return failure();
- }
- AffineExpr linearizedExpr;
- for (auto index : enumerate(resultExprs)) {
- auto val = getAffineBinaryOpExpr(AffineExprKind::Mul, index.value(),
- resultStrides[index.index()]);
- if (doAffineExprSimplify) {
- val = simplifyAffineExpr(val, numDims, numSymbols);
- }
- linearizedExpr = (index.index() ? getAffineBinaryOpExpr(AffineExprKind::Add,
- linearizedExpr, val)
- : val);
- if (doAffineExprSimplify) {
- linearizedExpr = simplifyAffineExpr(val, numDims, numSymbols);
- }
- }
-
- // Unlinearize the index, assuming row-major-packed layout.
- // TODO(ravishankarm) : When there is stride information, use that to map from
- // memory location to index.
- SmallVector<AffineExpr, 4> operandStrides;
- if (failed(getRowMajorPackedStrides(builder, operandShape, operandStrides))) {
- return failure();
- }
- operandExprs.resize(operandStrides.size());
- for (auto stride : enumerate(operandStrides)) {
- if (stride.index() == operandStrides.size() - 1) {
- operandExprs[stride.index()] = linearizedExpr;
- break;
- }
- auto expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, linearizedExpr,
- stride.value());
- operandExprs[stride.index()] =
- (doAffineExprSimplify ? simplifyAffineExpr(expr, numDims, numSymbols)
- : expr);
-
- linearizedExpr = getAffineBinaryOpExpr(AffineExprKind::Mod, linearizedExpr,
- stride.value());
- if (doAffineExprSimplify) {
- linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols);
- }
- }
- return success();
-}
-} // namespace
-
-LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap,
- ArrayRef<int64_t> resultShapeRef,
- ArrayRef<int64_t> operandShapeRef,
- AffineMap &operandIndexMap) {
- auto resultShape = handleIfScalar(resultShapeRef);
- auto operandShape = handleIfScalar(operandShapeRef);
- auto resultExprs = resultIndexMap.getResults();
- assert(resultShape.size() == resultExprs.size() &&
- "Ranks of the Domain of index map and result must be the same");
- SmallVector<AffineExpr, 4> operandExprs;
- if (failed(getAffineExprForAddOrRemoveDimension(
- builder, resultExprs, resultShape, operandShape, operandExprs)) &&
- failed(getAffineExprForReshape(
- builder, resultIndexMap.getNumDims(), resultIndexMap.getNumSymbols(),
- resultExprs, resultShape, operandShape, operandExprs))) {
- return failure();
- }
- assert(operandExprs.size() == operandShape.size() &&
- "expected as many exprs for the operand as the rank of the operand");
- operandIndexMap =
- AffineMap::get(resultIndexMap.getNumDims(),
- resultIndexMap.getNumSymbols(), operandExprs);
-
- return success();
-}
-
-LogicalResult IndexPropagation::propagateIndexMap(
- Operation *op, IndexComputationCache &indexMap) const {
- if (op->getNumResults() == 0) {
- // Nothing to do for this op.
- return success();
- }
- if (op->getNumResults() != 1) {
- return op->emitError(
- "default index propagation handles case with a single-return value");
- }
- // Initialize the storage for all the operands.
- for (auto arg : op->getOperands()) {
- indexMap[arg];
- }
- for (auto &resultIndexMap : indexMap[op->getResult(0)]) {
- SmallVector<AffineMap, 4> operandIndices;
- if (failed(this->propagateIndexMap(op, resultIndexMap.first,
- operandIndices))) {
- return failure();
- }
- assert(operandIndices.size() == op->getNumOperands() &&
- "Expected as many indices as operands");
- for (auto arg : enumerate(op->getOperands())) {
- indexMap[arg.value()][operandIndices[arg.index()]];
- resultIndexMap.second.push_back(operandIndices[arg.index()]);
- }
- }
- return success();
-}
-
-void dumpIndexCache(IndexComputationCache &indexMap) {
- for (auto &el : indexMap) {
- // llvm::errs() << "Value : " << *(el.first);
- // llvm::errs().flush();
- if (isa<OpResult>(el.first)) {
- llvm::errs() << "Operation : " << el.first->getDefiningOp()->getName();
- } else if (isa<BlockArgument>(el.first)) {
- llvm::errs() << "BlockArgument";
- }
- for (auto &used : el.second) {
- llvm::errs() << "\n\t" << used.first << " : [";
- std::string sep = "";
- for (auto &operand : used.second) {
- llvm::errs() << sep << operand;
- sep = ", ";
- }
- llvm::errs() << "]";
- }
- llvm::errs() << "\n";
- }
- llvm::errs() << "\n";
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl b/iree/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl
deleted file mode 100644
index 49f03aa..0000000
--- a/iree/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl
+++ /dev/null
@@ -1,32 +0,0 @@
-"""Utilities for handling hand-written SPIR-V files."""
-
-load("//iree:build_defs.bzl", "iree_glsl_vulkan")
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
-
-def spirv_kernel_cc_library(name, srcs):
- """Compiles GLSL files into SPIR-V binaries and embeds them in a cc_library.
-
- Args:
- name: cc_library name to depend on.
- srcs: a list of GLSL source files.
- """
- spv_files = []
- for src in srcs:
- spv_name = src.split(".")[-2]
- iree_glsl_vulkan(
- name = spv_name,
- srcs = [src],
- )
- spv_files.append(spv_name + ".spv")
- native.filegroup(
- name = name + "_files",
- srcs = spv_files,
- )
- cc_embed_data(
- name = name,
- srcs = spv_files,
- cc_file_output = name + ".cc",
- h_file_output = name + ".h",
- cpp_namespace = "mlir::iree_compiler::spirv_kernels",
- flatten = True,
- )
diff --git a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
deleted file mode 100644
index a014d26..0000000
--- a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
+++ /dev/null
@@ -1,314 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h"
-
-#include <cstdint>
-#include <iostream>
-#include <map>
-#include <vector>
-
-#include "flatbuffers/flatbuffers.h"
-#include "iree/compiler/Translation/SPIRV/EmbeddedKernels.h"
-#include "iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "iree/compiler/Utils/TranslationUtils.h"
-#include "iree/schemas/executable_def_generated.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Dialect/SPIRV/Serialization.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Passes.h"
-#include "mlir/Translation.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-class SPIRVTranslator {
- public:
- explicit SPIRVTranslator(ExecutableTranslationOptions options)
- : options_(options) {}
-
- const ExecutableTranslationOptions &options() const { return options_; }
-
- // Returns a populated ExecutableDef or nullptr if translation is
- // unsuccessful.
- std::unique_ptr<iree::ExecutableDefT> translateExecutable(
- IREE::ExecutableOp executableOp);
-
- private:
- // Returns a list of entry point names matching the expected export ordinals.
- std::vector<std::string> populateEntryPointNames(
- IREE::ExecutableOp executableOp);
-
- // Translates the input module into the SPIR-V dialect and returns the
- // serialized code words or empty if translation failed.
- std::vector<uint32_t> translateAndSerializeShaderModule(
- IREE::ExecutableOp executableOp);
-
- // Returns a pipeline layout definition based on the bindings required.
- std::unique_ptr<iree::VkPipelineLayoutDefT> populatePipelineLayout(
- spirv::ModuleOp spirvModuleOp);
-
- ExecutableTranslationOptions options_;
-};
-
-std::unique_ptr<iree::ExecutableDefT> SPIRVTranslator::translateExecutable(
- IREE::ExecutableOp executableOp) {
- // Try first to match against an embedded kernel (such as matmul) and
- // otherwise fall back to generating the kernel.
- iree::SpirVExecutableDefT spirvExecutableDef;
- if (!tryEmbeddedKernelRewrite(executableOp, &spirvExecutableDef)) {
- // The sequencer and runtime use ordinals instead of names. We provide the
- // list of entry point names here that are then passed in
- // VkShaderModuleCreateInfo.
- spirvExecutableDef.entry_points = populateEntryPointNames(executableOp);
-
- // Translate the module and generate the SPIR-V code.
- // The module is expected to be modified and must contain the metadata
- // required to enable the following information needed for the
- // SpirVExecutableDef to be extracted.
- spirvExecutableDef.code = translateAndSerializeShaderModule(executableOp);
- if (spirvExecutableDef.code.empty()) {
- executableOp.emitError()
- << "Failed to translate and serialize SPIR-V executable";
- return {};
- }
-
- // Reflect against the entry thunk to identify the required pipeline
- // layout based on binding information. This is used by the runtime to
- // create the VkPipelineLayout.
- for (auto spirvModuleOp :
- executableOp.getBlock().getOps<spirv::ModuleOp>()) {
- spirvExecutableDef.pipeline_layout =
- populatePipelineLayout(spirvModuleOp);
- if (!spirvExecutableDef.pipeline_layout) {
- spirvModuleOp.emitError()
- << "Failed to generate pipeline for SPIR-V module";
- return {};
- }
- break;
- }
- }
-
- // Pack the executable definition and get the bytes with the proper header.
- // The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::SpirVExecutableDef::Pack(fbb, &spirvExecutableDef);
- iree::FinishSpirVExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
-
- OpBuilder builder(executableOp);
- executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>(
- IREE::ExecutableFormat::SpirV)));
-
- auto executableDef = std::make_unique<iree::ExecutableDefT>();
- executableDef->format = static_cast<uint32_t>(IREE::ExecutableFormat::SpirV);
- executableDef->contents = std::move(bytes);
- return executableDef;
-}
-
-std::vector<std::string> SPIRVTranslator::populateEntryPointNames(
- IREE::ExecutableOp executableOp) {
- auto module = executableOp.getInnerModule();
- DenseMap<unsigned, StringRef> entryPoints;
- for (auto funcOp : module.getOps<FuncOp>()) {
- if (!funcOp.getAttr("iree.executable.export")) continue;
- auto ordinalAttr = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal");
- entryPoints[ordinalAttr.getInt()] = funcOp.getName();
- }
- std::vector<std::string> entryPointNames(entryPoints.size());
- for (auto &entry : entryPoints) {
- entryPointNames[entry.first] = entry.second.str();
- }
- return entryPointNames;
-}
-
-std::vector<uint32_t> SPIRVTranslator::translateAndSerializeShaderModule(
- IREE::ExecutableOp executableOp) {
- auto module = executableOp.getInnerModule();
-
- // We can use the workload hint to know what the expected dispatch workload
- // is. If we want to remap this to make more sense for the operations we are
- // performing we can do that here.
- //
- // Note that workloads are computed per entry point. There may be some
- // dimensions of the workload that are static (in which case workloadAttr will
- // have non-dynamic dims) and others that need to be taken from an argument
- // shape (in which case workloadRef is the argument ordinal to take dynamic
- // dimensions from).
- // TODO(benvanik): make it just an arg instead? iree.workload special op?
- // TODO(benvanik): instead of FuncOp have an iree.entry_point op with these.
- for (auto funcOp : module.getOps<FuncOp>()) {
- // TODO(ravishankarm): FuncOps in executable that are not dispatch functions
- // are not lowered to SPIR-V. Fix this limitation.
- if (!funcOp.getAttr("iree.executable.export")) continue;
- auto workloadAttr =
- funcOp.getAttrOfType<ElementsAttr>("iree.executable.workload");
- auto workloadRefAttr =
- funcOp.getAttrOfType<IntegerAttr>("iree.executable.workload_ref");
- std::array<int32_t, 3> staticWorkloadDims = {-1, -1, -1};
- if (workloadAttr) {
- for (unsigned i = 0; i < 3; ++i) {
- if (auto dimAttr =
- workloadAttr.getValue({i}).dyn_cast_or_null<IntegerAttr>()) {
- staticWorkloadDims[i] = dimAttr.getInt();
- }
- }
- }
- std::array<BlockArgument *, 3> dynamicWorkloadDimRefs;
- if (workloadRefAttr) {
- for (unsigned i = 0; i < 3; ++i) {
- if (staticWorkloadDims[i] == -1) {
- dynamicWorkloadDimRefs[i] =
- funcOp.getArgument(workloadRefAttr.getInt());
- }
- }
- }
-
- // Now staticWorkloadDims will have non-negative values for known dimensions
- // and any dim with -1 will need to be pulled from the corresponding shape
- // dimension of dynamicWorkloadDimRefs.
-
- // TODO(b/137868263): use this information to map from workgroup to
- // invocation and perform indexing.
- }
-
- // Lower module to spirv::ModuleOp.
- auto spirvGenPasses = createPassManager(module.getContext(), options());
- spirvGenPasses->addPass(xla_hlo::createLegalizeToStdPass());
- spirvGenPasses->addPass(createIREEToSPIRVPass());
- if (failed(runPassPipeline(options(), spirvGenPasses.get(), module))) {
- executableOp.emitError() << "Failed to generate spv.module";
- return {};
- }
-
- auto spvModules = module.getOps<spirv::ModuleOp>();
- if (std::distance(spvModules.begin(), spvModules.end()) != 1) {
- executableOp.emitError()
- << "Expected a single spv.module for an IREE executable op";
- return {};
- }
-
- // Serialize the spirv::ModuleOp into the binary that we will embed in the
- // final flatbuffer.
- std::vector<uint32_t> spvBinaries;
- for (auto spvModule : spvModules) {
- SmallVector<uint32_t, 256> spvBinary;
- if (failed(spirv::serialize(spvModule, spvBinary))) {
- executableOp.emitError() << "Failed to serialize spv.module";
- return {};
- }
- spvBinaries.insert(spvBinaries.end(), spvBinary.begin(), spvBinary.end());
-
- // Clone the module into executableOp directly.
- auto clonedModule = spvModule.clone();
- executableOp.getBlock().getOperations().insert(
- std::prev(executableOp.getBlock().getOperations().end()), clonedModule);
- }
- // Remove the original code.
- module.erase();
-
- return spvBinaries;
-}
-
-std::unique_ptr<iree::VkPipelineLayoutDefT>
-SPIRVTranslator::populatePipelineLayout(spirv::ModuleOp spirvModuleOp) {
- // NOTE: we currently make some assumptions about this based on the expected
- // ABI of the runtime. If we wanted to support more general shaders with more
- // complex I/O we'd need to find a better way to communicate this through the
- // VkPipelineLayoutDef.
- auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
- pipelineLayoutDef->buffer_binding_set = 0;
-
- // Build a set of descriptor_set -> binding -> variable.
- // This makes it easier to write out the descriptor in a logical order, even
- // though this is not strictly required.
- int64_t maxDescriptorSetOrdinal = -1;
- std::map<int32_t, std::map<int32_t, spirv::GlobalVariableOp>> descriptorSets;
- for (auto globalVar :
- spirvModuleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
- auto descriptorSetAttr =
- globalVar.getAttrOfType<IntegerAttr>("descriptor_set");
- auto bindingAttr = globalVar.getAttrOfType<IntegerAttr>("binding");
- if (!descriptorSetAttr || !bindingAttr) {
- // Not something the runtime cares about.
- continue;
- }
- maxDescriptorSetOrdinal =
- std::max(descriptorSetAttr.getInt(), maxDescriptorSetOrdinal);
- auto &descriptorSet = descriptorSets[descriptorSetAttr.getInt()];
- descriptorSet[bindingAttr.getInt()] = globalVar;
- }
-
- // Create the individual layout and binding defs.
- pipelineLayoutDef->descriptor_set_layouts.resize(maxDescriptorSetOrdinal + 1);
- for (auto &descriptorSetBindings : descriptorSets) {
- int32_t descriptorSet = descriptorSetBindings.first;
- auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
-
- for (auto &globalVarBinding : descriptorSetBindings.second) {
- auto binding = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>();
- binding->binding = globalVarBinding.first;
- binding->descriptor_count = 1;
- // TODO(benvanik): pull from type info.
- binding->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
- binding->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT
- dsl->bindings.push_back(std::move(binding));
- }
-
- pipelineLayoutDef->descriptor_set_layouts[descriptorSet] = std::move(dsl);
- }
-
- return pipelineLayoutDef;
-}
-
-} // namespace
-
-llvm::Optional<ExecutableTranslationResult>
-translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps,
- ExecutableTranslationOptions options) {
- SPIRVTranslator translator(options);
- ExecutableTranslationResult translationResult;
- for (auto executableOp : llvm::make_early_inc_range(executableOps)) {
- auto executableDef = translator.translateExecutable(executableOp);
- if (!executableDef) {
- executableOp.emitError() << "Failed to translate one or more executables";
- return llvm::None;
- }
- translationResult.executable_defs.push_back(std::move(executableDef));
- }
- return translationResult;
-}
-
-static ExecutableTranslationRegistration SPIRVExecutableTranslationRegistration(
- "vulkan-spirv", translateExecutableToSPIRVExecutable);
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h
deleted file mode 100644
index 68e3d5d..0000000
--- a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
-#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
-
-#include <vector>
-
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/Utils/TranslationUtils.h"
-#include "mlir/IR/Module.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Translates an MLIR module into a SPIR-V executable.
-// These executables are stored as FlatBuffers in the
-// https://github.com/google/iree/tree/master/iree/schemas/spirv_executable_def.fbs
-// schema.
-llvm::Optional<ExecutableTranslationResult>
-translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps,
- ExecutableTranslationOptions options = {});
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp b/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp
deleted file mode 100644
index a1f8e35..0000000
--- a/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp
+++ /dev/null
@@ -1,131 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- SPIRVLowering.cpp ---------------------------------------*- C++//-*-===//
-//
-// SPIR-V Code-generation for XLA-HLO Ops within IREE Dispatch functions
-//
-//===----------------------------------------------------------------------===//
-#include "iree/compiler/Translation/SPIRV/SPIRVLowering.h"
-
-namespace mlir {
-namespace iree_compiler {
-//===----------------------------------------------------------------------===//
-// ConstantOp
-//===----------------------------------------------------------------------===//
-LogicalResult ConstantOpSPIRVLowering::lowerOperation(
- Operation *op, OpBuilder &builder, AffineMap index, ArrayRef<Value *>,
- ValueCache &valueCache) const {
- auto constOp = cast<ConstantOp>(op);
- auto attr = constOp.value().dyn_cast<DenseElementsAttr>();
- if (!attr || !attr.isSplat()) {
- return op->emitError(
- "unhandled constant lowering unless value is a splat dense element "
- "attribute");
- }
- auto resultType = constOp.getResult()->getType();
- Type resultElemType;
- if (resultType.isIntOrFloat()) {
- resultElemType = resultType;
- } else if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
- resultElemType = shapedType.getElementType();
- } else {
- return op->emitError("unhandled result type of constant : ") << resultType;
- }
- Attribute constVal = attr.getSplatValue();
- auto spirvConstOp =
- builder.create<spirv::ConstantOp>(op->getLoc(), resultElemType, constVal);
- valueCache.setOperandDstValue(constOp.getResult(), index,
- spirvConstOp.getResult());
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// CmpFOp
-//===----------------------------------------------------------------------===//
-LogicalResult CmpFOpSPIRVLowering::lowerOperation(
- Operation *op, OpBuilder &builder, AffineMap index,
- ArrayRef<Value *> operands, ValueCache &valueCache) const {
- if (operands.size() != 2) {
- return op->emitError("expected two operands in spir-v lowering of CmpFOp");
- }
- Operation *spirvOp = nullptr;
- auto opInfo = op->getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
- if (!opInfo) {
- return op->emitError("expected CmpFOp to contain ")
- << CmpFOp::getPredicateAttrName() << " attribute";
- }
- auto boolType = builder.getI1Type();
- auto predicateVal = static_cast<CmpFPredicate>(opInfo.getInt());
- switch (predicateVal) {
-#define DISPATCH(caseLabel, opName) \
- case caseLabel: \
- spirvOp = builder.create<opName>(op->getLoc(), boolType, operands[0], \
- operands[1]); \
- break;
-
- DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
- DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
- DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
- DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
- DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
- DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
- DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
- DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
- DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
- DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
- DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
- DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
-
-#undef DISPATCH
-
- default:
- return op->emitError("unhandled predicate attribute for SPIR-V lowering");
- }
- valueCache.setOperandDstValue(op->getResult(0), index, spirvOp->getResult(0));
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ReturnOp
-//===----------------------------------------------------------------------===//
-LogicalResult ReturnOpSPIRVLowering::lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
- auto returnOp = cast<ReturnOp>(op);
- if (returnOp.getNumOperands() != 1) {
- return returnOp.emitError(
- "unhandled lowering of return statement with multiple returns");
- }
- auto returnTensor = returnOp.getOperand(0);
- auto indices = affineExprCodegen.getIndices(returnTensor);
- if (indices.size() != 1) {
- return returnOp.emitError(
- "expected to compute a single element of the return tensor");
- }
- assert(outputBuffers.size() == 1 && "Expected a single output buffer");
- auto var = outputBuffers[0];
- auto ptr = genPointerOffset(builder, returnOp.getLoc(), affineExprCodegen,
- indices[0], var);
- auto scalarVal = valueCache.getOperandDstValue(returnTensor, indices[0]);
- builder.create<spirv::StoreOp>(returnOp.getLoc(), ptr, scalarVal,
- /*memory_access = */ nullptr,
- /*alignment = */ nullptr);
- builder.create<spirv::ReturnOp>(returnOp.getLoc());
- return success();
-}
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
deleted file mode 100644
index f6de5d0..0000000
--- a/iree/compiler/Translation/SPIRV/SPIRVLowering.h
+++ /dev/null
@@ -1,591 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- SPIRVLowering.h -----------------------------------------*- C++//-*-===//
-//
-// SPIR-V Code-generation for tensor operations within IREE Dispatch functions
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
-#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
-
-#include "iree/compiler/Translation/SPIRV/AffineExprCodegen.h"
-#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Support/StringExtras.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class ValueCache {
- public:
- Value *getOperandDstValue(Value *value, AffineMap index) {
- return convertedValueMap.lookup(value).lookup(index);
- }
-
- void setOperandDstValue(Value *value, AffineMap index, Value *scalar) {
- convertedValueMap[value][index] = scalar;
- }
-
- private:
- DenseMap<Value *, DenseMap<AffineMap, Value *>> convertedValueMap;
-};
-
-/// Base class for lowering tensor operations in the dispatch function to SPIR-V
-/// op.
-class SPIRVLowering {
- public:
- virtual ~SPIRVLowering() = default;
- virtual StringRef getOpName() = 0;
- /// This method (in the derived class) should generate the scalar operation
- /// corresponding the the tensor operation `op` to generate the value of the
- /// result tensor at a particular `index`. The scalar value of the operands
- /// needed to compute this value is passed in within `operands`. The methods
- /// have to insert the scalar result value of the generated operation into the
- /// `valueCache`.
- virtual LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index,
- ArrayRef<Value *> operands,
- ValueCache &valueCache) const {
- return failure();
- }
-
- /// This method (in the derived class) should generate the scalar operations
- /// corresponding to the tensor operation `op`. This should be implemented
- /// when the `op` has no result value, typically store operations and return
- /// operations.
- virtual LogicalResult lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
- return failure();
- }
-};
-
-/// Base class that gets the opName for the operation.
-template <typename OpTy>
-class SPIRVOpLowering : public SPIRVLowering {
- public:
- using SPIRVLowering::SPIRVLowering;
- virtual ~SPIRVOpLowering<OpTy>() {}
- StringRef getOpName() override { return OpTy::getOperationName(); }
-};
-
-/// SPIR-V lowering for ConstantOp.
-class ConstantOpSPIRVLowering final : public SPIRVOpLowering<ConstantOp> {
- public:
- using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
- LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index, ArrayRef<Value *> operands,
- ValueCache &valueCache) const override;
-};
-
-/// SPIR-V lowering for CmpFOp.
-class CmpFOpSPIRVLowering final : public SPIRVOpLowering<CmpFOp> {
- public:
- using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index, ArrayRef<Value *> operands,
- ValueCache &valueCache) const override;
-};
-
-/// SPIR-V lowering for Min/Max operations.
-template <typename OpTy, typename CmpOpTy, typename CmpFOpTy>
-class CmpSelectOpSPIRVLowering final : public SPIRVOpLowering<OpTy> {
- public:
- using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
- LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index, ArrayRef<Value *> operands,
- ValueCache &valueCache) const override {
- if (op->getNumOperands() != 2) {
- return op->emitError(
- "unhandled SPIR-V lowering for more than 2 operands");
- }
- assert(operands.size() == op->getNumOperands() &&
- "expected as many operands for the replacement as the original "
- "instruction");
- auto cmpSelectOp = cast<OpTy>(op);
- auto result = cmpSelectOp.getResult();
- auto resultTy = result->getType().template dyn_cast<ShapedType>();
- if (!resultTy) {
- return op->emitError(
- "unhandled lowering of operations that don't return a "
- "ShapedType");
- }
- auto elementTy = resultTy.getElementType();
- auto boolTy = builder.getI1Type();
- Operation *cmpOp = nullptr;
- if (elementTy.template isa<FloatType>()) {
- cmpOp = builder.create<CmpFOpTy>(op->getLoc(), boolTy, operands,
- ArrayRef<NamedAttribute>());
- } else {
- cmpOp = builder.create<CmpOpTy>(op->getLoc(), boolTy, operands,
- ArrayRef<NamedAttribute>());
- }
- auto selectOp = builder.create<spirv::SelectOp>(
- op->getLoc(), operands[0]->getType(), cmpOp->getResult(0), operands[0],
- operands[1]);
- valueCache.setOperandDstValue(op->getResult(0), index,
- selectOp.getResult());
- return success();
- }
-};
-
-/// This class is the general template used to emit scalar instruction
-/// corresponding for point-wise operations. Assumes that the original
-/// instruction has a single result value of type ShapedType.
-/// TODO(ravishankarm) : In XLA-HLO, the same operations is used for
-/// integer/float tensor operations. So allow this op to take an additional op
-/// type as a template parameter to handle such cases. Find a better way to do
-/// this.
-template <typename OpTy, typename ReplacementOpTy,
- typename FloatOpTy = ReplacementOpTy>
-class SPIRVPwOpLowering final : public SPIRVOpLowering<OpTy> {
- public:
- using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index,
- ArrayRef<Value *> scalarOperands,
- ValueCache &valueCache) const override {
- // TODO(ravishankarm) : This check should really be a static_assert. See if
- // that can be changed.
- if (op->getNumOperands() == 0) {
- return op->emitError("expected op to have at least one operand");
- }
- auto pwOp = cast<OpTy>(op);
- auto result = pwOp.getResult();
- auto resultType = result->getType().template dyn_cast<ShapedType>();
- if (!resultType) {
- return op->emitError(
- "unhandled lowering of operations that don't return a "
- "ShapedType");
- }
- auto elementType = resultType.getElementType();
- Operation *scalarOp = nullptr;
- if (elementType.template isa<IntegerType>()) {
- scalarOp = builder
- .create<ReplacementOpTy>(op->getLoc(), elementType,
- scalarOperands,
- ArrayRef<NamedAttribute>())
- .getOperation();
- } else {
- scalarOp =
- builder
- .create<FloatOpTy>(op->getLoc(), elementType, scalarOperands,
- ArrayRef<NamedAttribute>())
- .getOperation();
- }
- if (!scalarOp) {
- return op->emitError("unable to lower operation");
- }
- valueCache.setOperandDstValue(pwOp.getResult(), index,
- scalarOp->getResult(0));
- return success();
- }
-};
-
-/// This class is the general template used to emit scalar instruction for index
-/// transformation instructions like transpose. Assumes a single result value
-/// and a single operand
-template <typename OpTy>
-class SPIRVIndexOpLowering final : public SPIRVOpLowering<OpTy> {
- public:
- using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(Operation *op, OpBuilder &builder,
- AffineMap index,
- ArrayRef<Value *> scalarOperands,
- ValueCache &valueCache) const override {
- if (op->getNumOperands() != 1) {
- return op->emitError(
- "unhandled lowering of index transformation operation with multiple "
- "operands");
- }
- auto indexOp = cast<OpTy>(op);
- valueCache.setOperandDstValue(indexOp.getResult(), index,
- scalarOperands[0]);
- return success();
- }
-};
-
-/// Ggenerates spv.AccessChain instruction to get the pointer value at a given
-/// location of a spv.globalVariable.
-inline Value *genPointerOffset(OpBuilder &builder, Location loc,
- AffineExprCodegen &affineExprCodegen,
- AffineMap indexMap,
- spirv::GlobalVariableOp &var) {
- auto basePtr = builder.create<spirv::AddressOfOp>(
- loc, var.type(), builder.getSymbolRefAttr(var.sym_name()));
- auto varPtrType = var.type().cast<spirv::PointerType>().getPointeeType();
- // The variable has to be a struct type with a single element.
- assert(varPtrType.isa<spirv::StructType>() &&
- "expected variable type to be a spv.ptr<spv.struct<...>>");
- auto varStructType = varPtrType.cast<spirv::StructType>();
- assert(varStructType.getNumElements() == 1 &&
- "expected variable type to be a spv.ptr of spv.struct with a single "
- "element");
- auto varType = varStructType.getElementType(0);
-
- SmallVector<Value *, 2> accessIndex;
- /// For scalar values, the index-map computed with already map to the 0-th
- /// element. For arrays, they map to the position accessed. So just for arrays
- /// we need to add an extra 0 to index into the struct.
- if (varType.isa<spirv::ArrayType>() ||
- varType.isa<spirv::RuntimeArrayType>()) {
- auto i32Type = builder.getIntegerType(32);
- auto zero = builder.create<spirv::ConstantOp>(loc, i32Type,
- builder.getI32IntegerAttr(0));
- accessIndex.push_back(zero);
- }
- for (auto indexExpr : indexMap.getResults()) {
- accessIndex.push_back(affineExprCodegen.getValue(
- indexExpr, builder.saveInsertionPoint(), loc));
- }
- return builder.create<spirv::AccessChainOp>(loc, basePtr, accessIndex);
-}
-
-/// Lower return statements during SPIR-V codegeneration.
-class ReturnOpSPIRVLowering : public SPIRVOpLowering<ReturnOp> {
- public:
- using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
-
- LogicalResult lowerOperation(
- Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache,
- DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
- ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override;
-};
-
-/// Class to drive the SPIRV code-generation.
-template <typename... Ts>
-class SPIRVCodegen {
- using OpCodegenListT = llvm::StringMap<std::unique_ptr<SPIRVLowering>>;
-
- public:
- explicit SPIRVCodegen() { insert(); }
-
- LogicalResult codegen(spirv::ModuleOp &spirvModule, FuncOp &fn,
- AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache) {
- if (fn.getBlocks().size() != 1) {
- return emitError(
- fn.getLoc(),
- "unimplemeneted handling multiple blocks within a function");
- }
-
- OpBuilder builder(spirvModule.body());
- // Create the entry function and generate global invocation ID. Creates a
- // global variable for all inputs and output tensors.
- return createEntryFn(builder, fn, affineExprCodegen, valueCache);
- }
-
- private:
- /// Helper method to create the entry function. Creates global variables for
- /// all inputs and outputs. Inserts the spv.EntryPoint operations as well.
- LogicalResult createEntryFn(OpBuilder &builder, FuncOp &fn,
- AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache) {
- auto loc = fn.getLoc();
- // TODO(ravishankarm) : This should actually be part of the SPIR-V
- // conversion framework in MLIR core. Move it there.
- auto convertType = [&loc](Type t,
- spirv::PointerType &varType) -> LogicalResult {
- auto shapedType = t.dyn_cast<ShapedType>();
- if (!shapedType) {
- return emitError(loc, "expected ShapedType argument");
- }
- auto elementType = shapedType.getElementType();
- if (!elementType.isIntOrFloat()) {
- return emitError(loc, "unhandled element type ")
- << elementType << " while lowering to SPIR-V";
- }
- int64_t stride = elementType.getIntOrFloatBitWidth() / 8;
- for (auto dim : reverse(shapedType.getShape())) {
- if (dim <= 0) {
- return emitError(loc, "expected tensor dimensions to be non-zero");
- }
- elementType = spirv::ArrayType::get(
- elementType, dim,
- static_cast<spirv::ArrayType::LayoutInfo>(stride));
- stride *= dim;
- }
- // TODO(ravishankarm): Verify that the type of the variable passes
- // spirv-val.
- varType = spirv::PointerType::get(
- spirv::StructType::get(elementType,
- static_cast<spirv::StructType::LayoutInfo>(0)),
- spirv::StorageClass::StorageBuffer);
- return success();
- };
-
- // Convert functions arguments and return values to
- // spirv::GlobalVariables. All global variables are given a descriptor set
- // of 0 and binding is the argument number.
- auto fnType = fn.getType();
- auto descriptorSetAttrName = convertToSnakeCase(
- stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingAttrName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
- for (auto argType : enumerate(fnType.getInputs())) {
- spirv::PointerType varType;
- if (failed(convertType(argType.value(), varType))) {
- return failure();
- }
- auto varName =
- fn.getName().str() + "_arg_" + std::to_string(argType.index());
- auto var = builder.create<spirv::GlobalVariableOp>(
- loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr);
- // Set descriptor_set to 0.
- var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0));
- // Set binding to argument number.
- var.setAttr(bindingAttrName, builder.getI32IntegerAttr(argType.index()));
-
- inputArgToVariable[fn.getArgument(argType.index())] = var;
- }
- for (auto resType : enumerate(fnType.getResults())) {
- spirv::PointerType varType;
- if (failed(convertType(resType.value(), varType))) {
- return failure();
- }
- auto varName =
- fn.getName().str() + "_res_" + std::to_string(resType.index());
- auto var = builder.create<spirv::GlobalVariableOp>(
- loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr);
- // Set descriptor_set to 0.
- var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0));
- // Set binding to (result number + num arguments)
- var.setAttr(
- bindingAttrName,
- builder.getI32IntegerAttr(fnType.getNumInputs() + resType.index()));
-
- resultIndexToVariable.push_back(var);
- }
-
- auto entryFnType =
- builder.getFunctionType(ArrayRef<Type>(), ArrayRef<Type>());
- auto entryFn = builder.create<FuncOp>(loc, fn.getName(), entryFnType,
- ArrayRef<NamedAttribute>());
-
- // Start a scope to create an insertion guard to reset the builder once the
- // function is lowered.
- {
- OpBuilder::InsertionGuard funcInsertGuard(builder);
- builder.setInsertionPointToStart(entryFn.addEntryBlock());
-
- // Create the Global invocation ID.
- if (failed(createGlobalInvocationID(builder, fn.getLoc(),
- affineExprCodegen))) {
- return failure();
- }
-
- if (failed(lowerFunction(builder, fn, entryFn, affineExprCodegen,
- valueCache))) {
- return failure();
- }
- }
-
- // Create the entry point instructions for the entry function.
- if (failed(createEntryPoint(builder, loc, entryFn))) {
- return failure();
- }
- return success();
- }
-
- /// Creates the global variable for GlobalInvocationID, and gets the ID at x,
- /// y and z dimensions.
- LogicalResult createGlobalInvocationID(OpBuilder &builder, Location loc,
- AffineExprCodegen &affineExprCodegen) {
- auto moduleOp = builder.getInsertionBlock()
- ->getParentOp()
- ->getParentOfType<spirv::ModuleOp>();
- OpBuilder moduleBuilder(moduleOp.body());
- auto i32Type = builder.getIntegerType(32);
- auto idType = VectorType::get(3, i32Type);
- auto ptrIdType =
- spirv::PointerType::get(idType, spirv::StorageClass::Input);
- auto globalInvocationID = moduleBuilder.create<spirv::GlobalVariableOp>(
- loc, TypeAttr::get(ptrIdType),
- builder.getStringAttr("globalInvocationID"), nullptr);
- globalInvocationID.setAttr(
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
- builder.getStringAttr(
- spirv::stringifyBuiltIn(spirv::BuiltIn::GlobalInvocationId)));
- interface.push_back(
- builder.getSymbolRefAttr(globalInvocationID.sym_name()));
-
- auto globalInvocationIDPtr = builder.create<spirv::AddressOfOp>(
- loc, ptrIdType,
- builder.getSymbolRefAttr(globalInvocationID.getOperation()));
- auto id = builder.create<spirv::LoadOp>(loc, idType, globalInvocationIDPtr,
- nullptr, nullptr);
- auto id_x = builder.create<spirv::CompositeExtractOp>(
- loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(0)));
- auto id_y = builder.create<spirv::CompositeExtractOp>(
- loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(1)));
- auto id_z = builder.create<spirv::CompositeExtractOp>(
- loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(2)));
- affineExprCodegen.setDimDstValue(0, id_x);
- affineExprCodegen.setDimDstValue(1, id_y);
- affineExprCodegen.setDimDstValue(2, id_z);
- return success();
- }
-
- /// Method to load the values of globalVariables corresponding to the
- /// arguments of the dispatch function at all indices needed within the
- /// dispatch function.
- LogicalResult initArgValues(OpBuilder &builder, Location loc,
- AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache, Value *origArg) {
- for (auto indexMap : affineExprCodegen.getIndices(origArg)) {
- auto var = inputArgToVariable.lookup(origArg);
- if (!var) {
- return emitError(
- loc, "undefined SPIR-V global variable for tensor argument");
- }
- auto ptr =
- genPointerOffset(builder, loc, affineExprCodegen, indexMap, var);
- auto elementType =
- ptr->getType().template cast<spirv::PointerType>().getPointeeType();
- auto val = builder.create<spirv::LoadOp>(loc, elementType, ptr,
- /*memory_access =*/nullptr,
- /*alignment = */ nullptr);
- valueCache.setOperandDstValue(origArg, indexMap, val);
- }
- return success();
- }
-
- /// Adds the spv.EntryPointOp and records all the interface variables used in
- /// the entryFn.
- LogicalResult createEntryPoint(OpBuilder &builder, Location loc,
- FuncOp entryFn) {
- builder.create<spirv::EntryPointOp>(
- loc,
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
- builder.getSymbolRefAttr(entryFn), builder.getArrayAttr(interface));
- builder.create<spirv::ExecutionModeOp>(
- loc, builder.getSymbolRefAttr(entryFn),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::ExecutionMode::LocalSize)),
- builder.getI32ArrayAttr({1, 1, 1}));
- interface.clear();
- return success();
- }
-
- /// Lowers the body of the function in the original dialect to SPIR-V dialect.
- LogicalResult lowerFunction(OpBuilder &builder, FuncOp fn, FuncOp entryFn,
- AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache) {
- for (auto arg : fn.getArguments()) {
- // Load values of the argument at all indices needed for computation
- // within the dispatch function.
- if (failed(initArgValues(builder, fn.getLoc(), affineExprCodegen,
- valueCache, arg))) {
- return failure();
- }
- }
-
- for (auto &block : fn) {
- for (auto &op : block) {
- // Lower individual operations.
- if (failed(
- lowerOperation(builder, affineExprCodegen, valueCache, &op))) {
- return failure();
- }
- }
- }
- return success();
- }
-
- /// Dispatches the lowering of tensor operation to SPIR-V scalar
- /// operation.
- LogicalResult lowerOperation(OpBuilder &builder,
- AffineExprCodegen &affineExprCodegen,
- ValueCache &valueCache, Operation *op) {
- auto opName = op->getName().getStringRef();
- if (!opCodegenList.count(opName)) {
- return op->emitError("unhandled codegen");
- }
- if (op->getNumResults() > 1) {
- return op->emitError("unhandled codegen for multiple result values");
- }
-
- // Zero return case.
- if (!op->getNumResults()) {
- return opCodegenList[opName]->lowerOperation(
- op, builder, affineExprCodegen, valueCache, inputArgToVariable,
- resultIndexToVariable);
- }
-
- // Single return case.
- auto resultTensor = op->getResult(0);
- auto indices = affineExprCodegen.getIndices(resultTensor);
- for (auto &index : indices) {
- auto operandIndices =
- affineExprCodegen.getOperandIndices(resultTensor, index);
- SmallVector<Value *, 2> scalarOperands;
- for (auto arg : llvm::enumerate(op->getOperands())) {
- auto scalarArg = valueCache.getOperandDstValue(
- arg.value(), operandIndices[arg.index()]);
- if (!scalarArg) {
- return op->emitError("argument ")
- << arg.index() << " has no scalar value";
- }
- scalarOperands.push_back(scalarArg);
- }
- if (failed(opCodegenList[opName]->lowerOperation(
- op, builder, index, scalarOperands, valueCache))) {
- return failure();
- }
- }
- return success();
- }
-
- void insert() {
- std::vector<std::unique_ptr<SPIRVLowering>> objs;
- using dummy = int[];
- (void)dummy{0, (objs.emplace_back(std::make_unique<Ts>()), 0)...};
- for (auto &elem : objs) {
- StringRef opName = elem->getOpName();
- opCodegenList.try_emplace(opName, std::move(elem));
- }
- }
-
- /// List of classes that implement the operation lowering from tensor
- /// operations to SPIR-V.
- OpCodegenListT opCodegenList;
-
- /// I/O interface for the entry function containing global variables that are
- /// used by the entire function call tree.
- SmallVector<Attribute, 4> interface;
-
- /// Mapping from argument of the dispatch function in tensor dialect to the
- /// corresponding spv.globalVariable.
- DenseMap<Value *, spirv::GlobalVariableOp> inputArgToVariable;
-
- /// List of spv.globalVariables created for tensors returned by the dispatch
- /// function in tensor dialects.
- SmallVector<spirv::GlobalVariableOp, 1> resultIndexToVariable;
-
- /// GlobalInvocationID variable.
- spirv::GlobalVariableOp globalInvocationID;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
diff --git a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
deleted file mode 100644
index 50cb8ec..0000000
--- a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- XLAIndexPropagation.cpp ---------------------------------*- C++//-*-===//
-//
-// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all
-// tensors needed to produce the value of the result tensors at a particlar
-// index.
-//
-//===----------------------------------------------------------------------===//
-
-#include "iree/compiler/Translation/SPIRV/XLAIndexPropagation.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-//===----------------------------------------------------------------------===//
-// BroadcastInDimOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult XLABroadcastInDimOpIndexPropagation::propagateIndexMap(
- Operation *operation, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &indexMap) const {
- auto broadcastOp = cast<xla_hlo::BroadcastInDimOp>(operation);
- auto broadcastDim = broadcastOp.broadcast_dimensions();
-
- Builder builder(operation->getContext());
- if (!broadcastDim) {
- // This is a scalar. So all indices map to the same element.
- AffineMap scalarMap =
- AffineMap::get(resultIndex.getNumDims(), resultIndex.getNumSymbols(),
- builder.getAffineConstantExpr(0));
- indexMap.push_back(scalarMap);
- return success();
- }
-
- // Handle non-scalar cases.
- auto dimensions = broadcastDim->getValues<int64_t>();
- SmallVector<AffineExpr, 4> exprs;
- for (auto resultExpr : enumerate(resultIndex.getResults())) {
- if (llvm::any_of(dimensions, [&resultExpr](int64_t dim) {
- return dim == resultExpr.index();
- })) {
- exprs.push_back(resultExpr.value());
- }
- }
- auto operandMap = AffineMap::get(resultIndex.getNumDims(),
- resultIndex.getNumSymbols(), exprs);
- indexMap.push_back(operandMap);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BroadcastOp
-//===----------------------------------------------------------------------===//
-
-// For broadcast op, just drop the first N expressions of the resultIndex, where
-// N is the number of elements in broadcast_sizes attribute.
-LogicalResult XLABroadcastOpIndexPropagation::propagateIndexMap(
- Operation *operation, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &indexMap) const {
- auto broadcastOp = cast<xla_hlo::BroadcastOp>(operation);
- auto broadcastDim = broadcastOp.broadcast_sizes();
-
- SmallVector<AffineExpr, 4> exprs;
- for (auto i : llvm::seq<size_t>(
- broadcastDim.getType().getShape()[0],
- operation->getResult(0)->getType().cast<ShapedType>().getRank())) {
- exprs.push_back(resultIndex.getResult(i));
- }
-
- Builder builder(operation->getContext());
- if (exprs.empty()) {
- // The result is a scalar. Just add a constant expr 0.
- exprs.push_back(builder.getAffineConstantExpr(0));
- }
- auto operandMap = AffineMap::get(resultIndex.getNumDims(),
- resultIndex.getNumSymbols(), exprs);
- indexMap.push_back(operandMap);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ReverseOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult XLAReverseOpIndexPropagation::propagateIndexMap(
- Operation *op, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &indexMap) const {
- auto reverseOp = cast<xla_hlo::ReverseOp>(op);
- DenseSet<unsigned> dimensions;
- for (auto index : reverseOp.dimensions()) {
- dimensions.insert(index.getZExtValue());
- }
- return propagateIndexMapImpl(op, dimensions, resultIndex, indexMap);
-}
-
-//===----------------------------------------------------------------------===//
-// TransposeOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult XLATransposeOpIndexPropagation::propagateIndexMap(
- Operation *op, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &indexMap) const {
- auto transposeOp = cast<xla_hlo::TransposeOp>(op);
- // Compute the affine map that represents the permutation.
- SmallVector<unsigned, 4> permutation;
- for (auto index : transposeOp.permutation()) {
- permutation.push_back(index.getZExtValue());
- }
- return propagateIndexMapImpl(op, permutation, resultIndex, indexMap);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h
deleted file mode 100644
index a20b0f1..0000000
--- a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- XLAIndexPropagation.h -----------------------------------*- C++//-*-===//
-//
-// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all
-// tensors needed to produce the value of the result tensors at a particlar
-// index.
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
-#define IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
-
-#include "iree/compiler/Translation/SPIRV/IndexComputation.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Function.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class XLABroadcastInDimOpIndexPropagation final
- : public IndexPropagationOp<xla_hlo::BroadcastInDimOp> {
- public:
- using IndexPropagationOp<xla_hlo::BroadcastInDimOp>::IndexPropagationOp;
-
- LogicalResult propagateIndexMap(
- Operation *operation, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &operandIndices) const override;
-};
-
-// For broadcast op, just drop the first N expressions of the resultIndex, where
-// N is the number of elements in broadcast_sizes attribute.
-class XLABroadcastOpIndexPropagation final
- : public IndexPropagationOp<xla_hlo::BroadcastOp> {
- public:
- using IndexPropagationOp<xla_hlo::BroadcastOp>::IndexPropagationOp;
-
- LogicalResult propagateIndexMap(
- Operation *operation, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &operandIndices) const override;
-};
-
-/// For return ops, it is assumed that each thread is computing the value of one
-/// element of the returned tensor.
-template <typename OpTy>
-class ReturnOpIndexPropagation : public IndexPropagationOp<OpTy> {
- public:
- using IndexPropagationOp<OpTy>::IndexPropagationOp;
-
- LogicalResult propagateIndexMap(
- Operation *operation, IndexComputationCache &indexMap) const override {
- if (operation->getNumOperands() != 1) {
- return operation->emitError("unhandled multiple return values");
- }
- auto returnValue = operation->getOperand(0);
- auto returnType = returnValue->getType().cast<RankedTensorType>();
- auto returnRank = returnType.getRank();
- if (returnRank > 3) {
- return operation->emitError("unhandled return tensor of dimension ")
- << returnType.getShape().size();
- }
- // Have as many symbols as the rank of the input tensor. These symbols map
- // to GlobalInvocationID along the three dimensions.
- Builder builder(operation->getContext());
- SmallVector<AffineExpr, 4> affineExprs;
- for (size_t i = returnRank; i > 0; --i) {
- affineExprs.push_back(builder.getAffineDimExpr(i - 1));
- }
- indexMap[operation->getOperand(0)]
- [AffineMap::get(returnRank, 0, affineExprs)];
- return success();
- }
-};
-
-/// Index propogation for XLA Reverse.
-class XLAReverseOpIndexPropagation final
- : public ReverseOpIndexPropagation<xla_hlo::ReverseOp> {
- public:
- using ReverseOpIndexPropagation<
- xla_hlo::ReverseOp>::ReverseOpIndexPropagation;
- LogicalResult propagateIndexMap(
- Operation *op, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &indexMap) const override;
-};
-
-/// Index propogation for XLA Transpose.
-class XLATransposeOpIndexPropagation final
- : public TransposeOpIndexPropagation<xla_hlo::TransposeOp> {
- public:
- using TransposeOpIndexPropagation<
- xla_hlo::TransposeOp>::TransposeOpIndexPropagation;
- LogicalResult propagateIndexMap(
- Operation *op, AffineMap resultIndex,
- SmallVectorImpl<AffineMap> &indexMap) const override;
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
diff --git a/iree/compiler/Translation/SPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/test/BUILD
deleted file mode 100644
index 79700a8..0000000
--- a/iree/compiler/Translation/SPIRV/test/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-# Tests for common transforms.
-
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_setup_lit_package(
- data = [
- "//iree/tools:iree-opt",
- ],
-)
-
-iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/Sequencer/BUILD b/iree/compiler/Translation/Sequencer/BUILD
deleted file mode 100644
index 4245b19..0000000
--- a/iree/compiler/Translation/Sequencer/BUILD
+++ /dev/null
@@ -1,33 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Sequencer",
- srcs = ["SequencerModuleTranslation.cpp"],
- hdrs = ["SequencerModuleTranslation.h"],
- deps = [
- "//iree/base:status",
- "//iree/compiler/IR",
- "//iree/compiler/IR/Sequencer",
- "//iree/compiler/Serialization",
- "//iree/compiler/Transforms",
- "//iree/compiler/Transforms/Sequencer",
- "//iree/compiler/Utils",
- "//iree/hal:executable_format",
- "//iree/schemas",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:Transforms",
- "@local_config_mlir//:Translation",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
- ],
- alwayslink = 1,
-)
diff --git a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp b/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
deleted file mode 100644
index 7184f6b..0000000
--- a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
+++ /dev/null
@@ -1,505 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h"
-
-#include <cstdint>
-#include <iostream>
-#include <memory>
-#include <unordered_map>
-#include <vector>
-
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/minireflect.h"
-#include "iree/base/status.h"
-#include "iree/compiler/IR/ConfigOps.h"
-#include "iree/compiler/IR/Sequencer/OpWriters.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/compiler/IR/Types.h"
-#include "iree/compiler/Serialization/VMFunctionBuilder.h"
-#include "iree/compiler/Serialization/VMFunctionTableBuilder.h"
-#include "iree/compiler/Serialization/VMModuleBuilder.h"
-#include "iree/compiler/Transforms/Passes.h"
-#include "iree/compiler/Transforms/Sequencer/Passes.h"
-#include "iree/compiler/Utils/Macros.h"
-#include "iree/compiler/Utils/OpUtils.h"
-#include "iree/compiler/Utils/TranslationUtils.h"
-#include "iree/hal/executable_format.h"
-#include "iree/schemas/executable_def_generated.h"
-#include "iree/schemas/executable_table_def_generated.h"
-#include "iree/schemas/module_def_generated.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/Passes.h"
-#include "mlir/Translation.h"
-#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Builds a pass pipeline that optimizes and legalizes the module to the form
-// expected by partitioning.
-void buildLegalizeInputPassPipeline(PassManager *passManager) {
- // Convert to the subset of XLA HLO and Standard dialects supported as IREE
- // input. In particular, move from XLA HLO to standard control flow.
- passManager->addPass(xla_hlo::createLegalizeControlFlowPass());
- passManager->addPass(createLegalizeInputOpsPass());
-
- // Standard passes that shake out a lot of garbage.
- // Some may have been run prior to translation but this ensures we are always
- // in a known state.
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createLoopFusionPass());
- passManager->addPass(createLoopInvariantCodeMotionPass());
- passManager->addPass(createMemRefDataFlowOptPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createSimplifyAffineStructuresPass());
- passManager->addPass(createCSEPass());
- passManager->addPass(createCanonicalizerPass());
-
- // Expand uses of tuples into independent args/results.
- passManager->addPass(createConvertFromTupleCallingConventionPass());
- passManager->addPass(createCanonicalizerPass());
-}
-
-// Builds a pass pipeline that partitions the module into sequencer functions
-// and executables ready to be translated.
-void buildPartitioningPassPipeline(PassManager *passManager) {
- // Find reduction ops and create iree.reduction_regions. We do this prior to
- // performing dispatch region identification so that we can build as big of
- // fused reduction regions as possible. The remaining ops will be put into
- // dispatch regions.
- passManager->addPass(createIdentifyReductionRegionsPass());
- passManager->addPass(createCSEPass());
-
- // Create all of the dispatch regions, CSE their workloads, and fold.
- passManager->addPass(createIdentifyDispatchRegionsPass());
- passManager->addPass(createCSEPass());
- passManager->addPass(createFoldCompatibleDispatchRegionsPass());
-
- // Note that as we are rematerializing things here it's critical we do not run
- // the canonicalizer/CSE between now and when we outline - otherwise it'll
- // undo all of our work!
- passManager->addPass(createRematerializeDispatchConstantsPass());
-
- // Outline the dispatch regions into their own functions. This separates the
- // sequencer functions performing dispatches from the dispatchees.
- passManager->addPass(createOutlineDispatchRegionsPass());
- passManager->addPass(createOutlineReductionRegionsPass());
-
- // Cleanup identity sequencer tensor-to-memref ops that clutter up the IR.
- passManager->addPass(createCanonicalizerPass());
-
- // Drop all functions that are no longer reachable.
- // This is important as many of the functions remaining are probably
- // dispatchable and unused now that we've outlined them executables.
- passManager->addPass(createDropUnreachableModuleFunctionsPass());
-
- // Drop all unused executables.
- // Note that we need to have dropped unreachable functions first otherwise
- // references could keep executables that are unreachable from exported
- // functions alive.
- passManager->addPass(createDropUnusedExecutablesPass());
-}
-
-// Builds a pass pipeline that converts sequencer functions to the iree_seq.hl
-// dialect.
-void buildSequencerConversionPassPipeline(PassManager *passManager) {
- passManager->addPass(createConvertToMemRefCallingConventionPass());
-
- // Convert ops that are supported by the sequencer directly to the sequencer
- // dialect. The ops that remain should be only those that can be moved into
- // dispatch regions.
- passManager->addPass(createLowerToSequencerDialectPass());
-
- // Cleanup identity sequencer tensor-to-memref ops and other memory accesses
- // that clutter up the IR.
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createMemRefDataFlowOptPass());
-
- // Eliminate ops we don't care about based on a lack of side-effects.
- // IREE does not guarantee exception/error behavior of dead ops.
- passManager->addPass(createAggressiveOpEliminationPass());
-
- // Perform any last-minute optimizations to trim down the IR.
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createMemRefDataFlowOptPass());
- passManager->addPass(createCSEPass());
-}
-
-// Builds a pass pipeline that lowers the iree_seq.hl dialect to the iree_seq.ll
-// dialect and prepares for serialization.
-void buildSequencerLoweringPassPipeline(PassManager *passManager) {
- // Lower iree_hl_seq -> iree_ll_seq.
- passManager->addPass(createLowerSequencerDialectPass());
- passManager->addPass(createCanonicalizerPass());
- passManager->addPass(createMemRefDataFlowOptPass());
- passManager->addPass(createAggressiveOpEliminationPass());
-
- // Assign ordinals used by the bytecode to reference executables and
- // functions.
- passManager->addPass(createAssignFunctionOrdinalsPass());
- passManager->addPass(createAssignExecutableOrdinalsPass());
-
- // Plumb workload information down into executable entry points. This allows
- // the backends to calculate their workgroup sizes, indexing, etc.
- passManager->addPass(createAssignExecutableWorkloadAttrsPass());
-}
-
-// Inserts one or more iree.executable_target_config ops based on the
-// translation options.
-void insertTargetConfigOps(const ModuleTranslationOptions &options,
- OpBuilder &builder) {
- llvm::StringSet<> targetBackends;
- if (options.target_backends.empty()) {
- // Add all backends when none are explicitly provided.
- targetBackends.insert(getExecutableTranslationRegistry().keys().begin(),
- getExecutableTranslationRegistry().keys().end());
- } else {
- for (auto &targetBackend : options.target_backends) {
- for (auto &matchedBackend :
- matchExecutableTranslationBackendNames(targetBackend)) {
- targetBackends.insert(matchedBackend);
- }
- }
- }
- for (auto &targetBackend : targetBackends) {
- builder.create<IREE::ExecutableTargetConfigOp>(builder.getUnknownLoc(),
- targetBackend.getKey());
- }
-}
-
-class SequencerTranslator {
- public:
- explicit SequencerTranslator(ModuleTranslationOptions options)
- : options_(options) {}
-
- const ModuleTranslationOptions &options() const { return options_; }
-
- std::vector<uint8_t> translateModule(ModuleOp module);
-
- private:
- LogicalResult translateMultiArchExecutable(
- IREE::MultiArchExecutableOp executableOp, VMModuleBuilder *moduleBuilder);
-
- LogicalResult translateSequencerModule(ModuleOp module,
- VMModuleBuilder *moduleBuilder);
- LogicalResult declareFunction(FuncOp function,
- VMModuleBuilder *moduleBuilder);
- LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder);
-
- ModuleTranslationOptions options_;
-};
-
-std::vector<uint8_t> SequencerTranslator::translateModule(ModuleOp module) {
- // Run one large set of passes to get to a partitioned module.
- auto partitioningPasses = createPassManager(module.getContext(), options());
- buildLegalizeInputPassPipeline(partitioningPasses.get());
- buildPartitioningPassPipeline(partitioningPasses.get());
- if (failed(runPassPipeline(options(), partitioningPasses.get(), module))) {
- module.emitError() << "Failed to run partitioning passes";
- return {};
- }
-
- // Run the sequencer-specific conversion passes on the module.
- auto sequencerConversionPasses =
- createPassManager(module.getContext(), options());
- buildSequencerConversionPassPipeline(sequencerConversionPasses.get());
- if (failed(runPassPipeline(options(), sequencerConversionPasses.get(),
- module))) {
- module.emitError() << "Failed to run sequencer conversion passes";
- return {};
- }
-
- // Lower sequencer functions to their final form.
- auto sequencerLoweringPasses =
- createPassManager(module.getContext(), options());
- buildSequencerLoweringPassPipeline(sequencerLoweringPasses.get());
- if (failed(
- runPassPipeline(options(), sequencerLoweringPasses.get(), module))) {
- module.emitError() << "Failed to run sequencer lowering passes";
- return {};
- }
-
- // Perform translation on all executables.
- // We then know exactly what executable formats we have and can query them to
- // see if we need to do any additional processing (such as to support better
- // types/etc).
- ::flatbuffers::FlatBufferBuilder fbb;
- VMModuleBuilder moduleBuilder(&fbb);
- for (auto multiArchExecutableOp :
- module.getOps<IREE::MultiArchExecutableOp>()) {
- if (failed(translateMultiArchExecutable(multiArchExecutableOp,
- &moduleBuilder))) {
- module.emitError() << "Failed to translate multi-arch-executable";
- return {};
- }
- }
-
- // Build the module bytecode.
- if (failed(translateSequencerModule(module, &moduleBuilder))) {
- module.emitError() << "Unable to translate sequencer module";
- return {};
- }
- auto moduleDef = moduleBuilder.Finish();
- if (moduleDef.IsNull()) {
- module.emitError() << "Failed to verify completed module def";
- return {};
- }
- auto bytes = moduleBuilder.Serialize(moduleDef);
- if (bytes.empty()) {
- module.emitError() << "Failed to serialize final module def";
- return {};
- }
- return bytes;
-}
-
-LogicalResult SequencerTranslator::translateMultiArchExecutable(
- IREE::MultiArchExecutableOp multiArchExecutableOp,
- VMModuleBuilder *moduleBuilder) {
- auto &fbb = *moduleBuilder->fbb();
-
- // Find the unspecified executable. This is the template from which we will
- // translate to other targets.
- IREE::ExecutableOp templateExecutableOp;
- for (auto executableOp :
- multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) {
- if (executableOp.format() ==
- static_cast<uint32_t>(IREE::ExecutableFormat::Unspecified)) {
- templateExecutableOp = executableOp;
- break;
- }
- }
- if (!templateExecutableOp) {
- // Fine for there to be no unspecified executable - just ignore.
- return success();
- }
- int entryPointCount = 0;
- for (auto func : templateExecutableOp.getInnerModule().getOps<FuncOp>()) {
- if (func.getAttr("iree.executable.export")) {
- ++entryPointCount;
- }
- }
-
- // For now we just add target config ops based on options. In the future we
- // could do this earlier via an analysis pass determining which targets should
- // be used for each executable.
- OpBuilder configBuilder(templateExecutableOp);
- configBuilder.setInsertionPointToStart(&templateExecutableOp.getBlock());
- insertTargetConfigOps(options(), configBuilder);
-
- // Find all target configs and bucket them into the backends that will
- // translate them. This way we can batch the translations and possibly enable
- // backends to dedupe some things.
- DenseMap<StringRef, std::vector<IREE::ExecutableTargetConfigOp>>
- backendTargetConfigOps;
- for (auto targetConfigOp : templateExecutableOp.getBlock()
- .getOps<IREE::ExecutableTargetConfigOp>()) {
- auto &targetConfigOps = backendTargetConfigOps[targetConfigOp.backend()];
- targetConfigOps.push_back(targetConfigOp);
- }
- if (backendTargetConfigOps.empty()) {
- // There are no target configs - which likely means we've already translated
- // this in a previous pass.
- return success();
- }
-
- ExecutableTranslationOptions translationOptions;
- translationOptions.CopyFrom(options());
-
- // Invoke each backend translator on the template executables to produce new
- // executables. The backends may produce any number of executables that we
- // then merge back in to the iree.multi_arch_executable and the module
- // flatbuffer.
- std::vector<std::unique_ptr<iree::ExecutableDefT>> translatedExecutableDefs;
- for (auto it : backendTargetConfigOps) {
- const auto &backendKey = it.first;
- const auto &targetConfigOps = it.second;
-
- // Find the translator to use in the registry. It must have been linked in
- // and the name must match what is used in the registration macro.
- auto translateExecutableFn =
- getExecutableTranslationRegistry().lookup(backendKey);
- if (!translateExecutableFn) {
- return multiArchExecutableOp.emitError()
- << "No registered backend found for target '" << backendKey.str()
- << "'; ensure it is linked in to your binary (have: "
- << llvm::join(getExecutableTranslationRegistry().keys(), ", ")
- << ")";
- }
-
- // Clone the executable for each config so that the translator is allowed to
- // modify it in-place.
- // We also need to strip all of the other configs so that the translator
- // backend only sees the one for each of its configs.
- OpBuilder builder(&multiArchExecutableOp.getBlock());
- builder.setInsertionPoint(multiArchExecutableOp.getBlock().getTerminator());
- SmallVector<IREE::ExecutableOp, 4> clonedExecutableOps;
- for (auto targetConfigOp : targetConfigOps) {
- auto executableCloneOp = cast<IREE::ExecutableOp>(
- builder.clone(*templateExecutableOp.getOperation()));
- for (auto existingTargetConfigOp : llvm::make_early_inc_range(
- executableCloneOp.getBlock()
- .getOps<IREE::ExecutableTargetConfigOp>())) {
- existingTargetConfigOp.erase();
- }
- OpBuilder configBuilder(executableCloneOp);
- configBuilder.setInsertionPointToStart(&executableCloneOp.getBlock());
- configBuilder.clone(*targetConfigOp.getOperation());
- clonedExecutableOps.push_back(executableCloneOp);
- }
-
- // Perform translation on all of the backend-specific targets.
- // Note that the results here may not have the same number of executables we
- // started with if the backend either couldn't satisfy some of the requests
- // or decided to dedupe or expand certain ones.
- auto translationResults =
- translateExecutableFn(clonedExecutableOps, translationOptions);
- if (!translationResults.hasValue()) {
- return multiArchExecutableOp.emitError()
- << "Failed to translate executable with backend " << backendKey;
- }
- for (auto &executableDef : translationResults.getValue().executable_defs) {
- translatedExecutableDefs.push_back(std::move(executableDef));
- }
- }
-
- // Remove configs from the template executable so that if we are called again
- // we don't re-translate.
- for (auto targetConfigOp : llvm::make_early_inc_range(
- templateExecutableOp.getBlock()
- .getOps<IREE::ExecutableTargetConfigOp>())) {
- targetConfigOp.erase();
- }
-
- // Create multi-arch executable with all of the target-specific executables.
- iree::MultiArchExecutableDefT maedf;
- maedf.name = multiArchExecutableOp.getName();
- maedf.entry_point_count = entryPointCount;
- maedf.executables = std::move(translatedExecutableDefs);
- auto maedfOffset = iree::MultiArchExecutableDef::Pack(fbb, &maedf);
- RETURN_IF_FAILURE(
- moduleBuilder->executable_table()->AddMultiArchExecutable(maedfOffset));
-
- return success();
-}
-
-LogicalResult SequencerTranslator::translateSequencerModule(
- ModuleOp module, VMModuleBuilder *moduleBuilder) {
- // Declare functions. This must happen first so that we get stable indices
- // during declaration (as call ops need to use the function table).
- for (auto function : module.getOps<FuncOp>()) {
- RETURN_IF_FAILURE(declareFunction(function, moduleBuilder));
- }
-
- // Define functions and convert their bodies to bytecode.
- for (auto function : module.getOps<FuncOp>()) {
- RETURN_IF_FAILURE(defineFunction(function, moduleBuilder));
- }
-
- return success();
-}
-
-LogicalResult SequencerTranslator::declareFunction(
- FuncOp function, VMModuleBuilder *moduleBuilder) {
- auto *functionTable = moduleBuilder->function_table();
- if (functionTable->IsFunctionDeclared(function)) {
- // Already declared.
- return success();
- }
-
- LinkageType linkageType;
- if (function.isExternal()) {
- linkageType = LinkageType::kImport;
- } else if (function.getAttr("iree.module.export")) {
- linkageType = LinkageType::kExport;
- } else {
- linkageType = LinkageType::kInternal;
- }
- if (failed(functionTable->DeclareFunction(function, linkageType))) {
- return function.emitError()
- << "Unable to declare function " << function.getName();
- }
-
- // Import functions must have their definition defined here so we get their
- // type. Internal and export functions will be defined during conversion.
- if (linkageType == LinkageType::kImport) {
- VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
- moduleBuilder->fbb());
- auto functionOffset = functionBuilder.Finish();
- if (functionOffset.IsNull()) {
- return function.emitError()
- << "Failed to create import function bytecode";
- }
- RETURN_IF_FAILURE(
- functionTable->DefineFunction(function, functionOffset, {}));
- }
-
- return success();
-}
-
-LogicalResult SequencerTranslator::defineFunction(
- FuncOp function, VMModuleBuilder *moduleBuilder) {
- VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
- moduleBuilder->fbb());
- registerSequencerCustomWriters(&functionBuilder);
- RETURN_IF_FAILURE(functionBuilder.ConvertBytecode());
- auto functionOffset = functionBuilder.Finish();
- if (functionOffset.IsNull()) {
- return function.emitError() << "Failed to convert function to bytecode";
- }
- RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction(
- function, functionOffset, functionBuilder.source_map()));
- return success();
-}
-
-} // namespace
-
-std::vector<uint8_t> translateMlirToIreeSequencerModule(
- ModuleOp module, ModuleTranslationOptions options) {
- SequencerTranslator translator(options);
- return translator.translateModule(module);
-}
-
-LogicalResult translateMlirToIreeSequencerModuleFile(
- ModuleOp module, llvm::raw_ostream &output) {
- ModuleTranslationOptions options;
- SequencerTranslator translator(options);
- auto bytecodeModule = translator.translateModule(module);
- if (bytecodeModule.empty()) {
- return emitError(UnknownLoc::get(module.getContext()),
- "failed to translate module");
- }
-
- output.write(reinterpret_cast<const char *>(bytecodeModule.data()),
- bytecodeModule.size());
- return success();
-}
-
-static TranslateFromMLIRRegistration MlirToIreeSequencerModuleTranslate(
- "mlir-to-iree-module", translateMlirToIreeSequencerModuleFile);
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h b/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h
deleted file mode 100644
index 0b861a3..0000000
--- a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
-#define IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
-
-#include <vector>
-
-#include "iree/compiler/Utils/TranslationUtils.h"
-#include "mlir/IR/Module.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Translates an MLIR module in a compatible IREE input dialect (such as XLA HLO
-// and/or Std) into an IREE Module. Executables will be lowered based on the
-// provided configuration.
-// Returns an empty vector on translation failure.
-std::vector<uint8_t> translateMlirToIreeSequencerModule(
- ModuleOp module, ModuleTranslationOptions options = {});
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD
deleted file mode 100644
index 8ff56d5..0000000
--- a/iree/compiler/Utils/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-# Utilities for working with IREE MLIR types.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Utils",
- srcs = [
- "DispatchUtils.cpp",
- "MemRefUtils.cpp",
- "ModuleUtils.cpp",
- "OpCreationUtils.cpp",
- "OpUtils.cpp",
- "TranslationUtils.cpp",
- "TypeConversionUtils.cpp",
- ],
- hdrs = [
- "DispatchUtils.h",
- "Macros.h",
- "MemRefUtils.h",
- "ModuleUtils.h",
- "OpCreationUtils.h",
- "OpUtils.h",
- "TranslationUtils.h",
- "TypeConversionUtils.h",
- ],
- deps = [
- "//iree/compiler/IR",
- "//iree/schemas",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardOps",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:TransformUtils",
- "@local_config_mlir//:Transforms",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- ],
-)
diff --git a/iree/compiler/Utils/DispatchUtils.cpp b/iree/compiler/Utils/DispatchUtils.cpp
deleted file mode 100644
index 9e3fb12..0000000
--- a/iree/compiler/Utils/DispatchUtils.cpp
+++ /dev/null
@@ -1,726 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/DispatchUtils.h"
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-Value *calculateWorkload(Operation *op, Value *baseOperand) {
- OpBuilder builder(op);
-
- std::array<int32_t, 3> workload = {1, 1, 1};
-
- // TODO(b/139353314): lookup/calculate based on type/etc.
- auto resultType = baseOperand->getType();
- if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- op->emitOpError() << "Dynamic shapes not yet supported";
- return nullptr;
- }
- auto shape = shapedType.getShape();
- // Drop the trailing ones from the shape.
- while (shape.size() > 1 && shape.back() == 1) {
- shape = shape.drop_back();
- }
- if (shape.size() <= 3) {
- // Maps to XYZ (possibly with 1's for unused dimensions).
- for (auto dim : enumerate(shape)) {
- workload[shape.size() - 1 - dim.index()] = dim.value();
- }
- } else {
- // Need to flatten the shape to fit XYZ. For now we just squash from LHS.
- workload[2] = 1;
- for (int i = 0; i < shape.size(); ++i) {
- workload[2] *= shape[i];
- }
- workload[1] = shape[shape.size() - 2];
- workload[0] = shape.back();
- }
- }
-
- // TODO(b/139353314): optimize workload layout.
-
- auto constantType = RankedTensorType::get({3}, builder.getIntegerType(32));
- return builder.create<ConstantOp>(
- op->getLoc(), constantType,
- DenseIntElementsAttr::get<int32_t>(constantType, workload));
-}
-
-bool isTriviallyDispatchable(FuncOp func) {
- if (func.empty()) return false;
- auto &block = func.front();
- if (block.getOperations().size() != 2) return false;
- auto &op0 = block.front();
- auto &op1 = block.back();
- auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op0);
- auto returnOp = dyn_cast<ReturnOp>(op1);
- if (!regionOp || !returnOp ||
- regionOp.getNumResults() != returnOp.getNumOperands()) {
- return false;
- }
- for (int i = 0; i < regionOp.getNumResults(); ++i) {
- if (regionOp.getResult(i) != returnOp.getOperand(i)) return false;
- }
- return true;
-}
-
-namespace {
-
-// Returns the set of values that must be captured for use by |ops| and the
-// set of values defined by |ops| that are used outside of the set.
-LogicalResult analyzeOpRangeValues(
- const llvm::SmallDenseSet<Operation *> &opSet,
- llvm::SetVector<Value *> *capturedValues,
- llvm::SetVector<Value *> *escapingValues) {
- for (auto *op : opSet) {
- for (auto *value : op->getOperands()) {
- if (!llvm::is_contained(opSet, value->getDefiningOp())) {
- // Op is using a value not in the ops set, ensure we capture it.
- capturedValues->insert(value);
- }
- }
- for (auto *value : op->getResults()) {
- for (auto &use : value->getUses()) {
- if (!llvm::is_contained(opSet, use.getOwner())) {
- // An op outside of the ops set is using the value, needs to escape.
- escapingValues->insert(value);
- }
- }
- }
- }
- return success();
-}
-
-} // namespace
-
-LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock,
- Value *workload, ArrayRef<Operation *> ops) {
- // Fused location with all ops.
- SmallVector<Location, 16> opLocs;
- for (auto *op : ops) {
- opLocs.push_back(op->getLoc());
- }
- auto regionLoc = FusedLoc::get(opLocs, func.getContext());
-
- // Get a list of values that we need to capture and values that escape the
- // region and need to be returned.
- llvm::SmallDenseSet<Operation *> opSet;
- opSet.reserve(ops.size());
- opSet.insert(ops.begin(), ops.end());
- llvm::SetVector<Value *> capturedValues;
- llvm::SetVector<Value *> escapingValues;
- if (failed(analyzeOpRangeValues(opSet, &capturedValues, &escapingValues))) {
- return failure();
- }
- SmallVector<Type, 8> escapingTypes;
- for (auto *value : escapingValues) escapingTypes.push_back(value->getType());
-
- // Build the region op and add it to the parent block.
- OpBuilder parentBuilder(parentBlock);
- parentBuilder.setInsertionPoint(ops.back());
- auto dispatchRegionOp = parentBuilder.create<IREE::DispatchRegionOp>(
- regionLoc, escapingTypes, workload, capturedValues.getArrayRef());
-
- // Create the block and setup the arg mapping for captured values.
- auto *regionBlock = new Block();
- dispatchRegionOp.getBody().push_back(regionBlock);
- OpBuilder regionBuilder(regionBlock);
- BlockAndValueMapping mapping;
- for (auto *capturedValue : capturedValues) {
- auto *blockArg = regionBlock->addArgument(capturedValue->getType());
- mapping.map(capturedValue, blockArg);
- }
-
- // Clone ops into the new region block.
- for (auto *op : ops) {
- // Note that this updates the mapping with the new values (so at the end
- // we have those new values).
- regionBuilder.clone(*op, mapping);
- }
-
- // Return results (as we need a terminator in our block).
- // These are all of the values that escape our region.
- SmallVector<Value *, 8> resultValues;
- for (auto *oldValue : escapingValues) {
- resultValues.push_back(mapping.lookupOrDefault(oldValue));
- }
- regionBuilder.create<IREE::ReturnOp>(opLocs.back(), resultValues);
-
- // Replace usage of values with the results of the region.
- for (int i = 0; i < escapingValues.size(); ++i) {
- escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i));
- }
-
- // Remove original ops from the parent region.
- for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
- (*it)->erase();
- }
-
- return success();
-}
-
-namespace {
-
-// Replaces |returnOp| with a clone including |newOperands| appended.
-LogicalResult appendReturnOperands(IREE::ReturnOp returnOp,
- ArrayRef<Value *> newOperands) {
- // Insert prior to the original return.
- OpBuilder builder(returnOp);
-
- // Clone with new args.
- SmallVector<Value *, 8> operands;
- operands.reserve(returnOp.getNumOperands() + newOperands.size());
- operands.append(returnOp.operand_begin(), returnOp.operand_end());
- operands.append(newOperands.begin(), newOperands.end());
- builder.create<IREE::ReturnOp>(returnOp.getLoc(), operands);
-
- // Remove original.
- returnOp.erase();
-
- return success();
-}
-
-// Replaces |regionOp| with a clone including |newArgs| and |newResults|.
-IREE::DispatchRegionOp appendRegionArgsAndResults(
- IREE::DispatchRegionOp &regionOp, ArrayRef<Value *> newArgs,
- ArrayRef<Value *> newResults, Location otherLoc) {
- // Insert prior to the original region.
- OpBuilder builder(regionOp);
-
- // Location is original region + new region location (both probably fused).
- SmallVector<Location, 2> fusedLocs = {regionOp.getLoc(), otherLoc};
- auto fusedLoc = FusedLoc::get(fusedLocs, regionOp.getContext());
-
- // Clone with new results.
- SmallVector<Value *, 8> operands;
- operands.append(regionOp.getArgOperands().begin(),
- regionOp.getArgOperands().end());
- operands.append(newArgs.begin(), newArgs.end());
- SmallVector<Type, 8> resultTypes;
- resultTypes.append(regionOp.result_type_begin(), regionOp.result_type_end());
- for (auto *newResult : newResults) {
- resultTypes.push_back(newResult->getType());
- }
- auto newRegionOp = builder.create<IREE::DispatchRegionOp>(
- fusedLoc, resultTypes, regionOp.getWorkload(), operands,
- regionOp.getAttrs());
- newRegionOp.getBody().takeBody(regionOp.getBody());
-
- // Replace uses of original values with the new values.
- for (int i = 0; i < regionOp.getNumResults(); ++i) {
- regionOp.getResult(i)->replaceAllUsesWith(newRegionOp.getResult(i));
- }
-
- // Erase the original region.
- regionOp.erase();
-
- return newRegionOp;
-}
-
-// Removes results that are not used from the dispatch region.
-// Returns the new operation. There may be unused ops in the region but DCE
-// should take care of that later.
-IREE::DispatchRegionOp removeUnusedResults(IREE::DispatchRegionOp regionOp) {
- // Find return value within the region.
- auto &regionBlock = regionOp.getBody().getBlocks().front();
- auto returnOp = dyn_cast<IREE::ReturnOp>(regionBlock.getTerminator());
- if (!returnOp) {
- regionBlock.getParent()->getParentOfType<FuncOp>().emitError()
- << "Block does not contain an iree.return op";
- }
-
- // Calculate new return values.
- SmallVector<Type, 8> newReturnTypes;
- SmallVector<Value *, 8> newReturnValues;
- SmallVector<Value *, 8> newRegionResults;
- for (int i = 0; i < returnOp.getNumOperands(); ++i) {
- auto *resultValue = regionOp.getResult(i);
- if (!resultValue->use_empty()) {
- // Still has uses so we will preserve it.
- newReturnTypes.push_back(resultValue->getType());
- newReturnValues.push_back(returnOp.getOperand(i));
- newRegionResults.push_back(resultValue);
- }
- }
-
- // Update return op operands. We can do this in-place as we are only shrinking
- // the list.
- returnOp.getOperation()->setOperands(newReturnValues);
-
- // Insert prior to the original region.
- OpBuilder builder(regionOp);
-
- // Clone with new results.
- SmallVector<Value *, 8> operands(regionOp.getArgOperands());
- auto newRegionOp = builder.create<IREE::DispatchRegionOp>(
- regionOp.getLoc(), newReturnTypes, regionOp.getWorkload(), operands,
- regionOp.getAttrs());
- newRegionOp.getBody().takeBody(regionOp.getBody());
-
- // Replace uses of original values with the new values.
- for (int i = 0; i < newRegionResults.size(); ++i) {
- newRegionResults[i]->replaceAllUsesWith(newRegionOp.getResult(i));
- }
-
- // Erase the original region.
- regionOp.erase();
-
- return newRegionOp;
-}
-
-// Returns true if |lhs| and |rhs| have either an identical workload or one that
-// is compatible.
-bool areDispatchRegionWorkloadsCompatible(IREE::DispatchRegionOp &lhs,
- IREE::DispatchRegionOp &rhs) {
- // TODO(benvanik): more sophisticated checking; right now it's just identical.
- return lhs.getWorkload() == rhs.getWorkload();
-}
-
-// Returns true if |value| depends in any way on |op| through any path.
-// Only works if the operations are within the same block.
-bool doesValueDependOnOperation(Value *value, Operation *op) {
- if (!value->getDefiningOp()) {
- return false;
- } else if (value->getDefiningOp() == op) {
- return true;
- } else if (value->getDefiningOp()->isBeforeInBlock(op)) {
- // Can't depend on |op| as it is defined prior to it.
- return false;
- }
- for (auto *operand : value->getDefiningOp()->getOperands()) {
- if (doesValueDependOnOperation(operand, op)) {
- return true;
- }
- }
- return true;
-}
-
-// Returns true if |rhs| transitively depends on any out of |lhs|.
-// |rhs| may depend directly on the results of |lhs| but no other ops in the
-// parent block will use the results prior to |rhs|.
-bool areDispatchRegionsTransitivelyDependent(IREE::DispatchRegionOp &lhs,
- IREE::DispatchRegionOp &rhs) {
- for (auto *arg : rhs.getArgOperands()) {
- if (arg->getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs)) {
- // Transitively dependent - boo - can't merge yet.
- return true;
- }
- }
- return false;
-}
-
-// Returns true if the dispatch region contains only a single block.
-// This is because our merge isn't very smart and will not preserve the CFG
-// right now. We can fix this when needed.
-bool isDispatchRegionMergable(IREE::DispatchRegionOp &regionOp) {
- // Disallow merging of dispatch regions containing matmuls and other big ops.
- // We do this to allow backends to lower the big op as entirely isolated such
- // that substituting library calls is easier.
- for (auto &block : regionOp.getBody().getBlocks()) {
- for (auto &op : block) {
- if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
- return false;
- }
- }
- }
- return regionOp.getBody().getBlocks().size() == 1;
-}
-
-// Merges |rhs| into |lhs| and returns the new |lhs| op.
-// Precondition: !areDispatchRegionsTransitivelyDependent
-IREE::DispatchRegionOp mergeDispatchRegions(IREE::DispatchRegionOp &lhs,
- IREE::DispatchRegionOp &rhs) {
- auto &lhsBlock = lhs.getBody().front();
- auto &rhsBlock = rhs.getBody().front();
-
- // Find the values used as return values in the lhs.
- // We'll need to replace the uses in rhs with these.
- auto lhsReturnOp = cast<IREE::ReturnOp>(lhsBlock.getTerminator());
- SmallVector<Value *, 8> lhsReturnValues;
- lhsReturnValues.reserve(lhsReturnOp.getNumOperands());
- lhsReturnValues.append(lhsReturnOp.operand_begin(),
- lhsReturnOp.operand_end());
-
- // Find the values used as return values in the rhs.
- // We'll add these to the results of the lhs region.
- auto rhsReturnOp = cast<IREE::ReturnOp>(rhsBlock.getTerminator());
- SmallVector<Value *, 8> rhsReturnValues;
- rhsReturnValues.reserve(rhsReturnOp.getNumOperands());
- rhsReturnValues.append(rhsReturnOp.operand_begin(),
- rhsReturnOp.operand_end());
-
- // Compute new args.
- BlockAndValueMapping mapping;
- SmallVector<Value *, 8> newArgs;
- for (int rhsOpIdx = 0; rhsOpIdx < rhs.getNumArgOperands(); ++rhsOpIdx) {
- bool didElide = false;
- // Find if the rhs arg already exists on the lhs and dedupe.
- for (int lhsOpIdx = 0; lhsOpIdx < lhs.getNumArgOperands(); ++lhsOpIdx) {
- if (rhs.getArgOperand(rhsOpIdx) == lhs.getArgOperand(lhsOpIdx)) {
- mapping.map(rhsBlock.getArgument(rhsOpIdx),
- lhsBlock.getArgument(lhsOpIdx));
- didElide = true;
- break;
- }
- }
- // Find if the arg has a direct dependency on the results of the lhs.
- for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
- ++lhsResultIdx) {
- if (rhs.getArgOperand(rhsOpIdx) == lhs.getResult(lhsResultIdx)) {
- // Direct dependency; can elide. We'll skip adding it to the new region
- // args and instead just remap it later.
- mapping.map(rhsBlock.getArgument(rhsOpIdx),
- lhsReturnValues[lhsResultIdx]);
- didElide = true;
- break;
- }
- }
- if (!didElide) {
- // Add to the lhs block.
- auto *oldArg = rhs.getOperand(rhsOpIdx + 1);
- auto *newArg = lhsBlock.addArgument(oldArg->getType());
- mapping.map(rhsBlock.getArgument(rhsOpIdx), newArg);
- newArgs.push_back(oldArg);
- }
- }
-
- OpBuilder regionBuilder(&lhsBlock);
-
- // Copy ops (replacing any args as needed).
- // Note that we need to insert prior to the terminator.
- regionBuilder.setInsertionPoint(lhsReturnOp);
- for (auto &op : rhsBlock) {
- // Note that this updates the mapping with the new values (so at the end
- // we have those new values).
- //
- // We avoid the return op here as we have already merged it above.
- if (!op.isKnownTerminator()) {
- regionBuilder.clone(op, mapping);
- }
- }
-
- // Compute new results and add to both region and return op.
- SmallVector<Value *, 8> newResults;
- for (auto *rhsResult : rhsReturnValues) {
- newResults.push_back(mapping.lookupOrDefault(rhsResult));
- }
- if (failed(appendReturnOperands(lhsReturnOp, newResults))) {
- return nullptr;
- }
- auto newRegionOp =
- appendRegionArgsAndResults(lhs, newArgs, newResults, rhs.getLoc());
-
- // Replace uses of original values with the new values.
- for (int i = 0; i < rhs.getNumResults(); ++i) {
- rhs.getResult(i)->replaceAllUsesWith(
- newRegionOp.getResult(lhsReturnValues.size() + i));
- }
-
- // Remove rhs region.
- rhs.erase();
-
- // Remove results from the lhs that aren't used anymore as they may have been
- // elided when we merged as only the rhs was using them.
- newRegionOp = removeUnusedResults(newRegionOp);
-
- return newRegionOp;
-}
-
-} // namespace
-
-LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock) {
- SmallVector<IREE::DispatchRegionOp, 8> mergableRegions;
- for (auto &op : *parentBlock) {
- if (auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op)) {
- if (isDispatchRegionMergable(regionOp)) {
- mergableRegions.push_back(regionOp);
- } else {
- regionOp.emitRemark(
- "Unable to merge into following iree.dispatch_regions; "
- "contains non-trivial control flow");
- }
- }
- }
- for (int i = 0; i < mergableRegions.size(); ++i) {
- if (!mergableRegions[i]) continue;
- auto &lhs = mergableRegions[i];
- for (int j = i + 1; j < mergableRegions.size(); ++j) {
- if (!mergableRegions[j]) continue;
- auto &rhs = mergableRegions[j];
- if (!areDispatchRegionWorkloadsCompatible(lhs, rhs) ||
- areDispatchRegionsTransitivelyDependent(lhs, rhs)) {
- continue;
- }
- if (!isDispatchRegionMergable(rhs)) {
- // TODO(b/134675461): support non-trivial control flow.
- rhs.emitRemark(
- "Unable to merge into previous iree.dispatch_region; "
- "contains non-trivial control flow");
- }
- mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
- if (!mergableRegions[i]) {
- return failure();
- }
- mergableRegions[j] = nullptr;
- --i; // Try again to see if there are subsequent regions to merge.
- break;
- }
- }
-
- return success();
-}
-
-namespace {
-
-// Recursively clones the given |sourceOp| and returns the newly cloned op.
-Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder,
- BlockAndValueMapping *mapping) {
- // Note that we dedupe required operands in the case of multiple arguments
- // coming from the same source operation.
- SmallPtrSet<Operation *, 4> operandOps;
- for (auto *operand : sourceOp->getOperands()) {
- operandOps.insert(operand->getDefiningOp());
- }
- for (auto *operandOp : operandOps) {
- recursivelyCloneOp(operandOp, builder, mapping);
- }
- return builder.clone(*sourceOp, *mapping);
-}
-
-// Clones the |sourceValue| op tree into |targetBlock|.
-// |mapping| is used to lookup existing values that may be present in the block
-// such as block arguments or already cloned ancestor ops. |mapping| will be
-// updated as the tree is cloned.
-Value *cloneOpTreeIntoBlock(Value *sourceValue, Block *targetBlock,
- BlockAndValueMapping *mapping) {
- // If the op has already been cloned we can just reuse that.
- // This happens if multiple arguments reference the same trees.
- if (auto *existingValue = mapping->lookupOrNull(sourceValue)) {
- return existingValue;
- }
-
- OpBuilder builder(targetBlock);
- builder.setInsertionPointToStart(targetBlock);
- auto *sourceOp = sourceValue->getDefiningOp();
- auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping);
-
- // Return only the result matching our source value (in the case of multiple
- // results).
- int resultIndex = std::distance(
- sourceOp->result_begin(),
- std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue));
- return clonedOp->getResult(resultIndex);
-}
-
-} // namespace
-
-LogicalResult inlineDispatchRegionOperandsUsingValue(
- IREE::DispatchRegionOp dispatchRegionOp, Value *value) {
- // Find all args that are using this value.
- SmallVector<unsigned, 4> argIndices;
- for (auto arg : llvm::enumerate(dispatchRegionOp.getArgOperands())) {
- if (arg.value() == value) {
- argIndices.push_back(arg.index());
- }
- }
- if (argIndices.empty()) {
- // Not used? Wasteful call!
- return success();
- }
-
- // Clone the value (and the ops required to create it) into the entry block.
- auto &entryBlock = dispatchRegionOp.getBody().getBlocks().front();
- BlockAndValueMapping mapping;
- auto *clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping);
-
- // Replace all uses of the inner operand with the new value.
- for (unsigned argIndex : argIndices) {
- entryBlock.getArgument(argIndex)->replaceAllUsesWith(clonedValue);
- }
-
- // Remove the dispatch region args and the block args that have been
- // replaced.
- for (unsigned argIndex : llvm::reverse(argIndices)) {
- dispatchRegionOp.getOperation()->eraseOperand(
- dispatchRegionOp.mapArgOperandToOpOperand(argIndex));
- entryBlock.eraseArgument(argIndex);
- }
-
- return success();
-}
-
-namespace {
-
-// Recursively finds all reachable functions from the given |rootFunc| and adds
-// them to the |reachableFuncs| set.
-//
-// Note that indirect calls are not supported, however we don't allow those in
-// dispatch regions anyway so they should not be present here.
-LogicalResult findReachableFunctions(Operation *rootFunc,
- llvm::SetVector<FuncOp> &reachableFuncs) {
- bool allCallsValid = true;
- rootFunc->walk([&](CallOp op) {
- auto callee = rootFunc->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
- op.getCallee());
- if (!callee.getAttr("iree.dispatchable")) {
- allCallsValid = false;
- rootFunc->emitError() << callee.getName().str() << " is not dispatchable";
- return;
- }
- if (reachableFuncs.insert(callee)) {
- findReachableFunctions(callee, reachableFuncs);
- }
- });
- return success(allCallsValid);
-}
-
-} // namespace
-
-std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable(
- Operation *op, FunctionType functionType, StringRef symbolSuffix) {
- // Create the function and take the region body directly.
- // NOTE: this will get uniquified if we have multiple in the same block.
- auto parentFunc = op->getParentOfType<FuncOp>();
- std::string functionName =
- (parentFunc.getName().str() + "_rgn" + symbolSuffix).str();
- auto outlinedFunc = FuncOp::create(op->getLoc(), functionName, functionType);
- BlockAndValueMapping mapping;
- op->getRegion(0).cloneInto(&outlinedFunc.getBody(), mapping);
-
- // Gather all reachable functions.
- llvm::SetVector<FuncOp> reachableFuncs;
- findReachableFunctions(outlinedFunc, reachableFuncs);
-
- // Create the multi-arch executable that will contain the outlined region.
- // NOTE: this will get uniquified if we have multiple in the same block.
- auto parentModule = parentFunc.getParentOfType<ModuleOp>();
- OpBuilder parentModuleBuilder(parentModule);
- parentModuleBuilder.setInsertionPoint(parentFunc);
- std::string executableName =
- (parentFunc.getName().str() + "_ex" + symbolSuffix).str();
- auto multiArchExecutable =
- parentModuleBuilder.create<IREE::MultiArchExecutableOp>(
- outlinedFunc.getLoc(), executableName);
-
- // Create the executable op initially unspecified so that later
- // transformations can compile it to various formats.
- OpBuilder multiArchExecutableBuilder(multiArchExecutable);
- multiArchExecutableBuilder.setInsertionPointToStart(
- &multiArchExecutable.getBlock());
- auto executable = multiArchExecutableBuilder.create<IREE::ExecutableOp>(
- outlinedFunc.getLoc(), IREE::ExecutableFormat::Unspecified);
-
- // Create the inner ModuleOp that contains the original functions. We need
- // to provide this shim as some ops (like std.call) look for the
- // containing module to provide symbol resolution.
- OpBuilder executableBuilder(executable);
- executableBuilder.setInsertionPointToStart(&executable.getBlock());
- auto innerModule = executableBuilder.create<ModuleOp>(outlinedFunc.getLoc());
-
- // TODO(b/137674142): make an ExecutableEntryPointOp and convert the
- // entry thunk into that format.
- innerModule.push_back(outlinedFunc);
-
- // Copy all reachable functions into the executable.
- // Linker passes may dedupe these later on.
- for (auto reachableFunc : reachableFuncs) {
- auto clonedFunc = reachableFunc.clone();
- clonedFunc.removeAttr("iree.dispatchable");
- innerModule.push_back(clonedFunc);
- }
-
- return std::make_pair(multiArchExecutable, outlinedFunc);
-}
-
-Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder) {
- if (!value) {
- return nullptr;
- }
-
- // If the previous value was already a memref we don't need to change
- // anything.
- // TODO(benvanik): ensure indices make sense.
- if (value->getType().isa<MemRefType>()) {
- return value;
- } else if (value->getType().isa<TensorType>()) {
- auto castOp = builder.create<IREE::TensorToMemRefOp>(op->getLoc(), value);
- return castOp.getResult();
- }
-
- // Allocate the memref to store the value.
- auto newStorage = builder.create<AllocOp>(
- op->getLoc(), convertTypeToMemRef(value->getType()));
-
- // Insert the store we'll use to box the value.
- builder.create<StoreOp>(op->getLoc(), value, newStorage, ArrayRef<Value *>{});
-
- return newStorage;
-}
-
-Value *insertDispatcherLoad(Operation *op, Value *originalValue,
- Value *allocatedValue, OpBuilder &builder) {
- // If old value was a memref we don't need to change anything.
- if (originalValue->getType().isa<MemRefType>()) {
- return allocatedValue;
- } else if (originalValue->getType().isa<TensorType>()) {
- auto castOp =
- builder.create<IREE::MemRefToTensorOp>(op->getLoc(), allocatedValue);
- originalValue->replaceAllUsesWith(castOp.getResult());
- return castOp.getResult();
- }
-
- // Insert the load we'll use to unbox the value.
- auto loadOp =
- builder.create<LoadOp>(op->getLoc(), allocatedValue, ArrayRef<Value *>{});
- originalValue->replaceAllUsesWith(loadOp);
- return loadOp;
-}
-
-// TODO(benvanik): enough information to walk into dispatch region and compute
-// shape when not static.
-Value *allocateDispatchOutputBuffer(Location loc, MemRefType type,
- OpBuilder &builder) {
- // TODO(benvanik): allocation algorithm:
- // - synthesize shape logic (magic) [[ for now assume fixed shapes ]]
- // - insert shape logic above region
- // - rely on folding to merge multiple calculations together
- // - unranked = death, need to be able to alloc shape outputs
- // - insert alloc
- SmallVector<Value *, 4> dimPieces;
- return builder.create<AllocOp>(loc, type, dimPieces);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Utils/DispatchUtils.h b/iree/compiler/Utils/DispatchUtils.h
deleted file mode 100644
index 80f9be1..0000000
--- a/iree/compiler/Utils/DispatchUtils.h
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utilities for dispatch region and function manipulation.
-// These are shared between all dispatchable types such as the standard
-// iree.dispatch_region as well as dispatch-related types like
-// iree.reduction_region.
-
-#ifndef IREE_COMPILER_UTILS_DISPATCHUTILS_H_
-#define IREE_COMPILER_UTILS_DISPATCHUTILS_H_
-
-#include <utility>
-
-#include "iree/compiler/IR/Ops.h"
-#include "iree/compiler/IR/StructureOps.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Value.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Calculates the workload for |op| based on the op type.
-Value *calculateWorkload(Operation *op, Value *baseOperand);
-
-// Returns true if the func is trivially dispatchable, meaning that:
-// - it contains a single block
-// - it contains a single dispatch region
-// - it contains a return op directly returning the dispatch region results
-bool isTriviallyDispatchable(FuncOp func);
-
-// Builds a new iree.dispatch_region with the given |ops|.
-// The region will capture all required values and return all values used
-// outside of the |ops| provided. The region will be inserted at the location of
-// the last operation in the set.
-//
-// All |ops| must be compatible with the |workload| specified as they will all
-// be dispatched with the same workgroup structure.
-// TODO(benvanik): ensure we want to insert at end. Maybe front?
-LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock,
- Value *workload, ArrayRef<Operation *> ops);
-
-// Merges multiple dispatch regions within a block into the same region,
-// if possible. Operations may be reordered if it's possible to merge more while
-// still obeying data dependencies.
-LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock);
-
-// Inlines use of the given |value| from outside of a dispatch region to inside
-// of it and removes the argument. Supports multiple arguments that reference
-// |value| and will clone the entire value tree.
-LogicalResult inlineDispatchRegionOperandsUsingValue(
- IREE::DispatchRegionOp dispatchRegionOp, Value *value);
-
-// Creates an iree.multi_arch_executable containing an iree.executable with an
-// exported function containing the body region of |op|. Created executables
-// will be named for their original function concatenated with |symbolSuffix|.
-std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable(
- Operation *op, FunctionType functionType, StringRef symbolSuffix);
-
-// Inserts a conversion of an arbitrary |value| to a memref, possibly by way of
-// wrapping in an allocation.
-// Returns a new memref containing the value or an alias to |value|.
-Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder);
-
-// Inserts a load from a wrapped memref.
-// Returns the value in the original type or an alias to the |value| memref.
-Value *insertDispatcherLoad(Operation *op, Value *originalValue,
- Value *allocatedValue, OpBuilder &builder);
-
-// TODO(benvanik): enough information to walk into dispatch region and compute
-// shape when not static.
-Value *allocateDispatchOutputBuffer(Location loc, MemRefType type,
- OpBuilder &builder);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_UTILS_DISPATCHUTILS_H_
diff --git a/iree/compiler/Utils/MemRefUtils.cpp b/iree/compiler/Utils/MemRefUtils.cpp
deleted file mode 100644
index acb8822..0000000
--- a/iree/compiler/Utils/MemRefUtils.cpp
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/MemRefUtils.h"
-
-#include <cassert>
-
-#include "iree/compiler/IR/Ops.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-Value *resolveValueToSourceMemRef(Value *value, Operation *useOp) {
- // TODO(benvanik): implement this for real; this is naive but enough for our
- // simple load patterns.
- auto *defInstr = value->getDefiningOp();
- if (auto loadOp = dyn_cast_or_null<LoadOp>(defInstr)) {
- // TODO(benvanik): support views.
- return loadOp.getMemRef();
- }
- return nullptr;
-}
-
-Value *wrapAsTensor(Value *value, Operation *srcOp, OpBuilder &builder) {
- if (srcOp->getResult(0)->getType().isa<TensorType>()) {
- if (isa_and_nonnull<IREE::TensorToMemRefOp>(value->getDefiningOp())) {
- return value->getDefiningOp()->getOperand(0);
- }
- auto newOp = builder.create<IREE::MemRefToTensorOp>(srcOp->getLoc(), value);
- value = newOp.getResult();
- }
- return value;
-}
-
-Value *wrapAsMemRef(Value *value, Operation *srcOp, OpBuilder &builder) {
- if (value->getType().isa<TensorType>()) {
- if (isa_and_nonnull<IREE::MemRefToTensorOp>(value->getDefiningOp())) {
- return value->getDefiningOp()->getOperand(0);
- }
- auto newOp = builder.create<IREE::TensorToMemRefOp>(srcOp->getLoc(), value);
- value = newOp.getResult();
- }
- return value;
-}
-
-Value *loadAccessValue(Location location, Value *operand, OpBuilder &builder) {
- if (operand->getType().isa<MemRefType>() ||
- operand->getType().isa<TensorType>()) {
- return operand;
- }
-
- auto memRefType = MemRefType::get({}, operand->getType());
- if (auto loadOp = dyn_cast_or_null<LoadOp>(operand->getDefiningOp())) {
- // TODO(benvanik): handle creating views.
- if (loadOp.getMemRefType() == memRefType) {
- return loadOp.getMemRef();
- }
- }
-
- auto allocOp = builder.create<AllocOp>(location, memRefType);
- builder.create<StoreOp>(location, operand, allocOp.getResult(),
- ArrayRef<Value *>{});
- return allocOp.getResult();
-}
-
-Value *loadResultValue(Location location, const Type &originalType,
- Value *result, OpBuilder &builder) {
- if (originalType.isa<MemRefType>()) {
- return result;
- } else if (auto tensorType = originalType.dyn_cast<TensorType>()) {
- return result;
- }
-
- auto loadOp = builder.create<LoadOp>(location, result, ArrayRef<Value *>{});
- return loadOp.getResult();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Utils/ModuleUtils.cpp b/iree/compiler/Utils/ModuleUtils.cpp
deleted file mode 100644
index 230fdb1..0000000
--- a/iree/compiler/Utils/ModuleUtils.cpp
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/ModuleUtils.h"
-
-#include "llvm/ADT/SetVector.h"
-#include "mlir/IR/Function.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Finds a list of functions with the given |attrName| and adds them to |funcs|.
-void findFunctionsWithAttr(ModuleOp module, const char *attrName,
- llvm::SetVector<FuncOp> &funcs) {
- for (auto func : module.getOps<FuncOp>()) {
- if (func.getAttr(attrName)) {
- funcs.insert(func);
- }
- }
-}
-
-// Inserts functions reachable directly from |func| to |usedFuncs|.
-void insertUsedFunctions(ModuleOp module, FuncOp func,
- DenseSet<FuncOp> *usedFuncs,
- std::vector<FuncOp> *toSearch) {
- auto onCalledFunction = [&](StringRef calleeName) {
- auto calleeFunc = module.lookupSymbol<FuncOp>(calleeName);
- if (usedFuncs->insert(calleeFunc).second) {
- // New function found! Add to queue for searching.
- toSearch->push_back(calleeFunc);
- }
- };
- for (auto &block : func) {
- for (auto &op : block) {
- // TODO(benvanik): replace with iree_hl.call check.
- if (auto calleeAttr = op.getAttr("callee")) {
- onCalledFunction(calleeAttr.cast<SymbolRefAttr>().getValue());
- }
- }
- }
-}
-
-// Returns a set containing the names of all functions used by the given
-// |rootFuncs| list.
-DenseSet<FuncOp> findUsedFunctions(ModuleOp module,
- ArrayRef<FuncOp> rootFuncs) {
- // Breadth-first search.
- DenseSet<FuncOp> usedFuncs;
- usedFuncs.insert(rootFuncs.begin(), rootFuncs.end());
- std::vector<FuncOp> toSearch = {rootFuncs.begin(), rootFuncs.end()};
- while (!toSearch.empty()) {
- auto func = toSearch.back();
- toSearch.pop_back();
- insertUsedFunctions(module, func, &usedFuncs, &toSearch);
- }
- return usedFuncs;
-}
-
-} // namespace
-
-void dropUnusedFunctions(ModuleOp module, ArrayRef<const char *> keepAttrs) {
- // Find all of the exported functions we'll treat as roots.
- llvm::SetVector<FuncOp> rootFuncs;
- for (auto keepAttr : keepAttrs) {
- findFunctionsWithAttr(module, keepAttr, rootFuncs);
- }
-
- // Find the full set of all used functions reachable from the given rootFuncs.
- // This set will contain the rootFuncs.
- auto usedFuncs = findUsedFunctions(module, rootFuncs.getArrayRef());
-
- // Drop all unused functions.
- std::vector<FuncOp> deadFuncs;
- for (auto func : module.getOps<FuncOp>()) {
- if (!llvm::is_contained(usedFuncs, func)) {
- deadFuncs.push_back(func);
- }
- }
- for (auto func : deadFuncs) {
- func.erase();
- }
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Utils/OpCreationUtils.cpp b/iree/compiler/Utils/OpCreationUtils.cpp
deleted file mode 100644
index d624310..0000000
--- a/iree/compiler/Utils/OpCreationUtils.cpp
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/OpCreationUtils.h"
-
-#include <cstdint>
-
-#include "iree/compiler/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-ElementsAttr elementsAttrFromArray(OpBuilder &builder,
- ArrayRef<int64_t> elements) {
- return DenseIntElementsAttr::get(
- RankedTensorType::get(elements.size(), builder.getIntegerType(64)),
- elements);
-}
-
-} // namespace
-
-IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc,
- llvm::ArrayRef<int64_t> elements) {
- auto elementsAttr = elementsAttrFromArray(builder, elements);
- return builder.create<IREE::ConstantOp>(loc, elementsAttr);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Utils/OpCreationUtils.h b/iree/compiler/Utils/OpCreationUtils.h
deleted file mode 100644
index 7cdd701..0000000
--- a/iree/compiler/Utils/OpCreationUtils.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utility functions related to the creation of new operations. Where possible,
-// use custom builders. These helpers are for situations where a custom builder
-// is not appropriate.
-
-#ifndef IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
-#define IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
-
-#include <cstdint>
-
-#include "iree/compiler/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc,
- llvm::ArrayRef<int64_t> elements);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
diff --git a/iree/compiler/Utils/OpUtils.cpp b/iree/compiler/Utils/OpUtils.cpp
deleted file mode 100644
index 5ff88b5..0000000
--- a/iree/compiler/Utils/OpUtils.cpp
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/OpUtils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-void removeDeadOperations(llvm::SetVector<Operation *> &deadOperations) {
- while (!deadOperations.empty()) {
- auto *op = deadOperations.front();
- deadOperations.erase(deadOperations.begin());
- for (auto *operand : op->getOperands()) {
- // TODO(benvanik): add check for op side effects.
- if (operand->hasOneUse()) {
- deadOperations.insert(operand->getDefiningOp());
- }
- }
- op->erase();
- }
-}
-
-void replaceSubsequentUses(Operation *userOp, Value *oldValue,
- Value *newValue) {
- for (auto &use : oldValue->getUses()) {
- if (userOp->isBeforeInBlock(use.getOwner())) {
- use.set(newValue);
- }
- }
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Utils/TranslationUtils.cpp b/iree/compiler/Utils/TranslationUtils.cpp
deleted file mode 100644
index 033485e..0000000
--- a/iree/compiler/Utils/TranslationUtils.cpp
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/TranslationUtils.h"
-
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Returns the static registry of translator names to translation functions.
-llvm::StringMap<TranslateExecutableFn>
- &getMutableExecutableTranslationRegistry() {
- static llvm::StringMap<TranslateExecutableFn> registry;
- return registry;
-}
-
-// Returns true if the given |value| matches |pattern| (normal * and ? rules).
-bool matchPattern(StringRef value, StringRef pattern) {
- size_t nextCharIndex = pattern.find_first_of("*?");
- if (nextCharIndex == std::string::npos) {
- return value == pattern;
- } else if (nextCharIndex > 0) {
- if (value.substr(0, nextCharIndex) != pattern.substr(0, nextCharIndex)) {
- return false;
- }
- value = value.substr(nextCharIndex);
- pattern = pattern.substr(nextCharIndex);
- }
- char patternChar = pattern[0];
- if (value.empty() && pattern.empty()) {
- return true;
- } else if (patternChar == '*' && pattern.size() > 1 && value.empty()) {
- return false;
- } else if (patternChar == '*' && pattern.size() == 1) {
- return true;
- } else if (patternChar == '?' || value[0] == patternChar) {
- return matchPattern(value.substr(1), pattern.substr(1));
- } else if (patternChar == '*') {
- return matchPattern(value, pattern.substr(1)) ||
- matchPattern(value.substr(1), pattern);
- }
- return false;
-}
-
-// Force enables IR printing on the |passManager|.
-void enableIRPrinting(PassManager *passManager) {
- auto notVerifier = [](Pass *pass) {
- return pass->getName() != "FunctionVerifier" &&
- pass->getName() != "ModuleVerifier";
- };
- bool printModuleScope = false;
- passManager->enableIRPrinting(/*shouldPrintBeforePass=*/{},
- /*shouldPrintAfterPass=*/notVerifier,
- printModuleScope, llvm::dbgs());
- passManager->disableMultithreading();
-}
-
-} // namespace
-
-ExecutableTranslationRegistration::ExecutableTranslationRegistration(
- llvm::StringRef name, const TranslateExecutableFn &fn) {
- auto &registry = getMutableExecutableTranslationRegistry();
- if (registry.find(name) != registry.end()) {
- llvm::report_fatal_error(
- "Attempting to overwrite an existing translation function");
- }
- assert(fn && "Attempting to register an empty translation function");
- registry[name] = fn;
-}
-
-const llvm::StringMap<TranslateExecutableFn>
- &getExecutableTranslationRegistry() {
- return getMutableExecutableTranslationRegistry();
-}
-
-std::vector<std::string> matchExecutableTranslationBackendNames(
- llvm::StringRef pattern) {
- std::vector<std::string> matches;
- for (auto &entry : getExecutableTranslationRegistry()) {
- if (matchPattern(entry.getKey(), pattern)) {
- matches.push_back(entry.getKey().str());
- }
- }
- return matches;
-}
-
-std::unique_ptr<PassManager> createPassManager(
- MLIRContext *ctx, const TranslationOptions &translationOptions) {
- std::unique_ptr<PassManager> passManager(new PassManager(ctx));
-
- // Enable IR printing/timing/etc from command line options.
- registerPassManagerCLOptions();
- applyPassManagerCLOptions(*passManager);
-
- // Override with programmatic options.
- if (translationOptions.print_mlir) {
- enableIRPrinting(passManager.get());
- }
-
- return passManager;
-}
-
-LogicalResult runPassPipeline(const TranslationOptions &translationOptions,
- PassManager *passManager, ModuleOp module) {
- if (translationOptions.print_mlir) {
- module.dump();
- }
-
- // Run on the module.
- if (failed(passManager->run(module))) {
- return failure();
- }
-
- if (translationOptions.print_mlir) {
- module.dump();
- }
-
- return success();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Utils/TranslationUtils.h b/iree/compiler/Utils/TranslationUtils.h
deleted file mode 100644
index a6aaefa..0000000
--- a/iree/compiler/Utils/TranslationUtils.h
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
-#define IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
-
-#include <functional>
-#include <memory>
-
-#include "iree/compiler/IR/StructureOps.h"
-#include "iree/schemas/executable_def_generated.h"
-#include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringRef.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/PassManager.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Common translation options for diagnostics and debugging.
-struct TranslationOptions {
- // Enables MLIR IR printing during translation.
- // This can be specified via the -print-ir-before-all and -print-ir-after-all
- // command line flags or overridden programmatically via this flag.
- bool print_mlir = false;
-
- void CopyFrom(const TranslationOptions &other) {
- print_mlir = other.print_mlir;
- }
-};
-
-// Options for iree.module translation for diagnostics and debugging.
-struct ModuleTranslationOptions : public TranslationOptions {
- // Defines which backend translators will be used to translate executables.
- // If empty then all linked in translators will be used.
- // TODO(benvanik): extend to allow specifying entire config blobs via mlir.
- std::vector<std::string> target_backends;
-};
-
-// Options for iree.executable translation for diagnostics and debugging.
-// Target configuration is sourced from the iree.target_config op within the
-// iree.executable.
-struct ExecutableTranslationOptions : public TranslationOptions {};
-
-// Results of a translation operation.
-// May contain zero or more executable defs depending on translation options,
-// defined target configs, and support.
-struct ExecutableTranslationResult {
- std::vector<std::unique_ptr<iree::ExecutableDefT>> executable_defs;
-};
-
-// Registered function that given a set of |executableOps| containing one
-// or more iree.executables will produce zero or more serialized executables.
-//
-// Each iree.executable provided contains one iree.executable_target_config with
-// backend-specific translation information. The translator can decide whether
-// to translate each independently, group them together, etc.
-//
-// The provided |executableOps| can be mutated by the callee and will be
-// preserved for debugging after translation. If any executable in
-// |executableOps| is not used by the translator then it should be erased.
-using TranslateExecutableFn =
- std::function<llvm::Optional<ExecutableTranslationResult>(
- ArrayRef<IREE::ExecutableOp> executableOps,
- ExecutableTranslationOptions options)>;
-
-// Registers an executable translation function.
-struct ExecutableTranslationRegistration {
- ExecutableTranslationRegistration(llvm::StringRef name,
- const TranslateExecutableFn &fn);
-};
-
-// Returns a read-only reference to the translator registry.
-const llvm::StringMap<TranslateExecutableFn>
- &getExecutableTranslationRegistry();
-
-// Returns executable translation backend names matching the given pattern.
-// This accepts wildcards for any delimited value. For example, 'foo-*-bar' will
-// match 'foo-123-bar' and 'foo-456-bar' and 'foo-10?' will match 'foo-101' and
-// 'foo-102'.
-std::vector<std::string> matchExecutableTranslationBackendNames(
- llvm::StringRef pattern);
-
-// Creates a new pass manager initialized with the given options.
-std::unique_ptr<PassManager> createPassManager(
- MLIRContext *ctx, const TranslationOptions &translationOptions);
-
-// Runs an initialized set of passes on the given module.
-LogicalResult runPassPipeline(const TranslationOptions &translationOptions,
- PassManager *passManager, ModuleOp module);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
diff --git a/iree/compiler/Utils/TypeConversionUtils.cpp b/iree/compiler/Utils/TypeConversionUtils.cpp
deleted file mode 100644
index ae4e11a..0000000
--- a/iree/compiler/Utils/TypeConversionUtils.cpp
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/TypeConversionUtils.h"
-
-#include <cassert>
-
-#include "iree/compiler/IR/Ops.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-Type legalizeType(Type type) {
- if (type.isIndex()) {
- return IntegerType::get(kIndexBitWidth, type.getContext());
- } else if (type.isInteger(1)) {
- return IntegerType::get(kBoolBitWidth, type.getContext());
- } else if (auto memRefType = type.dyn_cast<MemRefType>()) {
- return MemRefType::get(memRefType.getShape(),
- legalizeType(memRefType.getElementType()));
- } else if (auto functionType = type.dyn_cast<FunctionType>()) {
- llvm::SmallVector<Type, 4> inputs;
- for (const auto &oldType : functionType.getInputs()) {
- inputs.push_back(legalizeType(oldType));
- }
- llvm::SmallVector<Type, 4> results;
- for (const auto &oldType : functionType.getResults()) {
- results.push_back(legalizeType(oldType));
- }
- return FunctionType::get(inputs, results, type.getContext());
- }
- return type;
-}
-
-Type LLTypeConverter::convertType(Type type) { return legalizeType(type); }
-
-MemRefType convertTypeToMemRef(Type type) {
- if (type.isIntOrIndexOrFloat()) {
- return MemRefType::get({}, type, {}, 0);
- } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- } else if (auto memRefType = type.dyn_cast<MemRefType>()) {
- return memRefType;
- } else {
- llvm_unreachable("Unconvertable type");
- }
-}
-
-MemRefType convertTypeToMemRef(Value *value) {
- return convertTypeToMemRef(value->getType());
-}
-
-Type MemRefTypeConverter::convertType(Type type) {
- return convertTypeToMemRef(type);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/hal/BUILD b/iree/hal/BUILD
deleted file mode 100644
index 880f528..0000000
--- a/iree/hal/BUILD
+++ /dev/null
@@ -1,377 +0,0 @@
-# HAL (Hardware Abstraction Layer).
-# Subdirectories contain implementations for different hardware and
-# software backends.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "allocator",
- srcs = ["allocator.cc"],
- hdrs = ["allocator.h"],
- deps = [
- ":buffer",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "api",
- srcs = ["api.cc"],
- hdrs = [
- "api.h",
- "api_detail.h",
- ],
- visibility = ["//visibility:public"],
- deps = [
- ":api_hdrs",
- ":buffer",
- ":buffer_view",
- ":fence",
- ":heap_buffer",
- ":semaphore",
- "//iree/base:api",
- "//iree/base:api_util",
- "//iree/base:shape",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
- deps = [
- "//iree/base:api_hdrs",
- ],
-)
-
-cc_library(
- name = "buffer",
- srcs = ["buffer.cc"],
- hdrs = ["buffer.h"],
- deps = [
- ":resource",
- "//iree/base:bitfield",
- "//iree/base:logging",
- "//iree/base:source_location",
- "//iree/base:status",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@com_google_absl//absl/types:variant",
- ],
-)
-
-cc_test(
- name = "buffer_test",
- srcs = [
- "buffer_mapping_test.cc",
- "buffer_test.cc",
- ],
- deps = [
- ":buffer",
- ":heap_buffer",
- "//iree/base:status",
- "//iree/base:status_matchers",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/types:span",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "buffer_view",
- srcs = ["buffer_view.cc"],
- hdrs = ["buffer_view.h"],
- deps = [
- ":buffer",
- "//iree/base:shape",
- "//iree/base:source_location",
- "//iree/base:status",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_test(
- name = "buffer_view_test",
- srcs = [
- "buffer_view_test.cc",
- ],
- deps = [
- ":buffer",
- ":buffer_view",
- ":heap_buffer",
- "//iree/base:status",
- "//iree/base:status_matchers",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "buffer_view_string_util",
- srcs = ["buffer_view_string_util.cc"],
- hdrs = ["buffer_view_string_util.h"],
- deps = [
- ":allocator",
- ":buffer_view",
- ":heap_buffer",
- "//iree/base:source_location",
- "//iree/base:status",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- ],
-)
-
-cc_test(
- name = "buffer_view_string_util_test",
- srcs = ["buffer_view_string_util_test.cc"],
- deps = [
- ":buffer_view_string_util",
- "//iree/base:status",
- "//iree/base:status_matchers",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "command_buffer",
- srcs = ["command_buffer.cc"],
- hdrs = ["command_buffer.h"],
- deps = [
- ":allocator",
- ":buffer",
- ":buffer_view",
- ":event",
- ":executable",
- ":resource",
- "//iree/base:bitfield",
- "//iree/base:shape",
- "//iree/base:status",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "command_buffer_validation",
- srcs = ["command_buffer_validation.cc"],
- hdrs = ["command_buffer_validation.h"],
- deps = [
- ":command_buffer",
- "//iree/base:logging",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "command_queue",
- hdrs = ["command_queue.h"],
- deps = [
- ":command_buffer",
- ":fence",
- ":semaphore",
- "//iree/base:bitfield",
- "//iree/base:status",
- "//iree/base:time",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "deferred_buffer",
- srcs = ["deferred_buffer.cc"],
- hdrs = ["deferred_buffer.h"],
- deps = [
- ":allocator",
- ":buffer",
- "//iree/base:status",
- ],
-)
-
-cc_test(
- name = "deferred_buffer_test",
- srcs = ["deferred_buffer_test.cc"],
- deps = [
- ":deferred_buffer",
- ":heap_buffer",
- "//iree/base:status_matchers",
- "//iree/hal/testing:mock_allocator",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "device",
- hdrs = ["device.h"],
- deps = [
- ":allocator",
- ":buffer",
- ":command_queue",
- ":device_info",
- ":event",
- ":executable_cache",
- ":semaphore",
- "//iree/base:status",
- "//iree/base:time",
- "@com_google_absl//absl/time",
- ],
-)
-
-cc_library(
- name = "device_info",
- hdrs = ["device_info.h"],
- deps = [
- "//iree/base:bitfield",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "device_manager",
- srcs = ["device_manager.cc"],
- hdrs = ["device_manager.h"],
- deps = [
- ":allocator",
- ":buffer",
- ":command_queue",
- ":device",
- ":device_placement",
- ":executable_format",
- ":fence",
- ":heap_buffer",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:time",
- "//iree/base:tracing",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "device_placement",
- hdrs = ["device_placement.h"],
-)
-
-cc_library(
- name = "driver",
- hdrs = ["driver.h"],
- deps = [
- ":device",
- ":device_info",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "driver_registry",
- srcs = ["driver_registry.cc"],
- hdrs = ["driver_registry.h"],
- deps = [
- ":driver",
- "//iree/base:init",
- "//iree/base:status",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- ],
-)
-
-cc_library(
- name = "event",
- hdrs = ["event.h"],
- deps = [
- ":resource",
- ],
-)
-
-cc_library(
- name = "executable",
- hdrs = ["executable.h"],
- deps = [":resource"],
-)
-
-cc_library(
- name = "executable_cache",
- srcs = ["executable_cache.cc"],
- hdrs = ["executable_cache.h"],
- deps = [
- ":executable",
- ":executable_format",
- ":executable_spec",
- "//iree/base:bitfield",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "executable_format",
- hdrs = ["executable_format.h"],
- deps = [
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "executable_spec",
- hdrs = ["executable_spec.h"],
- deps = [
- ":executable_format",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "fence",
- hdrs = ["fence.h"],
- deps = [
- ":resource",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "heap_buffer",
- srcs = ["heap_buffer.cc"],
- hdrs = ["heap_buffer.h"],
- deps = [
- ":allocator",
- ":buffer",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal/host:host_buffer",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "resource",
- hdrs = ["resource.h"],
- deps = [
- "//iree/base:ref_ptr",
- ],
-)
-
-cc_library(
- name = "semaphore",
- hdrs = ["semaphore.h"],
- deps = [
- ":resource",
- "@com_google_absl//absl/types:variant",
- ],
-)
-
-cc_library(
- name = "stack_trace",
- hdrs = ["stack_trace.h"],
-)
diff --git a/iree/hal/allocator.cc b/iree/hal/allocator.cc
deleted file mode 100644
index 7232804..0000000
--- a/iree/hal/allocator.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/allocator.h"
-
-#include <cstdint>
-#include <cstdlib>
-#include <string>
-#include <utility>
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-bool Allocator::CanUseBuffer(Buffer* buffer,
- BufferUsageBitfield intended_usage) const {
- return CanUseBufferLike(buffer->allocator(), buffer->memory_type(),
- buffer->usage(), intended_usage);
-}
-
-StatusOr<ref_ptr<Buffer>> Allocator::AllocateConstant(
- BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) {
- if (AnyBitSet(source_buffer->usage() & BufferUsage::kConstant) &&
- CanUseBuffer(source_buffer.get(), buffer_usage)) {
- // Buffer can be used directly by the device.
- return source_buffer;
- }
-
- IREE_TRACE_SCOPE0("Allocator::AllocateConstant");
-
- // We need to map so we can copy into it.
- buffer_usage |= BufferUsage::kMapping;
- // It will be constant after we write it.
- buffer_usage |= BufferUsage::kConstant;
-
- MemoryTypeBitfield memory_type =
- MemoryType::kDeviceLocal | MemoryType::kHostVisible;
- ASSIGN_OR_RETURN(auto device_buffer, Allocate(memory_type, buffer_usage,
- source_buffer->byte_length()));
- ASSIGN_OR_RETURN(auto source_mapping,
- source_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
- RETURN_IF_ERROR(device_buffer->WriteData(0, source_mapping.data(),
- source_mapping.byte_length()));
- return device_buffer;
-}
-
-StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- const void* data,
- size_t data_length) {
- return WrapMutable(memory_type, MemoryAccess::kRead, buffer_usage,
- const_cast<void*>(data), data_length);
-}
-
-StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable(
- MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage, void* data, size_t data_length) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Allocator does not support wrapping host memory";
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/allocator.h b/iree/hal/allocator.h
deleted file mode 100644
index ef56240..0000000
--- a/iree/hal/allocator.h
+++ /dev/null
@@ -1,138 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_ALLOCATOR_H_
-#define IREE_HAL_ALLOCATOR_H_
-
-#include <cstddef>
-#include <memory>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// Allocates buffers for a particular device memory space.
-//
-// Buffers allocated are only guaranteed to work with the driver that the
-// allocator services. Any attempt to use buffers on drivers they were not
-// allocated from must first be checked with CanUseBuffer.
-//
-// Thread-safe.
-class Allocator {
- public:
- virtual ~Allocator() = default;
-
- // Returns true if the device can use the given buffer for the provided usage.
- // For buffers allocated from this allocator it's expected that the result
- // will always be true. For buffers that originate from another allocator
- // there may be limited support for cross-device usage.
- //
- // Returning false indicates that the buffer must be transferred externally
- // into a buffer compatible with the device this allocator services.
- bool CanUseBuffer(Buffer* buffer, BufferUsageBitfield intended_usage) const;
- virtual bool CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const = 0;
-
- // Returns true if the allocator can allocate a buffer with the given
- // attributes.
- virtual bool CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const = 0;
-
- // Adjusts allocation parameters to be compatible with the allocator.
- // Certain allocators may require particular memory types to function. By
- // adjusting the parameters prior to allocation callers can be sure they are
- // able to successfully Allocate a buffer later on with the same parameters.
- virtual Status MakeCompatible(MemoryTypeBitfield* memory_type,
- BufferUsageBitfield* buffer_usage) const {
- return OkStatus();
- }
-
- // Allocates a buffer from the allocator.
- // Fails if the memory type requested for the given usage cannot be serviced.
- // Callers can use CanAllocate to decide their memory use strategy.
- //
- // The memory type of the buffer returned may differ from the requested value
- // if the device can provide more functionality; for example, if requesting
- // MemoryType::kHostVisible but the memory is really host cached you may get
- // a buffer back with MemoryType::kHostVisible | MemoryType::kHostCached. The
- // only requirement is that the buffer satisfy the required bits.
- virtual StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) = 0;
-
- // Allocates a buffer from the allocator for use as a constant value.
- // The provided |source_buffer| may be returned if the device can use it
- // directly and otherwise will be copied.
- virtual StatusOr<ref_ptr<Buffer>> AllocateConstant(
- BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer);
-
- // Wraps an existing host heap allocation in a buffer.
- // Ownership of the host allocation remains with the caller and the memory
- // must remain valid for so long as the Buffer may be in use.
- // Will have MemoryType::kHostLocal in most cases and may not be usable
- // by the device.
- //
- // The inference optimizer makes assumptions about buffer aliasing based on
- // Buffer instances and because of this wrapping the same host buffer in
- // multiple Buffers will create potential memory aliasing issues that can be
- // difficult to track down. There's no checking as to whether a host buffer
- // has already been wrapped so it's best for callers to ensure this is never
- // possible (the simplest way being to never use Wrap and always just allocate
- // new Buffers).
- //
- // Fails if the allocator cannot access host memory in this way.
- StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- const void* data, size_t data_length);
- virtual StatusOr<ref_ptr<Buffer>> WrapMutable(
- MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage, void* data, size_t data_length);
- template <typename T>
- StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- absl::Span<const T> data);
- template <typename T>
- StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage,
- absl::Span<T> data);
-};
-
-// Inline functions and template definitions follow:
-
-template <typename T>
-StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- absl::Span<const T> data) {
- return Wrap(memory_type, buffer_usage, data.data(), data.size() * sizeof(T));
-}
-
-template <typename T>
-StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable(
- MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage, absl::Span<T> data) {
- return WrapMutable(memory_type, allowed_access, buffer_usage, data.data(),
- data.size() * sizeof(T));
-}
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_ALLOCATOR_H_
diff --git a/iree/hal/api.cc b/iree/hal/api.cc
deleted file mode 100644
index 79c1a27..0000000
--- a/iree/hal/api.cc
+++ /dev/null
@@ -1,439 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/api.h"
-
-#include "iree/base/api.h"
-#include "iree/base/api_util.h"
-#include "iree/base/shape.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/api_detail.h"
-#include "iree/hal/buffer.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/fence.h"
-#include "iree/hal/heap_buffer.h"
-#include "iree/hal/semaphore.h"
-
-namespace iree {
-namespace hal {
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Buffer
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_subspan(
- iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_subspan");
-
- if (!out_buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_buffer = nullptr;
-
- if (!buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- auto handle = add_ref(reinterpret_cast<Buffer*>(buffer));
-
- IREE_API_ASSIGN_OR_RETURN(auto new_handle,
- Buffer::Subspan(handle, byte_offset, byte_length));
-
- *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(new_handle.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_retain(iree_hal_buffer_t* buffer) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_retain");
- auto* handle = reinterpret_cast<Buffer*>(buffer);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_release(iree_hal_buffer_t* buffer) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_release");
- auto* handle = reinterpret_cast<Buffer*>(buffer);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_device_size_t
-iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_byte_length");
- const auto* handle = reinterpret_cast<const Buffer*>(buffer);
- CHECK(handle) << "NULL buffer handle";
- return handle->byte_length();
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_zero");
- auto* handle = reinterpret_cast<Buffer*>(buffer);
- if (!buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(handle->Fill8(byte_offset, byte_length, 0));
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length, const void* pattern,
- iree_host_size_t pattern_length) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_fill");
- auto* handle = reinterpret_cast<Buffer*>(buffer);
- if (!buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(
- handle->Fill(byte_offset, byte_length, pattern, pattern_length));
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_read_data(
- iree_hal_buffer_t* buffer, iree_device_size_t source_offset,
- void* target_buffer, iree_device_size_t data_length) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_read_data");
- auto* handle = reinterpret_cast<Buffer*>(buffer);
- if (!buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(
- handle->ReadData(source_offset, target_buffer, data_length));
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_write_data(
- iree_hal_buffer_t* buffer, iree_device_size_t target_offset,
- const void* source_buffer, iree_device_size_t data_length) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_write_data");
- auto* handle = reinterpret_cast<Buffer*>(buffer);
- if (!buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(
- handle->WriteData(target_offset, source_buffer, data_length));
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_map(
- iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
- iree_device_size_t element_offset, iree_device_size_t element_length,
- iree_hal_mapped_memory_t* out_mapped_memory) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_map");
-
- if (!out_mapped_memory) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory));
-
- auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
- if (!buffer_handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- IREE_API_ASSIGN_OR_RETURN(
- auto mapping, buffer_handle->MapMemory<uint8_t>(
- static_cast<MemoryAccessBitfield>(memory_access),
- element_offset, element_length));
-
- static_assert(sizeof(iree_hal_mapped_memory_t::reserved) >=
- sizeof(MappedMemory<uint8_t>),
- "C mapped memory struct must have large enough storage for the "
- "matching C++ struct");
- auto* mapping_storage =
- reinterpret_cast<MappedMemory<uint8_t>*>(out_mapped_memory->reserved);
- *mapping_storage = std::move(mapping);
-
- out_mapped_memory->contents = {const_cast<uint8_t*>(mapping_storage->data()),
- mapping_storage->size()};
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_unmap(
- iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_map");
- auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
- if (!buffer_handle || !mapped_memory) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto* mapping =
- reinterpret_cast<MappedMemory<uint8_t>*>(mapped_memory->reserved);
- mapping->reset();
-
- std::memset(mapped_memory, 0, sizeof(*mapped_memory));
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::hal::HeapBuffer
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate(
- iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
- iree_host_size_t allocation_size, iree_allocator_t contents_allocator,
- iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) {
- IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate");
-
- if (!out_buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_buffer = nullptr;
-
- if (!allocation_size) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto handle = HeapBuffer::Allocate(
- static_cast<MemoryTypeBitfield>(memory_type),
- static_cast<BufferUsageBitfield>(usage), allocation_size);
-
- *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(
- static_cast<Buffer*>(handle.release()));
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy(
- iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
- iree_hal_memory_access_t allowed_access, iree_byte_span_t contents,
- iree_allocator_t contents_allocator, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer) {
- IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate_copy");
-
- if (!out_buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_buffer = nullptr;
-
- if (!contents.data || !contents.data_length) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto handle = HeapBuffer::AllocateCopy(
- static_cast<BufferUsageBitfield>(usage),
- static_cast<MemoryAccessBitfield>(allowed_access), contents.data,
- contents.data_length);
-
- *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(handle.release());
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap(
- iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
- iree_hal_buffer_usage_t usage, iree_byte_span_t contents,
- iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) {
- IREE_TRACE_SCOPE0("iree_hal_heap_buffer_wrap");
-
- if (!out_buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_buffer = nullptr;
-
- if (!contents.data || !contents.data_length) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto handle =
- HeapBuffer::WrapMutable(static_cast<MemoryTypeBitfield>(memory_type),
- static_cast<MemoryAccessBitfield>(allowed_access),
- static_cast<BufferUsageBitfield>(usage),
- contents.data, contents.data_length);
-
- *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(handle.release());
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::hal::BufferView
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create(
- iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size,
- iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_create");
-
- if (!out_buffer_view) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_buffer_view = nullptr;
-
- if (!buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (shape.rank > kMaxRank || element_size <= 0) {
- return IREE_STATUS_OUT_OF_RANGE;
- }
-
- // Allocate and initialize the iree_hal_buffer_view struct.
- iree_hal_buffer_view* handle = nullptr;
- IREE_API_RETURN_IF_API_ERROR(allocator.alloc(
- allocator.self, sizeof(*handle), reinterpret_cast<void**>(&handle)));
- new (handle) iree_hal_buffer_view();
- handle->allocator = allocator;
-
- handle->impl.buffer = add_ref(reinterpret_cast<Buffer*>(buffer));
- handle->impl.shape = {shape.dims, shape.rank};
- handle->impl.element_size = element_size;
-
- *out_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(handle);
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_retain");
- auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_release");
- auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_view_assign(
- iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer,
- iree_shape_t shape, int8_t element_size) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_assign");
- auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->impl.buffer.reset();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_reset");
- auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->impl.buffer.reset();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer(
- const iree_hal_buffer_view_t* buffer_view) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_buffer");
- const auto* handle =
- reinterpret_cast<const iree_hal_buffer_view*>(buffer_view);
- CHECK(handle) << "NULL buffer_view handle";
- return reinterpret_cast<iree_hal_buffer_t*>(handle->impl.buffer.get());
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_buffer_view_shape(
- const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_shape");
-
- if (!out_shape) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- out_shape->rank = 0;
-
- const auto* handle =
- reinterpret_cast<const iree_hal_buffer_view*>(buffer_view);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- const auto& shape = handle->impl.shape;
- return ToApiShape(shape, out_shape);
-}
-
-IREE_API_EXPORT int8_t
-iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) {
- IREE_TRACE_SCOPE0("iree_hal_buffer_view_element_size");
- const auto* handle =
- reinterpret_cast<const iree_hal_buffer_view*>(buffer_view);
- CHECK(handle) << "NULL buffer_view handle";
- return handle->impl.element_size;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Semaphore
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t
-iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore) {
- IREE_TRACE_SCOPE0("iree_hal_semaphore_retain");
- auto* handle = reinterpret_cast<Semaphore*>(semaphore);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t
-iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore) {
- IREE_TRACE_SCOPE0("iree_hal_semaphore_release");
- auto* handle = reinterpret_cast<Semaphore*>(semaphore);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Fence
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t iree_hal_fence_retain(iree_hal_fence_t* fence) {
- IREE_TRACE_SCOPE0("iree_hal_fence_retain");
- auto* handle = reinterpret_cast<Fence*>(fence);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t iree_hal_fence_release(iree_hal_fence_t* fence) {
- IREE_TRACE_SCOPE0("iree_hal_fence_release");
- auto* handle = reinterpret_cast<Fence*>(fence);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/api.h b/iree/hal/api.h
deleted file mode 100644
index 6347519..0000000
--- a/iree/hal/api.h
+++ /dev/null
@@ -1,366 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// See iree/base/api.h for documentation on the API conventions used.
-
-#ifndef IREE_HAL_API_H_
-#define IREE_HAL_API_H_
-
-#include <stdint.h>
-
-#include "iree/base/api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-//===----------------------------------------------------------------------===//
-// Types and Enums
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_buffer iree_hal_buffer_t;
-typedef struct iree_hal_buffer_view iree_hal_buffer_view_t;
-typedef struct iree_hal_semaphore iree_hal_semaphore_t;
-typedef struct iree_hal_fence iree_hal_fence_t;
-
-// Reference to a buffer's mapped memory.
-typedef struct {
- // Contents of the buffer. Behavior is undefined if an access is performed
- // whose type was not specified during mapping.
- iree_byte_span_t contents;
-
- // Used internally - do not modify.
- uint64_t reserved[8];
-} iree_hal_mapped_memory_t;
-
-// A bitfield specifying properties for a memory type.
-typedef enum {
- IREE_HAL_MEMORY_TYPE_NONE = 0,
-
- // Memory is lazily allocated by the device and only exists transiently.
- // This is the optimal mode for memory used only within a single command
- // buffer. Transient buffers, even if they have
- // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE set, should be treated as device-local
- // and opaque as they may have no memory attached to them outside of the time
- // they are being evaluated on devices.
- //
- // This flag can be treated as a hint in most cases; allocating a buffer with
- // it set _may_ return the same as if it had not be set. Certain allocation
- // routines may use the hint to more tightly control reuse or defer wiring the
- // memory.
- IREE_HAL_MEMORY_TYPE_TRANSIENT = 1 << 0,
-
- // Memory allocated with this type can be mapped for host access using
- // iree_hal_buffer_map.
- IREE_HAL_MEMORY_TYPE_HOST_VISIBLE = 1 << 1,
-
- // The host cache management commands MappedMemory::Flush and
- // MappedMemory::Invalidate are not needed to flush host writes
- // to the device or make device writes visible to the host, respectively.
- IREE_HAL_MEMORY_TYPE_HOST_COHERENT = 1 << 2,
-
- // Memory allocated with this type is cached on the host. Host memory
- // accesses to uncached memory are slower than to cached memory, however
- // uncached memory is always host coherent. MappedMemory::Flush must be used
- // to ensure the device has visibility into any changes made on the host and
- // Invalidate must be used to ensure the host has visibility into any changes
- // made on the device.
- IREE_HAL_MEMORY_TYPE_HOST_CACHED = 1 << 3,
-
- // Memory is accessible as normal host allocated memory.
- IREE_HAL_MEMORY_TYPE_HOST_LOCAL =
- IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT,
-
- // Memory allocated with this type is visible to the device for execution.
- // Being device visible does not mean the same thing as
- // IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL. Though an allocation may be visible to
- // the device and therefore useable for execution it may require expensive
- // mapping or implicit transfers.
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE = 1 << 4,
-
- // Memory allocated with this type is the most efficient for device access.
- // Devices may support using memory that is not device local via
- // IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE but doing so can incur non-trivial
- // performance penalties. Device local memory, on the other hand, is
- // guaranteed to be fast for all operations.
- IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL =
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | (1 << 5),
-} iree_hal_memory_type_t;
-
-// A bitfield specifying how memory will be accessed in a mapped memory region.
-typedef enum {
- // Memory is not mapped.
- IREE_HAL_MEMORY_ACCESS_NONE = 0,
- // Memory will be read.
- // If a buffer is only mapped for reading it may still be possible to write to
- // it but the results will be undefined (as it may present coherency issues).
- IREE_HAL_MEMORY_ACCESS_READ = 1 << 0,
- // Memory will be written.
- // If a buffer is only mapped for writing it may still be possible to read
- // from it but the results will be undefined or incredibly slow (as it may
- // be mapped by the driver as uncached).
- IREE_HAL_MEMORY_ACCESS_WRITE = 1 << 1,
- // Memory will be discarded prior to mapping.
- // The existing contents will be undefined after mapping and must be written
- // to ensure validity.
- IREE_HAL_MEMORY_ACCESS_DISCARD = 1 << 2,
- // Memory will be discarded and completely overwritten in a single operation.
- IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE =
- IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD,
- // Memory may have any operation performed on it.
- IREE_HAL_MEMORY_ACCESS_ALL = IREE_HAL_MEMORY_ACCESS_READ |
- IREE_HAL_MEMORY_ACCESS_WRITE |
- IREE_HAL_MEMORY_ACCESS_DISCARD,
-} iree_hal_memory_access_t;
-
-// Bitfield that defines how a buffer is intended to be used.
-// Usage allows the driver to appropriately place the buffer for more
-// efficient operations of the specified types.
-typedef enum {
- IREE_HAL_BUFFER_USAGE_NONE = 0,
-
- // The buffer, once defined, will not be mapped or updated again.
- // This should be used for uniform parameter values such as runtime
- // constants for executables. Doing so may allow drivers to inline values or
- // represent them in command buffers more efficiently (avoiding memory reads
- // or swapping, etc).
- IREE_HAL_BUFFER_USAGE_CONSTANT = 1 << 0,
-
- // The buffer can be used as the source or target of a transfer command
- // (CopyBuffer, UpdateBuffer, etc).
- //
- // If |IREE_HAL_BUFFER_USAGE_MAPPING| is not specified drivers may safely
- // assume that the host may never need visibility of this buffer as all
- // accesses will happen via command buffers.
- IREE_HAL_BUFFER_USAGE_TRANSFER = 1 << 1,
-
- // The buffer can be mapped by the host application for reading and writing.
- //
- // As mapping may require placement in special address ranges or system
- // calls to enable visibility the driver can use the presence (or lack of)
- // this flag to perform allocation-type setup and avoid initial mapping
- // overhead.
- IREE_HAL_BUFFER_USAGE_MAPPING = 1 << 2,
-
- // The buffer can be provided as an input or output to an executable.
- // Buffers of this type may be directly used by drivers during dispatch.
- IREE_HAL_BUFFER_USAGE_DISPATCH = 1 << 3,
-
- // Buffer may be used for any operation.
- IREE_HAL_BUFFER_USAGE_ALL = IREE_HAL_BUFFER_USAGE_TRANSFER |
- IREE_HAL_BUFFER_USAGE_MAPPING |
- IREE_HAL_BUFFER_USAGE_DISPATCH,
-} iree_hal_buffer_usage_t;
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Buffer
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Returns a reference to a subspan of the |buffer|.
-// If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer after
-// |byte_offset| (possibly 0) will be selected.
-//
-// The parent buffer will remain alive for the lifetime of the subspan
-// returned. If the subspan is a small portion this may cause additional
-// memory to remain allocated longer than required.
-//
-// Returns the given |buffer| if the requested span covers the entire range.
-// |out_buffer| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan(
- iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer);
-
-// Retains the given |buffer| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_retain(iree_hal_buffer_t* buffer);
-
-// Releases the given |buffer| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_release(iree_hal_buffer_t* buffer);
-
-// Returns the size in bytes of the buffer.
-IREE_API_EXPORT iree_device_size_t IREE_API_CALL
-iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer);
-
-// Sets a range of the buffer to binary zero.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length);
-
-// Sets a range of the buffer to the given value.
-// Only |pattern_length| values with 1, 2, or 4 bytes are supported.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length, const void* pattern,
- iree_host_size_t pattern_length);
-
-// Reads a block of data from the buffer at the given offset.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data(
- iree_hal_buffer_t* buffer, iree_device_size_t source_offset,
- void* target_buffer, iree_device_size_t data_length);
-
-// Writes a block of byte data into the buffer at the given offset.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data(
- iree_hal_buffer_t* buffer, iree_device_size_t target_offset,
- const void* source_buffer, iree_device_size_t data_length);
-
-// Maps the buffer to be accessed as a host pointer into |out_mapped_memory|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map(
- iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
- iree_device_size_t element_offset, iree_device_size_t element_length,
- iree_hal_mapped_memory_t* out_mapped_memory);
-
-// Unmaps the buffer as was previously mapped to |mapped_memory|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_unmap(
- iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::hal::HeapBuffer
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Allocates a zeroed host heap buffer of the given size.
-// The buffer contents will be allocated with |contents_allocator| while
-// |allocator| is used for the iree_hal_buffer_t.
-//
-// Returns a buffer allocated with malloc that may not be usable by devices
-// without copies. |memory_type| should be set to
-// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases.
-// |out_buffer| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate(
- iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
- iree_host_size_t allocation_size, iree_allocator_t contents_allocator,
- iree_allocator_t allocator, iree_hal_buffer_t** out_buffer);
-
-// Allocates a host heap buffer with a copy of the given data.
-// The buffer contents will be allocated with |contents_allocator| while
-// |allocator| is used for the iree_hal_buffer_t.
-//
-// Returns a buffer allocated with malloc that may not be usable by devices
-// without copies. |memory_type| should be set to
-// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases.
-// |out_buffer| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy(
- iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
- iree_hal_memory_access_t allowed_access, iree_byte_span_t contents,
- iree_allocator_t contents_allocator, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer);
-
-// Wraps an existing host heap allocation in a buffer.
-// Ownership of the host allocation remains with the caller and the memory
-// must remain valid for so long as the iree_hal_buffer_t may be in use.
-//
-// Returns a buffer allocated with malloc that may not be usable by devices
-// without copies. |memory_type| should be set to
-// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases.
-// |out_buffer| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap(
- iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
- iree_hal_buffer_usage_t usage, iree_byte_span_t contents,
- iree_allocator_t allocator, iree_hal_buffer_t** out_buffer);
-
-// TODO(benvanik): add a wrap that takes an allocator just for the buffer.
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::hal::BufferView
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Creates a buffer view with the given |buffer|, which may be nullptr.
-// |out_buffer_view| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_create(
- iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size,
- iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view);
-
-// Retains the given |buffer_view| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view);
-
-// Releases the given |buffer_view| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view);
-
-// Sets the buffer view to point at the new |buffer| with the given metadata.
-// To clear a buffer_view to empty use iree_hal_buffer_view_reset.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_assign(
- iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer,
- iree_shape_t shape, int8_t element_size);
-
-// Resets the buffer view to have an empty buffer and shape.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view);
-
-// Returns the buffer underlying the buffer view.
-// The caller must retain the returned buffer if they want to continue using it.
-IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL
-iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view);
-
-// Returns the shape of the buffer view in |out_shape|.
-// If there is not enough space in |out_shape| to store all dimensions then
-// IREE_STATUS_OUT_OF_RANGE is returned and |out_shape|.rank is set to the rank.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape(
- const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape);
-
-// Returns the size of each element in the buffer view in bytes.
-IREE_API_EXPORT int8_t IREE_API_CALL
-iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Semaphore
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Retains the given |semaphore| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore);
-
-// Releases the given |semaphore| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Fence
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Retains the given |fence| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_fence_retain(iree_hal_fence_t* fence);
-
-// Releases the given |fence| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_fence_release(iree_hal_fence_t* fence);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-
-#endif // IREE_HAL_API_H_
diff --git a/iree/hal/api_detail.h b/iree/hal/api_detail.h
deleted file mode 100644
index 9bc3047..0000000
--- a/iree/hal/api_detail.h
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Additional definitions for internal users of the api. This should only
-// be included from internal implementation files.
-
-#ifndef IREE_HAL_API_DETAIL_H_
-#define IREE_HAL_API_DETAIL_H_
-
-#include "iree/hal/api.h"
-#include "iree/hal/buffer_view.h"
-
-namespace iree {
-namespace hal {
-
-// In the API, buffer views are ref objects, and this allows parts of the
-// API outside of the HAL to work with them.
-struct iree_hal_buffer_view : public RefObject<iree_hal_buffer_view> {
- BufferView impl;
- iree_allocator_t allocator;
-
- static void Delete(iree_hal_buffer_view* ptr) {
- ptr->impl.buffer.reset();
- ptr->allocator.free(ptr->allocator.self, ptr);
- }
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif
diff --git a/iree/hal/buffer.cc b/iree/hal/buffer.cc
deleted file mode 100644
index 49a0602..0000000
--- a/iree/hal/buffer.cc
+++ /dev/null
@@ -1,549 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/buffer.h"
-
-#include <algorithm>
-#include <atomic>
-#include <cstdint>
-#include <cstring>
-#include <sstream>
-
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/types/variant.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-
-#if HAS_IREE_BUFFER_DEBUG_NAME
-namespace {
-// Used for diagnostic purposes only as a default buffer name.
-std::atomic<int> next_buffer_id_{0};
-} // namespace
-#endif // HAS_IREE_BUFFER_DEBUG_NAME
-
-std::string MemoryTypeString(MemoryTypeBitfield memory_type) {
- return FormatBitfieldValue(memory_type,
- {
- // Combined:
- {MemoryType::kHostLocal, "kHostLocal"},
- {MemoryType::kDeviceLocal, "kDeviceLocal"},
- // Separate:
- {MemoryType::kTransient, "kTransient"},
- {MemoryType::kHostVisible, "kHostVisible"},
- {MemoryType::kHostCoherent, "kHostCoherent"},
- {MemoryType::kHostCached, "kHostCached"},
- {MemoryType::kDeviceVisible, "kDeviceVisible"},
- });
-}
-
-std::string MemoryAccessString(MemoryAccessBitfield memory_access) {
- return FormatBitfieldValue(memory_access,
- {
- // Combined:
- {MemoryAccess::kAll, "kAll"},
- {MemoryAccess::kDiscardWrite, "kDiscardWrite"},
- // Separate:
- {MemoryAccess::kRead, "kRead"},
- {MemoryAccess::kWrite, "kWrite"},
- {MemoryAccess::kDiscard, "kDiscard"},
- });
-}
-
-std::string BufferUsageString(BufferUsageBitfield buffer_usage) {
- return FormatBitfieldValue(buffer_usage,
- {
- // Combined:
- {BufferUsage::kAll, "kAll"},
- // Separate:
- {BufferUsage::kConstant, "kConstant"},
- {BufferUsage::kTransfer, "kTransfer"},
- {BufferUsage::kMapping, "kMapping"},
- {BufferUsage::kDispatch, "kDispatch"},
- });
-}
-
-// Special router for buffers that just reference other buffers.
-// We keep this out of the base Buffer so that it's a bit easier to track
-// delegation.
-class SubspanBuffer : public Buffer {
- public:
- SubspanBuffer(ref_ptr<Buffer> parent_buffer, device_size_t byte_offset,
- device_size_t byte_length)
- : Buffer(parent_buffer->allocator(), parent_buffer->memory_type(),
- parent_buffer->allowed_access(), parent_buffer->usage(),
- parent_buffer->allocation_size(), byte_offset, byte_length) {
- allocated_buffer_ = parent_buffer.get();
- parent_buffer_ = std::move(parent_buffer);
- }
-
- protected:
- Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length) override {
- return parent_buffer_->FillImpl(byte_offset, byte_length, pattern,
- pattern_length);
- }
-
- Status ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) override {
- return parent_buffer_->ReadDataImpl(source_offset, data, data_length);
- }
-
- Status WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) override {
- return parent_buffer_->WriteDataImpl(target_offset, data, data_length);
- }
-
- Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) override {
- return parent_buffer_->CopyDataImpl(target_offset, source_buffer,
- source_offset, data_length);
- }
-
- Status MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) override {
- return parent_buffer_->MapMemoryImpl(mapping_mode, memory_access,
- local_byte_offset, local_byte_length,
- out_data);
- }
-
- Status UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data) override {
- return parent_buffer_->UnmapMemoryImpl(local_byte_offset, local_byte_length,
- data);
- }
-
- Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override {
- return parent_buffer_->InvalidateMappedMemoryImpl(local_byte_offset,
- local_byte_length);
- }
-
- Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override {
- return parent_buffer_->FlushMappedMemoryImpl(local_byte_offset,
- local_byte_length);
- }
-};
-
-// static
-StatusOr<ref_ptr<Buffer>> Buffer::Subspan(const ref_ptr<Buffer>& buffer,
- device_size_t byte_offset,
- device_size_t byte_length) {
- RETURN_IF_ERROR(buffer->CalculateRange(byte_offset, byte_length, &byte_offset,
- &byte_length));
- if (byte_offset == 0 && byte_length == buffer->byte_length()) {
- // Asking for the same buffer.
- return add_ref(buffer);
- }
-
- // To avoid heavy nesting of subspans that just add indirection we go to the
- // parent buffer directly. If we wanted better accounting (to track where
- // buffers came from) we'd want to avoid this but I'm not sure that's worth
- // the super deep indirection that could arise.
- if (buffer->allocated_buffer() != buffer.get()) {
- CHECK(buffer->parent_buffer_);
- return Buffer::Subspan(buffer->parent_buffer_, byte_offset, byte_length);
- } else {
- return {make_ref<SubspanBuffer>(add_ref(buffer), byte_offset, byte_length)};
- }
-}
-
-// static
-Buffer::Overlap Buffer::TestOverlap(
- Buffer* lhs_buffer, device_size_t lhs_offset, device_size_t lhs_length,
- Buffer* rhs_buffer, device_size_t rhs_offset, device_size_t rhs_length) {
- if (lhs_buffer->allocated_buffer() != rhs_buffer->allocated_buffer()) {
- // Not even the same buffers.
- return Overlap::kDisjoint;
- }
- // Resolve offsets into the underlying allocation.
- device_size_t lhs_alloc_offset = lhs_buffer->byte_offset() + lhs_offset;
- device_size_t rhs_alloc_offset = rhs_buffer->byte_offset() + rhs_offset;
- device_size_t lhs_alloc_length = lhs_length == kWholeBuffer
- ? lhs_buffer->byte_length() - lhs_offset
- : lhs_length;
- device_size_t rhs_alloc_length = rhs_length == kWholeBuffer
- ? rhs_buffer->byte_length() - rhs_offset
- : rhs_length;
- if (!lhs_alloc_length || !rhs_alloc_length) {
- return Overlap::kDisjoint;
- }
- if (lhs_alloc_offset == rhs_alloc_offset &&
- lhs_alloc_length == rhs_alloc_length) {
- return Overlap::kComplete;
- }
- return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset &&
- rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset
- ? Overlap::kPartial
- : Overlap::kDisjoint;
-}
-
-// static
-bool Buffer::DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset,
- device_size_t lhs_length, Buffer* rhs_buffer,
- device_size_t rhs_offset, device_size_t rhs_length) {
- return TestOverlap(lhs_buffer, lhs_offset, lhs_length, rhs_buffer, rhs_offset,
- rhs_length) != Overlap::kDisjoint;
-}
-
-Buffer::Buffer(Allocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
- device_size_t allocation_size, device_size_t byte_offset,
- device_size_t byte_length)
- : allocated_buffer_(const_cast<Buffer*>(this)),
- allocator_(allocator),
- memory_type_(memory_type),
- allowed_access_(allowed_access),
- usage_(usage),
- allocation_size_(allocation_size),
- byte_offset_(byte_offset),
- byte_length_(byte_length) {
-#if HAS_IREE_BUFFER_DEBUG_NAME
- // Default name for logging.
- // It'd be nice to defer this until it's required but that would require
- // synchronization or something.
- const char* debug_name_prefix = "";
- if ((memory_type_ & MemoryType::kHostLocal) == MemoryType::kHostLocal) {
- debug_name_prefix = "host_buffer_";
- } else if ((memory_type_ & MemoryType::kDeviceLocal) ==
- MemoryType::kDeviceLocal) {
- // TODO(benvanik): include allocator ID to differentiate devices.
- debug_name_prefix = "device_buffer_";
- }
- debug_name_ = absl::StrCat(debug_name_prefix, next_buffer_id_++);
-#endif // HAS_IREE_BUFFER_DEBUG_NAME
-}
-
-Buffer* Buffer::allocated_buffer() const noexcept {
- Buffer* allocated_buffer = allocated_buffer_;
- while (allocated_buffer != this &&
- allocated_buffer != allocated_buffer->allocated_buffer()) {
- allocated_buffer = allocated_buffer->allocated_buffer();
- }
- return allocated_buffer;
-}
-
-std::string Buffer::DebugString() const {
- std::ostringstream stream;
- stream << allocated_buffer()->debug_name() << "["
- << (allocation_size() == kWholeBuffer
- ? "?"
- : std::to_string(allocation_size()))
- << "].";
- if (AnyBitSet(memory_type() & MemoryType::kTransient)) stream << "Z";
- if ((memory_type() & MemoryType::kHostLocal) == MemoryType::kHostLocal) {
- stream << "h";
- } else {
- if (AnyBitSet(memory_type() & MemoryType::kHostVisible)) stream << "v";
- if (AnyBitSet(memory_type() & MemoryType::kHostCoherent)) stream << "x";
- if (AnyBitSet(memory_type() & MemoryType::kHostCached)) stream << "c";
- }
- if ((memory_type() & MemoryType::kDeviceLocal) == MemoryType::kDeviceLocal) {
- stream << "D";
- } else {
- if (AnyBitSet(memory_type() & MemoryType::kDeviceVisible)) stream << "V";
- }
- stream << ".";
- if (AnyBitSet(usage() & BufferUsage::kConstant)) stream << "c";
- if (AnyBitSet(usage() & BufferUsage::kTransfer)) stream << "t";
- if (AnyBitSet(usage() & BufferUsage::kMapping)) stream << "m";
- if (AnyBitSet(usage() & BufferUsage::kDispatch)) stream << "d";
- if (byte_offset_ || byte_length_ != allocation_size_) {
- stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1)
- << ")";
- }
- return stream.str();
-}
-
-std::string Buffer::DebugStringShort() const {
- // TODO(benvanik): figure out what's most useful here. Maybe a long variant?
- std::ostringstream stream;
- stream << allocated_buffer()->debug_name() << "["
- << (allocation_size() == kWholeBuffer
- ? "?"
- : std::to_string(allocation_size()))
- << "]";
- if (byte_offset_ || byte_length_ != allocation_size_) {
- stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1)
- << ")";
- }
- return stream.str();
-}
-
-Status Buffer::ValidateCompatibleMemoryType(
- MemoryTypeBitfield memory_type) const {
- if ((memory_type_ & memory_type) != memory_type) {
- // Missing one or more bits.
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "Buffer memory type is not compatible with the requested "
- "operation; buffer has "
- << MemoryTypeString(memory_type_) << ", operation requires "
- << MemoryTypeString(memory_type);
- }
- return OkStatus();
-}
-
-Status Buffer::ValidateAccess(MemoryAccessBitfield memory_access) const {
- if (!AnyBitSet(memory_access &
- (MemoryAccess::kRead | MemoryAccess::kWrite))) {
- // No actual access bits defined.
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Memory access must specify one or more of kRead or kWrite";
- } else if ((allowed_access_ & memory_access) != memory_access) {
- // Bits must match exactly.
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "The buffer does not support the requested access type; buffer "
- "allows "
- << MemoryAccessString(allowed_access_) << ", operation requires "
- << MemoryAccessString(memory_access);
- }
- return OkStatus();
-}
-
-Status Buffer::ValidateUsage(BufferUsageBitfield usage) const {
- if ((usage_ & usage) != usage) {
- // Missing one or more bits.
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "Requested usage was not specified when the buffer was "
- "allocated; buffer allows "
- << BufferUsageString(usage_) << ", operation requires "
- << BufferUsageString(usage);
- }
- return OkStatus();
-}
-
-Status Buffer::CalculateRange(device_size_t base_offset,
- device_size_t max_length, device_size_t offset,
- device_size_t length,
- device_size_t* out_adjusted_offset,
- device_size_t* out_adjusted_length) {
- // Check if the start of the range runs off the end of the buffer.
- if (offset > max_length) {
- *out_adjusted_offset = 0;
- if (out_adjusted_length) *out_adjusted_length = 0;
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Attempted to access an address off the end of the valid buffer "
- "range (offset="
- << offset << ", length=" << length
- << ", buffer byte_length=" << max_length << ")";
- }
-
- // Handle length as kWholeBuffer by adjusting it (if allowed).
- if (length == kWholeBuffer && !out_adjusted_length) {
- *out_adjusted_offset = 0;
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "kWholeBuffer may only be used with buffer ranges, not external "
- "pointer ranges";
- }
-
- // Calculate the real ranges adjusted for our region within the allocation.
- device_size_t adjusted_offset = base_offset + offset;
- device_size_t adjusted_length =
- length == kWholeBuffer ? max_length - offset : length;
- if (adjusted_length == 0) {
- // Fine to have a zero length.
- *out_adjusted_offset = adjusted_offset;
- if (out_adjusted_length) *out_adjusted_length = adjusted_length;
- return OkStatus();
- }
-
- // Check if the end runs over the allocation.
- device_size_t end = offset + adjusted_length - 1;
- if (end >= max_length) {
- *out_adjusted_offset = 0;
- if (out_adjusted_length) *out_adjusted_length = 0;
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Attempted to access an address outside of the valid buffer "
- "range (offset="
- << offset << ", adjusted_length=" << adjusted_length
- << ", end=" << end << ", buffer byte_length=" << max_length << ")";
- }
-
- *out_adjusted_offset = adjusted_offset;
- if (out_adjusted_length) *out_adjusted_length = adjusted_length;
- return OkStatus();
-}
-
-Status Buffer::CalculateRange(device_size_t offset, device_size_t length,
- device_size_t* out_adjusted_offset,
- device_size_t* out_adjusted_length) const {
- return CalculateRange(byte_offset_, byte_length_, offset, length,
- out_adjusted_offset, out_adjusted_length);
-}
-
-Status Buffer::CalculateLocalRange(device_size_t max_length,
- device_size_t offset, device_size_t length,
- device_size_t* out_adjusted_offset,
- device_size_t* out_adjusted_length) {
- return CalculateRange(0, max_length, offset, length, out_adjusted_offset,
- out_adjusted_length);
-}
-
-Status Buffer::Fill(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length) {
- // If not host visible we'll need to issue command buffers.
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- RETURN_IF_ERROR(
- CalculateRange(byte_offset, byte_length, &byte_offset, &byte_length));
- if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Fill patterns must be 1, 2, or 4 bytes";
- }
- if ((byte_offset % pattern_length) != 0 ||
- (byte_length % pattern_length) != 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Attempting to fill a range with " << pattern_length
- << " byte values that is not "
- "aligned (offset="
- << byte_offset << ", length=" << byte_length << ")";
- }
- if (byte_length == 0) {
- return OkStatus(); // No-op.
- }
- const uint32_t kZero = 0;
- if (std::memcmp(pattern, &kZero, pattern_length) == 0) {
- // We can turn all-zero values into single-byte fills as that can be much
- // faster on devices (doing a fill8 vs fill32).
- pattern_length = 1;
- }
- return FillImpl(byte_offset, byte_length, pattern, pattern_length);
-}
-
-Status Buffer::ReadData(device_size_t source_offset, void* data,
- device_size_t data_length) {
- // If not host visible we'll need to issue command buffers.
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- RETURN_IF_ERROR(CalculateRange(source_offset, data_length, &source_offset));
- if (data_length == 0) {
- return OkStatus(); // No-op.
- }
- return ReadDataImpl(source_offset, data, data_length);
-}
-
-Status Buffer::WriteData(device_size_t target_offset, const void* data,
- device_size_t data_length) {
- // If not host visible we'll need to issue command buffers.
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- RETURN_IF_ERROR(CalculateRange(target_offset, data_length, &target_offset));
- if (data_length == 0) {
- return OkStatus(); // No-op.
- }
- return WriteDataImpl(target_offset, data, data_length);
-}
-
-Status Buffer::CopyData(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) {
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- RETURN_IF_ERROR(
- source_buffer->ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(source_buffer->ValidateAccess(MemoryAccess::kRead));
- RETURN_IF_ERROR(source_buffer->ValidateUsage(BufferUsage::kMapping));
-
- // We need to validate both buffers.
- device_size_t source_data_length = data_length;
- device_size_t target_data_length = data_length;
- device_size_t adjusted_source_offset;
- RETURN_IF_ERROR(source_buffer->CalculateRange(
- source_offset, source_data_length, &adjusted_source_offset,
- &source_data_length));
- RETURN_IF_ERROR(CalculateRange(target_offset, target_data_length,
- &target_offset, &target_data_length));
- device_size_t adjusted_data_length;
- if (data_length == kWholeBuffer) {
- // Whole buffer copy requested - that could mean either, so take the min.
- adjusted_data_length = std::min(source_data_length, target_data_length);
- } else {
- // Specific length requested - validate that we have matching lengths.
- CHECK_EQ(source_data_length, target_data_length);
- adjusted_data_length = source_data_length;
- }
-
- // Elide zero length copies.
- if (adjusted_data_length == 0) {
- return OkStatus();
- }
-
- // Check for overlap.
- if (this == source_buffer &&
- adjusted_source_offset <= target_offset + adjusted_data_length &&
- target_offset <= adjusted_source_offset + adjusted_data_length) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Source and target ranges overlap within the same buffer";
- }
-
- return CopyDataImpl(target_offset, source_buffer, source_offset,
- adjusted_data_length);
-}
-
-Status Buffer::MapMemory(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t* byte_offset, device_size_t* byte_length,
- void** out_data) {
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(ValidateAccess(memory_access));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- RETURN_IF_ERROR(
- CalculateRange(*byte_offset, *byte_length, byte_offset, byte_length));
- *out_data = nullptr;
- return MapMemoryImpl(mapping_mode, memory_access, *byte_offset, *byte_length,
- out_data);
-}
-
-Status Buffer::UnmapMemory(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data) {
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- // NOTE: local_byte_offset/local_byte_length are already adjusted.
- return UnmapMemoryImpl(local_byte_offset, local_byte_length, data);
-}
-
-Status Buffer::InvalidateMappedMemory(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible));
- if (AnyBitSet(memory_type_ & MemoryType::kHostCoherent)) {
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "Buffer memory type is coherent and invalidation is not required";
- }
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- // NOTE: local_byte_offset/local_byte_length are already adjusted.
- return InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length);
-}
-
-Status Buffer::FlushMappedMemory(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible |
- MemoryType::kHostCached));
- RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping));
- // NOTE: local_byte_offset/local_byte_length are already adjusted.
- return FlushMappedMemoryImpl(local_byte_offset, local_byte_length);
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h
deleted file mode 100644
index 20065e6..0000000
--- a/iree/hal/buffer.h
+++ /dev/null
@@ -1,903 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Allocated memory buffer wrapper type and utilities.
-//
-// Buffers are the basic unit of memory used by the inference system. They may
-// be allocated such that they are accessible from the host (normal C++ code
-// running on the main CPU), a particular device (such as an accelerator) or
-// family of devices, or from some mix of all of those.
-//
-// The type of memory a buffer is allocated within has implications on it's
-// performance and lifetime. For example if an application attempts to use a
-// host-allocated buffer (MemoryType::kHostLocal) on an accelerator with
-// discrete memory the accelerator may either be unable to access the memory or
-// take a non-trivial performance hit when attempting to do so (involving
-// setting up kernel mappings, doing DMA transfers, etc). Likewise, trying to
-// access a device-allocated buffer (MemoryType::kDeviceLocal) may incur similar
-// overhead or not be possible at all. This may be due to restrictions in the
-// memory visibility, address spaces, mixed endianness or pointer widths,
-// and other weirdness.
-//
-// The memory types (defined by a bitfield of MemoryType values) that a
-// particular context (host or device) may use vary from device to device and
-// must be queried by the application when allocating buffers. It's strongly
-// recommended that the most specific memory type be set as possible. For
-// example allocating a buffer with MemoryType::kHostCoherent even when it will
-// never be used in a way that requires coherency may occupy address space
-// reservations or memory mapping that would otherwise not be needed.
-//
-// As buffers may sometimes not be accessible from the host the base Buffer type
-// does not allow for direct void* access and instead buffers must be either
-// manipulated using utility functions (such as ReadData or WriteData) or by
-// mapping them into a host-accessible address space via MapMemory. Buffer must
-// be unmapped before any command may use it.
-//
-// Buffers may map (roughly) 1:1 with an allocation either from the host heap or
-// a device. Buffer::Subspan can be used to reference subspans of buffers like
-// absl::Span - though unlike absl::Span the returned Buffer holds a reference
-// to the parent buffer.
-
-#ifndef IREE_HAL_BUFFER_H_
-#define IREE_HAL_BUFFER_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/types/span.h"
-#include "absl/types/variant.h"
-#include "iree/base/bitfield.h"
-#include "iree/base/logging.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/hal/resource.h"
-
-// Only enable debug names in non-opt modes (unless the user forces it on).
-#if !defined(NDEBUG) && !defined(HAS_IREE_BUFFER_DEBUG_NAME)
-#define HAS_IREE_BUFFER_DEBUG_NAME 1
-#endif // !NDEBUG
-
-namespace iree {
-
-// std::size_t equivalent that is the size as used on device.
-// As the device may have a larger memory address space than the host we treat
-// all byte offsets as this type instead of the host-specified size_t.
-using device_size_t = uint64_t;
-
-// When used as a length value in functions causes the length to be the entire
-// remaining buffer from the specified offset.
-constexpr device_size_t kWholeBuffer = ~0ull;
-
-} // namespace iree
-
-namespace iree {
-namespace hal {
-
-class Allocator;
-template <typename T>
-class MappedMemory;
-
-// A bitfield specifying properties for a memory type.
-enum class MemoryType : uint32_t {
- kNone = 0,
-
- // Memory is lazily allocated by the device and only exists transiently.
- // This is the optimal mode for memory used only within a single command
- // buffer. Transient buffers, even if they have kHostVisible set, should be
- // treated as device-local and opaque as they may have no memory attached to
- // them outside of the time they are being evaluated on devices.
- //
- // This flag can be treated as a hint in most cases; allocating a buffer with
- // it set _may_ return the same as if it had not be set. Certain allocation
- // routines may use the hint to more tightly control reuse or defer wiring the
- // memory.
- kTransient = 1 << 0,
-
- // Memory allocated with this type can be mapped for host access using
- // Buffer::MapMemory.
- kHostVisible = 1 << 1,
-
- // The host cache management commands MappedMemory::Flush and
- // MappedMemory::Invalidate are not needed to flush host writes
- // to the device or make device writes visible to the host, respectively.
- kHostCoherent = 1 << 2,
-
- // Memory allocated with this type is cached on the host. Host memory
- // accesses to uncached memory are slower than to cached memory, however
- // uncached memory is always host coherent. MappedMemory::Flush must be used
- // to ensure the device has visibility into any changes made on the host and
- // Invalidate must be used to ensure the host has visibility into any changes
- // made on the device.
- kHostCached = 1 << 3,
-
- // Memory is accessible as normal host allocated memory.
- kHostLocal = kHostVisible | kHostCoherent,
-
- // Memory allocated with this type is visible to the device for execution.
- // Being device visible does not mean the same thing as kDeviceLocal. Though
- // an allocation may be visible to the device and therefore useable for
- // execution it may require expensive mapping or implicit transfers.
- kDeviceVisible = 1 << 4,
-
- // Memory allocated with this type is the most efficient for device access.
- // Devices may support using memory that is not device local via
- // kDeviceVisible but doing so can incur non-trivial performance penalties.
- // Device local memory, on the other hand, is guaranteed to be fast for all
- // operations.
- kDeviceLocal = kDeviceVisible | (1 << 5),
-};
-IREE_BITFIELD(MemoryType);
-using MemoryTypeBitfield = MemoryType;
-std::string MemoryTypeString(MemoryTypeBitfield memory_type);
-
-// A bitfield specifying how memory will be accessed in a mapped memory region.
-enum class MemoryAccess : uint32_t {
- // Memory is not mapped.
- kNone = 0,
-
- // Memory will be read.
- // If a buffer is only mapped for reading it may still be possible to write to
- // it but the results will be undefined (as it may present coherency issues).
- kRead = 1 << 0,
-
- // Memory will be written.
- // If a buffer is only mapped for writing it may still be possible to read
- // from it but the results will be undefined or incredibly slow (as it may
- // be mapped by the driver as uncached).
- kWrite = 1 << 1,
-
- // Memory will be discarded prior to mapping.
- // The existing contents will be undefined after mapping and must be written
- // to ensure validity.
- kDiscard = 1 << 2,
-
- // Memory will be discarded and completely overwritten in a single operation.
- kDiscardWrite = kWrite | kDiscard,
-
- // Memory may have any operation performed on it.
- kAll = kRead | kWrite | kDiscard,
-};
-IREE_BITFIELD(MemoryAccess);
-using MemoryAccessBitfield = MemoryAccess;
-std::string MemoryAccessString(MemoryAccessBitfield memory_access);
-
-// Bitfield that defines how a buffer is intended to be used.
-// Usage allows the driver to appropriately place the buffer for more
-// efficient operations of the specified types.
-enum class BufferUsage {
- kNone = 0,
-
- // The buffer, once defined, will not be mapped or updated again.
- // This should be used for uniform parameter values such as runtime
- // constants for executables. Doing so may allow drivers to inline values or
- // represent them in command buffers more efficiently (avoiding memory reads
- // or swapping, etc).
- kConstant = 1 << 0,
-
- // The buffer can be used as the source or target of a transfer command
- // (CopyBuffer, UpdateBuffer, etc).
- //
- // If |kMapping| is not specified drivers may safely assume that the host
- // may never need visibility of this buffer as all accesses will happen via
- // command buffers.
- kTransfer = 1 << 1,
-
- // The buffer can be mapped by the host application for reading and writing.
- //
- // As mapping may require placement in special address ranges or system
- // calls to enable visibility the driver can use the presence (or lack of)
- // this flag to perform allocation-type setup and avoid initial mapping
- // overhead.
- kMapping = 1 << 2,
-
- // The buffer can be provided as an input or output to an executable.
- // Buffers of this type may be directly used by drivers during dispatch.
- kDispatch = 1 << 3,
-
- // Buffer may be used for any operation.
- kAll = kTransfer | kMapping | kDispatch,
-};
-IREE_BITFIELD(BufferUsage);
-using BufferUsageBitfield = BufferUsage;
-std::string BufferUsageString(BufferUsageBitfield buffer_usage);
-
-// A memory buffer.
-// Buffers have a specific memory_type that is used to describe the capabilities
-// and behavior of the backing memory of the buffer. Buffers may be any mix of
-// host-accessible, host-coherent, or device-accessible for various usages.
-// Depending on these memory types the buffers may be mapped for access on the
-// host as memory though certain restrictions may be imposed.
-//
-// See MemoryType for more information about the types and what operations they
-// support.
-class Buffer : public Resource {
- public:
- // Returns a reference to a subspan of the buffer.
- // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after
- // |byte_offset| (possibly 0) will be selected.
- //
- // The parent buffer will remain alive for the lifetime of the subspan
- // returned. If the subspan is a small portion this may cause additional
- // memory to remain allocated longer than required.
- //
- // Returns the given |buffer| if the requested span covers the entire range.
- static StatusOr<ref_ptr<Buffer>> Subspan(const ref_ptr<Buffer>& buffer,
- device_size_t byte_offset,
- device_size_t byte_length);
-
- // Overlap test results.
- enum class Overlap {
- // No overlap between the two buffers.
- kDisjoint,
- // Partial overlap between the two buffers.
- kPartial,
- // Complete overlap between the two buffers (they are the same).
- kComplete,
- };
-
- // Tests whether the given buffers overlap, including support for subspans.
- // kWholeBuffer may be used for |lhs_length| and/or |rhs_length| to use the
- // lengths of those buffers, respectively.
- static Overlap TestOverlap(Buffer* lhs_buffer, device_size_t lhs_offset,
- device_size_t lhs_length, Buffer* rhs_buffer,
- device_size_t rhs_offset,
- device_size_t rhs_length);
-
- // Returns true if the two buffer ranges overlap at all.
- static bool DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset,
- device_size_t lhs_length, Buffer* rhs_buffer,
- device_size_t rhs_offset, device_size_t rhs_length);
-
- // Disallow copies (as copying requires real work).
- Buffer(const Buffer&) = delete;
- Buffer& operator=(const Buffer&) = delete;
-
- ~Buffer() override = default;
-
-#if HAS_IREE_BUFFER_DEBUG_NAME
- // Optionally populated name useful for logging a persistent name for the
- // buffer.
- absl::string_view debug_name() const { return debug_name_; }
- void set_debug_name(std::string debug_name) {
- debug_name_ = std::move(debug_name);
- }
-#else
- absl::string_view debug_name() const { return ""; }
- void set_debug_name(std::string debug_name) {}
-#endif // HAS_IREE_BUFFER_DEBUG_NAME
-
- // Memory allocator this buffer was allocated from.
- // May be nullptr if the buffer has no particular allocator and should be
- // assumed to be allocated from the host heap.
- constexpr Allocator* allocator() const {
- return allocated_buffer_ == this ? allocator_
- : allocated_buffer_->allocator();
- }
-
- // Memory type this buffer is allocated from.
- MemoryTypeBitfield memory_type() const { return memory_type_; }
-
- // Memory access operations allowed on the buffer.
- MemoryAccessBitfield allowed_access() const { return allowed_access_; }
-
- // Bitfield describing how the buffer is to be used.
- BufferUsageBitfield usage() const { return usage_; }
-
- // Returns the underlying buffer that represents the allocated memory for the
- // Buffer. In most cases this is the buffer itself but for buffer subspan
- // references it will point to the parent buffer.
- Buffer* allocated_buffer() const noexcept;
-
- // Size of the resource memory allocation in bytes.
- // This may be rounded up from the originally requested size or the ideal
- // size for the resource based on device restrictions.
- constexpr device_size_t allocation_size() const {
- return allocated_buffer_ == this ? allocation_size_
- : allocated_buffer_->allocation_size();
- }
-
- // Range within the underlying allocation this buffer occupies.
- // For buffers that map 1:1 with an allocation this should be
- // [0, allocation_size()), however may still differ if the allocation needed
- // to be aligned.
- //
- // The offset is most often manipulated by Subspan, however it's important to
- // note that the offset may not be what was passed to Subspan as it refers to
- // the offset in the original ancestor buffer, not the buffer from which the
- // subspan was taken.
- constexpr device_size_t byte_offset() const noexcept { return byte_offset_; }
- constexpr device_size_t byte_length() const noexcept { return byte_length_; }
-
- // TODO(benvanik): add debug_name.
-
- // Returns a longer debug string describing the buffer and its attributes.
- std::string DebugString() const;
- // Returns a short debug string describing the buffer.
- std::string DebugStringShort() const;
-
- // Sets a range of the buffer to the given value.
- // This requires that the resource was allocated with
- // MemoryType::kHostVisible and BufferUsage::kMapping.
- // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after
- // |byte_offset| (possibly 0) will be filled.
- //
- // The |byte_offset| and |byte_length| must be aligned to the size of the fill
- // value. Multi-byte values will be written in host order for host buffers and
- // device order for device buffers.
- //
- // Only |pattern_length| values with 1, 2, or 4 bytes are supported.
- //
- // Fails if the write could not be performed; either the bounds are out of
- // range or the memory type does not support writing in this way.
- Status Fill(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length);
- template <typename T>
- Status Fill8(device_size_t byte_offset, device_size_t byte_length, T value);
- template <typename T>
- Status Fill16(device_size_t byte_offset, device_size_t byte_length, T value);
- template <typename T>
- Status Fill32(device_size_t byte_offset, device_size_t byte_length, T value);
- template <typename T>
- Status Fill8(T value);
- template <typename T>
- Status Fill16(T value);
- template <typename T>
- Status Fill32(T value);
-
- // Reads a block of byte data from the resource at the given offset.
- // This requires that the resource was allocated with
- // MemoryType::kHostVisible and BufferUsage::kMapping.
- //
- // Fails if the read could not be performed; either the bounds are out of
- // range or the memory type does not support reading in this way.
- Status ReadData(device_size_t source_offset, void* data,
- device_size_t data_length);
-
- // Writes a block of byte data into the resource at the given offset.
- // This requires that the resource was allocated with
- // MemoryType::kHostVisible and BufferUsage::kMapping.
- //
- // Fails if the write could not be performed; either the bounds are out of
- // range or the memory type does not support writing in this way.
- Status WriteData(device_size_t target_offset, const void* data,
- device_size_t data_length);
-
- // Copies data from the provided source_buffer into the buffer.
- // This requires that the resource was allocated with
- // MemoryType::kHostVisible and BufferUsage::kMapping.
- // The source and destination may be the same buffer but the ranges must not
- // overlap (a la memcpy).
- //
- // Fails if the write could not be performed; either the bounds are out of
- // range or the memory type does not support writing in this way.
- Status CopyData(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset, device_size_t data_length);
- Status CopyData(device_size_t target_offset, Buffer* source_buffer) {
- return CopyData(target_offset, source_buffer, 0, kWholeBuffer);
- }
-
- // Maps the resource memory for direct access from the host.
- // This requires that the resource was allocated with
- // MemoryType::kHostVisible and BufferUsage::kMapping.
- //
- // If MemoryType::kHostCoherent was not specified then explicit
- // Invalidate and Flush calls must be used to control visibility of the data
- // on the device. If MemoryType::kHostCached is not set callers must not
- // attempt to read from the mapped memory as doing so may produce undefined
- // results and/or ultra slow reads.
- //
- // If the MemoryAccess::kDiscard bit is set when mapping for writes the caller
- // guarantees that they will be overwriting all data in the mapped range. This
- // is used as a hint to the device that the prior contents are no longer
- // required and can enable optimizations that save on synchronization and
- // readback. Note however that it is strictly a hint and the contents are not
- // guaranteed to be zeroed during mapping.
- //
- // This allows mapping the memory as a C++ type. Care must be taken to ensure
- // the data layout in C++ matches the expected data layout in the executables
- // that consume this data. For simple primitives like uint8_t or float this is
- // usually not a problem however struct packing may have many restrictions.
- //
- // The returned mapping should be unmapped when it is no longer required.
- // Unmapping does not implicitly flush.
- //
- // Fails if the memory could not be mapped due to mapping exhaustion, invalid
- // arguments, or unsupported memory types.
- //
- // Example:
- // ASSIGN_OR_RETURN(auto mapping, buffer->MapForRead<MyStruct>());
- // mapping[5].foo = 3;
- // std::memcpy(mapping.data(), source_data, mapping.size());
- // mapping.reset();
- template <typename T>
- StatusOr<MappedMemory<T>> MapMemory(
- MemoryAccessBitfield memory_access, device_size_t element_offset = 0,
- device_size_t element_length = kWholeBuffer);
-
- protected:
- template <typename T>
- friend class MappedMemory;
-
- // Defines the mode of a MapMemory operation.
- enum class MappingMode {
- // The call to MapMemory will always be matched with UnmapMemory.
- kScoped,
- };
-
- Buffer(Allocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
- device_size_t allocation_size, device_size_t byte_offset,
- device_size_t byte_length);
-
- // Allows subclasses to override the allowed access bits.
- // This should only be done when known safe by the allocation scheme.
- void set_allowed_access(MemoryAccessBitfield allowed_access) {
- allowed_access_ = allowed_access;
- }
-
- // Sets a range of the buffer to the given value.
- // State and parameters have already been validated. For the >8bit variants
- // the offset and length have already been validated to be aligned to the
- // natural alignment of the type.
- virtual Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern,
- device_size_t pattern_length) = 0;
-
- // Reads a block of byte data from the resource at the given offset.
- // State and parameters have already been validated.
- virtual Status ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) = 0;
-
- // Writes a block of byte data into the resource at the given offset.
- // State and parameters have already been validated.
- virtual Status WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) = 0;
-
- // Copies a block of byte data into the resource at the given offset.
- // State and parameters have already been validated.
- virtual Status CopyDataImpl(device_size_t target_offset,
- Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) = 0;
-
- // Maps memory directly.
- // The output data pointer will be properly aligned to the start of the data.
- // |local_byte_offset| and |local_byte_length| are the adjusted values that
- // should map into the local space of the buffer.
- //
- // Fails if the memory could not be mapped (invalid access type, invalid
- // range, or unsupported memory type).
- // State and parameters have already been validated.
- virtual Status MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) = 0;
-
- // Unmaps previously mapped memory.
- // No-op if the memory is not mapped. As this is often used in destructors
- // we can't rely on failures here propagating with anything but CHECK/DCHECK.
- // State and parameters have already been validated.
- virtual Status UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void* data) = 0;
-
- // Invalidates ranges of non-coherent memory from the host caches.
- // Use this before reading from non-coherent memory.
- // This guarantees that device writes to the memory ranges provided are
- // visible on the host.
- // This is only required for memory types without kHostCoherent set.
- // State and parameters have already been validated.
- virtual Status InvalidateMappedMemoryImpl(
- device_size_t local_byte_offset, device_size_t local_byte_length) = 0;
-
- // Flushes ranges of non-coherent memory from the host caches.
- // Use this after writing to non-coherent memory.
- // This guarantees that host writes to the memory ranges provided are made
- // available for device access.
- // This is only required for memory types without kHostCoherent set.
- // State and parameters have already been validated.
- virtual Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) = 0;
-
- // Validates the given buffer range and adjusts the offset and length if the
- // provided length is kWholeBuffer or the buffer is offset within its
- // allocation. This calculates the range in the given domain without adjusting
- // to any particular buffer base offsets.
- static Status CalculateLocalRange(device_size_t max_length,
- device_size_t offset, device_size_t length,
- device_size_t* out_adjusted_offset,
- device_size_t* out_adjusted_length);
-
- private:
- friend class Allocator;
-
- // This is not great and deserves cleanup.
- friend class DeferredBuffer;
- friend class SubspanBuffer;
- friend class HeapBuffer;
-
- // Maps memory directly.
- // The byte offset and byte length may be adjusted for device alignment.
- // The output data pointer will be properly aligned to the start of the data.
- // Fails if the memory could not be mapped (invalid access type, invalid
- // range, or unsupported memory type).
- Status MapMemory(MappingMode mapping_mode, MemoryAccessBitfield memory_access,
- device_size_t* byte_offset, device_size_t* byte_length,
- void** out_data);
-
- // Unmaps previously mapped memory.
- // No-op if the memory is not mapped. As this is often used in destructors
- // we can't rely on failures here propagating with anything but CHECK/DCHECK.
- Status UnmapMemory(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data);
-
- // Invalidates ranges of non-coherent memory from the host caches.
- // Use this before reading from non-coherent memory.
- // This guarantees that device writes to the memory ranges provided are
- // visible on the host.
- // This is only required for memory types without kHostCoherent set.
- Status InvalidateMappedMemory(device_size_t local_byte_offset,
- device_size_t local_byte_length);
-
- // Flushes ranges of non-coherent memory from the host caches.
- // Use this after writing to non-coherent memory.
- // This guarantees that host writes to the memory ranges provided are made
- // available for device access.
- // This is only required for memory types without kHostCoherent set.
- Status FlushMappedMemory(device_size_t local_byte_offset,
- device_size_t local_byte_length);
-
- // Returns a failure if the memory type the buffer was allocated from is not
- // compatible with the given type.
- Status ValidateCompatibleMemoryType(MemoryTypeBitfield memory_type) const;
- // Returns a failure if the buffer memory type or usage disallows the given
- // access type.
- Status ValidateAccess(MemoryAccessBitfield memory_access) const;
- // Returns a failure if the buffer was not allocated for the given usage.
- Status ValidateUsage(BufferUsageBitfield usage) const;
- // Validates the given buffer range and optionally adjusts the offset and
- // length if the provided length is kWholeBuffer or the buffer is offset
- // within its allocation.
- static Status CalculateRange(device_size_t base_offset,
- device_size_t max_length, device_size_t offset,
- device_size_t length,
- device_size_t* out_adjusted_offset,
- device_size_t* out_adjusted_length = nullptr);
- Status CalculateRange(device_size_t offset, device_size_t length,
- device_size_t* out_adjusted_offset,
- device_size_t* out_adjusted_length = nullptr) const;
-
- // Points to either this or parent_buffer_.get().
- Buffer* allocated_buffer_ = nullptr;
-
- Allocator* allocator_ = nullptr;
- MemoryTypeBitfield memory_type_ = MemoryType::kNone;
- MemoryAccessBitfield allowed_access_ = MemoryAccess::kNone;
- BufferUsageBitfield usage_ = BufferUsage::kNone;
-
- device_size_t allocation_size_ = 0;
- device_size_t byte_offset_ = 0;
- device_size_t byte_length_ = 0;
-
-#if HAS_IREE_BUFFER_DEBUG_NAME
- // Friendly name for the buffer used in DebugString. May be set by the app or
- // auto generated.
- std::string debug_name_;
-#endif // HAS_IREE_BUFFER_DEBUG_NAME
-
- // Defined when this buffer is a subspan of another buffer.
- ref_ptr<Buffer> parent_buffer_;
-};
-
-// A memory mapping RAII object.
-// The mapping will stay active until it is reset and will retain the buffer.
-template <typename T>
-class MappedMemory {
- public:
- using unspecified_bool_type = const T* MappedMemory<T>::*;
-
- MappedMemory() = default;
- MappedMemory(MemoryAccessBitfield access, ref_ptr<Buffer> buffer,
- device_size_t byte_offset, device_size_t byte_length,
- device_size_t element_size, T* data);
-
- // Allow moving but disallow copying as the mapping is stateful.
- MappedMemory(MappedMemory&& rhs) noexcept;
- MappedMemory& operator=(MappedMemory&& rhs) noexcept;
- MappedMemory(const MappedMemory&) = delete;
- MappedMemory& operator=(const MappedMemory&) = delete;
-
- ~MappedMemory();
-
- // The buffer resource that this mapping references.
- const ref_ptr<Buffer>& buffer() const noexcept { return buffer_; }
- // Offset, in bytes, into the resource allocation.
- // This value is *informative only*, as it may vary from device to device.
- device_size_t byte_offset() const noexcept { return byte_offset_; }
- // Length, in bytes, of the resource mapping.
- // This may be larger than the originally requested length due to alignment.
- // This value is *informative only*, as it may vary from device to device.
- device_size_t byte_length() const noexcept { return byte_length_; }
-
- // True if the mapping is empty.
- bool empty() const noexcept { return element_size_ == 0; }
- // The size of the mapping as requested in elements.
- size_t size() const noexcept { return static_cast<size_t>(element_size_); }
-
- // Returns a read-only pointer to the mapped memory.
- // This will be nullptr if the mapping failed or the mapping is not readable.
- const T* data() const noexcept;
- absl::Span<const T> contents() const noexcept { return {data(), size()}; }
-
- // Returns a mutable pointer to the mapped memory.
- // This will be nullptr if the mapping failed or the mapping is not writable.
- // If the mapping was not made with read access it may still be possible to
- // read from this memory but behavior is undefined.
- T* mutable_data() noexcept;
- absl::Span<T> mutable_contents() noexcept { return {mutable_data(), size()}; }
-
- // Equivalent to absl::Span::subspan().
- // May return a 0-length span.
- // Fails if the buffer is not mapped or not mapped for the requested access.
- StatusOr<absl::Span<const T>> Subspan(
- device_size_t element_offset = 0,
- device_size_t element_length = kWholeBuffer) const noexcept;
- StatusOr<absl::Span<T>> MutableSubspan(
- device_size_t element_offset = 0,
- device_size_t element_length = kWholeBuffer) noexcept;
-
- // Accesses an element in the mapped memory.
- // Must be called with a valid index in [0, size()).
- const T& operator[](device_size_t i) const noexcept { return data_[i]; }
-
- // Invalidates a range of non-coherent elements from the host caches.
- Status Invalidate(device_size_t element_offset = 0,
- device_size_t element_length = kWholeBuffer) const;
-
- // Flushes a range of non-coherent elements from the host caches.
- Status Flush(device_size_t element_offset = 0,
- device_size_t element_length = kWholeBuffer);
-
- // Unmaps the mapped memory.
- // The memory will not be implicitly flushed when unmapping.
- void reset();
-
- private:
- Status ValidateAccess(MemoryAccessBitfield memory_access) const;
- Status CalculateDataRange(device_size_t element_offset,
- device_size_t element_length,
- device_size_t* out_adjusted_element_offset,
- device_size_t* out_adjusted_element_length) const;
-
- MemoryAccessBitfield access_ = MemoryAccess::kNone;
- ref_ptr<Buffer> buffer_;
- device_size_t byte_offset_ = 0;
- device_size_t byte_length_ = 0;
- device_size_t element_size_ = 0;
- T* data_ = nullptr;
-};
-
-// Inline functions and template definitions follow:
-
-template <typename T>
-Status Buffer::Fill8(device_size_t byte_offset, device_size_t byte_length,
- T value) {
- auto sized_value = reinterpret_cast<uint8_t*>(&value);
- return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value));
-}
-
-template <typename T>
-Status Buffer::Fill16(device_size_t byte_offset, device_size_t byte_length,
- T value) {
- auto sized_value = reinterpret_cast<uint16_t*>(&value);
- return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value));
-}
-
-template <typename T>
-Status Buffer::Fill32(device_size_t byte_offset, device_size_t byte_length,
- T value) {
- auto sized_value = reinterpret_cast<uint32_t*>(&value);
- return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value));
-}
-
-template <typename T>
-Status Buffer::Fill8(T value) {
- return Fill8(0, kWholeBuffer, value);
-}
-
-template <typename T>
-Status Buffer::Fill16(T value) {
- return Fill16(0, kWholeBuffer, value);
-}
-
-template <typename T>
-Status Buffer::Fill32(T value) {
- return Fill32(0, kWholeBuffer, value);
-}
-
-template <typename T>
-StatusOr<MappedMemory<T>> Buffer::MapMemory(MemoryAccessBitfield memory_access,
- device_size_t element_offset,
- device_size_t element_length) {
- device_size_t byte_offset = element_offset * sizeof(T);
- device_size_t byte_length = element_length == kWholeBuffer
- ? kWholeBuffer
- : element_length * sizeof(T);
- void* data = nullptr;
- RETURN_IF_ERROR(MapMemory(MappingMode::kScoped, memory_access, &byte_offset,
- &byte_length, &data));
- return MappedMemory<T>{
- memory_access, add_ref(this), byte_offset,
- byte_length, byte_length / sizeof(T), static_cast<T*>(data)};
-}
-
-template <typename T>
-MappedMemory<T>::MappedMemory(MemoryAccessBitfield access,
- ref_ptr<Buffer> buffer, device_size_t byte_offset,
- device_size_t byte_length,
- device_size_t element_size, T* data)
- : access_(access),
- buffer_(std::move(buffer)),
- byte_offset_(byte_offset),
- byte_length_(byte_length),
- element_size_(element_size),
- data_(data) {}
-
-template <typename T>
-MappedMemory<T>::MappedMemory(MappedMemory<T>&& rhs) noexcept
- : access_(rhs.access_),
- buffer_(std::move(rhs.buffer_)),
- byte_offset_(rhs.byte_offset_),
- byte_length_(rhs.byte_length_),
- element_size_(rhs.element_size_),
- data_(rhs.data_) {
- rhs.access_ = MemoryAccess::kNone;
- rhs.buffer_.reset();
- rhs.byte_offset_ = 0;
- rhs.byte_length_ = 0;
- rhs.element_size_ = 0;
- rhs.data_ = nullptr;
-}
-
-template <typename T>
-MappedMemory<T>& MappedMemory<T>::operator=(MappedMemory<T>&& rhs) noexcept {
- if (this != &rhs) {
- reset();
- access_ = rhs.access_;
- buffer_ = std::move(rhs.buffer_);
- byte_offset_ = rhs.byte_offset_;
- byte_length_ = rhs.byte_length_;
- element_size_ = rhs.element_size_;
- data_ = rhs.data_;
-
- rhs.access_ = MemoryAccess::kNone;
- rhs.buffer_.reset();
- rhs.byte_offset_ = 0;
- rhs.byte_length_ = 0;
- rhs.element_size_ = 0;
- rhs.data_ = nullptr;
- }
- return *this;
-}
-
-template <typename T>
-MappedMemory<T>::~MappedMemory() {
- // Unmap (if needed) - note that we can't fail gracefully here :(
- reset();
-}
-
-template <typename T>
-const T* MappedMemory<T>::data() const noexcept {
- if (!data_ || !AnyBitSet(access_ & MemoryAccess::kRead)) {
- return nullptr;
- }
- return data_;
-}
-
-template <typename T>
-T* MappedMemory<T>::mutable_data() noexcept {
- if (!data_ || !AnyBitSet(access_ & MemoryAccess::kWrite)) {
- return nullptr;
- }
- return data_;
-}
-
-template <typename T>
-Status MappedMemory<T>::ValidateAccess(
- MemoryAccessBitfield memory_access) const {
- if (!data_) {
- return FailedPreconditionErrorBuilder(IREE_LOC) << "Buffer is not mapped";
- } else if (!AnyBitSet(access_ & memory_access)) {
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "Buffer is not mapped for the desired access";
- }
- return OkStatus();
-}
-
-template <typename T>
-Status MappedMemory<T>::CalculateDataRange(
- device_size_t element_offset, device_size_t element_length,
- device_size_t* out_adjusted_element_offset,
- device_size_t* out_adjusted_element_length) const {
- RETURN_IF_ERROR(Buffer::CalculateLocalRange(
- element_size_ * sizeof(T), element_offset * sizeof(T),
- element_length == kWholeBuffer ? kWholeBuffer
- : element_length * sizeof(T),
- out_adjusted_element_offset, out_adjusted_element_length));
- *out_adjusted_element_offset /= sizeof(T);
- *out_adjusted_element_length /= sizeof(T);
- return OkStatus();
-}
-
-template <typename T>
-inline StatusOr<absl::Span<const T>> MappedMemory<T>::Subspan(
- device_size_t element_offset, device_size_t element_length) const noexcept {
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
- RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
- &element_offset, &element_length));
- return absl::Span<const T>(data_ + element_offset, element_length);
-}
-
-template <typename T>
-inline StatusOr<absl::Span<T>> MappedMemory<T>::MutableSubspan(
- device_size_t element_offset, device_size_t element_length) noexcept {
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
- RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
- &element_offset, &element_length));
- return absl::Span<T>(data_ + element_offset, element_length);
-}
-
-template <typename T>
-Status MappedMemory<T>::Invalidate(device_size_t element_offset,
- device_size_t element_length) const {
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
- RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
- &element_offset, &element_length));
- if (!element_length) return OkStatus();
- return buffer_->InvalidateMappedMemory(
- byte_offset_ + element_offset * sizeof(T), element_length * sizeof(T));
-}
-
-template <typename T>
-Status MappedMemory<T>::Flush(device_size_t element_offset,
- device_size_t element_length) {
- RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
- RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
- &element_offset, &element_length));
- if (!element_length) return OkStatus();
- return buffer_->FlushMappedMemory(byte_offset_ + element_offset * sizeof(T),
- element_length * sizeof(T));
-}
-
-template <typename T>
-void MappedMemory<T>::reset() {
- if (!buffer_) return;
- // TODO(benvanik): better handling of errors? may be fine to always warn.
- buffer_->UnmapMemory(byte_offset_, byte_length_, data_).IgnoreError();
- buffer_.reset();
- access_ = MemoryAccess::kNone;
- byte_offset_ = 0;
- byte_length_ = 0;
- element_size_ = 0;
- data_ = nullptr;
-}
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_BUFFER_H_
diff --git a/iree/hal/buffer_mapping_test.cc b/iree/hal/buffer_mapping_test.cc
deleted file mode 100644
index 9b7e9ec..0000000
--- a/iree/hal/buffer_mapping_test.cc
+++ /dev/null
@@ -1,539 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Tests for the MemoryMapping RAII wrapper.
-// This uses a mock buffer implementation such that it is only testing
-// MemoryMapping and not any real underlying memory mapping behavior.
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "absl/types/span.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-class Allocator;
-
-namespace {
-
-using ::testing::_;
-using ::testing::DoAll;
-using ::testing::Return;
-using ::testing::SetArgPointee;
-
-static void* const kValidPtr = reinterpret_cast<void*>(0xBEEFCAFEF00D1234ull);
-
-class MockBuffer : public Buffer {
- public:
- using Buffer::MappingMode;
-
- MockBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
- device_size_t allocation_size)
- : Buffer(allocator, memory_type, allowed_access, usage, allocation_size,
- 0, allocation_size) {}
-
- MOCK_METHOD4(FillImpl,
- Status(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length));
-
- MOCK_METHOD3(ReadDataImpl, Status(device_size_t source_offset, void* data,
- device_size_t data_length));
- MOCK_METHOD3(WriteDataImpl,
- Status(device_size_t target_offset, const void* data,
- device_size_t data_length));
- MOCK_METHOD4(CopyDataImpl,
- Status(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset, device_size_t data_length));
-
- MOCK_METHOD5(MapMemoryImpl,
- Status(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length, void** out_data));
- MOCK_METHOD3(UnmapMemoryImpl,
- Status(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data));
- MOCK_METHOD2(InvalidateMappedMemoryImpl,
- Status(device_size_t local_byte_offset,
- device_size_t local_byte_length));
- MOCK_METHOD2(FlushMappedMemoryImpl, Status(device_size_t local_byte_offset,
- device_size_t local_byte_length));
-};
-
-TEST(MemoryMappingTest, MapWholeBuffer) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mapping.reset();
-}
-
-TEST(MemoryMappingTest, MapPartialBuffer) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 4, 12, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 4, 12));
- EXPECT_CALL(*buffer, UnmapMemoryImpl(4, 12, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mapping.reset();
-}
-
-TEST(MemoryMappingTest, EmptyHandle) {
- MappedMemory<uint8_t> mm_a;
- MappedMemory<uint8_t> mm_b;
- mm_a = std::move(mm_b);
- EXPECT_EQ(nullptr, mm_a.buffer());
- EXPECT_EQ(0, mm_a.byte_offset());
- EXPECT_EQ(0, mm_a.byte_length());
- EXPECT_TRUE(mm_a.empty());
- EXPECT_EQ(0, mm_a.size());
- EXPECT_EQ(nullptr, mm_a.data());
- EXPECT_EQ(nullptr, mm_a.mutable_data());
- EXPECT_TRUE(IsFailedPrecondition(mm_a.Subspan().status()));
- EXPECT_TRUE(IsFailedPrecondition(mm_a.MutableSubspan().status()));
- EXPECT_TRUE(IsFailedPrecondition(mm_a.Invalidate()));
- EXPECT_TRUE(IsFailedPrecondition(mm_a.Flush()));
- mm_a.reset();
-}
-
-TEST(MemoryMappingTest, MoveHandle) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
-
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_a,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Should be able to move the handle around without having any calls.
- auto mm_b = std::move(mm_a);
- mm_a = std::move(mm_b);
- mm_b = std::move(mm_a);
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_b.reset();
-}
-
-TEST(MemoryMappingTest, ReadOnlyAccess) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kRead, BufferUsage::kAll, 128);
-
- // Should succeed to map for reading.
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_r,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Non-mutable access is fine.
- EXPECT_EQ(kValidPtr, mm_r.data());
- ASSERT_OK_AND_ASSIGN(auto span, mm_r.Subspan());
- (void)span;
-
- // Read-only mappings should not be able to get mutable access.
- EXPECT_EQ(nullptr, mm_r.mutable_data());
- EXPECT_TRUE(IsPermissionDenied(mm_r.MutableSubspan().status()));
-
- // Read-only mappings should not be able to call Flush.
- EXPECT_TRUE(IsPermissionDenied(mm_r.Flush()));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-
- // Should fail to map for writing.
- EXPECT_TRUE(IsPermissionDenied(
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status()));
-}
-
-TEST(MemoryMappingTest, ReadWriteAccess) {
- auto buffer = std::make_shared<MockBuffer>(
- nullptr, MemoryType::kHostLocal,
- MemoryAccess::kRead | MemoryAccess::kWrite, BufferUsage::kAll, 128);
-
- // Should succeed to map for reading and/or writing.
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead | MemoryAccess::kWrite,
- 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(
- auto mm_rw,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kWrite));
-
- // Everything valid.
- EXPECT_EQ(kValidPtr, mm_rw.data());
- ASSERT_OK_AND_ASSIGN(auto span, mm_rw.Subspan());
- EXPECT_EQ(kValidPtr, mm_rw.mutable_data());
- ASSERT_OK_AND_ASSIGN(span, mm_rw.MutableSubspan());
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_rw.reset();
-
- // Should fail to map for discard.
- EXPECT_TRUE(IsPermissionDenied(
- buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status()));
-}
-
-TEST(MemoryMappingTest, WriteOnlyAccess) {
- auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kWrite,
- BufferUsage::kAll, 128);
-
- // Should succeed to map for writing.
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_w,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-
- // Mutable access is valid.
- EXPECT_EQ(kValidPtr, mm_w.mutable_data());
- ASSERT_OK_AND_ASSIGN(auto span, mm_w.MutableSubspan());
- (void)span;
-
- // Write-only mappings should not be able to get non-mutable access.
- EXPECT_EQ(nullptr, mm_w.data());
- EXPECT_TRUE(IsPermissionDenied(mm_w.Subspan().status()));
-
- // Write-only mappings should not be able to call Invalidate.
- EXPECT_TRUE(IsPermissionDenied(mm_w.Invalidate()));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-
- // Should fail to map for reading.
- EXPECT_TRUE(IsPermissionDenied(
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status()));
-
- // Should fail to map for discard.
- EXPECT_TRUE(IsPermissionDenied(
- buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status()));
-}
-
-TEST(MemoryMappingTest, WriteDiscardAccess) {
- auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kDiscardWrite,
- BufferUsage::kAll, 128);
-
- // Should succeed to map for writing with discard.
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kDiscardWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_dw,
- buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite));
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_dw.reset();
-
- // Should also be ok to map for just writing.
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_w,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-
- // Should fail to map for reading.
- EXPECT_TRUE(IsPermissionDenied(
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status()));
-}
-
-TEST(MemoryMappingTest, Subspan) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_r,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Request some valid ranges and ensure the byte offsets are correct.
- ASSERT_OK_AND_ASSIGN(auto ss, mm_r.Subspan());
- EXPECT_EQ(kValidPtr, ss.data());
- EXPECT_EQ(128, ss.size());
- ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, 2));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
- EXPECT_EQ(2, ss.size());
- ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, kWholeBuffer));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
- EXPECT_EQ(28, ss.size());
-
- // Zero length ranges are fine.
- ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(0, 0));
- EXPECT_EQ(kValidPtr, ss.data());
- EXPECT_TRUE(ss.empty());
- ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, 0));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
- EXPECT_TRUE(ss.empty());
- ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, kWholeBuffer));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
- EXPECT_TRUE(ss.empty());
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-}
-
-TEST(MemoryMappingTest, SubspanOutOfRange) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_r,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Try some invalid ranges that would overrun the span.
- EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 0).status()));
- EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 2).status()));
- EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, kWholeBuffer).status()));
- EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(100, 1234).status()));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-}
-
-TEST(MemoryMappingTest, MutableSubspan) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_w,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-
- // Request some valid ranges and ensure the byte offsets are correct.
- ASSERT_OK_AND_ASSIGN(auto ss, mm_w.MutableSubspan());
- EXPECT_EQ(kValidPtr, ss.data());
- EXPECT_EQ(128, ss.size());
- ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, 2));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
- EXPECT_EQ(2, ss.size());
- ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, kWholeBuffer));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data());
- EXPECT_EQ(28, ss.size());
-
- // Zero length ranges are fine.
- ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(0, 0));
- EXPECT_EQ(kValidPtr, ss.data());
- EXPECT_TRUE(ss.empty());
- ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, 0));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
- EXPECT_TRUE(ss.empty());
- ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, kWholeBuffer));
- EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data());
- EXPECT_TRUE(ss.empty());
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-}
-
-TEST(MemoryMappingTest, MutableSubspanOutOfRange) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_w,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-
- // Try some invalid ranges that would overrun the span.
- EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 0).status()));
- EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 2).status()));
- EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, kWholeBuffer).status()));
- EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(100, 1234).status()));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-}
-
-TEST(MemoryMappingTest, ElementOperator) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_r,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Just verify we are getting the expected pointer back.
- EXPECT_EQ(kValidPtr, &mm_r[0]);
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-}
-
-TEST(MemoryMappingTest, Invalidate) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_r,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Invalidate a few ways.
- EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(0, 128))
- .WillOnce(Return(OkStatus()));
- EXPECT_OK(mm_r.Invalidate());
- EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 2))
- .WillOnce(Return(OkStatus()));
- EXPECT_OK(mm_r.Invalidate(100, 2));
- EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 28))
- .WillOnce(Return(OkStatus()));
- EXPECT_OK(mm_r.Invalidate(100, kWholeBuffer));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-}
-
-TEST(MemoryMappingTest, InvalidateOutOfRange) {
- auto buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_r,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-
- // Try to invalidate invalid ranges.
- EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 0)));
- EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 12345)));
- EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, kWholeBuffer)));
- EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1, 1234)));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-}
-
-TEST(MemoryMappingTest, InvalidateBadMode) {
- // Invalidate is not required on coherent memory.
- auto coherent_buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*coherent_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kRead, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(
- auto mm_r, coherent_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
- EXPECT_TRUE(IsPermissionDenied(mm_r.Invalidate()));
- EXPECT_CALL(*coherent_buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_r.reset();
-}
-
-TEST(MemoryMappingTest, Flush) {
- auto buffer = std::make_shared<MockBuffer>(
- nullptr, MemoryType::kHostVisible | MemoryType::kHostCached,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_w,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-
- // Flush a few ways.
- EXPECT_CALL(*buffer, FlushMappedMemoryImpl(0, 128))
- .WillOnce(Return(OkStatus()));
- EXPECT_OK(mm_w.Flush());
- EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 2))
- .WillOnce(Return(OkStatus()));
- EXPECT_OK(mm_w.Flush(100, 2));
- EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 28))
- .WillOnce(Return(OkStatus()));
- EXPECT_OK(mm_w.Flush(100, kWholeBuffer));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-}
-
-TEST(MemoryMappingTest, FlushOutOfRange) {
- auto buffer = std::make_shared<MockBuffer>(
- nullptr, MemoryType::kHostVisible | MemoryType::kHostCached,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(auto mm_w,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-
- // Try to flush invalid ranges.
- EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 0)));
- EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 12345)));
- EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, kWholeBuffer)));
- EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1, 1234)));
-
- EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-}
-
-TEST(MemoryMappingTest, FlushBadMode) {
- // Flush is not required on uncached memory.
- auto uncached_buffer =
- std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible,
- MemoryAccess::kAll, BufferUsage::kAll, 128);
- EXPECT_CALL(*uncached_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped,
- MemoryAccess::kWrite, 0, 128, _))
- .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus())));
- ASSERT_OK_AND_ASSIGN(
- auto mm_w, uncached_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
- EXPECT_TRUE(IsPermissionDenied(mm_w.Flush()));
- EXPECT_CALL(*uncached_buffer, UnmapMemoryImpl(0, 128, kValidPtr))
- .WillOnce(Return(OkStatus()));
- mm_w.reset();
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/buffer_test.cc b/iree/hal/buffer_test.cc
deleted file mode 100644
index 1d9fced..0000000
--- a/iree/hal/buffer_test.cc
+++ /dev/null
@@ -1,1000 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Tests for the shared buffer functionality and host heap buffers.
-// This does not test device-specific buffer implementations; see the device
-// code for associated tests.
-
-#include "iree/hal/buffer.h"
-
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/heap_buffer.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-using ::testing::_;
-using ::testing::ElementsAre;
-using ::testing::Eq;
-using ::testing::Not;
-
-TEST(BufferTest, Allocate) {
- auto buffer =
- HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 14);
- EXPECT_NE(nullptr, buffer->allocator());
- EXPECT_EQ(MemoryAccess::kAll, buffer->allowed_access());
- EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
- EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
-
- // We don't currently do any padding on the host.
- // Other implementations may differ.
- EXPECT_GE(14, buffer->allocation_size());
- EXPECT_EQ(0, buffer->byte_offset());
- EXPECT_EQ(14, buffer->byte_length());
-
- // Data should be zeroed by default.
- std::vector<uint8_t> zero_data(buffer->allocation_size());
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(zero_data));
-}
-
-TEST(BufferTest, AllocateZeroLength) {
- auto buffer =
- HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 0);
- EXPECT_NE(nullptr, buffer->allocator());
- EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
- EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
- EXPECT_EQ(0, buffer->allocation_size());
-}
-
-TEST(BufferTest, AllocateCopy) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- EXPECT_NE(nullptr, buffer->allocator());
- EXPECT_GE(src_data.size(), buffer->allocation_size());
-
- // Data should have been copied.
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Modify the source data and ensure it is not reflected in the buffer.
- src_data[0] = 0x88;
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Not(Eq(src_data)));
-}
-
-TEST(BufferTest, AllocateCopyZeroLength) {
- std::vector<uint8_t> src_data;
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- EXPECT_NE(nullptr, buffer->allocator());
- EXPECT_EQ(0, buffer->allocation_size());
-}
-
-TEST(BufferTest, AllocateCopyTyped) {
- std::vector<int32_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- absl::MakeConstSpan(src_data));
- EXPECT_NE(nullptr, buffer->allocator());
- EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
- EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
- EXPECT_GE(src_data.size() * sizeof(int32_t), buffer->allocation_size());
-
- // Data should have been copied.
- std::vector<int32_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(),
- actual_data.size() * sizeof(int32_t)));
- EXPECT_THAT(actual_data, Eq(src_data));
-}
-
-TEST(BufferTest, WrapConstant) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer = HeapBuffer::Wrap(MemoryType::kHostLocal,
- BufferUsage::kTransfer | BufferUsage::kMapping,
- absl::MakeConstSpan(src_data));
- EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
- EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
- EXPECT_EQ(src_data.size(), buffer->allocation_size());
-
- // src_data and buffer should match after the wrapping.
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Modify the source data directly.
- src_data[0] = 123;
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Attempts to modify the buffer should fail.
- std::vector<uint8_t> new_data = {3, 2, 1, 0};
- EXPECT_TRUE(IsPermissionDenied(
- buffer->WriteData(0, new_data.data(), new_data.size())));
-}
-
-TEST(BufferTest, WrapMutable) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer = HeapBuffer::WrapMutable(
- MemoryType::kHostLocal, MemoryAccess::kAll,
- BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data));
- EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type());
- EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
- EXPECT_EQ(src_data.size(), buffer->allocation_size());
-
- // src_data and buffer should match after the wrapping.
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Modify the source data directly.
- src_data[0] = 123;
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Modify the source data via the Buffer and ensure reflected in src_data.
- std::vector<uint8_t> new_data = {3, 2, 1, 0};
- EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size()));
- EXPECT_THAT(src_data, Eq(new_data));
-}
-
-TEST(BufferTest, WrapExternal) {
- // This is not fully supported yet, but does let us verify that the validation
- // of memory types is working.
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer = HeapBuffer::Wrap(MemoryType::kDeviceLocal, BufferUsage::kAll,
- absl::MakeConstSpan(src_data));
- EXPECT_EQ(MemoryType::kDeviceLocal, buffer->memory_type());
-
- // Should fail (for now) as the buffer is not host visible.
- EXPECT_TRUE(IsPermissionDenied(buffer->Fill8(0, kWholeBuffer, 0x99u)));
-}
-
-TEST(BufferTest, DoesOverlap) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto parent_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
-
- // A buffer should overlap with itself.
- EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1,
- parent_buffer.get(), 1, 1));
- EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1,
- parent_buffer.get(), 0, 1));
-
- // Zero length buffers never overlap.
- EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 1, 1,
- parent_buffer.get(), 1, 0));
-
- // Subspans should offset within their allocation.
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer_0,
- Buffer::Subspan(parent_buffer, 1, 2));
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer_1,
- Buffer::Subspan(parent_buffer, 2, 2));
- EXPECT_FALSE(Buffer::DoesOverlap(subspan_buffer_0.get(), 0, 1,
- subspan_buffer_1.get(), 0, 1));
- EXPECT_TRUE(Buffer::DoesOverlap(subspan_buffer_0.get(), 1, 1,
- subspan_buffer_1.get(), 0, 1));
-
- // Mixing subspans and normal buffers.
- EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1,
- subspan_buffer_0.get(), 0, 1));
- EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 1, 2,
- subspan_buffer_0.get(), 1, 1));
-
- // Independent buffers should not be able to overlap.
- auto other_buffer = HeapBuffer::Allocate(BufferUsage::kAll, 128);
- EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, kWholeBuffer,
- other_buffer.get(), 0, kWholeBuffer));
-}
-
-TEST(BufferTest, Subspan) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto parent_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(parent_buffer);
-
- // Create a subspan of the buffer.
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
- Buffer::Subspan(parent_buffer, 1, 2));
- ASSERT_TRUE(subspan_buffer);
- EXPECT_EQ(1, subspan_buffer->byte_offset());
- EXPECT_EQ(2, subspan_buffer->byte_length());
-
- // Modifications to either buffer should appear in the other.
- EXPECT_OK(subspan_buffer->Fill8(1, kWholeBuffer, 0xFFu));
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xFF, 3));
-
- // Subspans should be able to create subspans.
- // NOTE: offset is from the original buffer.
- ASSERT_OK_AND_ASSIGN(auto subsubspan_buffer,
- Buffer::Subspan(subspan_buffer, 1, 1));
- ASSERT_TRUE(subsubspan_buffer);
- EXPECT_EQ(2, subsubspan_buffer->byte_offset());
- EXPECT_EQ(1, subsubspan_buffer->byte_length());
-
- // Zero length subspans are fine.
- ASSERT_OK_AND_ASSIGN(auto zero_subspan_buffer,
- Buffer::Subspan(parent_buffer, 0, 0));
- ASSERT_TRUE(zero_subspan_buffer);
- EXPECT_EQ(0, zero_subspan_buffer->byte_offset());
- EXPECT_EQ(0, zero_subspan_buffer->byte_length());
-
- // Subspan with kWholeBuffer should get the remaining size (or zero).
- ASSERT_OK_AND_ASSIGN(auto whole_subspan_buffer,
- Buffer::Subspan(parent_buffer, 1, kWholeBuffer));
- ASSERT_TRUE(whole_subspan_buffer);
- EXPECT_EQ(1, whole_subspan_buffer->byte_offset());
- EXPECT_EQ(3, whole_subspan_buffer->byte_length());
-
- // Zero length subspans are fine.
- ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, 0));
- ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, kWholeBuffer));
-}
-
-TEST(BufferTest, SubspanIdentity) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto parent_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
-
- // Asking for a subspan of the entire buffer should return the same buffer.
- // Mostly an optimization.
- EXPECT_EQ(parent_buffer.get(),
- Buffer::Subspan(parent_buffer, 0, kWholeBuffer).ValueOrDie().get());
- EXPECT_EQ(parent_buffer.get(),
- Buffer::Subspan(parent_buffer, 0, 4).ValueOrDie().get());
-}
-
-TEST(BufferTest, SubspanOutOfRange) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto parent_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(parent_buffer);
-
- // Create a subspan of the buffer.
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
- Buffer::Subspan(parent_buffer, 1, 2));
- ASSERT_TRUE(subspan_buffer);
- EXPECT_EQ(1, subspan_buffer->byte_offset());
- EXPECT_EQ(2, subspan_buffer->byte_length());
-
- // Try to make subspans from invalid ranges.
- EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 5, 0).status()));
- EXPECT_TRUE(
- IsOutOfRange(Buffer::Subspan(parent_buffer, 5, kWholeBuffer).status()));
- EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 4, 1).status()));
- EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 0, 123).status()));
- EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 1, 2).status()));
- EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 0, 44).status()));
-}
-
-TEST(BufferTest, Fill8) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5);
- ASSERT_TRUE(buffer);
-
- // Data should be zeroed by default.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0));
-
- // Fill with a sentinel.
- EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u));
-
- // Verify data.
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
-
- // Zero fills are fine.
- EXPECT_OK(buffer->Fill8(0, 0, 0x44u));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
-
- // Fill the remaining parts of the buffer by using kWholeBuffer.
- EXPECT_OK(buffer->Fill8(2, kWholeBuffer, 0x55u));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55));
-
- // Fill a small region of the buffer.
- EXPECT_OK(buffer->Fill8(1, 1, 0x66u));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55));
-
- // Whole buffer helper.
- EXPECT_OK(buffer->Fill8(0x99u));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99));
-}
-
-TEST(BufferTest, Fill8OutOfRange) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5);
- ASSERT_TRUE(buffer);
-
- // Fill with a sentinel.
- EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u));
-
- // Try to fill with invalid ranges.
- EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 444, 0x44u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 1, 0x44u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u)));
-
- // Ensure nothing happened with the bad ranges.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
-}
-
-TEST(BufferTest, Fill8BadMode) {
- // Fail to fill buffers not supporting mapping.
- auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- EXPECT_TRUE(
- IsPermissionDenied(nonmapping_buffer->Fill8(0, kWholeBuffer, 0x99u)));
-
- // Fail to fill constant buffers.
- std::vector<uint8_t> const_data = {1, 2, 3};
- auto constant_buffer =
- HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping,
- absl::MakeConstSpan(const_data));
- EXPECT_TRUE(
- IsPermissionDenied(constant_buffer->Fill8(0, kWholeBuffer, 0x99u)));
-}
-
-TEST(BufferTest, Fill8Subspan) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5);
- ASSERT_TRUE(buffer);
-
- // Test on subspan.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 3));
- EXPECT_OK(subspan_buffer->Fill8(2, kWholeBuffer, 0xDDu));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0xDD, 0));
-}
-
-TEST(BufferTest, Fill16) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Data should be zeroed by default.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0));
-
- // Fill with a sentinel.
- EXPECT_OK(buffer->Fill16(0, 4, 0x1122u));
-
- // Verify data.
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0));
-
- // Zero fills are fine.
- EXPECT_OK(buffer->Fill16(0, 0, 0x5566u));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0));
-
- // Fill the remaining parts of the buffer by using kWholeBuffer.
- auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8);
- EXPECT_OK(aligned_buffer->Fill16(4, kWholeBuffer, 0x5566u));
- std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size());
- EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
- aligned_actual_data.size()));
- EXPECT_THAT(aligned_actual_data,
- ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55));
-
- // Whole buffer helper.
- EXPECT_OK(aligned_buffer->Fill16(0x5566u));
- EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
- aligned_actual_data.size()));
- EXPECT_THAT(aligned_actual_data,
- ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55));
-}
-
-TEST(BufferTest, Fill16OutOfRange) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Try to fill with invalid ranges.
- EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 444, 0x5566u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 4, 0x5566u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u)));
-}
-
-TEST(BufferTest, Fill16Unaligned) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Try to fill with unaligned ranges.
- EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(1, 4, 0x5566u)));
- EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(0, 5, 0x5566u)));
-}
-
-TEST(BufferTest, Fill16BadMode) {
- // Fail to fill buffers not supporting mapping.
- auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- EXPECT_TRUE(
- IsPermissionDenied(nonmapping_buffer->Fill16(0, kWholeBuffer, 0x99AAu)));
-
- // Fail to fill constant buffers.
- std::vector<uint8_t> const_data = {1, 2, 3};
- auto constant_buffer =
- HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping,
- absl::MakeConstSpan(const_data));
- EXPECT_TRUE(
- IsPermissionDenied(constant_buffer->Fill16(0, kWholeBuffer, 0x99AAu)));
-}
-
-TEST(BufferTest, Fill16Subspan) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Fill with a sentinel.
- EXPECT_OK(buffer->Fill16(0, 4, 0x1122u));
-
- // Test on subspan.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 2, 4));
- EXPECT_OK(subspan_buffer->Fill16(2, kWholeBuffer, 0xAABBu));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data,
- ElementsAre(0x22, 0x11, 0x22, 0x11, 0xBB, 0xAA, 0, 0, 0));
-}
-
-TEST(BufferTest, Fill32) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Data should be zeroed by default.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0));
-
- // Fill with a sentinel.
- EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u));
-
- // Verify data.
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data,
- ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0));
-
- // Zero fills are fine.
- EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data,
- ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0));
-
- // Fill the remaining parts of the buffer by using kWholeBuffer.
- auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8);
- EXPECT_OK(aligned_buffer->Fill32(4, kWholeBuffer, 0x55667788u));
- std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size());
- EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
- aligned_actual_data.size()));
- EXPECT_THAT(aligned_actual_data,
- ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55));
-
- // Whole buffer helper.
- EXPECT_OK(aligned_buffer->Fill32(0x55667788u));
- EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
- aligned_actual_data.size()));
- EXPECT_THAT(aligned_actual_data,
- ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55));
-}
-
-TEST(BufferTest, Fill32OutOfRange) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Try to fill with invalid ranges.
- EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 444, 0x55667788u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 4, 0x55667788u)));
- EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u)));
-}
-
-TEST(BufferTest, Fill32Unaligned) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Try to fill with unaligned ranges.
- EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(1, 4, 0x55667788u)));
- EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(0, 5, 0x55667788u)));
-}
-
-TEST(BufferTest, Fill32BadMode) {
- // Fail to fill buffers not supporting mapping.
- auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- EXPECT_TRUE(IsPermissionDenied(
- nonmapping_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu)));
-
- // Fail to fill constant buffers.
- std::vector<uint8_t> const_data = {1, 2, 3};
- auto constant_buffer =
- HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping,
- absl::MakeConstSpan(const_data));
- EXPECT_TRUE(IsPermissionDenied(
- constant_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu)));
-}
-
-TEST(BufferTest, Fill32Subspan) {
- auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9);
- ASSERT_TRUE(buffer);
-
- // Fill with a sentinel.
- EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u));
-
- // Test on subspan.
- std::vector<uint8_t> actual_data(buffer->allocation_size());
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 4, 4));
- EXPECT_OK(subspan_buffer->Fill32(0, kWholeBuffer, 0xAABBCCDDu));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data,
- ElementsAre(0x44, 0x33, 0x22, 0x11, 0xDD, 0xCC, 0xBB, 0xAA, 0));
-}
-
-TEST(BufferTest, ReadData) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Read the data back.
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Reading zero bytes is valid.
- std::vector<uint8_t> zero_data(0);
- EXPECT_OK(buffer->ReadData(1, zero_data.data(), 0));
-
- // Read a portion of the data.
- std::vector<uint8_t> partial_data(2);
- EXPECT_OK(buffer->ReadData(1, partial_data.data(), 2));
- EXPECT_THAT(partial_data, ElementsAre(1, 2));
-}
-
-TEST(BufferTest, ReadDataOutOfRange) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Try to read out of range.
- std::vector<uint8_t> partial_data(2);
- EXPECT_TRUE(IsOutOfRange(buffer->ReadData(0, partial_data.data(), 444)));
- EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 444)));
- EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 1)));
- EXPECT_TRUE(IsInvalidArgument(
- buffer->ReadData(0, partial_data.data(), kWholeBuffer)));
-}
-
-TEST(BufferTest, ReadDataBadMode) {
- // Fail to read buffers not supporting mapping.
- std::vector<uint8_t> actual_data(1);
- auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- EXPECT_TRUE(IsPermissionDenied(
- nonmapping_buffer->ReadData(0, actual_data.data(), 1)));
-}
-
-TEST(BufferTest, ReadDataSubspan) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Test on subspan.
- std::vector<uint8_t> subspan_data(1);
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2));
- EXPECT_OK(subspan_buffer->ReadData(1, subspan_data.data(), 1));
- EXPECT_THAT(subspan_data, ElementsAre(2));
-}
-
-TEST(BufferTest, WriteData) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Read the data back - should still match.
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(src_data));
-
- // Write over the entire buffer.
- std::vector<uint8_t> new_data = {10, 20, 30, 40};
- EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size()));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(new_data));
-
- // Writing zero bytes is valid.
- std::vector<uint8_t> zero_data;
- EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(new_data));
-
- // Write over a portion of the buffer.
- std::vector<uint8_t> partial_data = {99};
- EXPECT_OK(buffer->WriteData(1, partial_data.data(), partial_data.size()));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40));
-}
-
-TEST(BufferTest, WriteDataOutOfRange) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Try to write out of range.
- std::vector<uint8_t> partial_data = {99};
- EXPECT_TRUE(IsOutOfRange(buffer->WriteData(0, partial_data.data(), 444)));
- EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 444)));
- EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 1)));
- EXPECT_TRUE(IsInvalidArgument(
- buffer->WriteData(0, partial_data.data(), kWholeBuffer)));
-}
-
-TEST(BufferTest, WriteDataBadMode) {
- std::vector<uint8_t> actual_data(4);
-
- // Fail to write buffers not supporting mapping.
- auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- EXPECT_TRUE(IsPermissionDenied(
- nonmapping_buffer->WriteData(0, actual_data.data(), 1)));
-
- // Fail to write to constant buffers.
- std::vector<uint8_t> const_data = {1, 2, 3};
- auto constant_buffer =
- HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer,
- absl::MakeConstSpan(const_data));
- EXPECT_TRUE(
- IsPermissionDenied(constant_buffer->WriteData(0, actual_data.data(), 2)));
-}
-
-TEST(BufferTest, WriteDataSubspan) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Test on subspan.
- std::vector<uint8_t> subspan_data = {0xAA};
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2));
- EXPECT_OK(subspan_buffer->WriteData(1, subspan_data.data(), 1));
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xAA, 3));
-}
-
-TEST(BufferTest, CopyData) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto src_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- src_data.data(), src_data.size());
- ASSERT_TRUE(src_buffer);
- std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
- auto dst_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- dst_data.data(), dst_data.size());
- ASSERT_TRUE(dst_buffer);
-
- // Copy of length 0 should not change the dest buffer.
- EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0));
- std::vector<uint8_t> actual_data(dst_data.size());
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, Eq(dst_data));
-
- // Copy a subrange of the buffer.
- EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4));
-
- // Copy the entire buffer using kWholeBuffer. This will adjust sizes
- // to ensure that the min buffer is taken. We test both src and dst buffer
- // offset/length calculations (note that some may end up as 0 copies).
- EXPECT_OK(dst_buffer->CopyData(3, src_buffer.get(), 0, kWholeBuffer));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1));
- EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 2, kWholeBuffer));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1));
- EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 3, kWholeBuffer));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1));
- EXPECT_OK(dst_buffer->CopyData(4, src_buffer.get(), 0, kWholeBuffer));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0));
-}
-
-TEST(BufferTest, CopyDataOutOfRange) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto src_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- src_data.data(), src_data.size());
- ASSERT_TRUE(src_buffer);
- std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
- auto dst_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- dst_data.data(), dst_data.size());
- ASSERT_TRUE(dst_buffer);
-
- // Try to copy out of range of source and dest.
- EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 0, 1)));
- EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(4, src_buffer.get(), 0, 4)));
- EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 1)));
- EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 0, 123)));
- EXPECT_TRUE(
- IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 123, 123)));
- EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 0)));
-}
-
-TEST(BufferTest, CopyDataOverlapping) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto src_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- src_data.data(), src_data.size());
- ASSERT_TRUE(src_buffer);
- std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
- auto dst_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- dst_data.data(), dst_data.size());
- ASSERT_TRUE(dst_buffer);
-
- // Test overlap. Non-overlapping regions should be fine, otherwise fail.
- std::vector<uint8_t> actual_data(dst_data.size());
- EXPECT_OK(dst_buffer->CopyData(0, dst_buffer.get(), 4, 1));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(4, 1, 2, 3, 4));
- EXPECT_TRUE(
- IsInvalidArgument(dst_buffer->CopyData(2, dst_buffer.get(), 0, 3)));
- EXPECT_TRUE(
- IsInvalidArgument(dst_buffer->CopyData(0, dst_buffer.get(), 0, 3)));
-}
-
-TEST(BufferTest, CopyDataBadMode) {
- // Both source and target buffers must support mapping.
- auto nonmapping_src_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- auto nonmapping_dst_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4);
- EXPECT_TRUE(IsPermissionDenied(nonmapping_dst_buffer->CopyData(
- 0, nonmapping_src_buffer.get(), 0, kWholeBuffer)));
- EXPECT_TRUE(IsPermissionDenied(nonmapping_src_buffer->CopyData(
- 0, nonmapping_dst_buffer.get(), 0, kWholeBuffer)));
-
- // Fail to copy into to constant buffers.
- std::vector<uint8_t> const_data = {1, 2, 3};
- auto constant_buffer =
- HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer,
- absl::MakeConstSpan(const_data));
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto src_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- src_data.data(), src_data.size());
- EXPECT_TRUE(IsPermissionDenied(
- constant_buffer->CopyData(0, src_buffer.get(), 0, kWholeBuffer)));
-}
-
-TEST(BufferTest, CopyDataSubspan) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- auto src_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- src_data.data(), src_data.size());
- ASSERT_TRUE(src_buffer);
- std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
- auto dst_buffer =
- HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer,
- dst_data.data(), dst_data.size());
- ASSERT_TRUE(dst_buffer);
-
- // Test on subspan.
- std::vector<uint8_t> actual_data(dst_data.size());
- ASSERT_OK_AND_ASSIGN(auto subspan_src_buffer,
- Buffer::Subspan(src_buffer, 1, 3));
- ASSERT_OK_AND_ASSIGN(auto subspan_dst_buffer,
- Buffer::Subspan(dst_buffer, 2, 3));
- EXPECT_OK(subspan_dst_buffer->CopyData(1, subspan_src_buffer.get(), 1, 2));
- EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 1, 2, 2, 3));
-}
-
-// NOTE: more tests related specifically to MappedMemory are in
-// buffer_mapping_test.cc. This tests the MapMemory operation and enough to
-// ensure the memory was mapped to the correct range and the HostBuffer and
-// SubspanBuffer work as intended for basic usage.
-TEST(BufferTest, MapMemory) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // 0-length mappings are valid.
- ASSERT_OK_AND_ASSIGN(auto mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 0, 0));
- EXPECT_TRUE(mapping.empty());
- EXPECT_EQ(0, mapping.size());
- EXPECT_EQ(0, mapping.byte_length());
- EXPECT_NE(nullptr, mapping.data());
- ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan());
- EXPECT_TRUE(span.empty());
- mapping.reset();
-
- // Map the whole buffer for reading.
- ASSERT_OK_AND_ASSIGN(mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead,
- 0, kWholeBuffer));
- EXPECT_EQ(src_data.size(), mapping.size());
- ASSERT_OK_AND_ASSIGN(span, mapping.Subspan());
- EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6));
- mapping.reset();
-
- // Map a portion of the buffer for reading.
- ASSERT_OK_AND_ASSIGN(mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 1, 2));
- EXPECT_EQ(2, mapping.size());
- ASSERT_OK_AND_ASSIGN(span, mapping.Subspan());
- EXPECT_THAT(span, ElementsAre(1, 2));
- mapping.reset();
-}
-
-TEST(BufferTest, MapMemoryNonByte) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Map the buffer as non-byte values.
- // Note that we'll round down to the number of valid elements at the
- // alignment.
- ASSERT_OK_AND_ASSIGN(auto mapping16,
- buffer->MapMemory<uint16_t>(MemoryAccess::kRead));
- EXPECT_EQ(3, mapping16.size());
- EXPECT_LE(6, mapping16.byte_length());
- ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan());
- EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504));
- mapping16.reset();
-}
-
-TEST(BufferTest, MapMemoryOutOfRange) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Test invalid mapping ranges.
- EXPECT_TRUE(IsOutOfRange(
- buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 0, 123).status()));
- EXPECT_TRUE(IsOutOfRange(
- buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 5, 1231).status()));
- EXPECT_TRUE(IsOutOfRange(
- buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 6, kWholeBuffer)
- .status()));
- EXPECT_TRUE(IsOutOfRange(
- buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 1236, 1).status()));
-}
-
-TEST(BufferTest, MapMemoryBadMode) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto read_buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead,
- src_data.data(), src_data.size());
- ASSERT_TRUE(read_buffer);
-
- // Test mapping the read-only buffer for writing.
- EXPECT_TRUE(IsPermissionDenied(
- read_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status()));
- EXPECT_TRUE(IsPermissionDenied(
- read_buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status()));
- EXPECT_TRUE(IsPermissionDenied(
- read_buffer
- ->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kDiscard)
- .status()));
- EXPECT_TRUE(IsInvalidArgument(
- read_buffer->MapMemory<uint8_t>(MemoryAccess::kNone).status()));
-}
-
-TEST(BufferTest, MapMemoryWrite) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Map and modify the data. We should see it when we read back.
- ASSERT_OK_AND_ASSIGN(auto mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kWrite, 1, 2));
- auto mutable_data = mapping.mutable_data();
- mutable_data[0] = 0xAA;
- mutable_data[1] = 0xBB;
- mapping.reset();
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6));
-}
-
-TEST(BufferTest, MapMemoryDiscard) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll,
- src_data.data(), src_data.size());
- ASSERT_TRUE(buffer);
-
- // Map for discard. Note that we can't really rely on the value of the data
- // so we just trust that it's been discarded. It's a hint, anyway. We can be
- // sure that the data we didn't want to discard is the same though.
- std::vector<uint8_t> actual_data(src_data.size());
- ASSERT_OK_AND_ASSIGN(auto mapping, buffer->MapMemory<uint8_t>(
- MemoryAccess::kDiscardWrite, 1, 2));
- EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6));
- mapping.reset();
-}
-
-TEST(BufferTest, MapMemorySubspan) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
- auto parent_buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll,
- src_data.data(), src_data.size());
- ASSERT_TRUE(parent_buffer);
- ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
- Buffer::Subspan(parent_buffer, 1, 3));
- ASSERT_OK_AND_ASSIGN(auto mapping, subspan_buffer->MapMemory<uint8_t>(
- MemoryAccess::kDiscardWrite, 1, 2));
- auto* mutable_data = mapping.mutable_data();
- mutable_data[0] = 0xCC;
- mutable_data[1] = 0xDD;
- mapping.reset();
-
- std::vector<uint8_t> actual_data(src_data.size());
- EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size()));
- EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6));
-
- // Just here to make coverage happy; they are currently no-ops on the host.
- // buffer_mapping_test.cc contains tests that ensure they are called
- // correctly.
- std::vector<uint8_t> external_data = {0, 1, 2, 3, 4};
- auto external_buffer = HeapBuffer::WrapMutable(
- MemoryType::kHostVisible | MemoryType::kHostCached, MemoryAccess::kAll,
- BufferUsage::kAll, absl::MakeSpan(external_data));
- ASSERT_OK_AND_ASSIGN(auto external_subspan_buffer,
- Buffer::Subspan(external_buffer, 0, 1));
- ASSERT_OK_AND_ASSIGN(
- mapping, external_subspan_buffer->MapMemory<uint8_t>(MemoryAccess::kAll));
- EXPECT_OK(mapping.Invalidate());
- EXPECT_OK(mapping.Flush());
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/buffer_view.cc b/iree/hal/buffer_view.cc
deleted file mode 100644
index 75687cb..0000000
--- a/iree/hal/buffer_view.cc
+++ /dev/null
@@ -1,180 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/buffer_view.h"
-
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-// Pretty prints an array, e.g. [1, 2, 3, 4]
-inline std::string PrettyPrint(absl::Span<const int32_t> arr) {
- return "[" + absl::StrJoin(arr, ",") + "]";
-}
-} // namespace
-
-// static
-bool BufferView::Equal(const BufferView& lhs, const BufferView& rhs) {
- return lhs.buffer.get() == rhs.buffer.get() &&
- lhs.element_size == rhs.element_size && lhs.shape == rhs.shape;
-}
-
-std::string BufferView::DebugStringShort() const {
- if (element_size == 0) {
- return "Ø";
- }
- return shape.empty() ? std::to_string(element_size)
- : absl::StrCat(absl::StrJoin(shape.subspan(), "x"), "x",
- element_size);
-}
-
-StatusOr<device_size_t> BufferView::CalculateOffset(
- absl::Span<const int32_t> indices) const {
- if (indices.empty()) {
- return 0;
- } else if (shape.empty() || indices.size() > shape.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Indices " << PrettyPrint(indices)
- << " out of bounds of the rank of buffer_view "
- << DebugStringShort();
- }
- device_size_t offset = 0;
- for (int i = 0; i < indices.size(); ++i) {
- if (indices[i] >= shape[i]) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Indices[" << i << "]=" << indices[i]
- << " out of bounds of buffer_view " << DebugStringShort();
- }
- device_size_t axis_offset = indices[i];
- for (int j = i + 1; j < shape.size(); ++j) {
- axis_offset *= shape[j];
- }
- offset += axis_offset;
- }
- offset *= element_size;
- return offset;
-}
-
-StatusOr<BufferView> BufferView::Slice(
- absl::Span<const int32_t> start_indices,
- absl::Span<const int32_t> lengths) const {
- if (start_indices.size() != shape.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Slice start_indices " << PrettyPrint(start_indices)
- << " do not match rank of buffer_view " << DebugStringShort();
- }
- if (start_indices.size() != lengths.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Slice start_indices " << PrettyPrint(start_indices)
- << " and lengths " << PrettyPrint(lengths)
- << " are not the same size";
- }
-
- // Buffer::Subspan only support contiguous memory. To ensure that this slice
- // only requests such, we validate that the offset in the buffer between the
- // start and end indices is the same as the requested size of the slice.
- absl::InlinedVector<int32_t, 6> end_indices(lengths.size());
- device_size_t subspan_length = element_size;
- for (int i = 0; i < lengths.size(); ++i) {
- subspan_length *= lengths[i];
- end_indices[i] = start_indices[i] + lengths[i] - 1;
- }
-
- ASSIGN_OR_RETURN(auto start_byte_offset, CalculateOffset(start_indices));
- // Also validates the ends are in bounds.
- ASSIGN_OR_RETURN(auto end_byte_offset, CalculateOffset(end_indices));
-
- auto offset_length = end_byte_offset - start_byte_offset + element_size;
- if (subspan_length != offset_length) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Slice for non-contiguous region of memory unimplemented. "
- "start_indices: "
- << PrettyPrint(start_indices) << " lengths: " << PrettyPrint(lengths)
- << " " << subspan_length << " " << offset_length << " "
- << PrettyPrint(end_indices);
- }
-
- ASSIGN_OR_RETURN(auto new_buffer,
- Buffer::Subspan(buffer, start_byte_offset, subspan_length));
- return BufferView(std::move(new_buffer), Shape(lengths), element_size);
-}
-
-// static
-Status BufferView::Copy(BufferView* src,
- absl::Span<const int32_t> src_start_indices,
- BufferView* dst,
- absl::Span<const int32_t> dst_start_indices,
- absl::Span<const int32_t> lengths) {
- if (src_start_indices.size() != src->shape.size() ||
- dst_start_indices.size() != dst->shape.size() ||
- src_start_indices.size() != lengths.size() ||
- dst_start_indices.size() != lengths.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Src/dst shape/size mismatch: src=" << src->DebugStringShort()
- << ", dst=" << dst->DebugStringShort()
- << ", src_indices=" << PrettyPrint(src_start_indices)
- << ", dst_indices=" << PrettyPrint(dst_start_indices)
- << ", lengths=" << PrettyPrint(lengths);
- }
-
- // Copies only support contiguous memory. To ensure that this copy
- // only requests such, we validate that the offset in the buffer between the
- // start and end indices is the same as the requested size of the copy.
- absl::InlinedVector<int32_t, 4> src_end_indices(lengths.size());
- absl::InlinedVector<int32_t, 4> dst_end_indices(lengths.size());
- device_size_t total_length = src->element_size;
- for (int i = 0; i < lengths.size(); ++i) {
- total_length *= lengths[i];
- src_end_indices[i] = src_start_indices[i] + lengths[i] - 1;
- dst_end_indices[i] = dst_start_indices[i] + lengths[i] - 1;
- }
-
- ASSIGN_OR_RETURN(auto src_start_byte_offset,
- src->CalculateOffset(src_start_indices));
- ASSIGN_OR_RETURN(auto src_end_byte_offset,
- src->CalculateOffset(src_end_indices));
- ASSIGN_OR_RETURN(auto dst_start_byte_offset,
- dst->CalculateOffset(dst_start_indices));
- ASSIGN_OR_RETURN(auto dst_end_byte_offset,
- dst->CalculateOffset(dst_end_indices));
-
- auto src_length =
- src_end_byte_offset - src_start_byte_offset + src->element_size;
- auto dst_length =
- dst_end_byte_offset - dst_start_byte_offset + dst->element_size;
- if (src_length != dst_length || src_length != total_length) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Copy for non-contiguous region of memory unimplemented: "
- << src->DebugStringShort() << ", dst=" << dst->DebugStringShort()
- << ", src_indices=" << PrettyPrint(src_start_indices)
- << ", dst_indices=" << PrettyPrint(dst_start_indices)
- << ", lengths=" << PrettyPrint(lengths);
- }
-
- RETURN_IF_ERROR(dst->buffer->CopyData(dst_start_byte_offset,
- src->buffer.get(),
- src_start_byte_offset, total_length));
-
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h
deleted file mode 100644
index 308a0bb..0000000
--- a/iree/hal/buffer_view.h
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_BUFFER_VIEW_H_
-#define IREE_HAL_BUFFER_VIEW_H_
-
-#include <memory>
-#include <ostream>
-
-#include "iree/base/shape.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-struct BufferView {
- // Returns true if the given buffer_views are exactly equal.
- static bool Equal(const BufferView& lhs, const BufferView& rhs);
-
- BufferView() = default;
- BufferView(ref_ptr<Buffer> buffer, Shape shape, int8_t element_size) noexcept
- : buffer(std::move(buffer)), shape(shape), element_size(element_size) {}
-
- BufferView(const BufferView& other) noexcept
- : buffer(add_ref(other.buffer)),
- shape(other.shape),
- element_size(other.element_size) {}
- BufferView& operator=(const BufferView& other) noexcept {
- buffer = add_ref(other.buffer);
- shape = other.shape;
- element_size = other.element_size;
- return *this;
- }
- BufferView(BufferView&& other) noexcept
- : buffer(std::move(other.buffer)),
- shape(other.shape),
- element_size(other.element_size) {}
- BufferView& operator=(BufferView&& other) noexcept {
- buffer = std::move(other.buffer);
- shape = other.shape;
- element_size = other.element_size;
- return *this;
- }
-
- // Returns a string useful for printing debug messages.
- std::string DebugStringShort() const;
-
- // Total length of the valid view range in bytes.
- device_size_t byte_length() const {
- return shape.element_count() * element_size;
- }
-
- // TODO(b/134586626): remove this when byte ranges are encoded in IR.
- // Calculates a byte offset into the buffer_view at the given dimension
- // indices.
- StatusOr<device_size_t> CalculateOffset(
- absl::Span<const int32_t> indices) const;
-
- // TODO(b/134586626): remove this when byte ranges are encoded in IR.
- // Returns a view onto the given range of the buffer underlying this view. The
- // returned view starts at the offset indicated by |start_indices| and has a
- // shape of |lengths|.
- // Only contiguous regions of memory are supported at the moment.
- StatusOr<BufferView> Slice(absl::Span<const int32_t> start_indices,
- absl::Span<const int32_t> lengths) const;
-
- // TODO(b/134586626): remove this when byte ranges are encoded in IR.
- static Status Copy(BufferView* src,
- absl::Span<const int32_t> src_start_indices,
- BufferView* dst,
- absl::Span<const int32_t> dst_start_indices,
- absl::Span<const int32_t> lengths);
-
- ref_ptr<Buffer> buffer;
- Shape shape;
- int8_t element_size;
- // TODO(benvanik): strides.
-};
-
-inline bool operator==(const BufferView& a, const BufferView& b) {
- return BufferView::Equal(a, b);
-}
-
-inline bool operator!=(const BufferView& a, const BufferView& b) {
- return !(a == b);
-}
-
-inline std::ostream& operator<<(std::ostream& stream,
- const BufferView& buffer_view) {
- stream << buffer_view.DebugStringShort();
- return stream;
-}
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_BUFFER_VIEW_H_
diff --git a/iree/hal/buffer_view_string_util.cc b/iree/hal/buffer_view_string_util.cc
deleted file mode 100644
index ad95743..0000000
--- a/iree/hal/buffer_view_string_util.cc
+++ /dev/null
@@ -1,542 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/buffer_view_string_util.h"
-
-#include <functional>
-#include <sstream>
-#include <type_traits>
-
-#include "absl/strings/ascii.h"
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/str_split.h"
-#include "absl/strings/strip.h"
-#include "absl/types/optional.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/hal/heap_buffer.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-
-/* clang-format off */
-constexpr char kHexValue[256] = {
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9'
- 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F'
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f'
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-};
-/* clang-format on */
-
-template <typename T>
-void HexStringToBytes(const char* from, T to, ptrdiff_t num) {
- for (int i = 0; i < num; i++) {
- to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) +
- (kHexValue[from[i * 2 + 1] & 0xFF]);
- }
-}
-
-constexpr char kHexTable[513] =
- "000102030405060708090a0b0c0d0e0f"
- "101112131415161718191a1b1c1d1e1f"
- "202122232425262728292a2b2c2d2e2f"
- "303132333435363738393a3b3c3d3e3f"
- "404142434445464748494a4b4c4d4e4f"
- "505152535455565758595a5b5c5d5e5f"
- "606162636465666768696a6b6c6d6e6f"
- "707172737475767778797a7b7c7d7e7f"
- "808182838485868788898a8b8c8d8e8f"
- "909192939495969798999a9b9c9d9e9f"
- "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf"
- "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf"
- "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf"
- "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf"
- "e0e1e2e3e4e5e6e7e8e9eaebecedeeef"
- "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff";
-
-template <typename T>
-void BytesToHexString(const unsigned char* src, T dest, ptrdiff_t num) {
- auto dest_ptr = &dest[0];
- for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest_ptr += 2) {
- const char* hex_p = &kHexTable[*src_ptr * 2];
- std::copy(hex_p, hex_p + 2, dest_ptr);
- }
-}
-
-// Returns true if the given type is represented as binary hex data.
-bool IsBinaryType(absl::string_view type_str) {
- return !type_str.empty() && absl::ascii_isdigit(type_str[0]);
-}
-
-// Parses binary hex data.
-Status ParseBinaryData(absl::string_view data_str, Buffer* buffer) {
- data_str = absl::StripAsciiWhitespace(data_str);
- ASSIGN_OR_RETURN(auto mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite));
- auto contents = mapping.mutable_contents();
- size_t dst_i = 0;
- size_t src_i = 0;
- while (src_i < data_str.size() && dst_i < contents.size()) {
- char c = data_str[src_i];
- if (absl::ascii_isspace(c) || c == ',') {
- ++src_i;
- continue;
- }
- if (src_i + 1 >= data_str.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Invalid input hex data (offset=" << src_i << ")";
- }
- HexStringToBytes(data_str.data() + src_i, contents.data() + dst_i, 1);
- src_i += 2;
- ++dst_i;
- }
- if (dst_i < contents.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Too few elements to fill type; expected " << contents.size()
- << " but only read " << dst_i;
- } else if (data_str.size() - src_i > 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Input data string contains more elements than the underlying "
- "buffer ("
- << contents.size() << ")";
- }
- return OkStatus();
-}
-
-// Prints binary hex data.
-Status PrintBinaryData(int element_size, Buffer* buffer, size_t max_entries,
- std::ostream* stream) {
- max_entries *= element_size; // Counting bytes, but treat them as elements.
- ASSIGN_OR_RETURN(auto mapping,
- buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
- auto contents = mapping.contents();
- char hex_buffer[8 * 2];
- for (size_t i = 0; i < std::min(max_entries, mapping.size());
- i += element_size) {
- if (i > 0) *stream << " ";
- BytesToHexString(contents.data() + i, hex_buffer, element_size);
- *stream << hex_buffer;
- }
- if (mapping.size() > max_entries) *stream << "...";
- return OkStatus();
-}
-
-template <typename ElementType, typename Enabled = void>
-struct SimpleStrToValue {
- absl::optional<ElementType> operator()(absl::string_view text) const = delete;
-};
-
-template <typename IntegerType>
-struct SimpleStrToValue<
- IntegerType,
- typename std::enable_if<(sizeof(IntegerType) < 4), void>::type> {
- absl::optional<IntegerType> operator()(absl::string_view text) const {
- int32_t value;
- return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value}
- : absl::nullopt;
- }
-};
-
-template <typename IntegerType>
-struct SimpleStrToValue<
- IntegerType,
- typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> {
- absl::optional<IntegerType> operator()(absl::string_view text) const {
- IntegerType value;
- return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value}
- : absl::nullopt;
- }
-};
-
-template <>
-struct SimpleStrToValue<float, void> {
- absl::optional<float> operator()(absl::string_view text) const {
- float value;
- return absl::SimpleAtof(text, &value) ? absl::optional<float>{value}
- : absl::nullopt;
- }
-};
-
-template <>
-struct SimpleStrToValue<double, void> {
- absl::optional<double> operator()(absl::string_view text) const {
- double value;
- return absl::SimpleAtod(text, &value) ? absl::optional<double>{value}
- : absl::nullopt;
- }
-};
-
-template <typename T>
-Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start,
- size_t token_end, absl::Span<T> contents,
- int dst_i) {
- if (dst_i >= contents.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Input data string contains more elements than the underlying "
- "buffer ("
- << contents.size() << ")";
- }
- auto element_str = data_str.substr(token_start, token_end - token_start + 1);
- auto element = SimpleStrToValue<T>()(element_str);
- if (!element.has_value()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unable to parse element " << dst_i << " = '" << element_str
- << "'";
- }
- contents[dst_i] = element.value();
- return OkStatus();
-}
-
-template <typename T>
-Status ParseNumericalDataAsType(absl::string_view data_str, Buffer* buffer) {
- ASSIGN_OR_RETURN(auto mapping,
- buffer->MapMemory<T>(MemoryAccess::kDiscardWrite));
- auto contents = mapping.mutable_contents();
- size_t src_i = 0;
- size_t dst_i = 0;
- size_t token_start = std::string::npos;
- while (src_i < data_str.size()) {
- char c = data_str[src_i++];
- bool is_separator =
- absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']';
- if (token_start == std::string::npos) {
- if (!is_separator) {
- token_start = src_i - 1;
- }
- continue;
- } else if (token_start != std::string::npos && !is_separator) {
- continue;
- }
- RETURN_IF_ERROR(ParseNumericalDataElement<T>(data_str, token_start,
- src_i - 2, contents, dst_i++));
- token_start = std::string::npos;
- }
- if (token_start != std::string::npos) {
- RETURN_IF_ERROR(ParseNumericalDataElement<T>(
- data_str, token_start, data_str.size() - 1, contents, dst_i++));
- }
- if (dst_i < contents.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Input data string contains fewer elements than the underlying "
- "buffer (expected "
- << contents.size() << ")";
- }
- return OkStatus();
-}
-
-// Parses numerical data (ints, floats, etc) in some typed form.
-Status ParseNumericalData(absl::string_view type_str,
- absl::string_view data_str, Buffer* buffer) {
- if (type_str == "i8") {
- return ParseNumericalDataAsType<int8_t>(data_str, buffer);
- } else if (type_str == "u8") {
- return ParseNumericalDataAsType<uint8_t>(data_str, buffer);
- } else if (type_str == "i16") {
- return ParseNumericalDataAsType<int16_t>(data_str, buffer);
- } else if (type_str == "u16") {
- return ParseNumericalDataAsType<uint16_t>(data_str, buffer);
- } else if (type_str == "i32") {
- return ParseNumericalDataAsType<int32_t>(data_str, buffer);
- } else if (type_str == "u32") {
- return ParseNumericalDataAsType<uint32_t>(data_str, buffer);
- } else if (type_str == "i64") {
- return ParseNumericalDataAsType<int64_t>(data_str, buffer);
- } else if (type_str == "u64") {
- return ParseNumericalDataAsType<uint64_t>(data_str, buffer);
- } else if (type_str == "f32") {
- return ParseNumericalDataAsType<float>(data_str, buffer);
- } else if (type_str == "f64") {
- return ParseNumericalDataAsType<double>(data_str, buffer);
- } else {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unsupported type: " << type_str;
- }
-}
-
-template <typename T>
-void PrintElementList(const Shape& shape, absl::Span<const T> data,
- size_t* max_entries, std::ostream* stream) {
- if (shape.empty()) {
- // Scalar value.
- PrintElementList({1}, data, max_entries, stream);
- return;
- } else if (shape.size() == 1) {
- // Leaf dimension; output data.
- size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0]));
- *stream << absl::StrJoin(data.subspan(0, max_count), " ");
- if (max_count < shape[0]) {
- *stream << "...";
- }
- *max_entries -= max_count;
- } else {
- // Nested; recurse into next dimension.
- Shape nested_shape = Shape(shape.subspan(1));
- size_t length = nested_shape.element_count();
- size_t offset = 0;
- for (int i = 0; i < shape[0]; ++i) {
- *stream << "[";
- PrintElementList<T>(nested_shape, data.subspan(offset, length),
- max_entries, stream);
- offset += length;
- *stream << "]";
- }
- }
-}
-
-template <typename T>
-Status PrintNumericalDataAsType(const Shape& shape, Buffer* buffer,
- size_t max_entries, std::ostream* stream) {
- ASSIGN_OR_RETURN(auto mapping, buffer->MapMemory<T>(MemoryAccess::kRead));
- PrintElementList(shape, mapping.contents(), &max_entries, stream);
- return OkStatus();
-}
-
-// Prints numerical data (ints, floats, etc) from some typed form.
-Status PrintNumericalData(const Shape& shape, absl::string_view type_str,
- Buffer* buffer, size_t max_entries,
- std::ostream* stream) {
- if (type_str == "i8") {
- return PrintNumericalDataAsType<int8_t>(shape, buffer, max_entries, stream);
- } else if (type_str == "u8") {
- return PrintNumericalDataAsType<uint8_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "i16") {
- return PrintNumericalDataAsType<int16_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "u16") {
- return PrintNumericalDataAsType<uint16_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "i32") {
- return PrintNumericalDataAsType<int32_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "u32") {
- return PrintNumericalDataAsType<uint32_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "i64") {
- return PrintNumericalDataAsType<int64_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "u64") {
- return PrintNumericalDataAsType<uint64_t>(shape, buffer, max_entries,
- stream);
- } else if (type_str == "f32") {
- return PrintNumericalDataAsType<float>(shape, buffer, max_entries, stream);
- } else if (type_str == "f64") {
- return PrintNumericalDataAsType<double>(shape, buffer, max_entries, stream);
- } else {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unsupported type: " << type_str;
- }
-}
-
-} // namespace
-
-StatusOr<int> GetTypeElementSize(absl::string_view type_str) {
- if (type_str.empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "Type is empty";
- } else if (IsBinaryType(type_str)) {
- // If the first character is a digit then we are dealign with binary data.
- // The type is just the number of bytes per element.
- int element_size = 0;
- if (!absl::SimpleAtoi(type_str, &element_size)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unable to parse element size type '" << type_str << "'";
- }
- return element_size;
- }
- // We know that our types are single characters followed by bit counts.
- // If we start to support other types we may need to do something more clever.
- int bit_count = 0;
- if (!absl::SimpleAtoi(type_str.substr(1), &bit_count)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unable to parse type bit count from '" << type_str
- << "'; expecting something like 'i32'";
- }
- return bit_count / 8;
-}
-
-StatusOr<Shape> ParseShape(absl::string_view shape_str) {
- std::vector<int> dims;
- for (auto dim_str : absl::StrSplit(shape_str, 'x', absl::SkipWhitespace())) {
- int dim_value = 0;
- if (!absl::SimpleAtoi(dim_str, &dim_value)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Invalid shape dimension '" << dim_str
- << "' while parsing shape '" << shape_str << "'";
- }
- dims.push_back(dim_value);
- }
- return Shape{dims};
-}
-
-StatusOr<BufferView> ParseBufferViewFromString(
- absl::string_view buffer_view_str, hal::Allocator* allocator) {
- // Strip whitespace that may come along (linefeeds/etc).
- buffer_view_str = absl::StripAsciiWhitespace(buffer_view_str);
- if (buffer_view_str.empty()) {
- // Empty lines denote empty buffer_views.
- return BufferView{};
- }
-
- // Split into the components we can work with: shape, type, and data.
- absl::string_view shape_and_type_str;
- absl::string_view data_str;
- auto equal_index = buffer_view_str.find('=');
- if (equal_index == std::string::npos) {
- // Treat a lack of = as defaulting the data to zeros.
- shape_and_type_str = buffer_view_str;
- } else {
- shape_and_type_str = buffer_view_str.substr(0, equal_index);
- data_str = buffer_view_str.substr(equal_index + 1);
- }
- absl::string_view shape_str;
- absl::string_view type_str;
- auto last_x_index = shape_and_type_str.rfind('x');
- if (last_x_index == std::string::npos) {
- // Scalar.
- type_str = shape_and_type_str;
- } else {
- // Has a shape.
- shape_str = shape_and_type_str.substr(0, last_x_index);
- type_str = shape_and_type_str.substr(last_x_index + 1);
- }
-
- // Populate BufferView metadata required for allocation.
- BufferView result;
- ASSIGN_OR_RETURN(result.element_size, GetTypeElementSize(type_str));
- ASSIGN_OR_RETURN(result.shape, ParseShape(shape_str));
-
- // Allocate the host buffer.
- size_t allocation_size = result.shape.element_count() * result.element_size;
- if (allocator) {
- ASSIGN_OR_RETURN(
- result.buffer,
- allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible,
- BufferUsage::kAll | BufferUsage::kConstant,
- allocation_size));
- } else {
- result.buffer = HeapBuffer::Allocate(
- MemoryType::kHostLocal, BufferUsage::kAll | BufferUsage::kConstant,
- allocation_size);
- }
-
- if (!data_str.empty()) {
- // Parse the data from the string right into the buffer.
- if (IsBinaryType(type_str)) {
- // Parse as binary hex.
- RETURN_IF_ERROR(ParseBinaryData(data_str, result.buffer.get()));
- } else {
- // Parse as some nicely formatted type.
- RETURN_IF_ERROR(
- ParseNumericalData(type_str, data_str, result.buffer.get()));
- }
- }
-
- return result;
-}
-
-StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str) {
- char str_char = str.empty() ? '?' : str[0];
- switch (str_char) {
- case 'b':
- return BufferViewPrintMode::kBinary;
- case 'i':
- return BufferViewPrintMode::kSignedInteger;
- case 'u':
- return BufferViewPrintMode::kUnsignedInteger;
- case 'f':
- return BufferViewPrintMode::kFloatingPoint;
- default:
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unsupported output type '" << str << "'";
- }
-}
-
-StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view,
- BufferViewPrintMode print_mode,
- size_t max_entries) {
- std::string result;
- RETURN_IF_ERROR(
- PrintBufferViewToString(buffer_view, print_mode, max_entries, &result));
- return result;
-}
-
-Status PrintBufferViewToString(const BufferView& buffer_view,
- BufferViewPrintMode print_mode,
- size_t max_entries, std::string* out_result) {
- std::ostringstream stream;
- RETURN_IF_ERROR(
- PrintBufferViewToStream(buffer_view, print_mode, max_entries, &stream));
- *out_result = stream.str();
- return OkStatus();
-}
-
-Status PrintBufferViewToStream(const BufferView& buffer_view,
- BufferViewPrintMode print_mode,
- size_t max_entries, std::ostream* stream) {
- if (!buffer_view.buffer) {
- // No buffer means the buffer_view is empty. We use the empty string to
- // denote this (as we have no useful information).
- return OkStatus();
- }
-
- // Pick a type based on the element size and the printing mode.
- std::string type_str;
- switch (print_mode) {
- case BufferViewPrintMode::kBinary:
- type_str = std::to_string(buffer_view.element_size);
- break;
- case BufferViewPrintMode::kSignedInteger:
- absl::StrAppend(&type_str, "i", buffer_view.element_size * 8);
- break;
- case BufferViewPrintMode::kUnsignedInteger:
- absl::StrAppend(&type_str, "u", buffer_view.element_size * 8);
- break;
- case BufferViewPrintMode::kFloatingPoint:
- absl::StrAppend(&type_str, "f", buffer_view.element_size * 8);
- break;
- }
-
- // [shape]x[type]= prefix (taking into account scalar values).
- *stream << absl::StrJoin(buffer_view.shape.begin(), buffer_view.shape.end(),
- "x");
- if (!buffer_view.shape.empty()) *stream << "x";
- *stream << type_str;
- *stream << "=";
-
- if (IsBinaryType(type_str)) {
- return PrintBinaryData(buffer_view.element_size, buffer_view.buffer.get(),
- max_entries, stream);
- } else {
- return PrintNumericalData(buffer_view.shape, type_str,
- buffer_view.buffer.get(), max_entries, stream);
- }
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/buffer_view_string_util.h b/iree/hal/buffer_view_string_util.h
deleted file mode 100644
index 6e9ca2c..0000000
--- a/iree/hal/buffer_view_string_util.h
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utilities for working with BufferView data, mostly useful for testing.
-// These functions allow for conversion between types, parsing and printing, and
-// basic comparisons.
-//
-// The canonical BufferView string format is:
-// [shape]x[type]=value,value,...
-// For example:
-// 2x2xi32=0,1,2,3
-// Characters like [] are optional and will be ignored during parsing:
-// 2x2xi32=[[0 1][2 3]]
-//
-// The type may be one of the following:
-// * 1/2/4/8 = 1/2/4/8 byte elements in binary hex format.
-// * i8/u8 = signed/unsigned 8-bit integers.
-// * i16/u16 = signed/unsigned 16-bit integers.
-// * i32/u32 = signed/unsigned 32-bit integers.
-// * i64/u64 = signed/unsigned 64-bit integers.
-// * f32 = 32-bit floating-point number.
-// * f64 = 64-bit floating-point number.
-
-#ifndef IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
-#define IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
-
-#include <ostream>
-#include <string>
-
-#include "absl/strings/string_view.h"
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer_view.h"
-
-namespace iree {
-namespace hal {
-
-// Returns the size, in bytes, of the given type.
-StatusOr<int> GetTypeElementSize(absl::string_view type_str);
-
-// Returns a Shape parsed from the given NxMx... string.
-StatusOr<Shape> ParseShape(absl::string_view shape_str);
-
-// Parses a BufferView encoded in a string.
-// If an |allocator| is provided the buffer will be allocated as host-local and
-// device-visible. Otherwise, buffers will be host-local.
-// The format accepted matches that produced by PrintBufferViewToString.
-StatusOr<BufferView> ParseBufferViewFromString(
- absl::string_view buffer_view_str, hal::Allocator* allocator = nullptr);
-
-// Defines how the elements within a BufferView are interpreted during printing.
-enum class BufferViewPrintMode {
- // Interpret the data as if it were serialized bytes.
- // In this mode no conversion is performed and the bytes in memory are printed
- // as hex in groupings based on the element size. Shortened to 'b'.
- kBinary,
- // Interpret elements as signed integers; shortened to 'i'.
- kSignedInteger,
- // Interpret elements as unsigned integers; shortened to 'u'.
- kUnsignedInteger,
- // Interpret elements as floating-point values; shortened to 'f'.
- kFloatingPoint,
-};
-
-// Returns the BufferViewPrintMode based on the shortened char in |str|.
-StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str);
-
-// Prints a BufferView to a string encoded in the canonical format.
-StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view,
- BufferViewPrintMode print_mode,
- size_t max_entries);
-Status PrintBufferViewToString(const BufferView& buffer_view,
- BufferViewPrintMode print_mode,
- size_t max_entries, std::string* out_result);
-
-// Prints a BufferView to a string stream encoded in the canonical format.
-Status PrintBufferViewToStream(const BufferView& buffer_view,
- BufferViewPrintMode print_mode,
- size_t max_entries, std::ostream* stream);
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
diff --git a/iree/hal/buffer_view_string_util_test.cc b/iree/hal/buffer_view_string_util_test.cc
deleted file mode 100644
index 11902ba..0000000
--- a/iree/hal/buffer_view_string_util_test.cc
+++ /dev/null
@@ -1,186 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/buffer_view_string_util.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-using ::testing::ElementsAre;
-
-template <typename T>
-StatusOr<std::vector<T>> ReadBuffer(const ref_ptr<Buffer>& buffer) {
- std::vector<T> result;
- result.resize(buffer->byte_length() / sizeof(T));
- RETURN_IF_ERROR(
- buffer->ReadData(0, result.data(), result.size() * sizeof(T)));
- return result;
-}
-
-TEST(BufferViewUtilTest, GetTypeElementSize) {
- EXPECT_EQ(1, GetTypeElementSize("1").ValueOrDie());
- EXPECT_EQ(7, GetTypeElementSize("7").ValueOrDie());
- EXPECT_EQ(4, GetTypeElementSize("i32").ValueOrDie());
- EXPECT_EQ(8, GetTypeElementSize("f64").ValueOrDie());
-
- EXPECT_FALSE(GetTypeElementSize("").ok());
- EXPECT_FALSE(GetTypeElementSize(" ").ok());
- EXPECT_FALSE(GetTypeElementSize("a").ok());
- EXPECT_FALSE(GetTypeElementSize("ib").ok());
- EXPECT_FALSE(GetTypeElementSize("i").ok());
- EXPECT_FALSE(GetTypeElementSize("i543ff").ok());
-}
-
-TEST(BufferViewUtilTest, ParseShape) {
- EXPECT_EQ((Shape{}), ParseShape("").ValueOrDie());
- EXPECT_EQ((Shape{1}), ParseShape("1").ValueOrDie());
- EXPECT_EQ((Shape{1, 2}), ParseShape("1x2").ValueOrDie());
- EXPECT_EQ((Shape{1, 2}), ParseShape(" 1 x 2 ").ValueOrDie());
-
- EXPECT_FALSE(ParseShape("abc").ok());
- EXPECT_FALSE(ParseShape("1xf").ok());
- EXPECT_FALSE(ParseShape("1xff23").ok());
-}
-
-TEST(BufferViewUtilTest, ParseBufferViewFromStringEmpty) {
- // Empty string = empty buffer_view.
- ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString(""));
- EXPECT_EQ(nullptr, m0.buffer.get());
- EXPECT_EQ(Shape{}, m0.shape);
- EXPECT_EQ(0, m0.element_size);
-
- // No = means no data.
- ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x2xf32"));
- EXPECT_EQ(4 * 2 * 4, m1.buffer->allocation_size());
- EXPECT_EQ(Shape({4, 2}), m1.shape);
- EXPECT_EQ(4, m1.element_size);
- EXPECT_THAT(ReadBuffer<float>(m1.buffer).ValueOrDie(),
- ElementsAre(0, 0, 0, 0, 0, 0, 0, 0));
-
- // No data after = means no data.
- ASSERT_OK_AND_ASSIGN(auto m2, ParseBufferViewFromString("4x2xf32="));
- EXPECT_EQ(4 * 2 * 4, m2.buffer->allocation_size());
- EXPECT_EQ(Shape({4, 2}), m2.shape);
- EXPECT_EQ(4, m2.element_size);
- EXPECT_THAT(ReadBuffer<float>(m2.buffer).ValueOrDie(),
- ElementsAre(0, 0, 0, 0, 0, 0, 0, 0));
-}
-
-TEST(BufferViewUtilTest, ParseBufferViewFromStringBinary) {
- ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString("4x1=00 01 02 03"));
- EXPECT_EQ(Shape({4}), m0.shape);
- EXPECT_EQ(1, m0.element_size);
- EXPECT_THAT(ReadBuffer<uint8_t>(m0.buffer).ValueOrDie(),
- ElementsAre(0, 1, 2, 3));
-
- // Whitespace shouldn't matter.
- ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x1=00,010203"));
- EXPECT_EQ(Shape({4}), m1.shape);
- EXPECT_EQ(1, m1.element_size);
- EXPECT_THAT(ReadBuffer<uint8_t>(m1.buffer).ValueOrDie(),
- ElementsAre(0, 1, 2, 3));
-
- // Should fail on malformed hex bytes.
- EXPECT_FALSE(ParseBufferViewFromString("4x1=1").ok());
- EXPECT_FALSE(ParseBufferViewFromString("4x1=00003").ok());
- EXPECT_FALSE(ParseBufferViewFromString("4x1=%0123%\1").ok());
- EXPECT_FALSE(ParseBufferViewFromString("4x1=00010203040506").ok());
-}
-
-TEST(BufferViewUtilTest, ParseBufferViewFromStringAllowBrackets) {
- ASSERT_OK_AND_ASSIGN(auto m0,
- ParseBufferViewFromString("4xi16=[[0][ 1 ][2]][3]"));
- EXPECT_EQ(Shape({4}), m0.shape);
- EXPECT_EQ(2, m0.element_size);
- EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(),
- ElementsAre(0, 1, 2, 3));
-}
-
-TEST(BufferViewUtilTest, ParseBufferViewFromStringInteger) {
- // Signed int16.
- ASSERT_OK_AND_ASSIGN(auto m0,
- ParseBufferViewFromString("4xi16=0 12345 65535 -2"));
- EXPECT_EQ(Shape({4}), m0.shape);
- EXPECT_EQ(2, m0.element_size);
- EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(),
- ElementsAre(0, 12345, -1, -2));
-
- // Unsigned int16.
- ASSERT_OK_AND_ASSIGN(auto m1,
- ParseBufferViewFromString("4xu16=0 12345 65535 -2"));
- EXPECT_EQ(Shape({4}), m1.shape);
- EXPECT_EQ(2, m1.element_size);
- EXPECT_THAT(ReadBuffer<uint16_t>(m1.buffer).ValueOrDie(),
- ElementsAre(0, 12345, 65535, 65534));
-
- // Mixing separator types is ok.
- ASSERT_OK_AND_ASSIGN(auto m2,
- ParseBufferViewFromString("4xu16=0, 12345, 65535, -2"));
- EXPECT_EQ(Shape({4}), m2.shape);
- EXPECT_EQ(2, m2.element_size);
- EXPECT_THAT(ReadBuffer<uint16_t>(m2.buffer).ValueOrDie(),
- ElementsAre(0, 12345, 65535, 65534));
-
- // Should fail on malformed integers bytes and out of bounds values.
- EXPECT_FALSE(ParseBufferViewFromString("4xi32=asodfj").ok());
- EXPECT_FALSE(ParseBufferViewFromString("4xi32=0 1 2 3 4").ok());
-}
-
-TEST(BufferViewUtilTest, ParseBufferViewFromStringFloat) {
- // Float.
- ASSERT_OK_AND_ASSIGN(auto m0,
- ParseBufferViewFromString("4xf32=0 1.0 1234 -2.0e-5"));
- EXPECT_EQ(Shape({4}), m0.shape);
- EXPECT_EQ(4, m0.element_size);
- EXPECT_THAT(ReadBuffer<float>(m0.buffer).ValueOrDie(),
- ElementsAre(0.0f, 1.0f, 1234.0f, -2.0e-5f));
-
- // Double.
- ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString(
- "4xf64=0 1.0 123456789012345 -2.0e-5"));
- EXPECT_EQ(Shape({4}), m1.shape);
- EXPECT_EQ(8, m1.element_size);
- EXPECT_THAT(ReadBuffer<double>(m1.buffer).ValueOrDie(),
- ElementsAre(0.0, 1.0, 123456789012345.0, -2.0e-5));
-
- // Should fail on malformed floats and out of bounds values.
- EXPECT_FALSE(ParseBufferViewFromString("4xf32=asodfj").ok());
- EXPECT_FALSE(ParseBufferViewFromString("4xf32=0").ok());
- EXPECT_FALSE(ParseBufferViewFromString("4xf32=0 1 2 3 4").ok());
-}
-
-TEST(BufferViewUtilTest, ParseBufferViewPrintMode) {
- EXPECT_EQ(BufferViewPrintMode::kBinary,
- ParseBufferViewPrintMode("b").ValueOrDie());
- EXPECT_EQ(BufferViewPrintMode::kSignedInteger,
- ParseBufferViewPrintMode("i").ValueOrDie());
- EXPECT_EQ(BufferViewPrintMode::kUnsignedInteger,
- ParseBufferViewPrintMode("u").ValueOrDie());
- EXPECT_EQ(BufferViewPrintMode::kFloatingPoint,
- ParseBufferViewPrintMode("f").ValueOrDie());
-
- EXPECT_FALSE(ParseBufferViewPrintMode("").ok());
- EXPECT_FALSE(ParseBufferViewPrintMode("s").ok());
- EXPECT_FALSE(ParseBufferViewPrintMode("asdfasdf").ok());
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/buffer_view_test.cc b/iree/hal/buffer_view_test.cc
deleted file mode 100644
index 18dcbcf..0000000
--- a/iree/hal/buffer_view_test.cc
+++ /dev/null
@@ -1,285 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/buffer_view.h"
-
-#include <numeric>
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/buffer.h"
-#include "iree/hal/heap_buffer.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-template <typename T>
-BufferView MakeView(const std::vector<T> src_data, Shape shape) {
- auto parent_buffer = HeapBuffer::AllocateCopy(
- BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data));
-
- return BufferView(std::move(parent_buffer), shape, sizeof(T));
-}
-
-template <typename T>
-std::vector<T> ReadData(BufferView view) {
- std::vector<T> data(view.shape.element_count());
- EXPECT_OK(view.buffer->ReadData(0, data.data(), data.size() * sizeof(T)));
- return data;
-}
-
-TEST(BufferViewTest, SliceWholeBuffer) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {0, 0};
- std::vector<int32_t> lengths = {2, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_TRUE(BufferView::Equal(parent_view, slice))
- << "original parent_view " << parent_view.DebugStringShort()
- << " and whole slice " << slice.DebugStringShort() << " are not equal";
-}
-
-TEST(BufferViewTest, SliceSingleRow) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({2, 3}));
-}
-
-TEST(BufferViewTest, SliceRowStart) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7};
- Shape shape = {2, 4};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {1, 3};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({4, 5, 6}));
-}
-
-TEST(BufferViewTest, SliceRowEnd) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7};
- Shape shape = {2, 4};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 1};
- std::vector<int32_t> lengths = {1, 3};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6, 7}));
-}
-
-TEST(BufferViewTest, SliceRowMiddle) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7};
- Shape shape = {2, 4};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 1};
- std::vector<int32_t> lengths = {1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6}));
-}
-
-TEST(BufferViewTest, SliceMultiRow) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8};
- Shape shape = {3, 3};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {2, 3};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({3, 4, 5, 6, 7, 8}));
-}
-
-TEST(BufferViewTest, SliceHighRank) {
- std::vector<uint8_t> src_data(81);
- std::iota(src_data.begin(), src_data.end(), 0);
- Shape shape = {3, 3, 3, 3};
-
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 2, 2, 1};
- std::vector<int32_t> lengths = {1, 1, 1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({52, 53}));
-}
-
-TEST(BufferViewTest, SliceModifySlice) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_OK(slice.buffer->Fill8(0, kWholeBuffer, 0xFFu));
-
- auto parent_data = ReadData<uint8_t>(parent_view);
- EXPECT_EQ(parent_data, std::vector<uint8_t>({0, 1, 0xFFu, 0xFFu}));
-}
-
-TEST(BufferViewTest, SliceModifyParent) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_OK(parent_view.buffer->Fill8(0, kWholeBuffer, 0xFFu));
-
- EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({0xFFu, 0xFFu}));
-}
-
-TEST(BufferViewTest, SliceMultiByteElementWholeBuffer) {
- const std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3};
-
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {0, 0};
- std::vector<int32_t> lengths = {2, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_TRUE(BufferView::Equal(parent_view, slice))
- << "original parent_view " << parent_view.DebugStringShort()
- << " and whole slice " << slice.DebugStringShort() << " are not equal";
-}
-
-TEST(BufferViewTest, SliceShapeAndElementSize) {
- std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
- EXPECT_EQ(slice.shape, Shape(lengths));
- EXPECT_EQ(slice.element_size, 4);
-}
-
-TEST(BufferViewTest, SliceMultiByteElement) {
- std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {1, 2};
- ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths));
-
- EXPECT_EQ(ReadData<int32_t>(slice), std::vector<int32_t>({2, 3}));
-}
-
-TEST(BufferViewTest, SliceIndexBadRank) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {0};
- std::vector<int32_t> lengths = {2};
- EXPECT_TRUE(
- IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
-}
-
-TEST(BufferViewTest, SliceIndexLengthMismatch) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {0, 0};
- std::vector<int32_t> lengths = {2};
- EXPECT_TRUE(
- IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
-}
-
-TEST(BufferViewTest, SliceIndicesOutOfBounds) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
-
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {0, 3};
- std::vector<int32_t> lengths = {1, 1};
- EXPECT_TRUE(
- IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
-}
-
-TEST(BufferViewTest, SliceLengthsOutOfBounds) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3};
- Shape shape = {2, 2};
-
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {0, 0};
- std::vector<int32_t> lengths = {1, 3};
- EXPECT_TRUE(
- IsInvalidArgument(parent_view.Slice(start_indices, lengths).status()));
-}
-
-TEST(BufferViewTest, SliceNonContiguous) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8};
- Shape shape = {3, 3};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 1};
- std::vector<int32_t> lengths = {2, 2};
- EXPECT_TRUE(
- IsUnimplemented(parent_view.Slice(start_indices, lengths).status()));
-}
-
-TEST(BufferViewTest, SliceNonContiguousMultiRowLeft) {
- std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8};
- Shape shape = {3, 3};
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0};
- std::vector<int32_t> lengths = {2, 1};
- EXPECT_TRUE(
- IsUnimplemented(parent_view.Slice(start_indices, lengths).status()));
-}
-
-TEST(BufferViewTest, SliceHighRankNonContiguous) {
- std::vector<uint8_t> src_data(81);
- std::iota(src_data.begin(), src_data.end(), 0);
- Shape shape = {3, 3, 3, 3};
-
- auto parent_view = MakeView(src_data, shape);
-
- std::vector<int32_t> start_indices = {1, 0, 2, 1};
- std::vector<int32_t> lengths = {1, 2, 1, 2};
- EXPECT_TRUE(
- IsUnimplemented(parent_view.Slice(start_indices, lengths).status()));
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/command_buffer.cc b/iree/hal/command_buffer.cc
deleted file mode 100644
index d83810a..0000000
--- a/iree/hal/command_buffer.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/command_buffer.h"
-
-namespace iree {
-namespace hal {
-
-std::string CommandCategoryString(CommandCategoryBitfield categories) {
- return FormatBitfieldValue(categories,
- {
- {CommandCategory::kTransfer, "kTransfer"},
- {CommandCategory::kDispatch, "kDispatch"},
- });
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/command_buffer.h b/iree/hal/command_buffer.h
deleted file mode 100644
index 9b32232..0000000
--- a/iree/hal/command_buffer.h
+++ /dev/null
@@ -1,383 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_COMMAND_BUFFER_H_
-#define IREE_HAL_COMMAND_BUFFER_H_
-
-#include <cstdint>
-
-#include "iree/base/bitfield.h"
-#include "iree/base/shape.h"
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/event.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/resource.h"
-
-namespace iree {
-namespace hal {
-
-// A bitfield specifying the mode of operation for a command buffer.
-enum class CommandBufferMode : uint32_t {
- // Command buffer will be submitted once and never used again.
- // This may enable in-place patching of command buffers that reduces overhead
- // when it's known that command buffers will not be reused.
- kOneShot = 1 << 0,
-};
-IREE_BITFIELD(CommandBufferMode);
-using CommandBufferModeBitfield = CommandBufferMode;
-std::string CommandBufferModeString(CommandBufferModeBitfield mode);  // NOTE(review): no definition in command_buffer.cc - confirm one exists.
-
-// A bitfield specifying the categories of commands a buffer/queue may carry.
-enum class CommandCategory : uint32_t {
- // Command is considered a transfer operation (fill, update, copy, etc).
- kTransfer = 1 << 0,
- // Command is considered a dispatch operation (dispatch/execute).
- kDispatch = 1 << 1,
-};
-IREE_BITFIELD(CommandCategory);
-using CommandCategoryBitfield = CommandCategory;
-std::string CommandCategoryString(CommandCategoryBitfield categories);
-
-// Bitfield specifying which execution stage a barrier should start/end at.
-//
-// Maps to VkPipelineStageFlagBits.
-enum class ExecutionStage : uint32_t {
- // Top of the pipeline when commands are initially issued by the device.
- kCommandIssue = 1 << 0,
- // Stage of the pipeline when dispatch parameter data is consumed.
- kCommandProcess = 1 << 1,
- // Stage where dispatch commands execute.
- kDispatch = 1 << 2,
- // Stage where transfer (copy/clear/fill/etc) commands execute.
- kTransfer = 1 << 3,
- // Final stage in the pipeline when commands are retired on the device.
- kCommandRetire = 1 << 4,
- // Pseudo-stage for read/writes by the host. Not executed on device.
- kHost = 1 << 5,
-};
-IREE_BITFIELD(ExecutionStage);
-using ExecutionStageBitfield = ExecutionStage;
-
-// Bitfield specifying which scopes will access memory and how.
-//
-// Maps to VkAccessFlagBits.
-enum class AccessScope : uint32_t {
- // Read access to indirect command data as part of an indirect dispatch.
- kIndirectCommandRead = 1 << 0,
- // Constant uniform buffer reads by the device.
- kConstantRead = 1 << 1,
- // Storage buffer reads by dispatch commands.
- kDispatchRead = 1 << 2,
- // Storage buffer writes by dispatch commands.
- kDispatchWrite = 1 << 3,
- // Source of a transfer operation.
- kTransferRead = 1 << 4,
- // Target of a transfer operation.
- kTransferWrite = 1 << 5,
- // Read operation by the host through mapped memory.
- kHostRead = 1 << 6,
- // Write operation by the host through mapped memory.
- kHostWrite = 1 << 7,
- // External/non-specific read.
- kMemoryRead = 1 << 8,
- // External/non-specific write.
- kMemoryWrite = 1 << 9,
-};
-IREE_BITFIELD(AccessScope);  // Presumably adds bitwise operators; see bitfield.h.
-using AccessScopeBitfield = AccessScope;  // Alias used where flags are combined.
-
-// Defines a global memory barrier.
-// These are cheaper to encode than buffer-specific barriers but may cause
-// stalls and bubbles in device pipelines if applied too broadly. Prefer them
-// over equivalently large sets of buffer-specific barriers (such as when
-// completely changing execution contexts).
-//
-// Maps to VkMemoryBarrier.
-struct MemoryBarrier {
- // All access scopes prior to the barrier (inclusive). No default value.
- AccessScopeBitfield source_scope;
- // All access scopes following the barrier (inclusive). No default value.
- AccessScopeBitfield target_scope;
-};
-
-// Defines a memory barrier that applies to a range of a specific buffer.
-// Use of these (vs. global memory barriers) provides fine-grained execution
-// ordering to device command processors and allows for more aggressive
-// reordering.
-//
-// Maps to VkBufferMemoryBarrier.
-struct BufferBarrier {
- // All access scopes prior to the barrier (inclusive). No default value.
- AccessScopeBitfield source_scope;
- // All access scopes following the barrier (inclusive). No default value.
- AccessScopeBitfield target_scope;
- // Buffer the barrier is restricted to.
- // The barrier will apply to the entire physical device allocation.
- Buffer* buffer = nullptr;
- // Relative offset/length within |buffer| (which may itself be mapped into the
- // device allocation at an offset). Defaults to kWholeBuffer from offset 0.
- device_size_t offset = 0;
- device_size_t length = kWholeBuffer;
-};
-
-// Represents a binding to a buffer with a set of attributes.
-// This may be used by drivers to validate alignment.
-struct BufferBinding {
- // Access rights of the buffer contents by the executable.
- MemoryAccessBitfield access = MemoryAccess::kAll;
-
- // The buffer this binding references.
- // The buffer is not retained by the binding and must be kept alive externally
- // for the duration it is in use by the queue.
- Buffer* buffer = nullptr;
-
- // Shape of the buffer contents.
- Shape shape;
-
- // Size of each element within the buffer, in bytes.
- int8_t element_size = 0;
-
- BufferBinding() = default;  // All members keep their in-class defaults.
- BufferBinding(MemoryAccessBitfield access, Buffer* buffer)
- : access(access), buffer(buffer) {}  // shape/element_size keep defaults.
- BufferBinding(MemoryAccessBitfield access, Buffer* buffer, Shape shape,
- int8_t element_size)
- : access(access),
- buffer(buffer),
- shape(shape),
- element_size(element_size) {}
- BufferBinding(MemoryAccessBitfield access, const BufferView& buffer_view)
- : access(access),
- buffer(buffer_view.buffer.get()),  // Borrowed raw pointer; not retained.
- shape(buffer_view.shape),
- element_size(buffer_view.element_size) {}  // NOTE(review): narrows if the view's field is wider than int8_t - confirm.
-};
-
-// Wraps parameters for a Dispatch request.
-struct DispatchRequest {
-  // Executable prepared for use on the device.
-  // The executable must remain alive until all in-flight dispatch requests
-  // that use it have completed.
-  Executable* executable = nullptr;
-
-  // Executable entry point ordinal.
-  int entry_point = 0;
-
-  // TODO(benvanik): predication.
-
-  // Static X, Y, and Z workgroup counts; zero until the caller sets them.
-  std::array<int32_t, 3> workload = {{0, 0, 0}};
-
-  // An optional buffer containing the dynamic workload to dispatch.
-  // The contents need not be available at the time of recording but must be
-  // made visible prior to execution of the dispatch command.
-  //
-  // Buffer contents are expected to be 3 int32 values defining the X, Y, and Z
-  // workgroup counts.
-  //
-  // The buffer must have been allocated with BufferUsage::kDispatch and be
-  // of MemoryType::kDeviceVisible.
-  Buffer* workload_buffer = nullptr;
-
-  // A list of buffers that contain the execution inputs/outputs.
-  // Order is dependent on executable arg layout.
-  //
-  // Buffers must have been allocated with BufferUsage::kDispatch and be
-  // of MemoryType::kDeviceVisible.
-  absl::Span<const BufferBinding> bindings;
-
-  // TODO(benvanik): push-constant equivalent (uniforms, etc).
-};
-
-// Asynchronous command buffer recording interface.
-// Commands are recorded by the implementation for later submission to command
-// queues.
-//
-// Buffers and synchronization objects referenced must remain valid and not be
-// modified or read while there are commands in-flight. The usual flow is to
-// populate input buffers, Dispatch using those buffers, wait on a Fence until
-// the buffers are guaranteed to no longer be in use, and then reuse or release
-// the buffers.
-//
-// Errors that can be recognized when operations are enqueued will be returned
-// immediately, such as invalid argument errors. Errors that can only be
-// determined at execution time will be returned on fences. Once a failure
-// occurs the device queue will enter an error state that invalidates all
-// operations on the device queue (as ordering is not strict and any may still
-// be in-flight). In this case the user of the device queue should treat all
-// in-flight operations as cancelled and fully reset themselves. Other device
-// queues that may be waiting on events from the device queue will also enter
-// error states. Only once a user has acknowledged and cleared the error state
-// with a Reset the queue will become usable, and otherwise all operations will
-// return errors.
-//
-// Command buffers are thread-compatible. Use multiple command buffers if trying
-// to record commands from multiple threads. Command buffers must not be mutated
-// between when they are submitted for execution on a queue and when the
-// fence fires indicating the completion of their execution.
-class CommandBuffer : public Resource {
- public:
- virtual CommandBuffer* impl() { return this; }
-
- // Device allocator that commands encoded into the buffer share compatibility
- // with.
- Allocator* allocator() const { return allocator_; }
-
- // Command buffer operation mode.
- CommandBufferModeBitfield mode() const { return mode_; }
-
- // Command categories that may be recorded into the buffer.
- CommandCategoryBitfield command_categories() const {
- return command_categories_;
- }
-
- // True if the command buffer is between a Begin/End recording block.
- virtual bool is_recording() const = 0;
-
- // Resets and begins recording into the command buffer, clearing all
- // previously recorded contents.
- // The command buffer must not be in-flight.
- virtual Status Begin() = 0;
-
- // Ends recording into the command buffer.
- // This must be called prior to submitting the command buffer for execution.
- virtual Status End() = 0;
-
- // TODO(benvanik): annotations for debugging and tracing:
- // enter/exit
- // stack frame manipulation
- // explicit timers? or profiling buffer?
-
- // TODO(b/138719910): cross-queue and external acquire/release.
- // virtual Status AcquireBuffer() = 0;
- // virtual Status ReleaseBuffer() = 0;
-
- // Defines a memory dependency between commands recorded before and after the
- // barrier. One or more memory or buffer barriers can be specified to indicate
- // between which stages or buffers the dependencies exist.
- virtual Status ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) = 0;
-
- // Sets an event to the signaled state.
- // |source_stage_mask| specifies when the event is signaled.
- //
- // Events are only valid within a single command buffer. Events can only be
- // used on non-transfer queues.
- virtual Status SignalEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) = 0;
-
- // Resets an event to the non-signaled state.
- // |source_stage_mask| specifies when the event is unsignaled.
- //
- // Events are only valid within a single command buffer. Events can only be
- // used on non-transfer queues.
- virtual Status ResetEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) = 0;
-
- // Waits for one or more events to be signaled and defines a memory dependency
- // between the synchronization scope of the signal operations and the commands
- // following the wait.
- //
- // |source_stage_mask| must include ExecutionStage::kHost for Event::Signal to
- // be visible.
- //
- // Events are only valid within a single command buffer. Events remain
- // signaled even after waiting and must be reset to be reused. Events can only
- // be used on non-transfer queues.
- virtual Status WaitEvents(
- absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) = 0;
-
- // Fills the target buffer with the given repeating value.
- // Expects that |pattern_length| is one of 1, 2, or 4 and that the offset and
- // length are aligned to the natural alignment of the value.
- // The target buffer must be compatible with the devices owned by this
- // device queue and be allocated with BufferUsage::kTransfer.
- virtual Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
- device_size_t length, const void* pattern,
- size_t pattern_length) = 0;
-
- // Hints to the device queue that the given buffer will not be used again.
- // After encoding a discard the buffer contents will be considered undefined.
- // This is because the discard may be used to elide write backs to host memory
- // or aggressively reuse the allocation for other purposes.
- //
- // For buffers allocated with MemoryType::kTransient this may allow
- // the device queue to reclaim the memory used by the buffer earlier than
- // otherwise possible.
- virtual Status DiscardBuffer(Buffer* buffer) = 0;
-
- // Updates a range of the given target buffer from the source host memory.
- // The source host memory is copied immediately into the command buffer and
- // occupies command buffer space. It is strongly recommended that large buffer
- // updates are performed via CopyBuffer where there is the possibility of a
- // zero-copy path.
- // The |source_buffer| may be released by the caller immediately after this
- // call returns.
- // The |target_buffer| must be compatible with the devices owned by this
- // device queue and be allocated with BufferUsage::kTransfer.
- virtual Status UpdateBuffer(const void* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) = 0;
-
- // Copies a range of one buffer to another.
- // Both buffers must be compatible with the devices owned by this device
- // queue and be allocated with BufferUsage::kTransfer. Though the source and
- // target buffer may be the same the ranges must not overlap (as with memcpy).
- //
- // This can be used to perform device->host, host->device, and device->device
- // copies.
- virtual Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) = 0;
-
- // Dispatches an execution request.
- // The request may execute overlapped with any other transfer operation or
- // dispatch made within the same barrier-defined sequence.
- //
- // The executable specified must be registered for use with the device driver
- // owning this queue. It must not be unregistered until all requests that use
- // it have completed.
- //
- // Fails if the queue does not support dispatch operations (as indicated by
- // can_dispatch).
- virtual Status Dispatch(const DispatchRequest& dispatch_request) = 0;
-
- protected:
- CommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories)
- : allocator_(allocator),
- mode_(mode),
- command_categories_(command_categories) {}
-
- private:
- Allocator* const allocator_;
- const CommandBufferModeBitfield mode_;
- const CommandCategoryBitfield command_categories_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_COMMAND_BUFFER_H_
diff --git a/iree/hal/command_buffer_validation.cc b/iree/hal/command_buffer_validation.cc
deleted file mode 100644
index c78580f..0000000
--- a/iree/hal/command_buffer_validation.cc
+++ /dev/null
@@ -1,403 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/command_buffer_validation.h"
-
-#include "iree/base/logging.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-
-// Command buffer validation shim.
-// Wraps an existing command buffer to provide in-depth validation during
-// recording. This should be enabled whenever the command buffer is being driven
-// by unsafe code or when early and readable diagnostics are needed.
-class ValidatingCommandBuffer : public CommandBuffer {
- public:
- explicit ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl);
- ~ValidatingCommandBuffer() override;
-
- CommandBuffer* impl() override { return impl_.get(); }  // The wrapped buffer.
-
- bool is_recording() const override;
-
- Status Begin() override;
- Status End() override;
-
- Status ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
- Status SignalEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
- Status ResetEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
- Status WaitEvents(absl::Span<Event*> events,
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
- Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
- device_size_t length, const void* pattern,
- size_t pattern_length) override;
- Status DiscardBuffer(Buffer* buffer) override;
- Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
- Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
- Status Dispatch(const DispatchRequest& dispatch_request) override;
-
- private:
- // Returns a failure if the buffer lacks any of the required categories.
- Status ValidateCategories(CommandCategoryBitfield required_categories) const;
- // Returns a failure if the memory type the buffer was allocated from is not
- // compatible with the given type.
- Status ValidateCompatibleMemoryType(Buffer* buffer,
- MemoryTypeBitfield memory_type) const;
- // Returns a failure if the buffer's allowed access lacks any bit of the
- // requested access type.
- Status ValidateAccess(Buffer* buffer,
- MemoryAccessBitfield memory_access) const;
- // Returns a failure if the buffer was not allocated for the given usage.
- Status ValidateUsage(Buffer* buffer, BufferUsageBitfield usage) const;
- // Validates that the range provided is within the given buffer.
- Status ValidateRange(Buffer* buffer, device_size_t byte_offset,
- device_size_t byte_length) const;
-
- ref_ptr<CommandBuffer> impl_;
-};
-
-ValidatingCommandBuffer::ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl)
- : CommandBuffer(impl->allocator(), impl->mode(),
- impl->command_categories()),
- impl_(std::move(impl)) {}  // Base is initialized before |impl| is moved.
-
-ValidatingCommandBuffer::~ValidatingCommandBuffer() = default;
-
-bool ValidatingCommandBuffer::is_recording() const {
- return impl_->is_recording();  // Forwarded to the wrapped buffer.
-}
-
-Status ValidatingCommandBuffer::Begin() {
-  DVLOG(3) << "CommandBuffer::Begin()";
-  // Begin() is rejected while the wrapped buffer is already recording.
-  const bool already_recording = impl_->is_recording();
-  if (!already_recording) return impl_->Begin();
-  return FailedPreconditionErrorBuilder(IREE_LOC)
-         << "Command buffer is already recording";
-}
-
-Status ValidatingCommandBuffer::End() {
-  DVLOG(3) << "CommandBuffer::End()";
-  // End() is rejected unless the wrapped buffer is currently recording.
-  const bool recording = impl_->is_recording();
-  if (recording) return impl_->End();
-  return FailedPreconditionErrorBuilder(IREE_LOC)
-         << "Command buffer is not recording";
-}
-
-Status ValidatingCommandBuffer::ValidateCategories(
-    CommandCategoryBitfield required_categories) const {
-  if (AllBitsSet(command_categories(), required_categories)) {
-    return OkStatus();  // Every required category is supported.
-  }
-  return FailedPreconditionErrorBuilder(IREE_LOC)
-         << "Operation requires categories "
-         << CommandCategoryString(required_categories)
-         << " but buffer only supports "
-         << CommandCategoryString(command_categories());
-}
-
-Status ValidatingCommandBuffer::ValidateCompatibleMemoryType(
-    Buffer* buffer, MemoryTypeBitfield memory_type) const {
-  // Compatible only when every requested bit is present on the buffer.
-  const auto present_bits = buffer->memory_type() & memory_type;
-  if (present_bits == memory_type) return OkStatus();
-
-  return PermissionDeniedErrorBuilder(IREE_LOC)
-         << "Buffer memory type is not compatible with the requested "
-            "operation; buffer has "
-         << MemoryTypeString(buffer->memory_type()) << ", operation requires "
-         << MemoryTypeString(memory_type);
-}
-
-Status ValidatingCommandBuffer::ValidateAccess(
- Buffer* buffer, MemoryAccessBitfield memory_access) const {
- if ((buffer->allowed_access() & memory_access) != memory_access) {
- // One or more requested access bits are not allowed (extra bits are fine).
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "The buffer does not support the requested access type; buffer "
- "allows "
- << MemoryAccessString(buffer->allowed_access())
- << ", operation requires " << MemoryAccessString(memory_access);
- }
- return OkStatus();
-}
-
-// Fails unless the allocator accepts |buffer| and it has all |usage| bits.
-Status ValidatingCommandBuffer::ValidateUsage(Buffer* buffer,
- BufferUsageBitfield usage) const {
- if (!allocator()->CanUseBuffer(buffer, usage)) {
- // This command buffer's allocator rejected the buffer for the given usage.
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "Requested usage of " << buffer->DebugString()
- << " is not supported for the buffer on this queue; "
- "buffer allows "
- << BufferUsageString(buffer->usage()) << ", queue requires "
- << BufferUsageString(usage);
- }
-
- if ((buffer->usage() & usage) != usage) {
- // The buffer was allocated without one or more of the requested bits.
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "Requested usage was not specified when the buffer was "
- "allocated; buffer allows "
- << BufferUsageString(buffer->usage()) << ", operation requires "
- << BufferUsageString(usage);
- }
-
- return OkStatus();
-}
-
-// Validates that [byte_offset, byte_offset + byte_length) lies in |buffer|.
-Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer,
-                                              device_size_t byte_offset,
-                                              device_size_t byte_length) const {
-  // Check if the start of the range runs off the end of the buffer.
-  if (byte_offset > buffer->byte_length()) {
-    return OutOfRangeErrorBuilder(IREE_LOC)
-           << "Attempted to access an address off the end of the valid buffer "
-              "range (offset="
-           << byte_offset << ", length=" << byte_length
-           << ", buffer byte_length=" << buffer->byte_length() << ")";
-  }
-
-  if (byte_length == 0) {
-    // Fine to have a zero length.
-    return OkStatus();
-  }
-
-  // Check if the end runs over the allocation. Compare against the remaining
-  // space so a huge length cannot wrap byte_offset + byte_length and pass.
-  if (byte_length > buffer->byte_length() - byte_offset) {
-    return OutOfRangeErrorBuilder(IREE_LOC)
-           << "Attempted to access an address outside of the valid buffer "
-              "range (offset="
-           << byte_offset << ", length=" << byte_length
-           << ", end(inc)=" << (byte_offset + byte_length - 1)
-           << ", buffer byte_length=" << buffer->byte_length() << ")";
-  }
-
-  return OkStatus();
-}
-
-Status ValidatingCommandBuffer::ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- DVLOG(3) << "CommandBuffer::ExecutionBarrier(...)";
- // Requires both transfer and dispatch support; masks/barriers are unchecked.
- // TODO(benvanik): additional synchronization validation.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer |
- CommandCategory::kDispatch));
-
- return impl_->ExecutionBarrier(source_stage_mask, target_stage_mask,
- memory_barriers, buffer_barriers);
-}
-
-Status ValidatingCommandBuffer::SignalEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- DVLOG(3) << "CommandBuffer::SignalEvent(...)";
- // Requires dispatch support; |event| and the mask are forwarded unchecked.
- // TODO(benvanik): additional synchronization validation.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
- return impl_->SignalEvent(event, source_stage_mask);
-}
-
-Status ValidatingCommandBuffer::ResetEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- DVLOG(3) << "CommandBuffer::ResetEvent(...)";
- // Requires dispatch support; |event| and the mask are forwarded unchecked.
- // TODO(benvanik): additional synchronization validation.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
- return impl_->ResetEvent(event, source_stage_mask);
-}
-
-Status ValidatingCommandBuffer::WaitEvents(
- absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- DVLOG(3) << "CommandBuffer::WaitEvents(...)";
- // Requires dispatch support; events, masks and barriers are unchecked.
- // TODO(benvanik): additional synchronization validation.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
- return impl_->WaitEvents(events, source_stage_mask, target_stage_mask,
- memory_barriers, buffer_barriers);
-}
-
-Status ValidatingCommandBuffer::FillBuffer(Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length,
- const void* pattern,
- size_t pattern_length) {
- DVLOG(3) << "CommandBuffer::FillBuffer(" << target_buffer->DebugString()
- << ", " << target_offset << ", " << length << ", ??, "
- << pattern_length << ")";
- // Requires transfer support and a device-visible, writable, in-range target.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
- RETURN_IF_ERROR(
- ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
- RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
- RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
- RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
-
- // Ensure the pattern length is supported (1, 2, or 4 bytes).
- if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Fill value length is not one of the supported values "
- "(pattern_length="
- << pattern_length << ")";
- }
-
- // Ensure the offset and length are multiples of the pattern length.
- if ((target_offset % pattern_length) != 0 || (length % pattern_length) != 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Fill offset and/or length do not match the natural alignment of "
- "the fill value (target_offset="
- << target_offset << ", length=" << length
- << ", pattern_length=" << pattern_length << ")";
- }
-
- return impl_->FillBuffer(target_buffer, target_offset, length, pattern,
- pattern_length);
-}
-
-Status ValidatingCommandBuffer::DiscardBuffer(Buffer* buffer) {
- DVLOG(3) << "CommandBuffer::DiscardBuffer(" << buffer->DebugString() << ")";
- // Requires transfer support and a device-visible buffer.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
- RETURN_IF_ERROR(
- ValidateCompatibleMemoryType(buffer, MemoryType::kDeviceVisible));
- RETURN_IF_ERROR(ValidateUsage(buffer, BufferUsage::kNone));  // Presumably leaves only the allocator check - confirm kNone == 0.
-
- return impl_->DiscardBuffer(buffer);
-}
-
-Status ValidatingCommandBuffer::UpdateBuffer(const void* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- DVLOG(3) << "CommandBuffer::UpdateBuffer(" << source_buffer << ", "
- << source_offset << ", " << target_buffer->DebugString() << ", "
- << target_offset << ", " << length << ")";
- // Only the target is validated; |source_buffer| is forwarded unchecked.
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
- RETURN_IF_ERROR(
- ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
- RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
- RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
- RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
-
- return impl_->UpdateBuffer(source_buffer, source_offset, target_buffer,
- target_offset, length);
-}
-
-Status ValidatingCommandBuffer::CopyBuffer(Buffer* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- DVLOG(3) << "CommandBuffer::CopyBuffer(" << source_buffer->DebugString()
- << ", " << source_offset << ", " << target_buffer->DebugString()
- << ", " << target_offset << ", " << length << ")";
-
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
-
- // At least source or destination must be device-visible to enable
- // host->device, device->host, and device->device.
- // TODO(b/117338171): host->host copies.
- if (!AnyBitSet(source_buffer->memory_type() & MemoryType::kDeviceVisible) &&
- !AnyBitSet(target_buffer->memory_type() & MemoryType::kDeviceVisible)) {
- return PermissionDeniedErrorBuilder(IREE_LOC)
- << "At least one buffer must be device-visible for a copy; "
- "source_buffer="
- << MemoryTypeString(source_buffer->memory_type())
- << ", target_buffer="
- << MemoryTypeString(target_buffer->memory_type());
- }
- // Source must be readable, target writable; both transfer-usable, in range.
- RETURN_IF_ERROR(ValidateAccess(source_buffer, MemoryAccess::kRead));
- RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
- RETURN_IF_ERROR(ValidateUsage(source_buffer, BufferUsage::kTransfer));
- RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
- RETURN_IF_ERROR(ValidateRange(source_buffer, source_offset, length));
- RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
-
- // Reject anything TestOverlap does not report as disjoint (as with memcpy).
- if (Buffer::TestOverlap(source_buffer, source_offset, length, target_buffer,
- target_offset,
- length) != Buffer::Overlap::kDisjoint) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Source and target ranges overlap within the same buffer";
- }
-
- return impl_->CopyBuffer(source_buffer, source_offset, target_buffer,
- target_offset, length);
-}
-
-Status ValidatingCommandBuffer::Dispatch(
- const DispatchRequest& dispatch_request) {
- DVLOG(3) << "CommandBuffer::Dispatch(?)";
-
- RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
- // Validate all buffers referenced have compatible memory types, access
- // rights, and usage. Only the memory-type failure is annotated with the
- // binding's access and short buffer description.
- for (const auto& binding : dispatch_request.bindings) {
- RETURN_IF_ERROR(ValidateCompatibleMemoryType(binding.buffer,
- MemoryType::kDeviceVisible))
- << "input buffer: " << MemoryAccessString(binding.access) << " "
- << binding.buffer->DebugStringShort();
- RETURN_IF_ERROR(ValidateAccess(binding.buffer, binding.access));
- RETURN_IF_ERROR(ValidateUsage(binding.buffer, BufferUsage::kDispatch));
- // TODO(benvanik): validate it matches the executable expectations.
- // TODO(benvanik): validate buffer contains enough data for shape+size.
- }
-
- // TODO(benvanik): validate no aliasing?
-
- return impl_->Dispatch(dispatch_request);
-}
-
-} // namespace
-
-ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
- ref_ptr<CommandBuffer> impl) {  // |impl| is moved into the wrapper.
- return make_ref<ValidatingCommandBuffer>(std::move(impl));
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/command_buffer_validation.h b/iree/hal/command_buffer_validation.h
deleted file mode 100644
index 036f132..0000000
--- a/iree/hal/command_buffer_validation.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
-#define IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
-
-#include "iree/hal/command_buffer.h"
-
-namespace iree {
-namespace hal {
-
-// Wraps an existing command buffer to provide in-depth validation during
-// recording. This should be enabled whenever the command buffer is being driven
-// by unsafe code or when early and readable diagnostics are needed.
-ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
- ref_ptr<CommandBuffer> impl);
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
diff --git a/iree/hal/command_queue.h b/iree/hal/command_queue.h
deleted file mode 100644
index 0e8eb5d..0000000
--- a/iree/hal/command_queue.h
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_COMMAND_QUEUE_H_
-#define IREE_HAL_COMMAND_QUEUE_H_
-
-#include <cstdint>
-#include <string>
-
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "iree/base/bitfield.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
-#include "iree/hal/command_buffer.h"
-#include "iree/hal/fence.h"
-#include "iree/hal/semaphore.h"
-
-namespace iree {
-namespace hal {
-
-// A batch of command buffers with synchronization information for submission.
-struct SubmissionBatch {
- // Semaphores that must be signaled prior to the execution of any command
- // buffer in this submission. For TimelineSemaphores the specified payload
- // must be reached or exceeded.
- absl::Span<const SemaphoreValue> wait_semaphores;
-
- // Command buffers that will execute in this batch.
- // The command buffers will begin execution in order but may complete out of
- // order.
- absl::Span<CommandBuffer* const> command_buffers;
-
- // Semaphores to signal after execution of all command buffers complete.
- // TimelineSemaphores will be set to the maximum of the specified payload or
- // their current payload.
- absl::Span<const SemaphoreValue> signal_semaphores;
-};
-
-// Asynchronous command execution queue.
-//
-// CommandQueues may capture device status at Fence barriers, including
-// information about device state such as thermal throttling. This information
-// is a snapshot of the state at the time the fence was signaled and not
-// necessarily live at the time of the application query.
-//
-// Command queues are thread-safe and submissions may occur from multiple
-// threads.
-class CommandQueue {
- public:
- virtual ~CommandQueue() = default;
-
- // Name of the queue used for logging purposes.
- // Try to keep at 4 characters total for prettier logging.
- const std::string& name() const { return name_; }
-
- // Capabilities of the command queue.
- CommandCategoryBitfield supported_categories() const {
- return supported_categories_;
- }
-
- // Whether this queue may be used for transfer commands.
- bool can_transfer() const {
- return AllBitsSet(supported_categories_, CommandCategory::kTransfer);
- }
-
- // Whether this queue may be used for dispatch commands.
- bool can_dispatch() const {
- return AllBitsSet(supported_categories_, CommandCategory::kDispatch);
- }
-
- // Submits one or more command batches for execution on the queue.
- // Dependencies between |batches| on BinarySemaphores must be sorted in order
- // such that all semaphores are signaled prior to any waits on them.
- // Dependencies between TimelineSemaphores may occur in any order.
- //
- // The provided |fence| will be signaled when all |batches| have retired.
- virtual Status Submit(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) = 0;
- inline Status Submit(const SubmissionBatch& batch, FenceValue fence) {
- return Submit(absl::MakeConstSpan(&batch, 1), std::move(fence));
- }
-
- // Blocks until all outstanding requests have been completed.
- // This is equivalent to having waited on all outstanding fences.
- // Implicitly calls Flush to ensure delayed requests are scheduled.
- //
- // If the command queue has encountered an error during submission at any
- // point it will be returned here (repeatedly).
- virtual Status WaitIdle(absl::Time deadline) = 0;
- inline Status WaitIdle(absl::Duration timeout) {
- return WaitIdle(RelativeTimeoutToDeadline(timeout));
- }
- inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); }
-
- protected:
- CommandQueue(std::string name, CommandCategoryBitfield supported_categories)
- : name_(std::move(name)), supported_categories_(supported_categories) {}
-
- const std::string name_;
- const CommandCategoryBitfield supported_categories_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_COMMAND_QUEUE_H_
diff --git a/iree/hal/dawn/BUILD b/iree/hal/dawn/BUILD
deleted file mode 100644
index 2c6293a..0000000
--- a/iree/hal/dawn/BUILD
+++ /dev/null
@@ -1,72 +0,0 @@
-# HAL implementation using Dawn and SPIR-V executables.
-# https://dawn.googlesource.com/dawn
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "dawn_device",
- srcs = ["dawn_device.cc"],
- hdrs = ["dawn_device.h"],
- deps = [
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:device",
- "//iree/hal:executable_cache",
- "//iree/hal:fence",
- "//iree/hal/host:host_local_allocator",
- "//third_party/dawn:dawn_headers",
- "//third_party/dawn:dawn_native",
- "//third_party/dawn:dawn_static_proc",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "dawn_driver",
- srcs = ["dawn_driver.cc"],
- hdrs = ["dawn_driver.h"],
- deps = [
- ":dawn_device",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//third_party/dawn:dawn_headers",
- "//third_party/dawn:dawn_native",
- "//third_party/dawn:dawn_static_proc",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- ],
-)
-
-# TODO(scotttodd): Use SwiftShader to test Vulkan backend
-cc_test(
- name = "dawn_driver_test",
- srcs = ["dawn_driver_test.cc"],
- deps = [
- ":dawn_driver",
- "//iree/base:status_matchers",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "dawn_driver_module",
- srcs = ["dawn_driver_module.cc"],
- deps = [
- ":dawn_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:driver_registry",
- "@com_google_absl//absl/flags:flag",
- ],
- alwayslink = 1,
-)
diff --git a/iree/hal/dawn/dawn_device.cc b/iree/hal/dawn/dawn_device.cc
deleted file mode 100644
index 7888a72..0000000
--- a/iree/hal/dawn/dawn_device.cc
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/dawn/dawn_device.h"
-
-#include "absl/memory/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/hal/fence.h"
-
-namespace iree {
-namespace hal {
-namespace dawn {
-
-namespace {
-
-// ExecutableCache implementation that compiles but does nothing.
-// This will be replaced with something functional soon.
-class NoopExecutableCache final : public ExecutableCache {
- public:
- explicit NoopExecutableCache() {}
- ~NoopExecutableCache() override = default;
-
- bool CanPrepareFormat(ExecutableFormat format) const override {
- return false;
- }
-
- StatusOr<ref_ptr<Executable>> PrepareExecutable(
- ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override {
- return UnimplementedErrorBuilder(IREE_LOC) << "PrepareExecutable NYI";
- }
-};
-
-} // namespace
-
-DawnDevice::DawnDevice(const DeviceInfo& device_info,
- ::dawn::Device backend_device)
- : Device(device_info), backend_device_(backend_device) {
- IREE_TRACE_SCOPE0("DawnDevice::ctor");
-
- // TODO(scotttodd): construct command queues, perform other initialization
-
- // Log some basic device info.
- std::string backend_type_str;
- auto* adapter =
- static_cast<dawn_native::Adapter*>(device_info.driver_handle());
- switch (adapter->GetBackendType()) {
- case dawn_native::BackendType::D3D12:
- backend_type_str = "D3D12";
- break;
- case dawn_native::BackendType::Metal:
- backend_type_str = "Metal";
- break;
- case dawn_native::BackendType::Null:
- backend_type_str = "Null";
- break;
- case dawn_native::BackendType::OpenGL:
- backend_type_str = "OpenGL";
- break;
- case dawn_native::BackendType::Vulkan:
- backend_type_str = "Vulkan";
- break;
- }
- LOG(INFO) << "Created DawnDevice '" << device_info.name() << "' ("
- << backend_type_str << ")";
-}
-
-DawnDevice::~DawnDevice() = default;
-
-std::shared_ptr<ExecutableCache> DawnDevice::CreateExecutableCache() {
- return std::make_shared<NoopExecutableCache>();
-}
-
-StatusOr<ref_ptr<CommandBuffer>> DawnDevice::CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) {
- return UnimplementedErrorBuilder(IREE_LOC) << "CreateCommandBuffer NYI";
-}
-
-StatusOr<ref_ptr<Event>> DawnDevice::CreateEvent() {
- return UnimplementedErrorBuilder(IREE_LOC) << "CreateEvent NYI";
-}
-
-StatusOr<ref_ptr<BinarySemaphore>> DawnDevice::CreateBinarySemaphore(
- bool initial_value) {
- IREE_TRACE_SCOPE0("DawnDevice::CreateBinarySemaphore");
-
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Binary semaphores not yet implemented";
-}
-
-StatusOr<ref_ptr<TimelineSemaphore>> DawnDevice::CreateTimelineSemaphore(
- uint64_t initial_value) {
- IREE_TRACE_SCOPE0("DawnDevice::CreateTimelineSemaphore");
-
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Timeline semaphores not yet implemented";
-}
-
-StatusOr<ref_ptr<Fence>> DawnDevice::CreateFence(uint64_t initial_value) {
- IREE_TRACE_SCOPE0("DawnDevice::CreateFence");
-
- return UnimplementedErrorBuilder(IREE_LOC) << "CreateFence NYI";
-}
-
-Status DawnDevice::WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) {
- IREE_TRACE_SCOPE0("DawnDevice::WaitAllFences");
-
- return UnimplementedErrorBuilder(IREE_LOC) << "WaitAllFences NYI";
-}
-
-StatusOr<int> DawnDevice::WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Time deadline) {
- IREE_TRACE_SCOPE0("DawnDevice::WaitAnyFence");
-
- return UnimplementedErrorBuilder(IREE_LOC) << "WaitAnyFence NYI";
-}
-
-Status DawnDevice::WaitIdle(absl::Time deadline) {
- return UnimplementedErrorBuilder(IREE_LOC) << "WaitIdle";
-}
-
-} // namespace dawn
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/dawn/dawn_device.h b/iree/hal/dawn/dawn_device.h
deleted file mode 100644
index 117c352..0000000
--- a/iree/hal/dawn/dawn_device.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DAWN_DAWN_DEVICE_H_
-#define IREE_HAL_DAWN_DAWN_DEVICE_H_
-
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/base/memory.h"
-#include "iree/hal/device.h"
-#include "iree/hal/host/host_local_allocator.h"
-#include "third_party/dawn/src/include/dawn/dawncpp.h"
-#include "third_party/dawn/src/include/dawn_native/DawnNative.h"
-
-namespace iree {
-namespace hal {
-namespace dawn {
-
-class DawnDevice final : public Device {
- public:
- explicit DawnDevice(const DeviceInfo& device_info,
- ::dawn::Device backend_device);
- ~DawnDevice() override;
-
- Allocator* allocator() const override { return &allocator_; }
-
- absl::Span<CommandQueue*> dispatch_queues() const override {
- return RawPtrSpan(absl::MakeSpan(command_queues_));
- }
-
- absl::Span<CommandQueue*> transfer_queues() const override {
- return RawPtrSpan(absl::MakeSpan(command_queues_));
- }
-
- std::shared_ptr<ExecutableCache> CreateExecutableCache() override;
-
- StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) override;
-
- StatusOr<ref_ptr<Event>> CreateEvent() override;
-
- StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
- bool initial_value) override;
- StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
- uint64_t initial_value) override;
-
- StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override;
- Status WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) override;
- StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Time deadline) override;
-
- Status WaitIdle(absl::Time deadline) override;
-
- private:
- mutable HostLocalAllocator allocator_;
- mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_;
-
- ::dawn::Device backend_device_;
-};
-
-} // namespace dawn
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DAWN_DAWN_DEVICE_H_
diff --git a/iree/hal/dawn/dawn_driver.cc b/iree/hal/dawn/dawn_driver.cc
deleted file mode 100644
index 9e9067a..0000000
--- a/iree/hal/dawn/dawn_driver.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/dawn/dawn_driver.h"
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/dawn/dawn_device.h"
-#include "iree/hal/device_info.h"
-
-namespace iree {
-namespace hal {
-namespace dawn {
-
-namespace {
-
-// Populates device information from the given dawn_native::Adapter.
-StatusOr<DeviceInfo> PopulateDeviceInfo(dawn_native::Adapter* adapter) {
- // TODO(scotttodd): Query these for each backend or implement?
- DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
- // supported_features |= DeviceFeature::kDebugging;
- // supported_features |= DeviceFeature::kCoverage;
- // supported_features |= DeviceFeature::kProfiling;
-
- // TODO(scotttodd): more clever/sanitized device naming.
- std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name);
-
- return DeviceInfo(device_name, supported_features,
- reinterpret_cast<void*>(adapter));
-}
-
-} // namespace
-
-DawnDriver::DawnDriver() : Driver("dawn") {
- dawn_instance_ = absl::make_unique<dawn_native::Instance>();
-}
-
-DawnDriver::~DawnDriver() = default;
-
-StatusOr<std::vector<DeviceInfo>> DawnDriver::EnumerateAvailableDevices() {
- IREE_TRACE_SCOPE0("DawnDriver::EnumerateAvailableDevices");
-
- if (dawn_backend_adapters_.empty()) {
- // Discover adapters (i.e. devices and their associated backend APIs).
- // Retain the list of adapters so pointers are valid for the lifetime of
- // this object.
- dawn_instance_->DiscoverDefaultAdapters();
- dawn_backend_adapters_ = dawn_instance_->GetAdapters();
- } else {
- // Assume that the list of adapters does not change. This is not guaranteed
- // to be true, but we also don't want to invalidate pointers by requesting
- // a new list each time. If the list of available devices would change,
- // tearing down and creating a new DawnDriver may be your best option.
- }
-
- // Convert to our HAL structure.
- std::vector<DeviceInfo> device_infos;
- device_infos.reserve(dawn_backend_adapters_.size());
- for (auto& adapter : dawn_backend_adapters_) {
- // TODO(scotttodd): if we fail should we just ignore the device in the list?
- ASSIGN_OR_RETURN(auto device_info, PopulateDeviceInfo(&adapter));
- device_infos.push_back(std::move(device_info));
- }
- return device_infos;
-}
-
-StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDefaultDevice() {
- IREE_TRACE_SCOPE0("DawnDriver::CreateDefaultDevice");
-
- // Query available devices.
- ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices());
- if (available_devices.empty()) {
- return NotFoundErrorBuilder(IREE_LOC) << "No devices are available";
- }
-
- // Create the first non-null device, if any.
- for (const auto& device : available_devices) {
- auto* adapter = static_cast<dawn_native::Adapter*>(device.driver_handle());
- if (adapter->GetBackendType() != dawn_native::BackendType::Null) {
- return CreateDevice(device);
- }
- }
-
- // Otherwise create the first null device.
- return CreateDevice(available_devices.front());
-}
-
-StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDevice(
- const DeviceInfo& device_info) {
- IREE_TRACE_SCOPE0("DawnDriver::CreateDevice");
-
- auto* adapter =
- static_cast<dawn_native::Adapter*>(device_info.driver_handle());
- ::DawnDevice c_backend_device = adapter->CreateDevice();
- if (!c_backend_device) {
- return InternalErrorBuilder(IREE_LOC) << "Failed to create a Dawn device";
- }
- DawnProcTable backend_procs = dawn_native::GetProcs();
- dawnSetProcs(&backend_procs);
- ::dawn::Device backend_device = ::dawn::Device::Acquire(c_backend_device);
-
- return std::make_shared<DawnDevice>(device_info, backend_device);
-}
-
-} // namespace dawn
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/dawn/dawn_driver.h b/iree/hal/dawn/dawn_driver.h
deleted file mode 100644
index 4c79c25..0000000
--- a/iree/hal/dawn/dawn_driver.h
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DAWN_DAWN_DRIVER_H_
-#define IREE_HAL_DAWN_DAWN_DRIVER_H_
-
-#include <memory>
-#include <vector>
-
-#include "iree/hal/driver.h"
-#include "third_party/dawn/src/include/dawn/dawncpp.h"
-#include "third_party/dawn/src/include/dawn_native/DawnNative.h"
-
-namespace iree {
-namespace hal {
-namespace dawn {
-
-class DawnDriver final : public Driver {
- public:
- DawnDriver();
- ~DawnDriver() override;
-
- StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
-
- StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override;
-
- StatusOr<std::shared_ptr<Device>> CreateDevice(
- const DeviceInfo& device_info) override;
-
- private:
- std::unique_ptr<dawn_native::Instance> dawn_instance_;
- std::vector<dawn_native::Adapter> dawn_backend_adapters_;
-};
-
-} // namespace dawn
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DAWN_DAWN_DRIVER_H_
diff --git a/iree/hal/dawn/dawn_driver_module.cc b/iree/hal/dawn/dawn_driver_module.cc
deleted file mode 100644
index 2fce4bf..0000000
--- a/iree/hal/dawn/dawn_driver_module.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/dawn/dawn_driver.h"
-#include "iree/hal/driver_registry.h"
-
-namespace iree {
-namespace hal {
-namespace dawn {
-namespace {
-
-StatusOr<std::shared_ptr<Driver>> CreateDawnDriver() {
- return std::make_shared<DawnDriver>();
-}
-
-} // namespace
-} // namespace dawn
-} // namespace hal
-} // namespace iree
-
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dawn_driver, {
- QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "dawn", ::iree::hal::dawn::CreateDawnDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dawn_driver);
diff --git a/iree/hal/dawn/dawn_driver_test.cc b/iree/hal/dawn/dawn_driver_test.cc
deleted file mode 100644
index 07a1561..0000000
--- a/iree/hal/dawn/dawn_driver_test.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/dawn/dawn_driver.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace dawn {
-namespace {
-
-TEST(DawnDriverTest, CreateDefaultDevice) {
- DawnDriver dawn_driver;
- ASSERT_OK_AND_ASSIGN(auto default_device, dawn_driver.CreateDefaultDevice());
-}
-
-TEST(DawnDriverTest, EnumerateDevicesAndCreate) {
- DawnDriver dawn_driver;
-
- ASSERT_OK_AND_ASSIGN(auto available_devices,
- dawn_driver.EnumerateAvailableDevices());
- ASSERT_GT(available_devices.size(), 0);
-
- ASSERT_OK_AND_ASSIGN(auto first_device,
- dawn_driver.CreateDevice(available_devices[0]));
-}
-
-} // namespace
-} // namespace dawn
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/deferred_buffer.cc b/iree/hal/deferred_buffer.cc
deleted file mode 100644
index 3414436..0000000
--- a/iree/hal/deferred_buffer.cc
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/deferred_buffer.h"
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-
-DeferredBuffer::DeferredBuffer(Allocator* allocator,
- MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage,
- device_size_t byte_length)
- : Buffer(allocator, memory_type, allowed_access, usage, 0, 0, byte_length) {
-}
-
-DeferredBuffer::~DeferredBuffer() = default;
-
-Status DeferredBuffer::GrowByteLength(device_size_t new_byte_length) {
- if (parent_buffer_) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Attempting to set min allocation size while bound to an "
- "allocation";
- }
- if (byte_length_ != kWholeBuffer && new_byte_length < byte_length_) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Attempting to shrink a buffer to " << new_byte_length
- << " when it has a minimum size of " << byte_length_;
- }
- byte_length_ = new_byte_length;
- return OkStatus();
-}
-
-Status DeferredBuffer::BindAllocation(ref_ptr<Buffer> allocated_buffer,
- device_size_t byte_offset,
- device_size_t byte_length) {
- // We can only be bound to allocations that are compatible with our specified
- // allocator and usage.
- if (!allocator_->CanUseBuffer(allocated_buffer.get(), usage())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Allocation is not compatible with the allocator specified for "
- "the deferred buffer";
- }
-
- // Calculate the range in the allocated_buffer that we are interested in.
- RETURN_IF_ERROR(Buffer::CalculateRange(0, allocated_buffer->byte_length(),
- byte_offset, byte_length, &byte_offset,
- &byte_length));
-
- // Verify that we have enough bytes for what we've promised.
- if (byte_length < byte_length_) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Allocation range is too small; min_allocation_size="
- << byte_length_ << " but the range of " << byte_offset << "-"
- << (byte_offset + byte_length - 1) << " (" << byte_length
- << "b) is too small";
- }
-
- allocated_buffer_ = allocated_buffer.get();
- parent_buffer_ = std::move(allocated_buffer);
- byte_offset_ = byte_offset;
- return OkStatus();
-}
-
-void DeferredBuffer::ResetAllocation() {
- allocated_buffer_ = this;
- parent_buffer_.reset();
- byte_offset_ = 0;
-}
-
-StatusOr<Buffer*> DeferredBuffer::ResolveAllocation() const {
- // If you get errors here then someone allocated the buffer with
- // MemoryType::kTransient and you are trying to use it outside of the time
- // it is actually allocated (such as during CommandBuffer evaluation). If
- // you need to use the buffer in non-transient ways then allocate the buffer
- // without the MemoryType::kTransient flag.
- if (!parent_buffer_) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Attempting to use a transient buffer prior to allocation: "
- << DebugString();
- }
- return parent_buffer_.get();
-}
-
-Status DeferredBuffer::FillImpl(device_size_t byte_offset,
- device_size_t byte_length, const void* pattern,
- device_size_t pattern_length) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->FillImpl(byte_offset, byte_length, pattern,
- pattern_length);
-}
-
-Status DeferredBuffer::ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->ReadDataImpl(source_offset, data, data_length);
-}
-
-Status DeferredBuffer::WriteDataImpl(device_size_t target_offset,
- const void* data,
- device_size_t data_length) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->WriteDataImpl(target_offset, data, data_length);
-}
-
-Status DeferredBuffer::CopyDataImpl(device_size_t target_offset,
- Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->CopyDataImpl(target_offset, source_buffer,
- source_offset, data_length);
-}
-
-Status DeferredBuffer::MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->MapMemoryImpl(mapping_mode, memory_access,
- local_byte_offset, local_byte_length,
- out_data);
-}
-
-Status DeferredBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void* data) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->UnmapMemoryImpl(local_byte_offset, local_byte_length,
- data);
-}
-
-Status DeferredBuffer::InvalidateMappedMemoryImpl(
- device_size_t local_byte_offset, device_size_t local_byte_length) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->InvalidateMappedMemoryImpl(local_byte_offset,
- local_byte_length);
-}
-
-Status DeferredBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation());
- return allocated_buffer->FlushMappedMemoryImpl(local_byte_offset,
- local_byte_length);
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/deferred_buffer.h b/iree/hal/deferred_buffer.h
deleted file mode 100644
index aeb19ad..0000000
--- a/iree/hal/deferred_buffer.h
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DEFERRED_BUFFER_H_
-#define IREE_HAL_DEFERRED_BUFFER_H_
-
-#include <cstddef>
-#include <memory>
-#include <utility>
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// A Buffer that can have its underlying allocation changed at runtime.
-// Unbound buffers act as a way to logically group dependent ranges of memory
-// without needing to have allocated that memory yet.
-//
-// Usage:
-// // Setup two spans referencing ranges of a deferred buffer.
-// auto deferred_buffer = std::make_shared<DeferredBuffer>(..., 200);
-// ASSIGN_OR_RETURN(auto span0, Buffer::Subspan(deferred_buffer, 0, 100));
-// ASSIGN_OR_RETURN(auto span1, Buffer::Subspan(deferred_buffer, 100, 100));
-//
-// // Attempting to access |deferred_buffer| or |span0| or |span1| will fail.
-// // ERROR: span0->Fill(false);
-//
-// // Now allocate a real buffer to serve as storage for the data.
-// ASSIGN_OR_RETURN(auto allocated_buffer, Buffer::Allocate(..., 200));
-// RETURN_IF_ERROR(deferred_buffer->BindAllocation(
-// allocated_buffer, 0, kWholeBuffer));
-//
-// // And now we can use the spans.
-// RETURN_IF_ERROR(span0->Fill(false));
-//
-// // If at some point we want to detach the buffer from the allocation (so we
-// // can use a different allocation, reuse the memory, etc).
-// deferred_buffer->ResetAllocation();
-//
-// Thread-compatible. Attempting to rebind the allocation while other threads
-// are using the buffer will lead to undefined behavior.
-class DeferredBuffer : public Buffer {
- public:
- DeferredBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
- device_size_t byte_length);
- ~DeferredBuffer() override;
-
- // Grows the minimum allocation size of the buffer to |new_byte_length|.
- // Attempting to bind an allocation less than this size will fail. This must
- // only be called when the buffer is not bound to an allocation.
- Status GrowByteLength(device_size_t new_byte_length);
-
- // Binds or rebinds the deferred buffer to an allocated buffer.
- Status BindAllocation(ref_ptr<Buffer> allocated_buffer,
- device_size_t byte_offset, device_size_t byte_length);
-
- // Resets the deferred buffer to have no binding.
- void ResetAllocation();
-
- private:
- // Resolves the allocated buffer that this subspan references into.
- // This will fail if the buffer has not yet been bound to an allocation or
- // the allocated buffer has not been committed.
- StatusOr<Buffer*> ResolveAllocation() const;
-
- Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length) override;
- Status ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) override;
- Status WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) override;
- Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) override;
- Status MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) override;
- Status UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data) override;
- Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override;
- Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DEFERRED_BUFFER_H_
diff --git a/iree/hal/deferred_buffer_test.cc b/iree/hal/deferred_buffer_test.cc
deleted file mode 100644
index dd8b968..0000000
--- a/iree/hal/deferred_buffer_test.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/deferred_buffer.h"
-
-#include "absl/memory/memory.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/heap_buffer.h"
-#include "iree/hal/testing/mock_allocator.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-using ::iree::hal::testing::MockAllocator;
-using ::testing::_;
-using ::testing::Return;
-
-// Tests properties of unbound buffers.
-TEST(DeferredBufferTest, Unbound) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- EXPECT_EQ(&allocator, deferred_buffer->allocator());
- EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer());
- EXPECT_EQ(0, deferred_buffer->allocation_size());
- EXPECT_EQ(0, deferred_buffer->byte_offset());
- EXPECT_EQ(100, deferred_buffer->byte_length());
-}
-
-// Tests that binding verifies allocators are compatible.
-TEST(DeferredBufferTest, AllocatorCheck) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- auto real_buffer =
- HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
- EXPECT_CALL(
- allocator,
- CanUseBufferLike(real_buffer->allocator(), real_buffer->memory_type(),
- real_buffer->usage(), BufferUsage::kAll))
- .WillOnce(Return(false));
- EXPECT_TRUE(IsInvalidArgument(
- deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100)));
-}
-
-// Tests that binding verifies allocation sizes.
-TEST(DeferredBufferTest, SizeCheck) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- auto real_buffer =
- HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
- EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
- .WillRepeatedly(Return(true));
-
- EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 100));
- EXPECT_EQ(256, deferred_buffer->allocation_size());
- EXPECT_EQ(10, deferred_buffer->byte_offset());
- EXPECT_EQ(100, deferred_buffer->byte_length());
- EXPECT_OK(
- deferred_buffer->BindAllocation(add_ref(real_buffer), 10, kWholeBuffer));
- EXPECT_EQ(256, deferred_buffer->allocation_size());
- EXPECT_EQ(10, deferred_buffer->byte_offset());
- EXPECT_EQ(100, deferred_buffer->byte_length());
-
- EXPECT_TRUE(IsOutOfRange(
- deferred_buffer->BindAllocation(add_ref(real_buffer), 200, 100)));
- EXPECT_TRUE(IsOutOfRange(deferred_buffer->BindAllocation(add_ref(real_buffer),
- 200, kWholeBuffer)));
- EXPECT_TRUE(IsOutOfRange(
- deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 10)));
-}
-
-// Tests resizing buffers after they have been allocated.
-TEST(DeferredBufferTest, Resizing) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- auto real_buffer =
- HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
- EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
- .WillRepeatedly(Return(true));
-
- // Grow.
- EXPECT_EQ(100, deferred_buffer->byte_length());
- EXPECT_OK(deferred_buffer->GrowByteLength(150));
- EXPECT_EQ(150, deferred_buffer->byte_length());
-
- // Shrinking should fail.
- EXPECT_TRUE(IsInvalidArgument(deferred_buffer->GrowByteLength(5)));
-
- // Growing should fail if bound.
- EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 150));
- EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->GrowByteLength(100)));
-}
-
-// Tests binding and rebinding behavior.
-TEST(DeferredBufferTest, Rebinding) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- auto real_buffer =
- HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
- EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
- .WillRepeatedly(Return(true));
-
- // Safe to reset when not bound.
- deferred_buffer->ResetAllocation();
- EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer());
- EXPECT_EQ(0, deferred_buffer->allocation_size());
-
- EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100));
- EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer());
- EXPECT_EQ(256, deferred_buffer->allocation_size());
- deferred_buffer->ResetAllocation();
- EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer());
- EXPECT_EQ(0, deferred_buffer->allocation_size());
- EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100));
- EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer());
- EXPECT_EQ(256, deferred_buffer->allocation_size());
-}
-
-// Tests normal usage of bound buffers.
-TEST(DeferredBufferTest, BoundUsage) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- auto real_buffer =
- HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256);
- EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _))
- .WillRepeatedly(Return(true));
- EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100));
-
- EXPECT_FALSE(deferred_buffer->DebugString().empty());
- EXPECT_FALSE(deferred_buffer->DebugStringShort().empty());
-
- EXPECT_OK(deferred_buffer->Fill8(0, 10, 0xFF));
-}
-
-// Tests that unbound buffers fail to perform any buffer actions.
-TEST(DeferredBufferTest, UnboundUsage) {
- MockAllocator allocator;
- auto deferred_buffer = absl::make_unique<DeferredBuffer>(
- &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll,
- 100);
- EXPECT_FALSE(deferred_buffer->DebugString().empty());
- EXPECT_FALSE(deferred_buffer->DebugStringShort().empty());
-
- EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->Fill8(0, 10, 0xFF)));
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/device.h b/iree/hal/device.h
deleted file mode 100644
index ff97a8a..0000000
--- a/iree/hal/device.h
+++ /dev/null
@@ -1,165 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DEVICE_H_
-#define IREE_HAL_DEVICE_H_
-
-#include <memory>
-
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/device_info.h"
-#include "iree/hal/event.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/hal/semaphore.h"
-
-namespace iree {
-namespace hal {
-
-class Device {
- public:
- virtual ~Device() = default;
-
- // Information about device capabilities.
- const DeviceInfo& info() const { return device_info_; }
-
- // TODO(benvanik): status (thermal, power mode, etc).
-
- // TODO(benvanik): throttling adjustment/power profile.
-
- // TODO(benvanik): control (suspend/resume, delay, etc).
-
- // An allocator providing buffers usable by the device.
- // This allocator may be shared with other devices in the same family.
- virtual Allocator* allocator() const = 0;
-
- // Returns a list of all general-purpose dispatch queues provided by the
- // device. In general these map 1:1 with independent execution contexts,
- // though some devices may hide that and expose only a single queue that is
- // scheduled internally.
- virtual absl::Span<CommandQueue*> dispatch_queues() const = 0;
-
- // Returns a list of transfer queues provided by the device. These queues may
- // perform transfer operations asynchronously with respect to execution on the
- // dispatch queues. For large sequences of transfer operations always prefer
- // using one of these queues.
- // Note that if the device does not support a dedicated transfer queue this
- // list may be the same as (or a subset of) dispatch_queues.
- virtual absl::Span<CommandQueue*> transfer_queues() const = 0;
-
- // TODO(b/137153339): accept initial cache data.
- // Creates a device-specific cache for executables prepared for dispatch.
- // The cache manages executable compilation, caching (on disk or in memory),
- // and lifetime. Users can decide to use one or more caches to allow differing
- // lifetimes (such as unloading modules), persistent on disk caching of only
- // specific hot executables, etc.
- //
- // Returns a thread-safe cache that must remain alive until all executables
- // using the cache are no longer in-flight.
- virtual std::shared_ptr<ExecutableCache> CreateExecutableCache() = 0;
-
- // Creates a command buffer for recording commands to submit to queues owned
- // by this device. The command buffer may come from a pool but will be reset
- // prior to being returned to the caller.
- virtual StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) = 0;
-
- // Creates an event for recording into command buffers.
- // The returned event object is only usable with this device and events must
- // only be used to synchronize within the same queue.
- virtual StatusOr<ref_ptr<Event>> CreateEvent() = 0;
-
- // Creates a binary semaphore that can be used with command queues owned by
- // this device. To use the semaphores with other devices or instances they
- // must first be exported.
- virtual StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
- bool initial_value) = 0;
-
- // Creates a timeline semaphore that can be used with command queues owned by
- // this device. To use the semaphores with other devices or instances they
- // must first be exported.
- virtual StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
- uint64_t initial_value) = 0;
-
- // Creates a fence that can be used with command queues owned by this device.
- // To use the fences with other devices or instances they must first be
- // exported.
- virtual StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) = 0;
-
- // TODO(benvanik): import/export semaphore utilities.
- // TODO(benvanik): import/export fence utilities.
- // TODO(benvanik): fences to wait handles.
-
- // Blocks the caller until all passed |fences| reach or exceed the specified
- // payload values or the |deadline| elapses. All |fences| must be created from
- // this device (or be imported into it).
- //
- // Returns success if the wait is successful and all fences have been
- // signaled.
- //
- // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all fences
- // having been signaled. Note that a subset of the |fences| may have been
- // signaled and each can be queried to see which ones.
- virtual Status WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) = 0;
- inline Status WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Duration timeout) {
- return WaitAllFences(fences, RelativeTimeoutToDeadline(timeout));
- }
-
- // Blocks the caller until at least one of the |fences| reaches or exceeds the
- // specified payload value or the |deadline| elapses. All |fences| must be
- // created from this device (or be imported into it).
- //
- // Returns an arbitrary index into |fences| of a fence that was signaled. Note
- // that more than one fence may have been signaled and all of the other
- // |fences| should be queried or waited on again until waits for them
- // succeed.
- //
- // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any fences
- // having been signaled.
- virtual StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Time deadline) = 0;
- inline StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Duration timeout) {
- return WaitAnyFence(fences, RelativeTimeoutToDeadline(timeout));
- }
-
- // Blocks until all outstanding requests on all queues have been
- // completed. This is equivalent to having waited on all outstanding
- // fences.
- virtual Status WaitIdle(absl::Time deadline) = 0;
- inline Status WaitIdle(absl::Duration timeout) {
- return WaitIdle(RelativeTimeoutToDeadline(timeout));
- }
- inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); }
-
- protected:
- explicit Device(DeviceInfo device_info)
- : device_info_(std::move(device_info)) {}
-
- private:
- const DeviceInfo device_info_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DEVICE_H_
diff --git a/iree/hal/device_info.h b/iree/hal/device_info.h
deleted file mode 100644
index 7a6c6d3..0000000
--- a/iree/hal/device_info.h
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DEVICE_INFO_H_
-#define IREE_HAL_DEVICE_INFO_H_
-
-#include <cstdint>
-#include <string>
-#include <utility>
-
-#include "iree/base/bitfield.h"
-
-namespace iree {
-namespace hal {
-
-// Describes features supported by the device.
-// These flags indicate the availability of features that may be enabled at the
-// request of the calling application. Note that certain features may disable
-// runtime optimizations or require compilation flags to ensure the required
-// metadata is present in executables.
-enum class DeviceFeature : uint32_t {
- kNone = 0,
-
- // Device supports executable debugging.
- // When present executables *may* be compiled with
- // ExecutableCachingMode::kEnableDebugging and will have usable debugging
- // related methods. Note that if the input executables do not have embedded
- // debugging information they still may not be able to perform disassembly or
- // fine-grained breakpoint insertion.
- kDebugging = 1 << 0,
-
- // Device supports executable coverage information.
- // When present executables *may* be compiled with
- // ExecutableCachingMode::kEnableCoverage and will produce coverage buffers
- // during dispatch. Note that input executables must have partial embedded
- // debug information to allow mapping back to source offsets.
- kCoverage = 1 << 1,
-
- // Device supports executable and command queue profiling.
- // When present executables *may* be compiled with
- // ExecutableCachingMode::kEnableProfiling and will produce profiling buffers
- // during dispatch. Note that input executables must have partial embedded
- // debug information to allow mapping back to source offsets.
- kProfiling = 1 << 2,
-};
-IREE_BITFIELD(DeviceFeature);
-using DeviceFeatureBitfield = DeviceFeature;
-
-// TODO(benvanik): device info (caps, physical mappings, etc).
-class DeviceInfo {
- public:
- DeviceInfo(std::string name, DeviceFeatureBitfield supported_features,
- void* driver_handle = nullptr)
- : name_(std::move(name)),
- supported_features_(supported_features),
- driver_handle_(driver_handle) {}
-
- const std::string& name() const { return name_; }
-
- // Features supported by the device.
- DeviceFeatureBitfield supported_features() const {
- return supported_features_;
- }
-
- // Opaque handle used by drivers to correlate this device with their internal
- // listing. This handle will not be valid across driver instances or outside
- // of the current process.
- void* driver_handle() const { return driver_handle_; }
-
- private:
- const std::string name_;
- const DeviceFeatureBitfield supported_features_;
- void* driver_handle_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DEVICE_INFO_H_
diff --git a/iree/hal/device_manager.cc b/iree/hal/device_manager.cc
deleted file mode 100644
index 6f8d9f4..0000000
--- a/iree/hal/device_manager.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/device_manager.h"
-
-#include <algorithm>
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/heap_buffer.h"
-
-namespace iree {
-namespace hal {
-
-DeviceManager::DeviceManager() = default;
-
-DeviceManager::~DeviceManager() {
- IREE_TRACE_SCOPE0("DeviceManager::dtor");
- WaitIdle().IgnoreError();
-}
-
-Status DeviceManager::RegisterDevice(std::shared_ptr<Device> device) {
- IREE_TRACE_SCOPE0("DeviceManager::RegisterDevice");
- absl::MutexLock lock(&device_mutex_);
- if (std::find(devices_.begin(), devices_.end(), device) != devices_.end()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Device already registered";
- }
- devices_.push_back(std::move(device));
- return OkStatus();
-}
-
-Status DeviceManager::UnregisterDevice(Device* device) {
- IREE_TRACE_SCOPE0("DeviceManager::UnregisterDevice");
- absl::MutexLock lock(&device_mutex_);
- auto it = std::find_if(devices_.begin(), devices_.end(),
- [device](const std::shared_ptr<Device>& other_device) {
- return device == other_device.get();
- });
- if (it == devices_.end()) {
- return NotFoundErrorBuilder(IREE_LOC) << "Device not registered";
- }
- devices_.erase(it);
- return OkStatus();
-}
-
-StatusOr<DevicePlacement> DeviceManager::ResolvePlacement(
- const PlacementSpec& placement_spec) const {
- IREE_TRACE_SCOPE0("DeviceManager::ResolvePlacement");
- absl::MutexLock lock(&device_mutex_);
- if (devices_.empty()) {
- return NotFoundErrorBuilder(IREE_LOC) << "No devices registered";
- }
-
- // TODO(benvanik): multiple devices and placement.
- QCHECK_EQ(devices_.size(), 1)
- << "Multiple devices not yet supported (need placement)";
- DevicePlacement device_placement;
- device_placement.device = devices_.front();
-
- return device_placement;
-}
-
-StatusOr<Allocator*> DeviceManager::FindCompatibleAllocator(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- absl::Span<const DevicePlacement> device_placements) const {
- IREE_TRACE_SCOPE0("DeviceManager::FindCompatibleAllocator");
- if (device_placements.empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No placements provided";
- }
-
- // Find the first allocator. As we only return an allocator if all placements
- // are compatible we'll compare allocator[0] against allocator[1,N].
- Allocator* some_allocator = nullptr;
- for (const auto& device_placement : device_placements) {
- auto* allocator = device_placement.device->allocator();
- if (!some_allocator) {
- some_allocator = allocator;
- continue;
- }
- // NOTE: as there can be asymmetry between usage restrictions (A can use B
- // but B cannot use A) we have to compare both directions.
- if (!some_allocator->CanUseBufferLike(allocator, memory_type, buffer_usage,
- buffer_usage) ||
- !allocator->CanUseBufferLike(some_allocator, memory_type, buffer_usage,
- buffer_usage)) {
- // Allocators are not compatible.
- return NotFoundErrorBuilder(IREE_LOC)
- << "No single allocator found that is compatible with all "
- "placements";
- }
- }
- return some_allocator;
-}
-
-StatusOr<ref_ptr<Buffer>> DeviceManager::TryAllocateDeviceVisibleBuffer(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements) {
- IREE_TRACE_SCOPE("DeviceManager::TryAllocateDeviceVisibleBuffer:size", int)
- (static_cast<int>(allocation_size));
- if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Host-local buffers require the kHostLocal bit: "
- << MemoryTypeString(memory_type);
- }
-
- // Strip kDeviceVisible as we conditionally add it based on support.
- memory_type &= ~MemoryType::kDeviceVisible;
-
- // Find an allocator that works for device-visible buffers.
- // If this fails we'll fall back to allocation a non-device-visible buffer.
- auto allocator_or =
- FindCompatibleAllocator(memory_type | MemoryType::kDeviceVisible,
- buffer_usage, device_placements);
- if (allocator_or.ok()) {
- return allocator_or.ValueOrDie()->Allocate(
- memory_type | MemoryType::kDeviceVisible, buffer_usage,
- allocation_size);
- }
-
- // Fallback to allocating a host-local buffer.
- return HeapBuffer::Allocate(memory_type, buffer_usage, allocation_size);
-}
-
-StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceVisibleBuffer(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements) {
- IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceVisibleBuffer:size", int)
- (static_cast<int>(allocation_size));
- if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Host-local buffers require the kHostLocal bit: "
- << MemoryTypeString(memory_type);
- }
-
- // Always use device-visible.
- memory_type |= MemoryType::kDeviceVisible;
-
- // Find an allocator that works for device-visible buffers.
- ASSIGN_OR_RETURN(
- auto* allocator,
- FindCompatibleAllocator(memory_type, buffer_usage, device_placements));
- return allocator->Allocate(memory_type, buffer_usage, allocation_size);
-}
-
-StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceLocalBuffer(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements) {
- IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceLocalBuffer:size", int)
- (static_cast<int>(allocation_size));
- if (!AnyBitSet(memory_type & MemoryType::kDeviceLocal)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Device-local buffers require the kDeviceLocal bit: "
- << MemoryTypeString(memory_type);
- }
-
- // Find an allocator that works for device-local buffers.
- ASSIGN_OR_RETURN(
- auto* allocator,
- FindCompatibleAllocator(memory_type, buffer_usage, device_placements));
- return allocator->Allocate(memory_type, buffer_usage, allocation_size);
-}
-
-Status DeviceManager::Submit(Device* device, CommandQueue* command_queue,
- absl::Span<const SubmissionBatch> batches,
- absl::Time deadline, FenceValue fence) {
- IREE_TRACE_SCOPE0("DeviceManager::Submit");
- return command_queue->Submit(batches, fence);
-}
-
-Status DeviceManager::Flush() {
- IREE_TRACE_SCOPE0("DeviceManager::Flush");
- return OkStatus();
-}
-
-Status DeviceManager::WaitIdle(absl::Time deadline) {
- IREE_TRACE_SCOPE0("DeviceManager::WaitIdle");
- absl::MutexLock lock(&device_mutex_);
- for (const auto& device : devices_) {
- RETURN_IF_ERROR(device->WaitIdle(deadline));
- }
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/device_manager.h b/iree/hal/device_manager.h
deleted file mode 100644
index b6a8783..0000000
--- a/iree/hal/device_manager.h
+++ /dev/null
@@ -1,209 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DEVICE_MANAGER_H_
-#define IREE_HAL_DEVICE_MANAGER_H_
-
-#include <vector>
-
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/device.h"
-#include "iree/hal/device_placement.h"
-#include "iree/hal/executable_format.h"
-#include "iree/hal/fence.h"
-
-namespace iree {
-namespace hal {
-
-// Specifies how devices should be resolved to DevicePlacements.
-// Most fields are optional and when not included will be ignored.
-struct PlacementSpec {
- // TODO(benvanik): other requirements (features/caps, power, etc).
-
- // A list of executable formats that the placement should support.
- // If more than one format is provided any device satisfying at least one
- // will be considered for placement. The formats can be sorted in descending
- // priority order to prefer the first available format in the case of ties.
- absl::Span<const ExecutableFormat> available_formats;
-};
-
-// Manages device lifetime and placement resolution.
-// Optionally the DeviceManager may be used for automatic device selection for
-// allocations or batched submissions, however this is not required if specific
-// devices and scheduling behavior are known to the caller.
-//
-// Thread-safe. Note that callers must ensure that unregistered devices are kept
-// alive for as long as any commands are in-flight that may be using them.
-class DeviceManager final {
- public:
- DeviceManager();
- ~DeviceManager();
-
- // Registers a device with the manager.
- // The device will be used to resolve placements. Any placements resolved
- // prior to the addition of the device will need to be refreshed by the caller
- // if they want to make use of the new device.
- Status RegisterDevice(std::shared_ptr<Device> device);
-
- // Unregisters a device with the manager.
- // Placements that resolved to the device prior to unregistering will remain
- // valid for that device. Callers will need to refresh the placements to
- // ensure the device stops being used.
- Status UnregisterDevice(Device* device);
-
- // TODO(benvanik): dispatch info + requirements + etc -> DevicePlacement.
-
- // Resolves a placement spec to a device placement based on the registered
- // devices.
- // If the placement is not fully specified the device and queue may be chosen
- // at random. See PlacementSpec for more information about resolution and
- // ranking.
- StatusOr<DevicePlacement> ResolvePlacement(
- const PlacementSpec& placement_spec) const;
-
- // Finds an allocator that can allocate buffers of the given |memory_type| and
- // |buffer_usage| such that the buffers can be used interchangebly.
- // Fails if there is no Allocator that can satisfy that requirement.
- StatusOr<Allocator*> FindCompatibleAllocator(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- absl::Span<const DevicePlacement> device_placements) const;
-
- // Tries to allocate a host-local buffer that _may_ be optimal for use with
- // the given |device_placements| and _may_ be device-visible. The buffer can
- // be used for staging uploads to device-local buffers and is useful for times
- // when the buffer will be used more on the host than the device. If a buffer
- // never needs to be used with a device prefer instead
- // Allocator::host_local()::Allocate.
- //
- // Returns a buffer even if it's not possible to satisfy the requested
- // |buffer_usage| for the |device_placements| at the cost of a run-time
- // performance hit.
- StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements);
- StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer(
- BufferUsageBitfield buffer_usage, device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements) {
- return TryAllocateDeviceVisibleBuffer(
- MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage,
- allocation_size, device_placements);
- }
-
- // Allocates a host-local buffer that is optimal for use on the host but is
- // usable by the given |device_placements| (at a possible performance
- // penalty). The buffer can be used for staging uploads to device-local
- // buffers and is useful for times when the buffer will be used more on the
- // host than the device. If a buffer never needs to be used with a device
- // prefer instead HeapBuffer::Allocate.
- //
- // Fails if it is not possible to allocate and satisfy all |device_placements|
- // for the requested |buffer_usage|.
- StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements);
- StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer(
- BufferUsageBitfield buffer_usage, device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements) {
- return AllocateDeviceVisibleBuffer(
- MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage,
- allocation_size, device_placements);
- }
-
- // Allocates a device-local buffer that is optimal for use with the given
- // |device_placements|. The buffer will not be host-visible and can only be
- // used from compatible device queues.
- //
- // Fails if it is not possible to allocate and satisfy all |device_placements|
- // for the requested |buffer_usage|.
- StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements);
- StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer(
- BufferUsageBitfield buffer_usage, device_size_t allocation_size,
- absl::Span<const DevicePlacement> device_placements) {
- return AllocateDeviceLocalBuffer(MemoryType::kDeviceLocal, buffer_usage,
- allocation_size, device_placements);
- }
-
- // Enqueues a submission against the given target |device| |command_queue|.
- // The provided |deadline| is used to determine how long the submission can
- // stay waiting in the queue prior to flushing, with absl::InfinitePast
- // indicating immediate submission and absl::InfiniteFuture indicating that
- // Flush must be called.
- //
- // If a |fence| is provided it will be signaled when the submission has
- // completed and otherwise the caller must use WaitIdle to ensure completion.
- // If a sequence of submissions are performed then the semaphore relationships
- // can be used to elide waits. Submit(A)+Submit(B, fence) where there is a
- // dependency from A->B is safe.
- //
- // All provided resources must remain alive until the provided |fence|
- // resolves or Scheduler::WaitIdle succeeds.
- //
- // Submissions may be made from any thread. Behavior is undefined
- // if a thread is performing a WaitIdle while another thread submits work.
- Status Submit(Device* device, CommandQueue* command_queue,
- absl::Span<const SubmissionBatch> batches, absl::Time deadline,
- FenceValue fence = {});
- Status Submit(Device* device, CommandQueue* command_queue,
- absl::Span<const SubmissionBatch> batches,
- absl::Duration timeout, FenceValue fence = {}) {
- return Submit(device, command_queue, batches,
- RelativeTimeoutToDeadline(timeout), fence);
- }
- Status Submit(Device* device, CommandQueue* command_queue,
- absl::Span<const SubmissionBatch> batches,
- FenceValue fence = {}) {
- return Submit(device, command_queue, batches, absl::InfinitePast(), fence);
- }
-
- // Flushes any requests that are pending in the scheduler and ensures they
- // begin executing ASAP regardless of policy.
- //
- // If any used device has encountered an error during submission at any
- // point it will be returned here (repeatedly).
- Status Flush();
-
- // Blocks until all outstanding requests have been completed.
- // This is equivalent to having waited on all outstanding fences.
- // Implicitly calls Flush to ensure delayed requests are scheduled.
- // Work submitted from other threads during a wait may not be included in the
- // wait set.
- //
- // If any used device has encountered an error during submission at any
- // point it will be returned here (repeatedly).
- Status WaitIdle(absl::Time deadline);
- inline Status WaitIdle(absl::Duration timeout) {
- return WaitIdle(RelativeTimeoutToDeadline(timeout));
- }
- inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); }
-
- private:
- mutable absl::Mutex device_mutex_;
- std::vector<std::shared_ptr<Device>> devices_ ABSL_GUARDED_BY(device_mutex_);
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DEVICE_MANAGER_H_
diff --git a/iree/hal/driver.h b/iree/hal/driver.h
deleted file mode 100644
index 023660c..0000000
--- a/iree/hal/driver.h
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DRIVER_H_
-#define IREE_HAL_DRIVER_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "iree/base/status.h"
-#include "iree/hal/device.h"
-#include "iree/hal/device_info.h"
-
-namespace iree {
-namespace hal {
-
-class Driver {
- public:
- virtual ~Driver() = default;
-
- // Driver name used during registration.
- const std::string& name() const { return name_; }
-
- // TODO(benvanik): info/query (version number, etc).
-
- // Enumerates devices available for creation from the driver.
- // This may fail if the driver is in an invalid state but otherwise will
- // return an empty list if no devices are available.
- virtual StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() = 0;
-
- // Creates the driver-defined 'default' device.
- // This may simply be the first device enumerated.
- virtual StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() = 0;
-
- // Creates a device as queried with the given |device_info|.
- virtual StatusOr<std::shared_ptr<Device>> CreateDevice(
- const DeviceInfo& device_info) = 0;
-
- protected:
- explicit Driver(std::string name) : name_(std::move(name)) {}
-
- private:
- const std::string name_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DRIVER_H_
diff --git a/iree/hal/driver_registry.cc b/iree/hal/driver_registry.cc
deleted file mode 100644
index 21ab5ea..0000000
--- a/iree/hal/driver_registry.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/driver_registry.h"
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-
-// static
-DriverRegistry* DriverRegistry::shared_registry() {
- static auto* singleton = new DriverRegistry();
- return singleton;
-}
-
-DriverRegistry::DriverRegistry() = default;
-
-DriverRegistry::~DriverRegistry() = default;
-
-Status DriverRegistry::Register(std::string driver_name, FactoryFn factory_fn) {
- absl::MutexLock lock(&mutex_);
- for (const auto& pair : driver_factory_fns_) {
- if (pair.first == driver_name) {
- return AlreadyExistsErrorBuilder(IREE_LOC)
- << "Driver already registered: " << driver_name;
- }
- }
- driver_factory_fns_.emplace_back(driver_name, std::move(factory_fn));
- return OkStatus();
-}
-
-bool DriverRegistry::HasDriver(absl::string_view driver_name) const {
- absl::MutexLock lock(&mutex_);
- for (const auto& pair : driver_factory_fns_) {
- if (pair.first == driver_name) {
- return true;
- }
- }
- return false;
-}
-
-std::vector<std::string> DriverRegistry::EnumerateAvailableDrivers() const {
- absl::MutexLock lock(&mutex_);
- std::vector<std::string> driver_names;
- driver_names.reserve(driver_factory_fns_.size());
- for (const auto& pair : driver_factory_fns_) {
- driver_names.push_back(pair.first);
- }
- return driver_names;
-}
-
-StatusOr<std::shared_ptr<Driver>> DriverRegistry::Create(
- absl::string_view driver_name) const {
- FactoryFn factory_fn;
- {
- absl::MutexLock lock(&mutex_);
- for (const auto& pair : driver_factory_fns_) {
- if (pair.first == driver_name) {
- factory_fn = pair.second;
- break;
- }
- }
- if (!factory_fn) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Driver " << driver_name << " not found";
- }
- }
- return factory_fn();
-}
-
-} // namespace hal
-} // namespace iree
-
-IREE_REGISTER_MODULE_INITIALIZER(
- iree_hal, ::iree::hal::DriverRegistry::shared_registry());
diff --git a/iree/hal/driver_registry.h b/iree/hal/driver_registry.h
deleted file mode 100644
index 26b05fc..0000000
--- a/iree/hal/driver_registry.h
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DRIVER_REGISTRY_H_
-#define IREE_HAL_DRIVER_REGISTRY_H_
-
-#include <memory>
-#include <vector>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/hal/driver.h"
-
-namespace iree {
-namespace hal {
-
-// Driver registry and factory.
-// Factory functions for available drivers are registered with a given name and
-// can be invoked with a call to Create. The configuration of the drivers is
-// generally contained within the factory function and consumers of the drivers
-// don't need to fiddle with things.
-//
-// This is used for dynamic *safe* link-time driver module registration.
-// Roughly: driver_registry provides the shared registry and a way to create
-// drivers and *_driver_module.cc files register drivers when linked in.
-// Remember to alwayslink=1 on cc_libraries providing modules.
-//
-// If link-time driver registration is not desired (or possible) it's also
-// possible to explicitly register drivers via this registry. This is useful
-// when programmatically enabling drivers.
-//
-// Thread-safe.
-class DriverRegistry final {
- public:
- using FactoryFn = std::function<StatusOr<std::shared_ptr<Driver>>()>;
-
- // The shared driver registry singleton that modules use when linked in.
- static DriverRegistry* shared_registry();
-
- DriverRegistry();
- ~DriverRegistry();
-
- // Registers a driver and its factory function.
- // The function will be called to create a new driver whenever it is requested
- // via Create.
- Status Register(std::string driver_name, FactoryFn factory_fn);
-
- // Returns true if there is a driver registered with the given name.
- bool HasDriver(absl::string_view driver_name) const;
-
- // Returns a list of registered drivers.
- std::vector<std::string> EnumerateAvailableDrivers() const;
-
- // TODO(benvanik): flags for enabling debug validation/control/etc.
- // Creates a driver by name.
- StatusOr<std::shared_ptr<Driver>> Create(absl::string_view driver_name) const;
-
- private:
- mutable absl::Mutex mutex_;
- std::vector<std::pair<std::string, FactoryFn>> driver_factory_fns_
- ABSL_GUARDED_BY(mutex_);
-};
-
-} // namespace hal
-} // namespace iree
-
-IREE_DECLARE_MODULE_INITIALIZER(iree_hal);
-IREE_REQUIRE_MODULE_LINKED(iree_hal);
-
-#endif // IREE_HAL_DRIVER_REGISTRY_H_
diff --git a/iree/hal/event.h b/iree/hal/event.h
deleted file mode 100644
index c7786f4..0000000
--- a/iree/hal/event.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_EVENT_H_
-#define IREE_HAL_EVENT_H_
-
-#include "iree/hal/resource.h"
-
-namespace iree {
-namespace hal {
-
-// Events are used for defining synchronization scopes within CommandBuffers.
-// An event only exists within a single CommandBuffer and must not be used
-// across CommandBuffers from the same device or others.
-//
-// See CommandBuffer::SignalEvent and CommandBuffer::WaitEvents for more info.
-class Event : public Resource {
- public:
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_EVENT_H_
diff --git a/iree/hal/executable.h b/iree/hal/executable.h
deleted file mode 100644
index d724d01..0000000
--- a/iree/hal/executable.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_EXECUTABLE_H_
-#define IREE_HAL_EXECUTABLE_H_
-
-#include "iree/hal/resource.h"
-
-namespace iree {
-namespace hal {
-
-class Executable : public Resource {
- public:
- ~Executable() override = default;
-
- // True if the executable was prepared with debugging enabled and the device
- // and input data support debugging (symbols present, etc).
- virtual bool supports_debugging() const = 0;
-
- // TODO(benvanik): disassembly methods.
-
- // TODO(benvanik): relative offset calculation:
- // - step once
- // - step over
- // - step out
-
- // TODO(benvanik): create executable split on breakpoint.
- // Executable should return when the breakpoint is hit without any future
- // modifications to output buffers. If the breakpoint is not hit the
- // executable should run to completion as normal.
-
- // TODO(benvanik): retrieve coverage info.
- // Returns a buffer containing offset -> coverage metrics. Note that depending
- // on the device this may only contain a single coverage metric for the entire
- // executable or some subset of the available offsets.
-
- // TODO(benvanik): retrieve profiling info.
-
- protected:
- Executable() = default;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_EXECUTABLE_H_
diff --git a/iree/hal/executable_cache.cc b/iree/hal/executable_cache.cc
deleted file mode 100644
index 26ce40c..0000000
--- a/iree/hal/executable_cache.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/executable_cache.h"
-
-namespace iree {
-namespace hal {
-
-ExecutableCache::ExecutableCache() = default;
-
-ExecutableCache::~ExecutableCache() = default;
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/executable_cache.h b/iree/hal/executable_cache.h
deleted file mode 100644
index 3e31831..0000000
--- a/iree/hal/executable_cache.h
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_EXECUTABLE_CACHE_H_
-#define IREE_HAL_EXECUTABLE_CACHE_H_
-
-#include "iree/base/bitfield.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_format.h"
-#include "iree/hal/executable_spec.h"
-
-namespace iree {
-namespace hal {
-
-// Defines how the executable cache performs preparation.
-enum class ExecutableCachingMode : uint32_t {
- // Allows the cache to reference the provided executable_data after it has
- // prepared the executable. Callers must ensure the data remains valid for the
- // lifetime of the cache. If memory mapping constant executable data from
- // disk this can be used to avoid copies.
- kAliasProvidedData = 1 << 0,
-
- // Allows the prepared executable to be cached persistently (on disk/etc).
- // Enable for any executable that is likely to be used in future runs.
- // Note that not all caches support persistent serialization and this is just
- // a hint.
- kAllowPersistentCaching = 1 << 1,
-
- // Allows the cache to optimize the executable as much as it can.
- // This may cause preparation to take significantly longer while (hopefully)
- // improving runtime performance. Avoid for one-shot executables.
- kAllowOptimization = 1 << 2,
-
- // Enables Executable debugging methods if supported by the device and
- // executable. This may disable certain optimizations or retain additional
- // data to allow disassembly, stepping, etc.
- //
- // Device must support the DeviceFeature::kDebugging feature and executables
- // must support the ExecutableFeature::kDebugging feature.
- kEnableDebugging = 1 << 3,
-
- // Enables Executable coverage if supported by the device and executable.
- // Depending on the optimization mode this may produce partial coverage
- // results (for example, when certain source operations were optimized away).
- //
- // Device must support the DeviceFeature::kCoverage feature and executables
- // must support the ExecutableFeature::kCoverage feature.
- kEnableCoverage = 1 << 4,
-
- // Enables Executable profiling if supported by the device and executable.
- // Depending on the optimization mode this may produce partial profiling
- // results. Profiling attribution (whether to the entire executable or
- // specific operations) depends on the implementation.
- //
- // Device must support the DeviceFeature::kProfiling feature and executables
- // must support the ExecutableFeature::kProfiling feature.
- kEnableProfiling = 1 << 5,
-
- // Default caching mode.
- kDefault = kAllowPersistentCaching | kAllowOptimization,
-};
-IREE_BITFIELD(ExecutableCachingMode);
-using ExecutableCachingModeBitfield = ExecutableCachingMode;
-
-// A cache of prepared executables for a particular device.
-// Caches may be shared across multiple devices from the same driver or specific
-// to individual devices. Caches may persist prepared executables across process
-// launches or reprepare them each run. Callers should assume that the cache is
-// a no-op and the returned Executables only live for as long as the cache does.
-//
-// The term 'cache' here is rather optimistic - it's perfectly acceptable for
-// implementations to not cache at all and return new Executables for each
-// PrepareExecutable called (even for the same executable). Callers should
-// expect such behavior and try to retain the results of the PrepareExecutable
-// calls to reduce overhead in re-preparing executables.
-//
-// Thread-safe - multiple threads may prepare executables (including the *same*
-// executable) simultaneously.
-class ExecutableCache {
- public:
- virtual ~ExecutableCache();
-
- // TODO(benvanik): status/queries (size, etc).
-
- // TODO(b/137153339): serialization/deserialization.
-
- // Returns true if the executable cache can prepare the given executable input
- // format. Perparation may still fail if the particular version or features
- // required by the executable are not supported.
- virtual bool CanPrepareFormat(ExecutableFormat format) const = 0;
-
- // Prepares an executable for use.
- // The provided |spec| and |executable_data| will be used to either lookup a
- // previously prepared executable in the cache or prepare a new one.
- //
- // Depending on the driver preparation may take a non-trivial amount of time
- // (such as when JITing/etc). As the cache is internally synchronized callers
- // can issue preparation requests from multiple threads - even for the same
- // executables - and calls will block until preparation completes.
- //
- // When preparing a large number of executables it's recommended to use the
- // PrepareExecutables method to batch and wait on the results.
- virtual StatusOr<ref_ptr<Executable>> PrepareExecutable(
- ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) = 0;
-
- protected:
- ExecutableCache();
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/executable_spec.h b/iree/hal/executable_spec.h
deleted file mode 100644
index a88553f..0000000
--- a/iree/hal/executable_spec.h
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_EXECUTABLE_SPEC_H_
-#define IREE_HAL_EXECUTABLE_SPEC_H_
-
-#include "absl/types/span.h"
-#include "iree/hal/executable_format.h"
-
-namespace iree {
-namespace hal {
-
-// Defines an executable specification used by a cache to prepare an executable.
-struct ExecutableSpec {
- // TODO(benvanik): pre-populated hash_code/key to avoid calculation.
-
- // Format of the executable input data.
- ExecutableFormat format = kExecutableFormatUnspecified;
-
- // A reference to the executable data as input to the cache.
- // If ExecutableCachingMode::kAliasProvidedData is set then this reference
- // may be retained by the cache and the backing buffer must be kept valid for
- // the lifetime of the cache.
- absl::Span<const uint8_t> executable_data;
-
- // TODO(benvanik): add specialization info (constants/defines).
- // TODO(benvanik): add compiler flags? could treat as opaque.
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_EXECUTABLE_SPEC_H_
diff --git a/iree/hal/fence.h b/iree/hal/fence.h
deleted file mode 100644
index b395e91..0000000
--- a/iree/hal/fence.h
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_FENCE_H_
-#define IREE_HAL_FENCE_H_
-
-#include <cstdint>
-
-#include "iree/base/status.h"
-#include "iree/hal/resource.h"
-
-namespace iree {
-namespace hal {
-
-// Synchronization mechanism for device->host notification.
-// Fences behave like timeline semaphores and contain a monotonically increasing
-// uint64_t payload. They may be waited on any number of times - even if they
-// have already been signaled.
-//
-// A fence is updated to its new value after all prior commands have completed
-// but the delay between completion and the host being woken varies. Some
-// implementations may coalesce fences to avoid spurious waking while others
-// will immediately synchronize with the host.
-//
-// The primary use of fences is for resource lifetime management: all resources
-// used by a set of submission batches must be considered live until the fence
-// attached to the submission has signaled.
-//
-// Fences may be set to a permanently failed state by implementations when
-// errors occur during asynchronous execution. Users are expected to propagate
-// the failures and possibly reset the entire device that produced the error.
-//
-// For more information on fences see the following docs describing how
-// timelines are generally used (specifically in the device->host case):
-// https://www.youtube.com/watch?v=SpE--Rf516Y
-// https://www.khronos.org/assets/uploads/developers/library/2018-xdc/Vulkan-Timeline-Semaphores-Part-1_Sep18.pdf
-// https://docs.microsoft.com/en-us/windows/win32/direct3d12/user-mode-heap-synchronization
-class Fence : public Resource {
- public:
- // Returns a permanent failure status if the fence is indicating an
- // asynchronous failure.
- //
- // Returns the status at the time the method is called without blocking and as
- // such is only valid after a fence has been signaled. The same failure status
- // will be returned regardless of when in the timeline the error occurred.
- virtual Status status() const = 0;
-
- // Queries the current payload of the fence. As the payload is monotonically
- // increasing it is guaranteed that the value is at least equal to the
- // previous result of a QueryValue call and coherent with any waits for a
- // specified value via Device::WaitAllFences.
- virtual StatusOr<uint64_t> QueryValue() = 0;
-};
-
-// A reference to a fence and associated payload value.
-using FenceValue = std::pair<Fence*, uint64_t>;
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_FENCE_H_
diff --git a/iree/hal/heap_buffer.cc b/iree/hal/heap_buffer.cc
deleted file mode 100644
index 8c78935..0000000
--- a/iree/hal/heap_buffer.cc
+++ /dev/null
@@ -1,190 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/heap_buffer.h"
-
-#include <cstdint>
-#include <cstdlib>
-#include <string>
-#include <utility>
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/host/host_buffer.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-
-// An allocator that allocates or wraps host-only buffers.
-// The resulting buffers are not usable by most devices without a copy and
-// using a device allocator is strongly preferred.
-class HeapAllocator : public Allocator {
- public:
- // Returns a singleton heap allocator that can provide buffers that have
- // MemoryType::kHostLocal and are allocated with malloc/free.
- // These buffers will not be usable by devices directly and may incur
- // additional copies.
- static Allocator* std_heap();
-
- // TODO(benvanik): specify custom allocator (not malloc/free).
- HeapAllocator();
- ~HeapAllocator() override;
-
- bool CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const override;
-
- bool CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const override;
-
- StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) override;
-
- StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage,
- void* data,
- size_t data_length) override;
-};
-
-// static
-Allocator* HeapAllocator::std_heap() {
- static Allocator* std_heap_allocator = new HeapAllocator();
- return std_heap_allocator;
-}
-
-HeapAllocator::HeapAllocator() = default;
-
-HeapAllocator::~HeapAllocator() = default;
-
-bool HeapAllocator::CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const {
- // The host can use anything with kHostVisible.
- if (!AnyBitSet(memory_type & MemoryType::kHostVisible)) {
- return false;
- }
-
- // Host currently uses mapping to copy buffers, which is done a lot.
- if (!AnyBitSet(buffer_usage & BufferUsage::kMapping)) {
- return false;
- }
-
- return true;
-}
-
-bool HeapAllocator::CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const {
- // This host only allocator cannot serve device visible allocation as we
- // can't know which devices these buffers will be used with.
- return (memory_type & MemoryType::kHostLocal) == MemoryType::kHostLocal &&
- !AnyBitSet(memory_type & MemoryType::kDeviceLocal) &&
- !AnyBitSet(memory_type & MemoryType::kDeviceVisible);
-}
-
-StatusOr<ref_ptr<Buffer>> HeapAllocator::Allocate(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- size_t allocation_size) {
- IREE_TRACE_SCOPE0("HeapAllocator::Allocate");
-
- if (!CanAllocate(memory_type, buffer_usage, allocation_size)) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Allocation not supported; memory_type="
- << MemoryTypeString(memory_type)
- << ", buffer_usage=" << BufferUsageString(buffer_usage)
- << ", allocation_size=" << allocation_size;
- }
-
- void* malloced_data = std::calloc(1, allocation_size);
- if (!malloced_data) {
- return ResourceExhaustedErrorBuilder(IREE_LOC)
- << "Failed to malloc " << allocation_size << " bytes";
- }
-
- auto buffer =
- make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage,
- allocation_size, malloced_data, true);
- return buffer;
-}
-
-StatusOr<ref_ptr<Buffer>> HeapAllocator::WrapMutable(
- MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage, void* data, size_t data_length) {
- auto buffer = make_ref<HostBuffer>(this, memory_type, allowed_access,
- buffer_usage, data_length, data, false);
- return buffer;
-}
-
-} // namespace
-
-// static
-ref_ptr<Buffer> HeapBuffer::Allocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield usage,
- size_t allocation_size) {
- auto buffer_or =
- HeapAllocator::std_heap()->Allocate(memory_type, usage, allocation_size);
- return std::move(buffer_or.ValueOrDie());
-}
-
-// static
-ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
- const void* data, size_t data_length) {
- return AllocateCopy(usage, MemoryAccess::kAll, data, data_length);
-}
-
-// static
-ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
- MemoryAccessBitfield allowed_access,
- const void* data, size_t data_length) {
- IREE_TRACE_SCOPE0("HeapBuffer::AllocateCopy");
- // Ensure we can map so that we can copy into it.
- usage |= BufferUsage::kMapping;
- auto buffer_or = HeapAllocator::std_heap()->Allocate(MemoryType::kHostLocal,
- usage, data_length);
- auto buffer = std::move(buffer_or.ValueOrDie());
- buffer->WriteData(0, data, data_length).IgnoreError();
- buffer->set_allowed_access(allowed_access);
- return buffer;
-}
-
-// static
-ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield usage, const void* data,
- size_t data_length) {
- auto buffer_or =
- HeapAllocator::std_heap()->Wrap(memory_type, usage, data, data_length);
- return std::move(buffer_or.ValueOrDie());
-}
-
-// static
-ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage, void* data,
- size_t data_length) {
- auto buffer_or = HeapAllocator::std_heap()->WrapMutable(
- memory_type, allowed_access, usage, data, data_length);
- return std::move(buffer_or.ValueOrDie());
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/heap_buffer.h b/iree/hal/heap_buffer.h
deleted file mode 100644
index eba9b72..0000000
--- a/iree/hal/heap_buffer.h
+++ /dev/null
@@ -1,117 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HEAP_BUFFER_H_
-#define IREE_HAL_HEAP_BUFFER_H_
-
-#include <memory>
-
-#include "iree/base/status.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// Factory for buffers that are allocated from the host heap (malloc/free).
-// These buffers cannot be used by devices and will incur copies/transfers when
-// used. Prefer device-specific allocators instead.
-class HeapBuffer {
- public:
- // Allocates a zeroed host heap buffer of the given size.
- // Returns a buffer allocated with malloc and have MemoryType::kHostLocal
- // and will not be usable by devices without copies.
- static ref_ptr<Buffer> Allocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield usage,
- size_t allocation_size);
- static ref_ptr<Buffer> Allocate(BufferUsageBitfield usage,
- size_t allocation_size) {
- return Allocate(MemoryType::kHostLocal, usage, allocation_size);
- }
-
- // Allocates a host heap buffer with a copy of the given data.
- // Returns a buffer allocated with malloc and have MemoryType::kHostLocal
- // and will not be usable by devices without copies.
- static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
- const void* data, size_t data_length);
- static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
- MemoryAccessBitfield allowed_access,
- const void* data, size_t data_length);
- template <typename T>
- static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
- absl::Span<const T> data);
- template <typename T>
- static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage,
- MemoryAccessBitfield allowed_access,
- absl::Span<const T> data);
-
- // Wraps an existing host heap allocation in a buffer.
- // Ownership of the host allocation remains with the caller and the memory
- // must remain valid for so long as the Buffer may be in use.
- // Will have MemoryType::kHostLocal in most cases and may not be usable
- // by the device.
- static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield usage, const void* data,
- size_t data_length);
- static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage, void* data,
- size_t data_length);
- template <typename T>
- static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield usage,
- absl::Span<const T> data);
- template <typename T>
- static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage,
- absl::Span<T> data);
-};
-
-// Inline functions and template definitions follow:
-
-template <typename T>
-ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
- absl::Span<const T> data) {
- return HeapBuffer::AllocateCopy(usage, MemoryAccess::kAll, data);
-}
-
-template <typename T>
-ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage,
- MemoryAccessBitfield allowed_access,
- absl::Span<const T> data) {
- return HeapBuffer::AllocateCopy(usage, allowed_access, data.data(),
- data.size() * sizeof(T));
-}
-
-template <typename T>
-ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type,
- BufferUsageBitfield usage,
- absl::Span<const T> data) {
- return HeapBuffer::Wrap(memory_type, usage, data.data(),
- data.size() * sizeof(T));
-}
-
-template <typename T>
-ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage,
- absl::Span<T> data) {
- return HeapBuffer::WrapMutable(memory_type, allowed_access, usage,
- data.data(), data.size() * sizeof(T));
-}
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HEAP_BUFFER_H_
diff --git a/iree/hal/host/BUILD b/iree/hal/host/BUILD
deleted file mode 100644
index 6e6c51f..0000000
--- a/iree/hal/host/BUILD
+++ /dev/null
@@ -1,155 +0,0 @@
-# Default implementations for HAL types that use the host resources.
-# These are generally just wrappers around host heap memory and host threads.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "async_command_queue",
- srcs = ["async_command_queue.cc"],
- hdrs = ["async_command_queue.h"],
- deps = [
- ":host_submission_queue",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:fence",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- ],
-)
-
-cc_test(
- name = "async_command_queue_test",
- srcs = ["async_command_queue_test.cc"],
- deps = [
- ":async_command_queue",
- ":host_submission_queue",
- "//iree/base:status",
- "//iree/base:status_matchers",
- "//iree/base:time",
- "//iree/hal:command_queue",
- "//iree/hal/testing:mock_command_buffer",
- "//iree/hal/testing:mock_command_queue",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "host_buffer",
- srcs = ["host_buffer.cc"],
- hdrs = ["host_buffer.h"],
- deps = [
- "//iree/base:logging",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/hal:buffer",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "host_event",
- srcs = ["host_event.cc"],
- hdrs = ["host_event.h"],
- deps = [
- "//iree/hal:event",
- ],
-)
-
-cc_library(
- name = "host_fence",
- srcs = ["host_fence.cc"],
- hdrs = ["host_fence.h"],
- deps = [
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:fence",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_test(
- name = "host_fence_test",
- srcs = ["host_fence_test.cc"],
- deps = [
- ":host_fence",
- "//iree/base:status",
- "//iree/base:status_matchers",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "host_local_allocator",
- srcs = ["host_local_allocator.cc"],
- hdrs = ["host_local_allocator.h"],
- deps = [
- ":host_buffer",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer",
- ],
-)
-
-cc_library(
- name = "host_local_command_processor",
- srcs = ["host_local_command_processor.cc"],
- hdrs = ["host_local_command_processor.h"],
- deps = [
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- ],
-)
-
-cc_library(
- name = "host_submission_queue",
- srcs = ["host_submission_queue.cc"],
- hdrs = ["host_submission_queue.h"],
- deps = [
- ":host_fence",
- "//iree/base:intrusive_list",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:fence",
- "//iree/hal:semaphore",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- ],
-)
-
-cc_test(
- name = "host_submission_queue_test",
- srcs = ["host_submission_queue_test.cc"],
- deps = [
- ":host_submission_queue",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "inproc_command_buffer",
- srcs = ["inproc_command_buffer.cc"],
- hdrs = ["inproc_command_buffer.h"],
- deps = [
- "//iree/base:arena",
- "//iree/base:intrusive_list",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- ],
-)
diff --git a/iree/hal/host/async_command_queue.cc b/iree/hal/host/async_command_queue.cc
deleted file mode 100644
index f7e549f..0000000
--- a/iree/hal/host/async_command_queue.cc
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/async_command_queue.h"
-
-#include "absl/base/thread_annotations.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-AsyncCommandQueue::AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue)
- : CommandQueue(target_queue->name(), target_queue->supported_categories()),
- target_queue_(std::move(target_queue)) {
- IREE_TRACE_SCOPE0("AsyncCommandQueue::ctor");
- thread_ = std::thread([this]() { ThreadMain(); });
-}
-
-AsyncCommandQueue::~AsyncCommandQueue() {
- IREE_TRACE_SCOPE0("AsyncCommandQueue::dtor");
- {
- // Signal to thread that we want to stop. Note that the thread may have
- // already been stopped and that's ok (as we'll Join right away).
- // The thread will finish processing any queued submissions.
- absl::MutexLock lock(&submission_mutex_);
- submission_queue_.SignalShutdown();
- }
- thread_.join();
-
- // Ensure we shut down OK.
- {
- absl::MutexLock lock(&submission_mutex_);
- CHECK(submission_queue_.empty())
- << "Dirty shutdown of async queue (unexpected thread exit?)";
- }
-}
-
-void AsyncCommandQueue::ThreadMain() {
- // TODO(benvanik): make this safer (may die if trace is flushed late).
- IREE_TRACE_THREAD_ENABLE(target_queue_->name().c_str());
-
- bool is_exiting = false;
- while (!is_exiting) {
- // Block until we are either requested to exit or there are pending
- // submissions.
- submission_mutex_.Lock();
- submission_mutex_.Await(absl::Condition(
- +[](HostSubmissionQueue* queue) {
- return queue->has_shutdown() || !queue->empty();
- },
- &submission_queue_));
- if (!submission_queue_.empty()) {
- // Run all ready submissions (this may be called many times).
- submission_mutex_.AssertHeld();
- submission_queue_
- .ProcessBatches(
- [this](absl::Span<CommandBuffer* const> command_buffers)
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(submission_mutex_) {
- // Release the lock while we perform the processing so that
- // other threads can submit more work.
- submission_mutex_.AssertHeld();
- submission_mutex_.Unlock();
-
- // Relay the command buffers to the target queue.
- // Since we are taking care of all synchronization they
- // don't need any waiters or fences.
- auto status = target_queue_->Submit(
- {{}, command_buffers, {}}, {nullptr, 0u});
-
- // Take back the lock so we can manipulate the queue safely.
- submission_mutex_.Lock();
- submission_mutex_.AssertHeld();
-
- return status;
- })
- .IgnoreError();
- submission_mutex_.AssertHeld();
- }
- if (submission_queue_.has_shutdown()) {
- // Exit when there are no more submissions to process and an exit was
- // requested (or we errored out).
- is_exiting = true;
- }
- submission_mutex_.Unlock();
- }
-}
-
-Status AsyncCommandQueue::Submit(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) {
- IREE_TRACE_SCOPE0("AsyncCommandQueue::Submit");
- absl::MutexLock lock(&submission_mutex_);
- return submission_queue_.Enqueue(batches, fence);
-}
-
-Status AsyncCommandQueue::WaitIdle(absl::Time deadline) {
- IREE_TRACE_SCOPE0("AsyncCommandQueue::WaitIdle");
-
- // Wait until the deadline, the thread exits, or there are no more pending
- // submissions.
- absl::MutexLock lock(&submission_mutex_);
- if (!submission_mutex_.AwaitWithDeadline(
- absl::Condition(
- +[](HostSubmissionQueue* queue) {
- return queue->empty() || !queue->permanent_error().ok();
- },
- &submission_queue_),
- deadline)) {
- return DeadlineExceededErrorBuilder(IREE_LOC)
- << "Deadline exceeded waiting for submission thread to go idle";
- }
- return submission_queue_.permanent_error();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/async_command_queue.h b/iree/hal/host/async_command_queue.h
deleted file mode 100644
index 5716d55..0000000
--- a/iree/hal/host/async_command_queue.h
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
-#define IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
-
-#include <memory>
-#include <thread> // NOLINT
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/fence.h"
-#include "iree/hal/host/host_submission_queue.h"
-
-namespace iree {
-namespace hal {
-
-// Asynchronous command queue wrapper.
-// This creates a single thread to perform all CommandQueue operations. Any
-// submitted CommandBuffer is dispatched in FIFO order on the queue thread
-// against the provided |target_queue|.
-//
-// Target queues will receive submissions containing only command buffers as
-// all semaphore synchronization is handled by the wrapper. Fences will also be
-// omitted and code should safely handle nullptr.
-//
-// AsyncCommandQueue (as with CommandQueue) is thread-safe. Multiple threads
-// may submit command buffers concurrently, though the order of execution in
-// such a case depends entirely on the synchronization primitives provided.
-class AsyncCommandQueue final : public CommandQueue {
- public:
- explicit AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue);
- ~AsyncCommandQueue() override;
-
- Status Submit(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) override;
-
- Status WaitIdle(absl::Time deadline) override;
-
- private:
- // Thread entry point for the async worker thread.
- // Waits for submissions to be queued up and processes them eagerly.
- void ThreadMain();
-
- // CommandQueue that the async queue relays submissions into.
- std::unique_ptr<CommandQueue> target_queue_;
-
- // Thread that runs the ThreadMain() function and processes submissions.
- std::thread thread_;
-
- // Queue that manages submission ordering.
- mutable absl::Mutex submission_mutex_;
- HostSubmissionQueue submission_queue_ ABSL_GUARDED_BY(submission_mutex_);
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
diff --git a/iree/hal/host/async_command_queue_test.cc b/iree/hal/host/async_command_queue_test.cc
deleted file mode 100644
index 307ed3b..0000000
--- a/iree/hal/host/async_command_queue_test.cc
+++ /dev/null
@@ -1,232 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/async_command_queue.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "absl/time/time.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-#include "iree/base/time.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/host/host_submission_queue.h"
-#include "iree/hal/testing/mock_command_buffer.h"
-#include "iree/hal/testing/mock_command_queue.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-using ::testing::_;
-
-using testing::MockCommandBuffer;
-using testing::MockCommandQueue;
-
-struct AsyncCommandQueueTest : public ::testing::Test {
- MockCommandQueue* mock_target_queue;
- std::unique_ptr<CommandQueue> command_queue;
-
- void SetUp() override {
- auto mock_queue = absl::make_unique<MockCommandQueue>(
- "mock", CommandCategory::kTransfer | CommandCategory::kDispatch);
- mock_target_queue = mock_queue.get();
- command_queue = absl::make_unique<AsyncCommandQueue>(std::move(mock_queue));
- }
-
- void TearDown() override {
- command_queue.reset();
- mock_target_queue = nullptr;
- }
-};
-
-// Tests that submitting a command buffer and immediately waiting will not
-// deadlock.
-TEST_F(AsyncCommandQueueTest, BlockingSubmit) {
- ::testing::InSequence sequence;
-
- auto cmd_buffer = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
-
- EXPECT_CALL(*mock_target_queue, Submit(_, _))
- .WillOnce(
- [&](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
- CHECK_EQ(1, batches.size());
- CHECK_EQ(1, batches[0].command_buffers.size());
- CHECK_EQ(cmd_buffer.get(), batches[0].command_buffers[0]);
- CHECK_EQ(nullptr, fence.first);
- return OkStatus();
- });
- HostFence fence(0u);
- ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u}));
- ASSERT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
-}
-
-// Tests that failure is propagated along the fence from the target queue.
-TEST_F(AsyncCommandQueueTest, PropagateSubmitFailure) {
- ::testing::InSequence sequence;
-
- auto cmd_buffer = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
-
- EXPECT_CALL(*mock_target_queue, Submit(_, _))
- .WillOnce(
- [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
- return DataLossErrorBuilder(IREE_LOC);
- });
- HostFence fence(0u);
- ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u}));
- EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
- {{&fence, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
-}
-
-// Tests that waiting for idle is a no-op when nothing is queued.
-TEST_F(AsyncCommandQueueTest, WaitIdleWhileIdle) {
- ASSERT_OK(command_queue->WaitIdle());
-}
-
-// Tests that waiting for idle will block when work is pending/in-flight.
-TEST_F(AsyncCommandQueueTest, WaitIdleWithPending) {
- ::testing::InSequence sequence;
-
- auto cmd_buffer = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
-
- EXPECT_CALL(*mock_target_queue, Submit(_, _))
- .WillOnce(
- [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
- Sleep(absl::Milliseconds(100));
- return OkStatus();
- });
- HostFence fence(0u);
- ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u}));
-
- // This should block for a sec or two.
- ASSERT_OK(command_queue->WaitIdle());
-
- // Should have already expired.
- ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue());
- ASSERT_EQ(1u, value);
-}
-
-// Tests that waiting for idle with multiple pending submissions will wait until
-// all of them complete while still allowing incremental progress.
-TEST_F(AsyncCommandQueueTest, WaitIdleAndProgress) {
- ::testing::InSequence sequence;
-
- EXPECT_CALL(*mock_target_queue, Submit(_, _))
- .WillRepeatedly(
- [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
- Sleep(absl::Milliseconds(100));
- return OkStatus();
- });
-
- auto cmd_buffer_0 = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
- auto cmd_buffer_1 = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
-
- HostFence fence_0(0u);
- ASSERT_OK(
- command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u}));
- HostFence fence_1(0u);
- ASSERT_OK(
- command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u}));
-
- // This should block for a sec or two.
- ASSERT_OK(command_queue->WaitIdle());
-
- // Both should have already expired.
- ASSERT_OK_AND_ASSIGN(uint64_t value_0, fence_0.QueryValue());
- ASSERT_EQ(1u, value_0);
- ASSERT_OK_AND_ASSIGN(uint64_t value_1, fence_1.QueryValue());
- ASSERT_EQ(1u, value_1);
-}
-
-// Tests that failures are sticky.
-TEST_F(AsyncCommandQueueTest, StickyFailures) {
- ::testing::InSequence sequence;
-
- // Fail.
- EXPECT_CALL(*mock_target_queue, Submit(_, _))
- .WillOnce(
- [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
- Sleep(absl::Milliseconds(100));
- return DataLossErrorBuilder(IREE_LOC);
- });
- auto cmd_buffer_0 = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
- HostFence fence_0(0u);
- ASSERT_OK(
- command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u}));
- EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
- {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
-
- // Future flushes/waits/etc should also fail.
- EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle()));
-
- // Future submits should fail asynchronously.
- auto cmd_buffer_1 = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
- HostFence fence_1(0u);
- EXPECT_TRUE(IsDataLoss(
- command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u})));
-}
-
-// Tests that a failure with two submissions pending causes the second to
-// bail as well.
-TEST_F(AsyncCommandQueueTest, FailuresCascadeAcrossSubmits) {
- ::testing::InSequence sequence;
-
- // Fail.
- EXPECT_CALL(*mock_target_queue, Submit(_, _))
- .WillOnce(
- [](absl::Span<const SubmissionBatch> batches, FenceValue fence) {
- Sleep(absl::Milliseconds(100));
- return DataLossErrorBuilder(IREE_LOC);
- });
-
- auto cmd_buffer_0 = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
- auto cmd_buffer_1 = make_ref<MockCommandBuffer>(
- nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer);
-
- HostBinarySemaphore semaphore_0_1(false);
- HostFence fence_0(0u);
- ASSERT_OK(command_queue->Submit({{}, {cmd_buffer_0.get()}, {&semaphore_0_1}},
- {&fence_0, 1u}));
- HostFence fence_1(0u);
- ASSERT_OK(command_queue->Submit({{&semaphore_0_1}, {cmd_buffer_1.get()}, {}},
- {&fence_1, 1u}));
-
- EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle()));
-
- EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
- {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
- EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences(
- {{&fence_1, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())));
-
- // Future flushes/waits/etc should also fail.
- EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle()));
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_buffer.cc b/iree/hal/host/host_buffer.cc
deleted file mode 100644
index 555f6ac..0000000
--- a/iree/hal/host/host_buffer.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_buffer.h"
-
-#include <cstdint>
-#include <cstdlib>
-#include <cstring>
-
-#include "iree/base/logging.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-
-class Allocator;
-
-HostBuffer::HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage, device_size_t allocation_size,
- void* data, bool owns_data)
- : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, 0,
- allocation_size),
- data_(data),
- owns_data_(owns_data) {}
-
-HostBuffer::~HostBuffer() {
- if (owns_data_ && data_) {
- std::free(data_);
- data_ = nullptr;
- }
-}
-
-Status HostBuffer::FillImpl(device_size_t byte_offset,
- device_size_t byte_length, const void* pattern,
- device_size_t pattern_length) {
- auto data_ptr = data_;
- switch (pattern_length) {
- case 1: {
- uint8_t* data = static_cast<uint8_t*>(data_ptr);
- uint8_t value_bits = *static_cast<const uint8_t*>(pattern);
- std::fill_n(data + byte_offset, byte_length, value_bits);
- break;
- }
- case 2: {
- uint16_t* data = static_cast<uint16_t*>(data_ptr);
- uint16_t value_bits = *static_cast<const uint16_t*>(pattern);
- std::fill_n(data + byte_offset / sizeof(uint16_t),
- byte_length / sizeof(uint16_t), value_bits);
- break;
- }
- case 4: {
- uint32_t* data = static_cast<uint32_t*>(data_ptr);
- uint32_t value_bits = *static_cast<const uint32_t*>(pattern);
- std::fill_n(data + byte_offset / sizeof(uint32_t),
- byte_length / sizeof(uint32_t), value_bits);
- break;
- }
- default:
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unsupported scalar data size: " << pattern_length;
- }
- return OkStatus();
-}
-
-Status HostBuffer::ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) {
- auto data_ptr = static_cast<uint8_t*>(data_);
- std::memcpy(data, data_ptr + source_offset, data_length);
- return OkStatus();
-}
-
-Status HostBuffer::WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) {
- auto data_ptr = static_cast<uint8_t*>(data_);
- std::memcpy(data_ptr + target_offset, data, data_length);
- return OkStatus();
-}
-
-Status HostBuffer::CopyDataImpl(device_size_t target_offset,
- Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) {
- // This is pretty terrible. Let's not do this.
- // TODO(benvanik): a way for allocators to indicate transfer compat.
- ASSIGN_OR_RETURN(auto source_data,
- source_buffer->MapMemory<uint8_t>(
- MemoryAccess::kRead, source_offset, data_length));
- CHECK_EQ(data_length, source_data.size());
- auto data_ptr = static_cast<uint8_t*>(data_);
- std::memcpy(data_ptr + target_offset, source_data.data(), data_length);
- return OkStatus();
-}
-
-Status HostBuffer::MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) {
- auto data_ptr = static_cast<uint8_t*>(data_);
- *out_data = data_ptr + local_byte_offset;
-
- // If we mapped for discard scribble over the bytes. This is not a mandated
- // behavior but it will make debugging issues easier. Alternatively for
- // heap buffers we could reallocate them such that ASAN yells, but that
- // would only work if the entire buffer was discarded.
-#ifndef NDEBUG
- if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) {
- std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
- }
-#endif // !NDEBUG
-
- return OkStatus();
-}
-
-Status HostBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void* data) {
- // No-op? We still want error checking to make finding misuse easier.
- return OkStatus();
-}
-
-Status HostBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- // No-op? We still want error checking to make finding misuse easier.
- return OkStatus();
-}
-
-Status HostBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- // No-op? We still want error checking to make finding misuse easier.
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_buffer.h b/iree/hal/host/host_buffer.h
deleted file mode 100644
index 7a52758..0000000
--- a/iree/hal/host/host_buffer.h
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_BUFFER_H_
-#define IREE_HAL_HOST_BUFFER_H_
-
-#include <cstdint>
-
-#include "iree/base/status.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// A buffer type that operates on host pointers.
-// This can be used by Allocator implementations when they support operating
-// on host memory (or mapping their memory to host memory).
-class HostBuffer : public Buffer {
- public:
- HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
- device_size_t allocation_size, void* data, bool owns_data);
-
- ~HostBuffer() override;
-
- protected:
- Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length) override;
- Status ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) override;
- Status WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) override;
- Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) override;
- Status MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) override;
- Status UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data) override;
- Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override;
- Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override;
-
- private:
- void* data_ = nullptr;
- bool owns_data_ = false;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_BUFFER_H_
diff --git a/iree/hal/host/host_event.cc b/iree/hal/host/host_event.cc
deleted file mode 100644
index 9ffac59..0000000
--- a/iree/hal/host/host_event.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_event.h"
-
-namespace iree {
-namespace hal {
-
-HostEvent::HostEvent() = default;
-
-HostEvent::~HostEvent() = default;
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_event.h b/iree/hal/host/host_event.h
deleted file mode 100644
index c5fb33a..0000000
--- a/iree/hal/host/host_event.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_EVENT_H_
-#define IREE_HAL_HOST_HOST_EVENT_H_
-
-#include "iree/hal/event.h"
-
-namespace iree {
-namespace hal {
-
-class HostEvent final : public Event {
- public:
- HostEvent();
- ~HostEvent() override;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_HOST_EVENT_H_
diff --git a/iree/hal/host/host_fence.cc b/iree/hal/host/host_fence.cc
deleted file mode 100644
index 6932b20..0000000
--- a/iree/hal/host/host_fence.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_fence.h"
-
-#include <atomic>
-#include <cstdint>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-HostFence::HostFence(uint64_t initial_value) : value_(initial_value) {}
-
-HostFence::~HostFence() = default;
-
-Status HostFence::status() const {
- absl::MutexLock lock(&mutex_);
- return status_;
-}
-
-StatusOr<uint64_t> HostFence::QueryValue() {
- return value_.load(std::memory_order_acquire);
-}
-
-Status HostFence::Signal(uint64_t value) {
- absl::MutexLock lock(&mutex_);
- if (!status_.ok()) {
- return status_;
- }
- if (value_.exchange(value) >= value) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Fence values must be monotonically increasing";
- }
- return OkStatus();
-}
-
-Status HostFence::Fail(Status status) {
- absl::MutexLock lock(&mutex_);
- status_ = status;
- value_.store(UINT64_MAX, std::memory_order_release);
- return OkStatus();
-}
-
-// static
-Status HostFence::WaitForFences(absl::Span<const FenceValue> fences,
- bool wait_all, absl::Time deadline) {
- IREE_TRACE_SCOPE0("HostFence::WaitForFences");
-
- // Some of the fences may already be signaled; we only need to wait for those
- // that are not yet at the expected value.
- using HostFenceValue = std::pair<HostFence*, uint64_t>;
- absl::InlinedVector<HostFenceValue, 4> waitable_fences;
- waitable_fences.reserve(fences.size());
- for (auto& fence_value : fences) {
- auto* fence = reinterpret_cast<HostFence*>(fence_value.first);
- ASSIGN_OR_RETURN(uint64_t current_value, fence->QueryValue());
- if (current_value == UINT64_MAX) {
- // Fence has failed. Return the error.
- return fence->status();
- } else if (current_value < fence_value.second) {
- // Fence has not yet hit the required value; wait for it.
- waitable_fences.push_back({fence, fence_value.second});
- }
- }
-
- // TODO(benvanik): maybe sort fences by value in case we are waiting on
- // multiple values from the same fence.
-
- // Loop over the fences and wait for them to complete.
- // TODO(b/140026716): add WaitHandle support for !wait_all (wait any).
- for (auto& fence_value : waitable_fences) {
- auto* fence = fence_value.first;
- absl::MutexLock lock(&fence->mutex_);
- if (!fence->mutex_.AwaitWithDeadline(
- absl::Condition(
- +[](HostFenceValue* fence_value) {
- return fence_value->first->value_.load(
- std::memory_order_acquire) >= fence_value->second;
- },
- &fence_value),
- deadline)) {
- return DeadlineExceededErrorBuilder(IREE_LOC)
- << "Deadline exceeded waiting for fences";
- }
- if (!fence->status_.ok()) {
- return fence->status_;
- }
- }
-
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_fence.h b/iree/hal/host/host_fence.h
deleted file mode 100644
index 80792e7..0000000
--- a/iree/hal/host/host_fence.h
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_FENCE_H_
-#define IREE_HAL_HOST_HOST_FENCE_H_
-
-#include <atomic>
-#include <cstdint>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/hal/fence.h"
-
-namespace iree {
-namespace hal {
-
-// TODO(b/140026716): add WaitHandle support for better multi-wait.
-// Simple host-only fence semaphore implemented with a mutex.
-//
-// Thread-safe (as instances may be imported and used by others).
-class HostFence final : public Fence {
- public:
- // Waits for one or more (or all) fences to reach or exceed the given values.
- static Status WaitForFences(absl::Span<const FenceValue> fences,
- bool wait_all, absl::Time deadline);
-
- explicit HostFence(uint64_t initial_value);
- ~HostFence() override;
-
- Status status() const override;
- StatusOr<uint64_t> QueryValue() override;
-
- Status Signal(uint64_t value);
- Status Fail(Status status);
-
- private:
- // The mutex is not required to query the value; this lets us quickly check if
- // a required value has been exceeded. The mutex is only used to update and
- // notify waiters.
- std::atomic<uint64_t> value_{0};
-
- // We have a full mutex here so that we can perform condvar waits on value
- // changes.
- mutable absl::Mutex mutex_;
- Status status_ ABSL_GUARDED_BY(mutex_);
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_HOST_FENCE_H_
diff --git a/iree/hal/host/host_fence_test.cc b/iree/hal/host/host_fence_test.cc
deleted file mode 100644
index a843923..0000000
--- a/iree/hal/host/host_fence_test.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_fence.h"
-
-#include <cstdint>
-#include <thread> // NOLINT
-
-#include "absl/time/time.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-// Tests that a fence that is unused properly cleans itself up.
-TEST(HostFenceTest, NoOp) {
- HostFence fence(123u);
- EXPECT_TRUE(fence.status().ok());
- ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue());
- EXPECT_EQ(123u, value);
-}
-
-// Tests that a fence will accept new values as it is signaled.
-TEST(HostFenceTest, NormalSignaling) {
- HostFence fence(2u);
- EXPECT_EQ(2u, fence.QueryValue().ValueOrDie());
- EXPECT_OK(fence.Signal(3u));
- EXPECT_EQ(3u, fence.QueryValue().ValueOrDie());
- EXPECT_OK(fence.Signal(40u));
- EXPECT_EQ(40u, fence.QueryValue().ValueOrDie());
-}
-
-// Tests that a fence will fail to set non-increasing values.
-TEST(HostFenceTest, RequireIncreasingValues) {
- HostFence fence(2u);
- EXPECT_EQ(2u, fence.QueryValue().ValueOrDie());
- // Same value.
- EXPECT_TRUE(IsInvalidArgument(fence.Signal(2u)));
- // Decreasing.
- EXPECT_TRUE(IsInvalidArgument(fence.Signal(1u)));
-}
-
-// Tests that a fence that has failed will remain in a failed state.
-TEST(HostFenceTest, StickyFailure) {
- HostFence fence(2u);
- // Signal to 3.
- EXPECT_OK(fence.Signal(3u));
- EXPECT_TRUE(fence.status().ok());
- EXPECT_EQ(3u, fence.QueryValue().ValueOrDie());
-
- // Fail now.
- EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC)));
- EXPECT_TRUE(IsUnknown(fence.status()));
- EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie());
-
- // Unable to signal again (it'll return the sticky failure).
- EXPECT_TRUE(IsUnknown(fence.Signal(4u)));
- EXPECT_TRUE(IsUnknown(fence.status()));
- EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie());
-}
-
-// Tests waiting on no fences.
-TEST(HostFenceTest, EmptyWait) {
- EXPECT_OK(
- HostFence::WaitForFences({}, /*wait_all=*/true, absl::InfiniteFuture()));
-}
-
-// Tests waiting on a fence that has already been signaled.
-TEST(HostFenceTest, WaitAlreadySignaled) {
- HostFence fence(2u);
- // Test both previous and current values.
- EXPECT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
- EXPECT_OK(HostFence::WaitForFences({{&fence, 2u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
-}
-
-// Tests waiting on a fence that has not been signaled.
-TEST(HostFenceTest, WaitUnsignaled) {
- HostFence fence(2u);
- // NOTE: we don't actually block here because otherwise we'd lock up.
- EXPECT_TRUE(IsDeadlineExceeded(HostFence::WaitForFences(
- {{&fence, 3u}}, /*wait_all=*/true, absl::InfinitePast())));
-}
-
-// Tests waiting on a failed fence (it should return the error on the fence).
-TEST(HostFenceTest, WaitAlreadyFailed) {
- HostFence fence(2u);
- EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC)));
- EXPECT_TRUE(IsUnknown(HostFence::WaitForFences(
- {{&fence, 2u}}, /*wait_all=*/true, absl::InfinitePast())));
-}
-
-// Tests threading behavior by ping-ponging between the test main thread and
-// a little thread.
-TEST(HostFenceTest, PingPong) {
- HostFence a2b(0u);
- HostFence b2a(0u);
- std::thread thread([&]() {
- // Should advance right past this because the value is already set.
- ASSERT_OK(HostFence::WaitForFences({{&a2b, 0u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
- ASSERT_OK(b2a.Signal(1u));
- // Jump ahead.
- ASSERT_OK(HostFence::WaitForFences({{&a2b, 4u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
- });
- ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
- ASSERT_OK(a2b.Signal(4u));
- thread.join();
-}
-
-// Tests that failure still wakes waiters and propagates the error.
-TEST(HostFenceTest, FailNotifies) {
- HostFence a2b(0u);
- HostFence b2a(0u);
- bool got_failure = false;
- std::thread thread([&]() {
- ASSERT_OK(b2a.Signal(1u));
- got_failure = IsUnknown(HostFence::WaitForFences(
- {{&a2b, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()));
- });
- ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true,
- absl::InfiniteFuture()));
- ASSERT_OK(a2b.Fail(UnknownErrorBuilder(IREE_LOC)));
- thread.join();
- ASSERT_TRUE(got_failure);
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_local_allocator.cc b/iree/hal/host/host_local_allocator.cc
deleted file mode 100644
index 10b898d..0000000
--- a/iree/hal/host/host_local_allocator.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_local_allocator.h"
-
-#include <cstdlib>
-#include <string>
-#include <utility>
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/host/host_buffer.h"
-
-namespace iree {
-namespace hal {
-
-HostLocalAllocator::HostLocalAllocator() = default;
-
-HostLocalAllocator::~HostLocalAllocator() = default;
-
-bool HostLocalAllocator::CanUseBufferLike(
- Allocator* source_allocator, MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const {
- // Must always have visibility to the device, which ensures we can test
- // against the host but have things work on devices with separate address
- // spaces.
- if (!AnyBitSet(memory_type & MemoryType::kDeviceVisible)) {
- return false;
- }
-
- // kHostVisible is required for mapping.
- if (AnyBitSet(intended_usage & BufferUsage::kMapping) &&
- !AnyBitSet(memory_type & MemoryType::kHostVisible)) {
- return false;
- }
-
- // Dispatch needs to be specified if we intend to dispatch.
- if (AnyBitSet(intended_usage & BufferUsage::kDispatch) &&
- !AnyBitSet(buffer_usage & BufferUsage::kDispatch)) {
- return false;
- }
-
- return true;
-}
-
-bool HostLocalAllocator::CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const {
- // Host allows everything, pretty much, so long as it is device-visible (as
- // the host is the device here).
- return AnyBitSet(memory_type & MemoryType::kDeviceVisible);
-}
-
-Status HostLocalAllocator::MakeCompatible(
- MemoryTypeBitfield* memory_type, BufferUsageBitfield* buffer_usage) const {
- // Always ensure we are host-visible.
- *memory_type |= MemoryType::kHostVisible;
-
- // Host currently uses mapping to copy buffers, which is done a lot.
- // We could probably remove this restriction somehow.
- *buffer_usage |= BufferUsage::kMapping;
-
- // TODO(b/111372612): tensorflow needs transfer too, but shouldn't.
- *buffer_usage |= BufferUsage::kTransfer;
-
- return OkStatus();
-}
-
-StatusOr<ref_ptr<Buffer>> HostLocalAllocator::Allocate(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- size_t allocation_size) {
- IREE_TRACE_SCOPE0("HostLocalAllocator::Allocate");
-
- if (!CanAllocate(memory_type, buffer_usage, allocation_size)) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Allocation not supported; memory_type="
- << MemoryTypeString(memory_type)
- << ", buffer_usage=" << BufferUsageString(buffer_usage)
- << ", allocation_size=" << allocation_size;
- }
-
- // Make compatible with our requirements.
- RETURN_IF_ERROR(MakeCompatible(&memory_type, &buffer_usage));
-
- void* malloced_data = std::calloc(1, allocation_size);
- if (!malloced_data) {
- return ResourceExhaustedErrorBuilder(IREE_LOC)
- << "Failed to malloc " << allocation_size << " bytes";
- }
-
- auto buffer =
- make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage,
- allocation_size, malloced_data, true);
- return buffer;
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_local_allocator.h b/iree/hal/host/host_local_allocator.h
deleted file mode 100644
index 020fbea..0000000
--- a/iree/hal/host/host_local_allocator.h
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
-#define IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
-
-#include <cstddef>
-#include <memory>
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// An allocator implementation that allocates buffers from host memory.
-// This can be used for drivers that do not have a memory space of their own.
-//
-// Buffers allocated will have be MemoryType::kHostLocal | kDeviceVisible as
-// the 'device' in the case of a host-local queue *is* the host. To keep code
-// written initially for a host-local queue working when other queues are used
-// the allocator only works with buffers that are kDeviceVisible.
-class HostLocalAllocator : public Allocator {
- public:
- HostLocalAllocator();
- ~HostLocalAllocator() override;
-
- bool CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const override;
-
- bool CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const override;
-
- Status MakeCompatible(MemoryTypeBitfield* memory_type,
- BufferUsageBitfield* buffer_usage) const override;
-
- StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) override;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
diff --git a/iree/hal/host/host_local_command_processor.cc b/iree/hal/host/host_local_command_processor.cc
deleted file mode 100644
index a2c94ec..0000000
--- a/iree/hal/host/host_local_command_processor.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_local_command_processor.h"
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-HostLocalCommandProcessor::HostLocalCommandProcessor(
- Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories)
- : CommandBuffer(allocator, mode, command_categories) {}
-
-HostLocalCommandProcessor::~HostLocalCommandProcessor() = default;
-
-Status HostLocalCommandProcessor::Begin() {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::Begin");
- is_recording_ = true;
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::End() {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::End");
- is_recording_ = false;
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ExecutionBarrier");
- // No-op.
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::SignalEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::SignalEvent");
- // No-op.
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::ResetEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ResetEvent");
- // No-op.
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::WaitEvents(
- absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::WaitEvents");
- // No-op.
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::FillBuffer(Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length,
- const void* pattern,
- size_t pattern_length) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::FillBuffer");
- return target_buffer->Fill(target_offset, length, pattern, pattern_length);
-}
-
-Status HostLocalCommandProcessor::DiscardBuffer(Buffer* buffer) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::DiscardBuffer");
- // No-op as we don't support lazily allocated buffers.
- return OkStatus();
-}
-
-Status HostLocalCommandProcessor::UpdateBuffer(const void* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::UpdateBuffer");
- return target_buffer->WriteData(
- target_offset, static_cast<const uint8_t*>(source_buffer) + source_offset,
- length);
-}
-
-Status HostLocalCommandProcessor::CopyBuffer(Buffer* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- IREE_TRACE_SCOPE0("HostLocalCommandProcessor::CopyBuffer");
- return target_buffer->CopyData(target_offset, source_buffer, source_offset,
- length);
-}
-
-Status HostLocalCommandProcessor::Dispatch(
- const DispatchRequest& dispatch_request) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Command processor does not support dispatch operations";
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_local_command_processor.h b/iree/hal/host/host_local_command_processor.h
deleted file mode 100644
index f60d4c2..0000000
--- a/iree/hal/host/host_local_command_processor.h
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
-#define IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
-
-#include "iree/hal/command_buffer.h"
-
-namespace iree {
-namespace hal {
-
-// Host-local command processor for dispatching transfer operations against
-// buffers allocated from the HostLocalAllocator.
-// This assumes that all buffers are host-visible (if not local) and that all
-// buffers can be mapped for access.
-//
-// Subclasses may implement Dispatch, otherwise the default implementation just
-// returns failure.
-//
-// Thread-compatible (as with CommandBuffer itself).
-class HostLocalCommandProcessor : public CommandBuffer {
- public:
- HostLocalCommandProcessor(Allocator* allocator,
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories);
- ~HostLocalCommandProcessor() override;
-
- bool is_recording() const override { return is_recording_; }
-
- Status Begin() override;
- Status End() override;
-
- Status ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
-
- Status SignalEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
-
- Status ResetEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
-
- Status WaitEvents(absl::Span<Event*> events,
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
-
- Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
- device_size_t length, const void* pattern,
- size_t pattern_length) override;
-
- Status DiscardBuffer(Buffer* buffer) override;
-
- Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
-
- Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
-
- Status Dispatch(const DispatchRequest& dispatch_request) override;
-
- private:
- bool is_recording_ = false;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
diff --git a/iree/hal/host/host_submission_queue.cc b/iree/hal/host/host_submission_queue.cc
deleted file mode 100644
index a86d442..0000000
--- a/iree/hal/host/host_submission_queue.cc
+++ /dev/null
@@ -1,295 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_submission_queue.h"
-
-#include <atomic>
-#include <cstdint>
-
-#include "absl/synchronization/mutex.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-HostBinarySemaphore::HostBinarySemaphore(bool initial_value) {
- State state = {0};
- state.signaled = initial_value ? 1 : 0;
- state_ = state;
-}
-
-bool HostBinarySemaphore::is_signaled() const {
- return state_.load(std::memory_order_acquire).signaled == 1;
-}
-
-Status HostBinarySemaphore::BeginSignaling() {
- State old_state = state_.load(std::memory_order_acquire);
- if (old_state.signal_pending != 0) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "A signal operation on a binary semaphore is already pending";
- }
- State new_state = old_state;
- new_state.signal_pending = 1;
- state_.compare_exchange_strong(old_state, new_state);
- return OkStatus();
-}
-
-Status HostBinarySemaphore::EndSignaling() {
- State old_state = state_.load(std::memory_order_acquire);
- DCHECK_EQ(old_state.signal_pending, 1)
- << "A signal operation on a binary semaphore was not pending";
- if (old_state.signaled != 0) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "A binary semaphore cannot be signaled multiple times";
- }
- State new_state = old_state;
- new_state.signal_pending = 0;
- new_state.signaled = 1;
- state_.compare_exchange_strong(old_state, new_state);
- return OkStatus();
-}
-
-Status HostBinarySemaphore::BeginWaiting() {
- State old_state = state_.load(std::memory_order_acquire);
- if (old_state.wait_pending != 0) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "A wait operation on a binary semaphore is already pending";
- }
- State new_state = old_state;
- new_state.wait_pending = 1;
- state_.compare_exchange_strong(old_state, new_state);
- return OkStatus();
-}
-
-Status HostBinarySemaphore::EndWaiting() {
- State old_state = state_.load(std::memory_order_acquire);
- DCHECK_EQ(old_state.wait_pending, 1)
- << "A wait operation on a binary semaphore was not pending";
- if (old_state.signaled != 1) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "A binary semaphore cannot be reset multiple times";
- }
- State new_state = old_state;
- new_state.wait_pending = 0;
- new_state.signaled = 0;
- state_.compare_exchange_strong(old_state, new_state);
- return OkStatus();
-}
-
-HostSubmissionQueue::HostSubmissionQueue() = default;
-
-HostSubmissionQueue::~HostSubmissionQueue() = default;
-
-bool HostSubmissionQueue::IsBatchReady(const PendingBatch& batch) const {
- for (auto& wait_point : batch.wait_semaphores) {
- if (wait_point.index() == 0) {
- auto* binary_semaphore =
- reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(wait_point));
- if (!binary_semaphore->is_signaled()) {
- return false;
- }
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return false;
- }
- }
- return true;
-}
-
-Status HostSubmissionQueue::Enqueue(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) {
- IREE_TRACE_SCOPE0("HostSubmissionQueue::Enqueue");
-
- if (has_shutdown_) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Cannot enqueue new submissions; queue is exiting";
- } else if (!permanent_error_.ok()) {
- return permanent_error_;
- }
-
- // Verify waiting/signaling behavior on semaphores and prepare them all.
- // We need to track this to ensure that we are modeling the Vulkan behavior
- // and are consistent across HAL implementations.
- for (auto& batch : batches) {
- for (auto& semaphore_value : batch.wait_semaphores) {
- if (semaphore_value.index() == 0) {
- auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>(
- absl::get<0>(semaphore_value));
- RETURN_IF_ERROR(binary_semaphore->BeginWaiting());
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
- }
- }
- for (auto& semaphore_value : batch.signal_semaphores) {
- if (semaphore_value.index() == 0) {
- auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>(
- absl::get<0>(semaphore_value));
- RETURN_IF_ERROR(binary_semaphore->BeginSignaling());
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
- }
- }
- }
-
- // Add to list - order does not matter as Process evaluates semaphores.
- auto submission = absl::make_unique<Submission>();
- submission->fence = std::move(fence);
- submission->pending_batches.resize(batches.size());
- for (int i = 0; i < batches.size(); ++i) {
- submission->pending_batches[i] = PendingBatch{
- {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()},
- {batches[i].command_buffers.begin(), batches[i].command_buffers.end()},
- {batches[i].signal_semaphores.begin(),
- batches[i].signal_semaphores.end()},
- };
- }
- list_.push_back(std::move(submission));
-
- return OkStatus();
-}
-
-Status HostSubmissionQueue::ProcessBatches(ExecuteFn execute_fn) {
- IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatches");
-
- if (!permanent_error_.ok()) {
- // Sticky failure state.
- return permanent_error_;
- }
-
- // Repeated try to run things until we quiesce or are blocked.
- while (permanent_error_.ok() && !list_.empty()) {
- // NOTE: to support re-entrancy where |execute_fn| may modify the submission
- // list we need to always start from the beginning. If we wanted we could
- // track a list of ready submissions however that's a lot of bookkeeping and
- // the list is usually short.
- bool restart_iteration = false;
- for (auto* submission : list_) {
- for (int i = 0; i < submission->pending_batches.size(); ++i) {
- auto& batch = submission->pending_batches[i];
- if (!IsBatchReady(batch)) {
- // Try the next batch in the submission until we find one that is
- // ready. If none are ready we'll return to the caller.
- continue;
- }
-
- // Batch can run! Process now and remove it from the list so we don't
- // try to run it again.
- auto batch_status = ProcessBatch(batch, execute_fn);
- submission->pending_batches.erase(submission->pending_batches.begin() +
- i);
- if (batch_status.ok()) {
- // Batch succeeded. Since we want to preserve submission order we'll
- // break out of the loop and try from the first submission again.
- if (submission->pending_batches.empty()) {
- // All work for this submission completed successfully. Signal the
- // fence and remove the submission from the list.
- RETURN_IF_ERROR(CompleteSubmission(submission, OkStatus()));
- list_.take(submission).reset();
- }
- } else {
- // Batch failed; set the permanent error flag and abort so we don't
- // try to process anything else.
- permanent_error_ = batch_status;
- RETURN_IF_ERROR(CompleteSubmission(submission, batch_status));
- list_.take(submission).reset();
- }
- restart_iteration = true;
- break;
- }
- if (restart_iteration) break;
- }
- }
-
- if (!permanent_error_.ok()) {
- // If the sticky error got set while processing we need to abort all
- // remaining submissions (simulating a device loss).
- FailAllPending(permanent_error_);
- return permanent_error_;
- }
-
- return OkStatus();
-}
-
-Status HostSubmissionQueue::ProcessBatch(const PendingBatch& batch,
- const ExecuteFn& execute_fn) {
- IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatch");
-
- // Complete the waits on all semaphores and reset them.
- for (auto& semaphore_value : batch.wait_semaphores) {
- if (semaphore_value.index() == 0) {
- auto* binary_semaphore =
- reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
- RETURN_IF_ERROR(binary_semaphore->EndWaiting());
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
- }
- }
-
- // Let the caller handle execution of the command buffers.
- RETURN_IF_ERROR(execute_fn(batch.command_buffers));
-
- // Signal all semaphores to allow them to unblock waiters.
- for (auto& semaphore_value : batch.signal_semaphores) {
- if (semaphore_value.index() == 0) {
- auto* binary_semaphore =
- reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
- RETURN_IF_ERROR(binary_semaphore->EndSignaling());
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI";
- }
- }
-
- return OkStatus();
-}
-
-Status HostSubmissionQueue::CompleteSubmission(Submission* submission,
- Status status) {
- IREE_TRACE_SCOPE0("HostSubmissionQueue::CompleteSubmission");
-
- // It's safe to drop any remaining batches - their semaphores will never be
- // signaled but that's fine as we should be the only thing relying on them.
- submission->pending_batches.clear();
-
- // Signal the fence.
- auto* fence = static_cast<HostFence*>(submission->fence.first);
- if (status.ok()) {
- RETURN_IF_ERROR(fence->Signal(submission->fence.second));
- } else {
- RETURN_IF_ERROR(fence->Fail(std::move(status)));
- }
-
- return OkStatus();
-}
-
-void HostSubmissionQueue::FailAllPending(Status status) {
- IREE_TRACE_SCOPE0("HostSubmissionQueue::FailAllPending");
- while (!list_.empty()) {
- auto submission = list_.take(list_.front());
- CompleteSubmission(submission.get(), status).IgnoreError();
- submission.reset();
- }
-}
-
-void HostSubmissionQueue::SignalShutdown() {
- IREE_TRACE_SCOPE0("HostSubmissionQueue::SignalShutdown");
- has_shutdown_ = true;
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/host_submission_queue.h b/iree/hal/host/host_submission_queue.h
deleted file mode 100644
index 14eae7a..0000000
--- a/iree/hal/host/host_submission_queue.h
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
-#define IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
-
-#include "absl/base/thread_annotations.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/status.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/host/host_fence.h"
-#include "iree/hal/semaphore.h"
-
-namespace iree {
-namespace hal {
-
-class HostSubmissionQueue;
-
-// Simple host-only binary semaphore implemented with a mutex.
-// To match the expected HAL behavior (mostly dictated by Vulkan) we can only
-// have a single waiter and waits can only occur once a signal has been
-// enqueued.
-//
-// Thread-safe (as instances may be imported and used by others).
-class HostBinarySemaphore final : public BinarySemaphore {
- public:
- explicit HostBinarySemaphore(bool initial_value);
-
- // Returns true if the semaphore has been signaled.
- bool is_signaled() const;
-
- private:
- friend class HostSubmissionQueue;
-
- // Begins a signal operation and ensures no other signal operation is pending.
- Status BeginSignaling();
- // Ends a signal operation by setting the semaphore to the signaled state.
- Status EndSignaling();
-
- // Begins a wait operation and ensures no other wait operation is pending.
- Status BeginWaiting();
- // Ends a wait operation by resetting the semaphore to the unsignaled state.
- Status EndWaiting();
-
- // A single 32-bit int for lock-free semaphore behavior. We need to do this
- // extra tracking so that we get consistent behavior across HAL
- // implementations that have strict semaphore semantics.
- struct State {
- uint32_t signal_pending : 1;
- uint32_t wait_pending : 1;
- uint32_t signaled : 1;
- };
- std::atomic<State> state_{{0, 0, 0}};
-};
-
-// Simple host-only timeline semaphore implemented with a mutex.
-//
-// Thread-safe (as instances may be imported and used by others).
-class HostTimelineSemaphore final : public TimelineSemaphore {
- public:
- // TODO(b/140141417): implement timeline semaphores.
-};
-
-// A queue managing CommandQueue submissions that uses host-local
-// synchronization primitives. Evaluates submission order by respecting the
-// wait and signal semaphores defined per batch and notifies fences upon
-// submission completion.
-//
-// Note that it's possible for HAL users to deadlock themselves; we don't try to
-// avoid that as in device backends it may not be possible and we want to have
-// some kind of warning in the host implementation that TSAN can catch.
-//
-// Thread-compatible. Const methods may be called from any thread.
-class HostSubmissionQueue {
- public:
- using ExecuteFn =
- std::function<Status(absl::Span<CommandBuffer* const> command_buffers)>;
-
- HostSubmissionQueue();
- ~HostSubmissionQueue();
-
- // Returns true if the queue is currently empty.
- bool empty() const { return list_.empty(); }
- // Returns true if SignalShutdown has been called.
- bool has_shutdown() const { return has_shutdown_; }
- // The sticky error status, if an error has occurred.
- Status permanent_error() const { return permanent_error_; }
-
- // Enqueues a new submission.
- // No work will be performed until Process is called.
- Status Enqueue(absl::Span<const SubmissionBatch> batches, FenceValue fence);
-
- // Processes all ready batches using the provided |execute_fn|.
- // The function may be called several times if new batches become ready due to
- // prior batches in the sequence completing during processing.
- //
- // Returns any errors returned by |execute_fn| (which will be the same as
- // permanent_error()). When an error occurs all in-flight submissions are
- // aborted, the permanent_error() is set, and the queue is shutdown.
- Status ProcessBatches(ExecuteFn execute_fn);
-
- // Marks the queue as having shutdown. All pending submissions will be allowed
- // to complete but future enqueues will fail.
- void SignalShutdown();
-
- private:
- // A submitted command buffer batch and its synchronization information.
- struct PendingBatch {
- absl::InlinedVector<SemaphoreValue, 4> wait_semaphores;
- absl::InlinedVector<CommandBuffer*, 4> command_buffers;
- absl::InlinedVector<SemaphoreValue, 4> signal_semaphores;
- };
- struct Submission : public IntrusiveLinkBase<void> {
- absl::InlinedVector<PendingBatch, 4> pending_batches;
- FenceValue fence;
- };
-
- // Returns true if all wait semaphores in the |batch| are signaled.
- bool IsBatchReady(const PendingBatch& batch) const;
-
- // Processes a batch by resetting semaphores, dispatching the command buffers
- // to the specified |execute_fn|, and signaling semaphores.
- //
- // Preconditions: IsBatchReady(batch) == true
- Status ProcessBatch(const PendingBatch& batch, const ExecuteFn& execute_fn);
-
- // Completes a submission by signaling the fence with the given |status|.
- Status CompleteSubmission(Submission* submission, Status status);
-
- // Fails all pending submissions with the given status.
- // Errors that occur during this process are silently ignored.
- void FailAllPending(Status status);
-
- // True to exit the thread after all submissions complete.
- bool has_shutdown_ = false;
-
- // A sticky error that is set on the first failed submit. All future
- // submissions will be skipped except for fences, which will receive this
- // error.
- Status permanent_error_;
-
- // Pending submissions in submission order.
- // Note that we may evaluate batches within the list out of order.
- IntrusiveList<std::unique_ptr<Submission>> list_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
diff --git a/iree/hal/host/host_submission_queue_test.cc b/iree/hal/host/host_submission_queue_test.cc
deleted file mode 100644
index 227131a..0000000
--- a/iree/hal/host/host_submission_queue_test.cc
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_submission_queue.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-TEST(HostSubmissionQueueTest, TBD) {
- // TODO(benvanik): test!
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/inproc_command_buffer.cc b/iree/hal/host/inproc_command_buffer.cc
deleted file mode 100644
index 0833843..0000000
--- a/iree/hal/host/inproc_command_buffer.cc
+++ /dev/null
@@ -1,264 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/inproc_command_buffer.h"
-
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-InProcCommandBuffer::InProcCommandBuffer(
- Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories)
- : CommandBuffer(allocator, mode, command_categories) {}
-
-InProcCommandBuffer::~InProcCommandBuffer() { Reset(); }
-
-Status InProcCommandBuffer::Begin() {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::Begin");
- is_recording_ = true;
- Reset();
- return OkStatus();
-}
-
-Status InProcCommandBuffer::End() {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::End");
- is_recording_ = false;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::ExecutionBarrier");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<ExecutionBarrierCmd>());
- cmd->source_stage_mask = source_stage_mask;
- cmd->target_stage_mask = target_stage_mask;
- cmd->memory_barriers = AppendStructSpan(memory_barriers);
- cmd->buffer_barriers = AppendStructSpan(buffer_barriers);
- return OkStatus();
-}
-
-Status InProcCommandBuffer::SignalEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::SignalEvent");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<SignalEventCmd>());
- cmd->event = event;
- cmd->source_stage_mask = source_stage_mask;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::ResetEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::ResetEvent");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<ResetEventCmd>());
- cmd->event = event;
- cmd->source_stage_mask = source_stage_mask;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::WaitEvents(
- absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::WaitEvents");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<WaitEventsCmd>());
- cmd->events = AppendStructSpan(events);
- cmd->source_stage_mask = source_stage_mask;
- cmd->target_stage_mask = target_stage_mask;
- cmd->memory_barriers = AppendStructSpan(memory_barriers);
- cmd->buffer_barriers = AppendStructSpan(buffer_barriers);
- return OkStatus();
-}
-
-Status InProcCommandBuffer::FillBuffer(Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length,
- const void* pattern,
- size_t pattern_length) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::FillBuffer");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<FillBufferCmd>());
- cmd->target_buffer = target_buffer;
- cmd->target_offset = target_offset;
- cmd->length = length;
- std::memcpy(cmd->pattern, pattern, pattern_length);
- cmd->pattern_length = pattern_length;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::DiscardBuffer(Buffer* buffer) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::DiscardBuffer");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DiscardBufferCmd>());
- cmd->buffer = buffer;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::UpdateBuffer(const void* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::UpdateBuffer");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<UpdateBufferCmd>());
- cmd->source_buffer = AppendCmdData(source_buffer, source_offset, length);
- cmd->target_buffer = target_buffer;
- cmd->target_offset = target_offset;
- cmd->length = length;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::CopyBuffer(Buffer* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::CopyBuffer");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<CopyBufferCmd>());
- cmd->source_buffer = source_buffer;
- cmd->source_offset = source_offset;
- cmd->target_buffer = target_buffer;
- cmd->target_offset = target_offset;
- cmd->length = length;
- return OkStatus();
-}
-
-Status InProcCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::Dispatch");
- ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DispatchCmd>());
- cmd->request.executable = dispatch_request.executable;
- cmd->request.entry_point = dispatch_request.entry_point;
- cmd->request.workload = dispatch_request.workload;
- cmd->request.workload_buffer = dispatch_request.workload_buffer;
- cmd->request.bindings = AppendStructSpan(dispatch_request.bindings);
- return OkStatus();
-}
-
-void InProcCommandBuffer::Reset() {
- auto* cmd_list = &current_cmd_list_;
- cmd_list->head = cmd_list->tail = nullptr;
- cmd_list->arena.Reset();
-}
-
-InProcCommandBuffer::CmdHeader* InProcCommandBuffer::AppendCmdHeader(
- CmdType type, size_t cmd_size) {
- auto* cmd_list = &current_cmd_list_;
- auto* cmd_header = reinterpret_cast<CmdHeader*>(
- cmd_list->arena.AllocateBytes(sizeof(CmdHeader) + cmd_size));
- cmd_header->next = nullptr;
- cmd_header->type = type;
- if (!cmd_list->head) {
- cmd_list->head = cmd_header;
- } else if (cmd_list->tail) {
- cmd_list->tail->next = cmd_header;
- }
- cmd_list->tail = cmd_header;
- return cmd_header;
-}
-
-void* InProcCommandBuffer::AppendCmdData(const void* source_buffer,
- device_size_t source_offset,
- device_size_t source_length) {
- auto* cmd_list = &current_cmd_list_;
-
- uint8_t* allocated_bytes = cmd_list->arena.AllocateBytes(source_length);
- std::memcpy(allocated_bytes,
- static_cast<const uint8_t*>(source_buffer) + source_offset,
- source_length);
- return allocated_bytes;
-}
-
-Status InProcCommandBuffer::Process(CommandBuffer* command_processor) const {
- IREE_TRACE_SCOPE0("InProcCommandBuffer::Process");
-
- RETURN_IF_ERROR(command_processor->Begin());
-
- // Process each command in the order they were recorded.
- auto* cmd_list = &current_cmd_list_;
- for (CmdHeader* cmd_header = cmd_list->head; cmd_header != nullptr;
- cmd_header = cmd_header->next) {
- auto command_status = ProcessCmd(cmd_header, command_processor);
- if (!command_status.ok()) {
- LOG(ERROR) << "DeviceQueue failure while executing command; permanently "
- "failing all future commands: "
- << command_status;
- }
- }
-
- RETURN_IF_ERROR(command_processor->End());
-
- return OkStatus();
-}
-
-Status InProcCommandBuffer::ProcessCmd(CmdHeader* cmd_header,
- CommandBuffer* command_processor) const {
- switch (cmd_header->type) {
- case CmdType::kExecutionBarrier: {
- auto* cmd = reinterpret_cast<ExecutionBarrierCmd*>(cmd_header + 1);
- return command_processor->ExecutionBarrier(
- cmd->source_stage_mask, cmd->target_stage_mask, cmd->memory_barriers,
- cmd->buffer_barriers);
- }
- case CmdType::kSignalEvent: {
- auto* cmd = reinterpret_cast<SignalEventCmd*>(cmd_header + 1);
- return command_processor->SignalEvent(cmd->event, cmd->source_stage_mask);
- }
- case CmdType::kResetEvent: {
- auto* cmd = reinterpret_cast<ResetEventCmd*>(cmd_header + 1);
- return command_processor->ResetEvent(cmd->event, cmd->source_stage_mask);
- }
- case CmdType::kWaitEvents: {
- auto* cmd = reinterpret_cast<WaitEventsCmd*>(cmd_header + 1);
- return command_processor->WaitEvents(
- cmd->events, cmd->source_stage_mask, cmd->target_stage_mask,
- cmd->memory_barriers, cmd->buffer_barriers);
- }
- case CmdType::kFillBuffer: {
- auto* cmd = reinterpret_cast<FillBufferCmd*>(cmd_header + 1);
- return command_processor->FillBuffer(cmd->target_buffer,
- cmd->target_offset, cmd->length,
- cmd->pattern, cmd->pattern_length);
- }
- case CmdType::kDiscardBuffer: {
- auto* cmd = reinterpret_cast<DiscardBufferCmd*>(cmd_header + 1);
- return command_processor->DiscardBuffer(cmd->buffer);
- }
- case CmdType::kUpdateBuffer: {
- auto* cmd = reinterpret_cast<UpdateBufferCmd*>(cmd_header + 1);
- return command_processor->UpdateBuffer(cmd->source_buffer, 0,
- cmd->target_buffer,
- cmd->target_offset, cmd->length);
- }
- case CmdType::kCopyBuffer: {
- auto* cmd = reinterpret_cast<CopyBufferCmd*>(cmd_header + 1);
- return command_processor->CopyBuffer(
- cmd->source_buffer, cmd->source_offset, cmd->target_buffer,
- cmd->target_offset, cmd->length);
- }
- case CmdType::kDispatch: {
- auto* cmd = reinterpret_cast<DispatchCmd*>(cmd_header + 1);
- return command_processor->Dispatch(cmd->request);
- }
- default:
- return DataLossErrorBuilder(IREE_LOC)
- << "Unrecognized command type "
- << static_cast<int>(cmd_header->type) << "; corrupt buffer?";
- }
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/host/inproc_command_buffer.h b/iree/hal/host/inproc_command_buffer.h
deleted file mode 100644
index a9d2bf6..0000000
--- a/iree/hal/host/inproc_command_buffer.h
+++ /dev/null
@@ -1,241 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
-#define IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
-
-#include "iree/base/arena.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/status.h"
-#include "iree/hal/command_buffer.h"
-
-namespace iree {
-namespace hal {
-
-// In-process command buffer with support for recording and playback.
-// Commands are recorded into heap-allocated arenas with pointers to used
-// resources (Buffer*, etc). To replay a command buffer against a real
-// implementation use Process to call each command method as it was originally
-// recorded.
-//
-// Thread-compatible (as with CommandBuffer itself).
-class InProcCommandBuffer final : public CommandBuffer {
- public:
- InProcCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories);
- ~InProcCommandBuffer() override;
-
- bool is_recording() const override { return is_recording_; }
-
- Status Begin() override;
- Status End() override;
-
- Status ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
-
- Status SignalEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
-
- Status ResetEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
-
- Status WaitEvents(absl::Span<Event*> events,
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
-
- Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
- device_size_t length, const void* pattern,
- size_t pattern_length) override;
-
- Status DiscardBuffer(Buffer* buffer) override;
-
- Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
-
- Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
-
- Status Dispatch(const DispatchRequest& dispatch_request) override;
-
- // Processes all commands in the buffer using the given |command_processor|.
- // The commands are issued in the order they were recorded.
- Status Process(CommandBuffer* command_processor) const;
-
- private:
- // Type of Cmd, used by CmdHeader to identify the command payload.
- enum class CmdType {
- kExecutionBarrier,
- kSignalEvent,
- kResetEvent,
- kWaitEvents,
- kFillBuffer,
- kDiscardBuffer,
- kUpdateBuffer,
- kCopyBuffer,
- kDispatch,
- };
-
- // Prefix for commands encoded into the CmdList.
- // This is used to identify the type of a command as well as connect commands
- // in the list sequence. Command data immediately follows the header in
- // memory.
- struct CmdHeader {
- // Optional next command in the list.
- CmdHeader* next;
- // Type of the command.
- CmdType type;
- };
-
- // A lightweight linked list of commands and an arena that stores them.
- // CmdLists are designed to be reused so that the arena allocations are
- // amortized across multiple uses.
- //
- // Note that this and the CmdHeader/Cmd types include raw pointers and as
- // such are *not* portable across processes. It'd be possible, though, to
- // extend this for cross-process use if a shared-memory Buffer was also
- // implemented. For YAGNI we avoid that here.
- struct CmdList : public IntrusiveLinkBase<void> {
- static constexpr size_t kArenaBlockSize = 64 * 1024;
-
- Arena arena{kArenaBlockSize};
- CmdHeader* head = nullptr;
- CmdHeader* tail = nullptr;
- };
-
- // Defines an execution barrier.
- struct ExecutionBarrierCmd {
- static constexpr CmdType kType = CmdType::kExecutionBarrier;
- ExecutionStageBitfield source_stage_mask;
- ExecutionStageBitfield target_stage_mask;
- absl::Span<const MemoryBarrier> memory_barriers;
- absl::Span<const BufferBarrier> buffer_barriers;
- };
-
- // Signals an event.
- struct SignalEventCmd {
- static constexpr CmdType kType = CmdType::kSignalEvent;
- Event* event;
- ExecutionStageBitfield source_stage_mask;
- };
-
- // Resets an event.
- struct ResetEventCmd {
- static constexpr CmdType kType = CmdType::kResetEvent;
- Event* event;
- ExecutionStageBitfield source_stage_mask;
- };
-
- // Waits for one or more events.
- struct WaitEventsCmd {
- static constexpr CmdType kType = CmdType::kWaitEvents;
- absl::Span<Event*> events;
- ExecutionStageBitfield source_stage_mask;
- ExecutionStageBitfield target_stage_mask;
- absl::Span<const MemoryBarrier> memory_barriers;
- absl::Span<const BufferBarrier> buffer_barriers;
- };
-
- // Fills the target buffer with the given repeating value.
- struct FillBufferCmd {
- static constexpr CmdType kType = CmdType::kFillBuffer;
- Buffer* target_buffer;
- device_size_t target_offset;
- device_size_t length;
- uint8_t pattern[4];
- size_t pattern_length;
- };
-
- // Hints to the device queue that the given buffer will not be used again.
- struct DiscardBufferCmd {
- static constexpr CmdType kType = CmdType::kDiscardBuffer;
- Buffer* buffer;
- };
-
- // Writes a range of the given target buffer from the embedded memory.
- // The source buffer contents immediately follow the command in the arena.
- struct UpdateBufferCmd {
- static constexpr CmdType kType = CmdType::kUpdateBuffer;
- const void* source_buffer;
- Buffer* target_buffer;
- device_size_t target_offset;
- device_size_t length;
- };
-
- // Copies a range of one buffer to another.
- struct CopyBufferCmd {
- static constexpr CmdType kType = CmdType::kCopyBuffer;
- Buffer* source_buffer;
- device_size_t source_offset;
- Buffer* target_buffer;
- device_size_t target_offset;
- device_size_t length;
- };
-
- // Dispatches an execution request.
- struct DispatchCmd {
- static constexpr CmdType kType = CmdType::kDispatch;
- DispatchRequest request;
- };
-
- // Resets the command list.
- void Reset();
-
- // Allocates a command and appends it to the current command list.
- // The caller must populate the fields in the returned pointer.
- template <typename T>
- StatusOr<T*> AppendCmd() {
- return reinterpret_cast<T*>(AppendCmdHeader(T::kType, sizeof(T)) + 1);
- }
-
- // Appends a command with the given |type| and payload |cmd_size| prefixed
- // with a CmdHeader. Returns a pointer to the CmdHeader that is followed
- // immediately by |cmd_size| zero bytes.
- CmdHeader* AppendCmdHeader(CmdType type, size_t cmd_size);
-
- // Appends a byte buffer to the command buffer and returns a pointer to the
- // copied data within the command buffer arena.
- void* AppendCmdData(const void* source_buffer, device_size_t source_offset,
- device_size_t source_length);
-
- // Appends a span of POD structs to the current CmdList and returns a span
- // pointing into the CmdList arena.
- template <typename T>
- absl::Span<T> AppendStructSpan(absl::Span<T> value) {
- static_assert(std::is_standard_layout<T>::value,
- "Struct must be a POD type");
- void* data_ptr = AppendCmdData(value.data(), 0, value.size() * sizeof(T));
- return absl::MakeSpan(static_cast<T*>(data_ptr), value.size());
- }
-
- // Processes a single command.
- Status ProcessCmd(CmdHeader* cmd_header,
- CommandBuffer* command_processor) const;
-
- bool is_recording_ = false;
-
- // NOTE: not synchronized. Expected to be used from a single thread.
- CmdList current_cmd_list_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
diff --git a/iree/hal/interpreter/BUILD b/iree/hal/interpreter/BUILD
deleted file mode 100644
index 204eb0a..0000000
--- a/iree/hal/interpreter/BUILD
+++ /dev/null
@@ -1,190 +0,0 @@
-# HAL implementation running on the CPU using the IREE bytecode.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "bytecode_cache",
- srcs = ["bytecode_cache.cc"],
- hdrs = ["bytecode_cache.h"],
- deps = [
- ":bytecode_executable",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- "//iree/rt",
- ],
-)
-
-cc_library(
- name = "bytecode_dispatch",
- srcs = [
- "bytecode_dispatch.cc",
- "bytecode_dispatch_conversion.h",
- "bytecode_dispatch_util.cc",
- "bytecode_dispatch_util.h",
- ],
- hdrs = ["bytecode_dispatch.h"],
- deps = [
- ":bytecode_kernels",
- "//iree/base:logging",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/hal:allocator",
- "//iree/hal:buffer_view",
- "//iree/hal:heap_buffer",
- "//iree/rt",
- "//iree/schemas/bytecode:interpreter_bytecode_v0",
- "//iree/vm:bytecode_module",
- "//iree/vm:bytecode_reader",
- "//iree/vm:bytecode_tables_interpreter",
- "//iree/vm:bytecode_util",
- "//iree/vm:opcode_info",
- "//iree/vm:type",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "bytecode_executable",
- srcs = ["bytecode_executable.cc"],
- hdrs = ["bytecode_executable.h"],
- deps = [
- ":interpreter_module",
- "//iree/base:status",
- "//iree/hal:allocator",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
- "//iree/rt",
- "//iree/vm:bytecode_tables_interpreter",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "bytecode_kernels",
- hdrs = ["bytecode_kernels.h"],
- textual_hdrs = [
- # TODO(benvanik): SIMD variants.
- "bytecode_kernels_generic.h",
- "bytecode_kernels_ruy.h",
- ],
- deps = [
- "//iree/base:shape",
- "//iree/base:status",
- "//iree/hal:buffer_view",
- "@com_google_absl//absl/algorithm",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
- "@org_tensorflow//tensorflow/lite/experimental/ruy",
- "@org_tensorflow//tensorflow/lite/experimental/ruy:context",
- ],
-)
-
-cc_test(
- name = "bytecode_kernels_test",
- srcs = ["bytecode_kernels_test.cc"],
- deps = [
- ":bytecode_kernels",
- "//iree/base:memory",
- "//iree/base:status_matchers",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
- name = "interpreter_command_processor",
- srcs = ["interpreter_command_processor.cc"],
- hdrs = ["interpreter_command_processor.h"],
- deps = [
- ":bytecode_executable",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:buffer_view",
- "//iree/hal/host:host_local_command_processor",
- "//iree/rt",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "interpreter_device",
- srcs = ["interpreter_device.cc"],
- hdrs = ["interpreter_device.h"],
- deps = [
- ":bytecode_cache",
- ":bytecode_kernels",
- ":interpreter_command_processor",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer_validation",
- "//iree/hal:command_queue",
- "//iree/hal:device",
- "//iree/hal:fence",
- "//iree/hal/host:async_command_queue",
- "//iree/hal/host:host_event",
- "//iree/hal/host:host_local_allocator",
- "//iree/hal/host:host_submission_queue",
- "//iree/hal/host:inproc_command_buffer",
- "//iree/rt",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "interpreter_driver",
- srcs = ["interpreter_driver.cc"],
- hdrs = ["interpreter_driver.h"],
- deps = [
- ":interpreter_device",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- ],
-)
-
-cc_library(
- name = "interpreter_driver_module",
- srcs = ["interpreter_driver_module.cc"],
- deps = [
- ":interpreter_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "interpreter_module",
- srcs = ["interpreter_module.cc"],
- hdrs = ["interpreter_module.h"],
- deps = [
- ":bytecode_dispatch",
- ":bytecode_kernels",
- "//iree/base:flatbuffer_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer_view",
- "//iree/rt",
- "//iree/vm:bytecode_module",
- "//iree/vm:bytecode_tables_interpreter",
- "@com_google_absl//absl/types:span",
- ],
-)
diff --git a/iree/hal/interpreter/bytecode_cache.cc b/iree/hal/interpreter/bytecode_cache.cc
deleted file mode 100644
index b42db5a..0000000
--- a/iree/hal/interpreter/bytecode_cache.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/bytecode_cache.h"
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/executable_format.h"
-#include "iree/hal/interpreter/bytecode_executable.h"
-
-namespace iree {
-namespace hal {
-
-BytecodeCache::BytecodeCache(ref_ptr<rt::Instance> instance,
- hal::Allocator* allocator)
- : instance_(std::move(instance)), allocator_(allocator) {}
-
-BytecodeCache::~BytecodeCache() = default;
-
-bool BytecodeCache::CanPrepareFormat(ExecutableFormat format) const {
- return format == kExecutableFormatIreeBytecode;
-}
-
-StatusOr<ref_ptr<Executable>> BytecodeCache::PrepareExecutable(
- ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
- IREE_TRACE_SCOPE0("BytecodeCache::PrepareExecutable");
- if (!CanPrepareFormat(spec.format)) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unsupported format: " << spec.format;
- }
-
- // Wrap the data (or copy it).
- bool allow_aliasing_data =
- AllBitsSet(mode, ExecutableCachingMode::kAliasProvidedData);
- ASSIGN_OR_RETURN(auto executable,
- BytecodeExecutable::Load(add_ref(instance_), allocator_,
- spec, !allow_aliasing_data));
-
- return executable;
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_cache.h b/iree/hal/interpreter/bytecode_cache.h
deleted file mode 100644
index 4da59ec..0000000
--- a/iree/hal/interpreter/bytecode_cache.h
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
-
-#include "iree/hal/allocator.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/rt/instance.h"
-
-namespace iree {
-namespace hal {
-
-class BytecodeCache final : public ExecutableCache {
- public:
- BytecodeCache(ref_ptr<rt::Instance> instance, hal::Allocator* allocator);
- ~BytecodeCache() override;
-
- bool CanPrepareFormat(ExecutableFormat format) const override;
-
- StatusOr<ref_ptr<Executable>> PrepareExecutable(
- ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override;
-
- private:
- ref_ptr<rt::Instance> instance_;
- hal::Allocator* allocator_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
diff --git a/iree/hal/interpreter/bytecode_dispatch.cc b/iree/hal/interpreter/bytecode_dispatch.cc
deleted file mode 100644
index f0e7040..0000000
--- a/iree/hal/interpreter/bytecode_dispatch.cc
+++ /dev/null
@@ -1,850 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Implements a full bytecode dispatch system.
-// Currently this is verbose and object oriented, but future revisions
-// (once we have interesting benchmarks) will likely simplify and inline
-// a lot of the checks to make things faster. Consider this to be as
-// experimental an implementation as the entire rest of the project :)
-
-#include "iree/hal/interpreter/bytecode_dispatch.h"
-
-#include <algorithm>
-
-#include "absl/base/attributes.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/base/logging.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/heap_buffer.h"
-#include "iree/hal/interpreter/bytecode_dispatch_conversion.h"
-#include "iree/hal/interpreter/bytecode_dispatch_util.h"
-#include "iree/hal/interpreter/bytecode_kernels.h"
-#include "iree/rt/function.h"
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-#include "iree/vm/bytecode_module.h"
-#include "iree/vm/bytecode_reader.h"
-#include "iree/vm/bytecode_tables_interpreter.h"
-#include "iree/vm/bytecode_util.h"
-#include "iree/vm/opcode_info.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-
-using ::iree::rt::Stack;
-using ::iree::rt::StackFrame;
-using ::iree::vm::BytecodeReader;
-
-} // namespace
-
-Status Dispatch(hal::Allocator* allocator,
- kernels::RuntimeState* kernel_runtime_state, Stack* stack,
- StackFrame* entry_stack_frame,
- absl::Span<BufferView> entry_results) {
- // Dispatch table mapping 1:1 with bytecode ops.
- // Each entry is a label within this function that can be used for computed
- // goto. You can find more information on computed goto here:
- // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables
- //
- // Note that we ensure the table is 256 elements long exactly to make sure
- // that unused opcodes are handled gracefully.
- static const void* kDispatchTable[256] = {
-#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name,
-#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled,
- IREE_INTERPRETER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED)
-#undef DECLARE_DISPATCH
-#undef DECLARE_DISPATCH_RESERVED
- };
-
- // Primary dispatch state. This is our 'native stack frame' and really just
- // enough to make dereferencing common addresses (like the current offset)
- // faster. You can think of this like CPU state (like PC).
- //
- // We hope that LLVM decides to keep these in registers (as they are touched
- // for every instruction executed). The stack_frame will change as we call
- // into different functions.
- BytecodeReader reader(stack);
- RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame));
-
-#define DISPATCH_NEXT() \
- { \
- uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \
- DVLOG(1) \
- << "Interpreter dispatching op code: " \
- << GetOpcodeInfo(vm::interpreter_opcode_table(), opcode).mnemonic; \
- goto* kDispatchTable[opcode]; \
- }
-
-#define DISPATCH_CORE_OPCODE(opcode, body) \
- _dispatch_##opcode : {body} DISPATCH_NEXT()
-#if defined(IREE_SUPPORT_F32) || defined(IREE_SUPPORT_F64)
-#define DISPATCH_FLOAT_OPCODE(opcode, body) \
- _dispatch_##opcode : {body} DISPATCH_NEXT()
-#else
-#define DISPATCH_FLOAT_OPCODE(...)
-#endif // IREE_SUPPORT_F32 || IREE_SUPPORT_F64
-
- DISPATCH_NEXT();
-
- DISPATCH_CORE_OPCODE(kConstant, {
- ASSIGN_OR_RETURN(auto value, reader.ReadConstant());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- *dst_local = std::move(value);
- });
-
- DISPATCH_CORE_OPCODE(kCall, {
- auto* old_stack_frame = stack->current_frame();
- ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction());
- // TODO(benvanik): rework register storage interface.
- ASSIGN_OR_RETURN(
- const auto* function_def,
- static_cast<const vm::BytecodeModule*>(target_function.module())
- ->GetFunctionDef(target_function.linkage(),
- target_function.ordinal()));
- ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function));
- new_stack_frame->mutable_registers()->buffer_views.resize(
- function_def->bytecode()->local_count());
- RETURN_IF_ERROR(
- reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame));
- DVLOG(1) << "Call; stack now: " << stack->DebugString();
- });
-
- DISPATCH_CORE_OPCODE(kCallImport, {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Non-module imports not supported";
- });
-
- DISPATCH_CORE_OPCODE(kCallIndirect, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect";
- });
-
- DISPATCH_CORE_OPCODE(kReturn, {
- auto* old_stack_frame = stack->current_frame();
- auto* new_stack_frame = stack->caller_frame();
- if (old_stack_frame == entry_stack_frame) {
- // Returning from entry function. Marshal results from the return stmt.
- ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount());
- for (int i = 0; i < src_count; ++i) {
- ASSIGN_OR_RETURN(
- auto* src_local,
- reader.ReadLocal(old_stack_frame->mutable_registers()));
- entry_results[i] = std::move(*src_local);
- }
- DVLOG(1) << "Returning to entry";
- return OkStatus();
- } else if (!new_stack_frame) {
- return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow";
- }
- RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame,
- new_stack_frame));
- RETURN_IF_ERROR(stack->PopFrame());
- DVLOG(1) << "Return; stack now: " << stack->DebugString();
- });
-
- DISPATCH_CORE_OPCODE(kBranch, {
- ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset());
- RETURN_IF_ERROR(reader.CopySlots());
- RETURN_IF_ERROR(reader.BranchToOffset(offset));
- });
-
- DISPATCH_CORE_OPCODE(kCondBranch, {
- // Evaluate condition first so we can do the copies as we read them for
- // which side of the branch we take.
- ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
- bool cond_value = BufferViewIsTrue(*cond_local);
- ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset());
- if (cond_value) {
- RETURN_IF_ERROR(reader.CopySlots());
- RETURN_IF_ERROR(reader.BranchToOffset(true_offset));
- } else {
- ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount());
- RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count));
- ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset());
- RETURN_IF_ERROR(reader.CopySlots());
- RETURN_IF_ERROR(reader.BranchToOffset(false_offset));
- }
- });
-
- DISPATCH_CORE_OPCODE(kCmpI, {
- ASSIGN_OR_RETURN(uint8_t predicate, reader.ReadUint8_t());
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
-
- switch (static_cast<CmpIPredicate>(predicate)) {
- case CmpIPredicate::kEq:
- RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareEQ>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kNe:
- RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareNE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kSlt:
- RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLT>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kSle:
- RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kSgt:
- RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGT>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kSge:
- RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kUlt:
- RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLT>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kUle:
- RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kUgt:
- RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGT>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpIPredicate::kUge:
- RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGE>(
- lhs_local, rhs_local, dst_local));
- break;
- }
- });
-
- DISPATCH_FLOAT_OPCODE(kCmpF, {
- ASSIGN_OR_RETURN(uint8_t p, reader.ReadUint8_t());
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
-
- auto predicate = static_cast<CmpFPredicate>(p);
- switch (predicate) {
- case CmpFPredicate::kOeq:
- RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareEQ>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpFPredicate::kUne:
- RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareNE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpFPredicate::kOlt:
- RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLT>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpFPredicate::kOle:
- RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpFPredicate::kOgt:
- RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGT>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpFPredicate::kOge:
- RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGE>(
- lhs_local, rhs_local, dst_local));
- break;
- case CmpFPredicate::kFalse:
- case CmpFPredicate::kOne:
- case CmpFPredicate::kOrd:
- case CmpFPredicate::kUeq:
- case CmpFPredicate::kUgt:
- case CmpFPredicate::kUge:
- case CmpFPredicate::kUlt:
- case CmpFPredicate::kUle:
- case CmpFPredicate::kUno:
- case CmpFPredicate::kTrue:
- // TODO(b/132183250) support these if we ever need them.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unsupported comparison predicate value "
- << static_cast<int>(p) << " ("
- << vm::PredicateToString(predicate) << ")";
- }
- });
-
- DISPATCH_CORE_OPCODE(kAllocStatic, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static";
- });
-
- DISPATCH_CORE_OPCODE(kAllocStack, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack";
- });
-
- DISPATCH_CORE_OPCODE(kAllocStackInit, {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented alloc_stack_init";
- });
-
- DISPATCH_CORE_OPCODE(kAllocHeap, {
- ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto type, reader.ReadType());
- size_t element_size = type.element_size();
-
- // TODO(benvanik): more efficient reading and storage.
- size_t element_count = 0;
- ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count));
- size_t allocation_size = element_size * element_count;
-
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- dst_local->element_size = element_size;
- dst_local->shape = shape;
-
- // TODO(benvanik): properly allocate with attributes from op.
- CHECK_EQ(heap_type, 0);
- ASSIGN_OR_RETURN(
- dst_local->buffer,
- allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible,
- BufferUsage::kAll, allocation_size));
- });
-
- DISPATCH_CORE_OPCODE(kDiscard, {
- // NOTE: if we were an encoder we would actually discard the buffer.
- ASSIGN_OR_RETURN(auto* local, reader.ReadLocal());
- *local = {};
- });
-
- DISPATCH_CORE_OPCODE(kRank, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- int32_t rank = src_local->shape.size();
- RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &rank, sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kDim, {
- ASSIGN_OR_RETURN(int32_t axis, reader.ReadUint8_t());
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(int32_t dim, src_local->shape.ResolveAxis(axis));
- RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &dim, sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kShape, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(dst_local->buffer->WriteData(
- 0, src_local->shape.subspan().data(),
- src_local->shape.subspan().size() * sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kLength, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- int32_t length = src_local->shape.element_count();
- RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kDynamicSlice, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths));
- });
-
- DISPATCH_CORE_OPCODE(kStaticSlice, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto indices, reader.ReadIndexList());
- ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths));
- });
-
- DISPATCH_CORE_OPCODE(kDynamicCopy, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto src_indices, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_indices, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>());
- RETURN_IF_ERROR(
- ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths));
- });
-
- DISPATCH_CORE_OPCODE(kStaticCopy, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto src_indices, reader.ReadIndexList());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_indices, reader.ReadIndexList());
- ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList());
- RETURN_IF_ERROR(
- ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths));
- });
-
- DISPATCH_CORE_OPCODE(kClone, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- dst_local->element_size = src_local->element_size;
- dst_local->shape = src_local->shape;
- dst_local->buffer = HeapBuffer::Allocate(src_local->buffer->usage(),
- src_local->buffer->byte_length());
- RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get()));
- });
-
- DISPATCH_CORE_OPCODE(kSplit, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented split";
- });
-
- DISPATCH_CORE_OPCODE(kAssign, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- *dst_local = *src_local;
- });
-
- DISPATCH_CORE_OPCODE(kCondAssign, {
- ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local;
- });
-
- DISPATCH_CORE_OPCODE(kReshape, {
- // TODO(benvanik): more logic required if strides differ.
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- Shape new_shape = Shape{shape_data};
- if (src_local->shape.element_count() != new_shape.element_count()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "New element count " << new_shape.element_count()
- << " != source element count " << src_local->shape.element_count();
- }
- dst_local->shape = new_shape;
- dst_local->buffer = add_ref(src_local->buffer);
- dst_local->element_size = src_local->element_size;
- });
-
- DISPATCH_CORE_OPCODE(kSelect, {
- ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto cond_buffer, cond_local->buffer->MapMemory<uint8_t>(
- MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto lhs_buffer, lhs_local->buffer->MapMemory<uint8_t>(
- MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto rhs_buffer, rhs_local->buffer->MapMemory<uint8_t>(
- MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>(
- MemoryAccess::kDiscardWrite));
- if (cond_local->element_size != 1) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "Select cond must be i8";
- } else if (lhs_buffer.size() != rhs_buffer.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "LHS " << lhs_buffer.size() << "b != RHS " << rhs_buffer.size()
- << "b; both arguments must match";
- } else if (lhs_buffer.size() != dst_buffer.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Dest " << dst_buffer.size() << "b != LHS/RHS "
- << lhs_buffer.size() << "b; dest must match inputs";
- }
- switch (lhs_local->element_size) {
- case 1:
- RETURN_IF_ERROR(kernels::Select::Execute<uint8_t>(
- cond_buffer.contents(), lhs_buffer.contents(),
- rhs_buffer.contents(), dst_buffer.mutable_contents()));
- break;
- case 2:
- RETURN_IF_ERROR(kernels::Select::Execute<uint16_t>(
- cond_buffer.contents(),
- ReinterpretSpan<uint16_t>(lhs_buffer.contents()),
- ReinterpretSpan<uint16_t>(rhs_buffer.contents()),
- ReinterpretSpan<uint16_t>(dst_buffer.mutable_contents())));
- break;
- case 4:
- RETURN_IF_ERROR(kernels::Select::Execute<uint32_t>(
- cond_buffer.contents(),
- ReinterpretSpan<uint32_t>(lhs_buffer.contents()),
- ReinterpretSpan<uint32_t>(rhs_buffer.contents()),
- ReinterpretSpan<uint32_t>(dst_buffer.mutable_contents())));
- break;
- case 8:
- RETURN_IF_ERROR(kernels::Select::Execute<uint64_t>(
- cond_buffer.contents(),
- ReinterpretSpan<uint64_t>(lhs_buffer.contents()),
- ReinterpretSpan<uint64_t>(rhs_buffer.contents()),
- ReinterpretSpan<uint64_t>(dst_buffer.mutable_contents())));
- break;
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
- });
-
- DISPATCH_CORE_OPCODE(kTranspose, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Transpose>(
- src_local, dst_local, src_local->shape,
- absl::MakeConstSpan(perm_data)));
- });
-
- DISPATCH_CORE_OPCODE(kReverse, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(
- ApplyUnaryOpIU<kernels::Reverse>(src_local, dst_local, src_local->shape,
- absl::MakeConstSpan(perm_data)));
- });
-
- DISPATCH_CORE_OPCODE(kPad, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* padding_value, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto edge_padding_low, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto edge_padding_high,
- reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto interior_padding, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
-
- RETURN_IF_ERROR(ApplyBinaryOpIU<kernels::Pad>(
- src_local, padding_value, dst_local, src_local->shape, dst_local->shape,
- absl::MakeConstSpan(edge_padding_low),
- absl::MakeConstSpan(edge_padding_high),
- absl::MakeConstSpan(interior_padding)));
- });
-
- DISPATCH_CORE_OPCODE(kBroadcast, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- dst_local->shape = Shape{shape_data};
- RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Broadcast>(src_local, dst_local));
- });
-
- DISPATCH_CORE_OPCODE(kTile, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- dst_local->shape = Shape{shape_data};
- RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Tile>(
- src_local, dst_local, src_local->shape, dst_local->shape));
- });
-
- DISPATCH_CORE_OPCODE(kNot, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpIU<kernels::Not>(&reader));
- });
- DISPATCH_CORE_OPCODE(kAnd, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::And>(&reader));
- });
- DISPATCH_CORE_OPCODE(kOr, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Or>(&reader));
- });
- DISPATCH_CORE_OPCODE(kXor, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Xor>(&reader));
- });
- DISPATCH_CORE_OPCODE(kShiftLeft, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::ShiftLeft>(&reader));
- });
- DISPATCH_CORE_OPCODE(kShiftRightLogical, {
- RETURN_IF_ERROR(
- DispatchElementwiseBinaryOpIU<kernels::ShiftRight>(&reader));
- });
- DISPATCH_CORE_OPCODE(kShiftRightArithmetic, {
- RETURN_IF_ERROR(
- DispatchElementwiseBinaryOpIS<kernels::ShiftRight>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kAddI, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Add>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kAddF, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Add>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kSubI, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Sub>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kSubF, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Sub>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kAbsI, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpIS<kernels::Abs>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kAbsF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Abs>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kMulI, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Mul>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kMulF, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Mul>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kDivIS, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Div>(&reader));
- });
- DISPATCH_CORE_OPCODE(kDivIU, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Div>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kDivF, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Div>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kMulAddI, {
- RETURN_IF_ERROR(DispatchElementwiseTernaryOpIU<kernels::MulAdd>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kMulAddF, {
- RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::MulAdd>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kExpF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Exp>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kLogF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Log>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kRsqrtF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Rsqrt>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kCosF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Cos>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kSinF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Sin>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kTanhF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Tanh>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kAtan2F, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Atan2>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kMinIS, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Min>(&reader));
- });
- DISPATCH_CORE_OPCODE(kMinIU, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Min>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kMinF, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Min>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kMaxIS, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Max>(&reader));
- });
- DISPATCH_CORE_OPCODE(kMaxIU, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Max>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kMaxF, {
- RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Max>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kClampIS, {
- RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader));
- });
- DISPATCH_CORE_OPCODE(kClampIU, {
- RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kClampF, {
- RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::Clamp>(&reader));
- });
-
- DISPATCH_FLOAT_OPCODE(kFloorF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Floor>(&reader));
- });
- DISPATCH_FLOAT_OPCODE(kCeilF, {
- RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Ceil>(&reader));
- });
-
- DISPATCH_CORE_OPCODE(kConvertSS, {
- ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(
- ApplyConvertSS::Apply(src_type, src_local, dst_type, dst_local));
- });
- DISPATCH_CORE_OPCODE(kConvertUU, {
- ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(
- ApplyConvertUU::Apply(src_type, src_local, dst_type, dst_local));
- });
- DISPATCH_CORE_OPCODE(kConvertSU, {
- ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(
- ApplyConvertSU::Apply(src_type, src_local, dst_type, dst_local));
- });
- DISPATCH_CORE_OPCODE(kConvertUS, {
- ASSIGN_OR_RETURN(auto src_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_type, reader.ReadType());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(
- ApplyConvertUS::Apply(src_type, src_local, dst_type, dst_local));
- });
-
- DISPATCH_CORE_OPCODE(kMatMulI, {
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- // TODO(benvanik): add fused matmul-with-bias op in MLIR and lower to this.
- BufferView* bias_local = nullptr;
- ASSIGN_OR_RETURN(auto* multiplier_mantissa_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* multiplier_exponent_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(ValidateMatMulOpI(lhs_local, rhs_local, bias_local,
- multiplier_mantissa_local,
- multiplier_exponent_local, dst_local));
- auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get();
- // TODO(benvanik): define as a matrix of supported types to enable 8*8=16,
- // accumulator options, and other precision modes.
- switch (lhs_local->element_size) {
- case 1:
- RETURN_IF_ERROR(ApplyMatMulOpI<int8_t>(
- mat_mul_state, lhs_local, rhs_local, bias_local,
- multiplier_mantissa_local, multiplier_exponent_local, dst_local));
- break;
- case 2:
- RETURN_IF_ERROR(ApplyMatMulOpI<int16_t>(
- mat_mul_state, lhs_local, rhs_local, bias_local,
- multiplier_mantissa_local, multiplier_exponent_local, dst_local));
- break;
- case 4:
- RETURN_IF_ERROR(ApplyMatMulOpI<int32_t>(
- mat_mul_state, lhs_local, rhs_local, bias_local,
- multiplier_mantissa_local, multiplier_exponent_local, dst_local));
- break;
- case 8:
- RETURN_IF_ERROR(ApplyMatMulOpI<int64_t>(
- mat_mul_state, lhs_local, rhs_local, bias_local,
- multiplier_mantissa_local, multiplier_exponent_local, dst_local));
- break;
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
- });
-
- DISPATCH_FLOAT_OPCODE(kMatMulF, {
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- BufferView* bias_local = nullptr;
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(
- ValidateMatMulOpF(lhs_local, rhs_local, bias_local, dst_local));
- auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get();
- switch (lhs_local->element_size) {
- case 4:
- RETURN_IF_ERROR(ApplyMatMulOpF<float>(
- mat_mul_state, lhs_local, rhs_local, bias_local, dst_local));
- break;
- case 8:
- RETURN_IF_ERROR(ApplyMatMulOpF<double>(
- mat_mul_state, lhs_local, rhs_local, bias_local, dst_local));
- break;
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
- });
-
- DISPATCH_CORE_OPCODE(kReduceSumI, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(scotttodd): validate
- RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceSum>(
- src_local, init_local, dst_local, dimension, src_local->shape,
- dst_local->shape));
- });
-
- DISPATCH_FLOAT_OPCODE(kReduceSumF, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(scotttodd): validate
- RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceSum>(
- src_local, init_local, dst_local, dimension, src_local->shape,
- dst_local->shape));
- });
-
- DISPATCH_CORE_OPCODE(kReduceMinI, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(scotttodd): validate
- RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMin>(
- src_local, init_local, dst_local, dimension, src_local->shape,
- dst_local->shape));
- });
-
- DISPATCH_FLOAT_OPCODE(kReduceMinF, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(scotttodd): validate
- RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMin>(
- src_local, init_local, dst_local, dimension, src_local->shape,
- dst_local->shape));
- });
-
- DISPATCH_CORE_OPCODE(kReduceMaxI, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(scotttodd): validate
- RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMax>(
- src_local, init_local, dst_local, dimension, src_local->shape,
- dst_local->shape));
- });
-
- DISPATCH_FLOAT_OPCODE(kReduceMaxF, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(scotttodd): validate
- RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMax>(
- src_local, init_local, dst_local, dimension, src_local->shape,
- dst_local->shape));
- });
-
- DISPATCH_CORE_OPCODE(kTrace, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace";
- });
-
- DISPATCH_CORE_OPCODE(kBreak, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break";
- });
-
- DISPATCH_CORE_OPCODE(kCondBreak, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break";
- });
-
-_dispatch_unhandled:
- // TODO(benvanik): better tracing.
- return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode";
-} // NOLINT(readability/fn_size)
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_dispatch.h b/iree/hal/interpreter/bytecode_dispatch.h
deleted file mode 100644
index edd92d2..0000000
--- a/iree/hal/interpreter/bytecode_dispatch.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/interpreter/bytecode_kernels.h"
-#include "iree/rt/stack.h"
-#include "iree/rt/stack_frame.h"
-
-namespace iree {
-namespace hal {
-
-Status Dispatch(hal::Allocator* allocator,
- kernels::RuntimeState* kernel_runtime_state, rt::Stack* stack,
- rt::StackFrame* entry_stack_frame,
- absl::Span<BufferView> entry_results);
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
diff --git a/iree/hal/interpreter/bytecode_dispatch_conversion.h b/iree/hal/interpreter/bytecode_dispatch_conversion.h
deleted file mode 100644
index 8973250..0000000
--- a/iree/hal/interpreter/bytecode_dispatch_conversion.h
+++ /dev/null
@@ -1,395 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Conversion helper tables.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
-
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/interpreter/bytecode_dispatch_util.h"
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-#include "iree/vm/type.h"
-
-namespace iree {
-namespace hal {
-
-template <typename KERNEL, bool src_signed, bool dst_signed, typename... ARGS>
-struct ApplyConversionOp {
- static Status Apply(const vm::Type& src_type, BufferView* src_local,
- const vm::Type& dst_type, BufferView* dst_local,
- ARGS... args) {
- // Validate ranges so that we cannot go out of bounds on thunk table.
- int src_type_index = src_type.type_index();
- int dst_type_index = dst_type.type_index();
- if (src_type_index < 0 || src_type_index >= kBuiltinTypeCount) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Conversion from invalid source builtin type "
- << src_type_index;
- } else if (dst_type_index < 0 || dst_type_index >= kBuiltinTypeCount) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Conversion to invalid dest builtin type " << dst_type_index;
- }
-
- // All possible combinations of conversions.
- using KernelFn = Status (*)(BufferView * src_local, BufferView * dst_local,
- ARGS... args);
- KernelFn fn = nullptr;
- if (src_signed && dst_signed) {
- // Signed -> signed.
- static const KernelFn
- kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
- // src_type = kI8:
- /* kI8 */ Thunk<int8_t, int8_t>::Apply,
- /* kI16 */ Thunk<int8_t, int16_t>::Apply,
- /* kI32 */ Thunk<int8_t, int32_t>::Apply,
- /* kI64 */ Thunk<int8_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<int8_t, float>::Apply,
- /* kF64 */ Thunk<int8_t, double>::Apply,
-
- // src_type = kI16:
- /* kI8 */ Thunk<int16_t, int8_t>::Apply,
- /* kI16 */ Thunk<int16_t, int16_t>::Apply,
- /* kI32 */ Thunk<int16_t, int32_t>::Apply,
- /* kI64 */ Thunk<int16_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<int16_t, float>::Apply,
- /* kF64 */ Thunk<int16_t, double>::Apply,
-
- // src_type = kI32:
- /* kI8 */ Thunk<int32_t, int8_t>::Apply,
- /* kI16 */ Thunk<int32_t, int16_t>::Apply,
- /* kI32 */ Thunk<int32_t, int32_t>::Apply,
- /* kI64 */ Thunk<int32_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<int32_t, float>::Apply,
- /* kF64 */ Thunk<int32_t, double>::Apply,
-
- // src_type = kI64:
- /* kI8 */ Thunk<int64_t, int8_t>::Apply,
- /* kI16 */ Thunk<int64_t, int16_t>::Apply,
- /* kI32 */ Thunk<int64_t, int32_t>::Apply,
- /* kI64 */ Thunk<int64_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<int64_t, float>::Apply,
- /* kF64 */ Thunk<int64_t, double>::Apply,
-
- // src_type = kF16:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ Thunk<uint16_t, uint16_t>::Apply,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF32:
- /* kI8 */ Thunk<float, int8_t>::Apply,
- /* kI16 */ Thunk<float, int16_t>::Apply,
- /* kI32 */ Thunk<float, int32_t>::Apply,
- /* kI64 */ Thunk<float, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<float, float>::Apply,
- /* kF64 */ Thunk<float, double>::Apply,
-
- // src_type = kF64:
- /* kI8 */ Thunk<double, int8_t>::Apply,
- /* kI16 */ Thunk<double, int16_t>::Apply,
- /* kI32 */ Thunk<double, int32_t>::Apply,
- /* kI64 */ Thunk<double, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<double, float>::Apply,
- /* kF64 */ Thunk<double, double>::Apply,
- };
- fn =
- kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
- } else if (src_signed && !dst_signed) {
- // Signed -> unsigned.
- static const KernelFn
- kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
- // src_type = kI8:
- /* kI8 */ Thunk<int8_t, uint8_t>::Apply,
- /* kI16 */ Thunk<int8_t, uint16_t>::Apply,
- /* kI32 */ Thunk<int8_t, uint32_t>::Apply,
- /* kI64 */ Thunk<int8_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kI16:
- /* kI8 */ Thunk<int16_t, uint8_t>::Apply,
- /* kI16 */ Thunk<int16_t, uint16_t>::Apply,
- /* kI32 */ Thunk<int16_t, uint32_t>::Apply,
- /* kI64 */ Thunk<int16_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kI32:
- /* kI8 */ Thunk<int32_t, uint8_t>::Apply,
- /* kI16 */ Thunk<int32_t, uint16_t>::Apply,
- /* kI32 */ Thunk<int32_t, uint32_t>::Apply,
- /* kI64 */ Thunk<int32_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kI64:
- /* kI8 */ Thunk<int64_t, uint8_t>::Apply,
- /* kI16 */ Thunk<int64_t, uint16_t>::Apply,
- /* kI32 */ Thunk<int64_t, uint32_t>::Apply,
- /* kI64 */ Thunk<int64_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF16:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF32:
- /* kI8 */ Thunk<float, uint8_t>::Apply,
- /* kI16 */ Thunk<float, uint16_t>::Apply,
- /* kI32 */ Thunk<float, uint32_t>::Apply,
- /* kI64 */ Thunk<float, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF64:
- /* kI8 */ Thunk<double, uint8_t>::Apply,
- /* kI16 */ Thunk<double, uint16_t>::Apply,
- /* kI32 */ Thunk<double, uint32_t>::Apply,
- /* kI64 */ Thunk<double, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
- };
- fn =
- kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
- } else if (!src_signed && dst_signed) {
- // Unsigned -> signed.
- static const KernelFn
- kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
- // src_type = kI8:
- /* kI8 */ Thunk<uint8_t, int8_t>::Apply,
- /* kI16 */ Thunk<uint8_t, int16_t>::Apply,
- /* kI32 */ Thunk<uint8_t, int32_t>::Apply,
- /* kI64 */ Thunk<uint8_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<uint8_t, float>::Apply,
- /* kF64 */ Thunk<uint8_t, double>::Apply,
-
- // src_type = kI16:
- /* kI8 */ Thunk<uint16_t, int8_t>::Apply,
- /* kI16 */ Thunk<uint16_t, int16_t>::Apply,
- /* kI32 */ Thunk<uint16_t, int32_t>::Apply,
- /* kI64 */ Thunk<uint16_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<uint16_t, float>::Apply,
- /* kF64 */ Thunk<uint16_t, double>::Apply,
-
- // src_type = kI32:
- /* kI8 */ Thunk<uint32_t, int8_t>::Apply,
- /* kI16 */ Thunk<uint32_t, int16_t>::Apply,
- /* kI32 */ Thunk<uint32_t, int32_t>::Apply,
- /* kI64 */ Thunk<uint32_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<uint32_t, float>::Apply,
- /* kF64 */ Thunk<uint32_t, double>::Apply,
-
- // src_type = kI64:
- /* kI8 */ Thunk<uint64_t, int8_t>::Apply,
- /* kI16 */ Thunk<uint64_t, int16_t>::Apply,
- /* kI32 */ Thunk<uint64_t, int32_t>::Apply,
- /* kI64 */ Thunk<uint64_t, int64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ Thunk<uint64_t, float>::Apply,
- /* kF64 */ Thunk<uint64_t, double>::Apply,
-
- // src_type = kF16:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF32:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF64:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
- };
- fn =
- kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
- } else if (!src_signed && !dst_signed) {
- // Unsigned -> unsigned.
- static const KernelFn
- kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
- // src_type = kI8:
- /* kI8 */ Thunk<uint8_t, uint8_t>::Apply,
- /* kI16 */ Thunk<uint8_t, uint16_t>::Apply,
- /* kI32 */ Thunk<uint8_t, uint32_t>::Apply,
- /* kI64 */ Thunk<uint8_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kI16:
- /* kI8 */ Thunk<uint16_t, uint8_t>::Apply,
- /* kI16 */ Thunk<uint16_t, uint16_t>::Apply,
- /* kI32 */ Thunk<uint16_t, uint32_t>::Apply,
- /* kI64 */ Thunk<uint16_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kI32:
- /* kI8 */ Thunk<uint32_t, uint8_t>::Apply,
- /* kI16 */ Thunk<uint32_t, uint16_t>::Apply,
- /* kI32 */ Thunk<uint32_t, uint32_t>::Apply,
- /* kI64 */ Thunk<uint32_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kI64:
- /* kI8 */ Thunk<uint64_t, uint8_t>::Apply,
- /* kI16 */ Thunk<uint64_t, uint16_t>::Apply,
- /* kI32 */ Thunk<uint64_t, uint32_t>::Apply,
- /* kI64 */ Thunk<uint64_t, uint64_t>::Apply,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF16:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF32:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
-
- // src_type = kF64:
- /* kI8 */ nullptr,
- /* kI16 */ nullptr,
- /* kI32 */ nullptr,
- /* kI64 */ nullptr,
- /* kF16 */ nullptr,
- /* kF32 */ nullptr,
- /* kF64 */ nullptr,
- };
- fn =
- kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
- }
- if (!fn) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unsupported conversion from " << src_type_index << " to "
- << dst_type_index;
- }
- return fn(src_local, dst_local, args...);
- }
-
- template <typename SRC, typename DST>
- struct Thunk {
- static Status Apply(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- ASSIGN_OR_RETURN(auto src_buffer,
- src_local->buffer->MapMemory<SRC>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<DST>(
- MemoryAccess::kDiscardWrite));
- return KERNEL::Execute(src_buffer.contents(),
- dst_buffer.mutable_contents(), args...);
- }
- };
-
-// Disable F32/F64 conversions if they are not supported.
-#if !defined(IREE_SUPPORT_F32)
- template <typename DST>
- struct Thunk<float, DST> {
- static Status Apply(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported";
- }
- };
- template <typename SRC>
- struct Thunk<SRC, float> {
- static Status Apply(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported";
- }
- };
-#endif // !IREE_SUPPORT_F32
-#if !defined(IREE_SUPPORT_F64)
- template <typename DST>
- struct Thunk<double, DST> {
- static Status Apply(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported";
- }
- };
- template <typename SRC>
- struct Thunk<SRC, double> {
- static Status Apply(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported";
- }
- };
-#endif // !IREE_SUPPORT_F64
-};
-
-using ApplyConvertSS = ApplyConversionOp<kernels::Convert, /*src_signed=*/true,
- /*dst_signed=*/true>;
-using ApplyConvertUU = ApplyConversionOp<kernels::Convert, /*src_signed=*/false,
- /*dst_signed=*/false>;
-using ApplyConvertSU = ApplyConversionOp<kernels::Convert, /*src_signed=*/true,
- /*dst_signed=*/false>;
-using ApplyConvertUS = ApplyConversionOp<kernels::Convert, /*src_signed=*/false,
- /*dst_signed=*/true>;
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
diff --git a/iree/hal/interpreter/bytecode_dispatch_util.cc b/iree/hal/interpreter/bytecode_dispatch_util.cc
deleted file mode 100644
index 40937e1..0000000
--- a/iree/hal/interpreter/bytecode_dispatch_util.cc
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/bytecode_dispatch_util.h"
-
-namespace iree {
-namespace hal {
-
-bool BufferViewIsTrue(const BufferView& buffer_view) {
- if (buffer_view.element_size == 0 || !buffer_view.buffer ||
- buffer_view.byte_length() == 0) {
- return false;
- }
- // TODO(benvanik): map more efficiently (based on element size?).
- auto mapping =
- buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead);
- if (!mapping.ok()) {
- return false;
- }
- for (uint8_t value : mapping.ValueOrDie().contents()) {
- if (value) return true;
- }
- return false;
-}
-
-Status ValidateElementwiseUnaryOp(BufferView* src_local,
- BufferView* dst_local) {
- // TODO(benvanik): validate shapes.
- return OkStatus();
-}
-
-Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local) {
- // TODO(benvanik): validate shapes.
- return OkStatus();
-}
-
-Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local,
- BufferView* c_local,
- BufferView* dst_local) {
- // TODO(benvanik): validate shapes.
- return OkStatus();
-}
-
-Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* bias_local,
- BufferView* multiplier_mantissa_local,
- BufferView* multiplier_exponent_local,
- BufferView* dst_local) {
- // TODO(benvanik): validate shapes.
- return OkStatus();
-}
-
-Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* bias_local, BufferView* dst_local) {
- // TODO(benvanik): validate shapes.
- return OkStatus();
-}
-
-Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices,
- BufferView* dst_local, absl::Span<const int32_t> dst_indices,
- absl::Span<const int32_t> lengths) {
- ASSIGN_OR_RETURN(auto src_buffer,
- src_local->buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
- // TODO(benvanik): discard if overwriting the entire buffer.
- ASSIGN_OR_RETURN(auto dst_buffer,
- dst_local->buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
- switch (src_local->element_size) {
- case 1:
- return kernels::Copy::Execute<1>(src_buffer.contents(), src_local->shape,
- src_indices,
- dst_buffer.mutable_contents(),
- dst_local->shape, dst_indices, lengths);
- case 2:
- return kernels::Copy::Execute<2>(src_buffer.contents(), src_local->shape,
- src_indices,
- dst_buffer.mutable_contents(),
- dst_local->shape, dst_indices, lengths);
- case 4:
- return kernels::Copy::Execute<4>(src_buffer.contents(), src_local->shape,
- src_indices,
- dst_buffer.mutable_contents(),
- dst_local->shape, dst_indices, lengths);
- case 8:
- return kernels::Copy::Execute<8>(src_buffer.contents(), src_local->shape,
- src_indices,
- dst_buffer.mutable_contents(),
- dst_local->shape, dst_indices, lengths);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << src_local->element_size;
- }
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_dispatch_util.h b/iree/hal/interpreter/bytecode_dispatch_util.h
deleted file mode 100644
index cb8d7d9..0000000
--- a/iree/hal/interpreter/bytecode_dispatch_util.h
+++ /dev/null
@@ -1,513 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utilities used by the bytecode_dispatch routines to aid in working with the
-// bytecode stream and kernel dispatch.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
-
-#include "absl/base/attributes.h"
-#include "absl/container/inlined_vector.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/heap_buffer.h"
-#include "iree/hal/interpreter/bytecode_kernels.h"
-#include "iree/rt/function.h"
-#include "iree/rt/stack.h"
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-#include "iree/vm/bytecode_reader.h"
-#include "iree/vm/type.h"
-
-// TODO(benvanik): move to dedicated config file/build flags.
-#define IREE_SUPPORT_F32 1
-#define IREE_SUPPORT_F64 1
-
-namespace iree {
-namespace hal {
-
-// Returns true if the contents of the BufferView are bitwise non-zero.
-// Returns false if there is no buffer, the buffer is empty, or the contents are
-// bitwise zero.
-bool BufferViewIsTrue(const BufferView& buffer_view);
-
-Status ValidateElementwiseUnaryOp(BufferView* src_local, BufferView* dst_local);
-Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local);
-Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local,
- BufferView* c_local, BufferView* dst_local);
-Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* bias_local,
- BufferView* multiplier_mantissa_local,
- BufferView* multiplier_exponent_local,
- BufferView* dst_local);
-Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* bias_local, BufferView* dst_local);
-
-template <typename KERNEL, typename T, typename... ARGS>
-Status ApplyUnaryOp(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- // TODO(benvanik): avoid mapping by changing buffer type?
- ASSIGN_OR_RETURN(auto src_buffer,
- src_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>(
- MemoryAccess::kDiscardWrite));
- return KERNEL::Execute(src_buffer.contents(), dst_buffer.mutable_contents(),
- args...);
-}
-
-template <typename KERNEL, typename T, typename... ARGS>
-Status ApplyBinaryOp(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local, ARGS... args) {
- ASSIGN_OR_RETURN(auto lhs_buffer,
- lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto rhs_buffer,
- rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>(
- MemoryAccess::kDiscardWrite));
- return KERNEL::Execute(lhs_buffer.contents(), rhs_buffer.contents(),
- dst_buffer.mutable_contents(), args...);
-}
-
-template <typename KERNEL, typename T, typename... ARGS>
-Status ApplyTernaryOp(BufferView* a_local, BufferView* b_local,
- BufferView* c_local, BufferView* dst_local,
- ARGS... args) {
- ASSIGN_OR_RETURN(auto a_buffer,
- a_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto b_buffer,
- b_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto c_buffer,
- c_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>(
- MemoryAccess::kDiscardWrite));
- return KERNEL::Execute(a_buffer.contents(), b_buffer.contents(),
- c_buffer.contents(), dst_buffer.mutable_contents(),
- args...);
-}
-
-template <typename KERNEL, typename T>
-Status ApplyComparisonOp(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local) {
- ASSIGN_OR_RETURN(auto lhs_buffer,
- lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto rhs_buffer,
- rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>(
- MemoryAccess::kDiscardWrite));
- return KERNEL::Execute(lhs_buffer.contents(), rhs_buffer.contents(),
- dst_buffer.mutable_contents());
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyUnaryOpIS(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- switch (src_local->element_size) {
- case 1:
- return ApplyUnaryOp<KERNEL, int8_t>(src_local, dst_local, args...);
- case 2:
- return ApplyUnaryOp<KERNEL, int16_t>(src_local, dst_local, args...);
- case 4:
- return ApplyUnaryOp<KERNEL, int32_t>(src_local, dst_local, args...);
- case 8:
- return ApplyUnaryOp<KERNEL, int64_t>(src_local, dst_local, args...);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << src_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyUnaryOpIU(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- switch (src_local->element_size) {
- case 1:
- return ApplyUnaryOp<KERNEL, uint8_t>(src_local, dst_local, args...);
- case 2:
- return ApplyUnaryOp<KERNEL, uint16_t>(src_local, dst_local, args...);
- case 4:
- return ApplyUnaryOp<KERNEL, uint32_t>(src_local, dst_local, args...);
- case 8:
- return ApplyUnaryOp<KERNEL, uint64_t>(src_local, dst_local, args...);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << src_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyUnaryOpF(BufferView* src_local, BufferView* dst_local,
- ARGS... args) {
- switch (src_local->element_size) {
-#if defined(IREE_SUPPORT_F32)
- case 4:
- return ApplyUnaryOp<KERNEL, float>(src_local, dst_local, args...);
-#endif // IREE_SUPPORT_F32
-#if defined(IREE_SUPPORT_F64)
- case 8:
- return ApplyUnaryOp<KERNEL, double>(src_local, dst_local, args...);
-#endif // IREE_SUPPORT_F64
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << src_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyBinaryOpIS(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local, ARGS... args) {
- switch (lhs_local->element_size) {
- case 1:
- return ApplyBinaryOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local,
- args...);
- case 2:
- return ApplyBinaryOp<KERNEL, int16_t>(lhs_local, rhs_local, dst_local,
- args...);
- case 4:
- return ApplyBinaryOp<KERNEL, int32_t>(lhs_local, rhs_local, dst_local,
- args...);
- case 8:
- return ApplyBinaryOp<KERNEL, int64_t>(lhs_local, rhs_local, dst_local,
- args...);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyBinaryOpIU(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local, ARGS... args) {
- switch (lhs_local->element_size) {
- case 1:
- return ApplyBinaryOp<KERNEL, uint8_t>(lhs_local, rhs_local, dst_local,
- args...);
- case 2:
- return ApplyBinaryOp<KERNEL, uint16_t>(lhs_local, rhs_local, dst_local,
- args...);
- case 4:
- return ApplyBinaryOp<KERNEL, uint32_t>(lhs_local, rhs_local, dst_local,
- args...);
- case 8:
- return ApplyBinaryOp<KERNEL, uint64_t>(lhs_local, rhs_local, dst_local,
- args...);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyBinaryOpF(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local, ARGS... args) {
- switch (lhs_local->element_size) {
-#if defined(IREE_SUPPORT_F32)
- case 4:
- return ApplyBinaryOp<KERNEL, float>(lhs_local, rhs_local, dst_local,
- args...);
-#endif // IREE_SUPPORT_F32
-#if defined(IREE_SUPPORT_F64)
- case 8:
- return ApplyBinaryOp<KERNEL, double>(lhs_local, rhs_local, dst_local,
- args...);
-#endif // IREE_SUPPORT_F64
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyTernaryOpIS(BufferView* a_local, BufferView* b_local,
- BufferView* c_local, BufferView* dst_local,
- ARGS... args) {
- switch (a_local->element_size) {
- case 1:
- return ApplyTernaryOp<KERNEL, int8_t>(a_local, b_local, c_local,
- dst_local, args...);
- case 2:
- return ApplyTernaryOp<KERNEL, int16_t>(a_local, b_local, c_local,
- dst_local, args...);
- case 4:
- return ApplyTernaryOp<KERNEL, int32_t>(a_local, b_local, c_local,
- dst_local, args...);
- case 8:
- return ApplyTernaryOp<KERNEL, int64_t>(a_local, b_local, c_local,
- dst_local, args...);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << a_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyTernaryOpIU(BufferView* a_local, BufferView* b_local,
- BufferView* c_local, BufferView* dst_local,
- ARGS... args) {
- switch (a_local->element_size) {
- case 1:
- return ApplyTernaryOp<KERNEL, uint8_t>(a_local, b_local, c_local,
- dst_local, args...);
- case 2:
- return ApplyTernaryOp<KERNEL, uint16_t>(a_local, b_local, c_local,
- dst_local, args...);
- case 4:
- return ApplyTernaryOp<KERNEL, uint32_t>(a_local, b_local, c_local,
- dst_local, args...);
- case 8:
- return ApplyTernaryOp<KERNEL, uint64_t>(a_local, b_local, c_local,
- dst_local, args...);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << a_local->element_size;
- }
-}
-
-template <typename KERNEL, typename... ARGS>
-Status ApplyTernaryOpF(BufferView* a_local, BufferView* b_local,
- BufferView* c_local, BufferView* dst_local,
- ARGS... args) {
- switch (a_local->element_size) {
-#if defined(IREE_SUPPORT_F32)
- case 4:
- return ApplyTernaryOp<KERNEL, float>(a_local, b_local, c_local, dst_local,
- args...);
-#endif // IREE_SUPPORT_F32
-#if defined(IREE_SUPPORT_F64)
- case 8:
- return ApplyTernaryOp<KERNEL, double>(a_local, b_local, c_local,
- dst_local, args...);
-#endif // IREE_SUPPORT_F64
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << a_local->element_size;
- }
-}
-
-template <typename KERNEL>
-Status ApplyComparisonOpIS(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local) {
- switch (lhs_local->element_size) {
- case 1:
- return ApplyComparisonOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local);
- case 2:
- return ApplyComparisonOp<KERNEL, int16_t>(lhs_local, rhs_local,
- dst_local);
- case 4:
- return ApplyComparisonOp<KERNEL, int32_t>(lhs_local, rhs_local,
- dst_local);
- case 8:
- return ApplyComparisonOp<KERNEL, int64_t>(lhs_local, rhs_local,
- dst_local);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
-}
-
-template <typename KERNEL>
-Status ApplyComparisonOpIU(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local) {
- switch (lhs_local->element_size) {
- case 1:
- return ApplyComparisonOp<KERNEL, uint8_t>(lhs_local, rhs_local,
- dst_local);
- case 2:
- return ApplyComparisonOp<KERNEL, uint16_t>(lhs_local, rhs_local,
- dst_local);
- case 4:
- return ApplyComparisonOp<KERNEL, uint32_t>(lhs_local, rhs_local,
- dst_local);
- case 8:
- return ApplyComparisonOp<KERNEL, uint64_t>(lhs_local, rhs_local,
- dst_local);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
-}
-
-template <typename KERNEL>
-Status ApplyComparisonOpF(BufferView* lhs_local, BufferView* rhs_local,
- BufferView* dst_local) {
- switch (lhs_local->element_size) {
- case 4:
- return ApplyComparisonOp<KERNEL, float>(lhs_local, rhs_local, dst_local);
- case 8:
- return ApplyComparisonOp<KERNEL, double>(lhs_local, rhs_local, dst_local);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented element size: " << lhs_local->element_size;
- }
-}
-
-template <typename T, typename ACC = int32_t>
-Status ApplyMatMulOpI(kernels::MatMul::RuntimeState* runtime_state,
- BufferView* lhs_local, BufferView* rhs_local,
- BufferView* bias_local,
- BufferView* multiplier_mantissa_local,
- BufferView* multiplier_exponent_local,
- BufferView* dst_local) {
- kernels::MatMul::Buffers<T, ACC> buffers;
- ASSIGN_OR_RETURN(auto lhs_buffer,
- lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- buffers.lhs_buffer = lhs_buffer.contents();
- buffers.lhs_shape = lhs_local->shape;
- ASSIGN_OR_RETURN(auto rhs_buffer,
- rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- buffers.rhs_buffer = rhs_buffer.contents();
- buffers.rhs_shape = rhs_local->shape;
- MappedMemory<ACC> bias_buffer;
- if (bias_local && bias_local->buffer && !bias_local->shape.empty()) {
- if (bias_local->element_size != sizeof(ACC)) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Only " << sizeof(ACC) << "b biases are supported right now";
- }
- ASSIGN_OR_RETURN(bias_buffer,
- bias_local->buffer->MapMemory<ACC>(MemoryAccess::kRead));
- buffers.bias_buffer = bias_buffer.contents();
- }
- ASSIGN_OR_RETURN(
- auto multiplier_mantissa_buffer,
- multiplier_mantissa_local->buffer->MapMemory<ACC>(MemoryAccess::kRead));
- buffers.multiplier_mantissa_buffer = multiplier_mantissa_buffer.contents();
- ASSIGN_OR_RETURN(auto multiplier_exponent_buffer,
- multiplier_exponent_local->buffer->MapMemory<int32_t>(
- MemoryAccess::kRead));
- buffers.multiplier_exponent_buffer = multiplier_exponent_buffer.contents();
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>(
- MemoryAccess::kDiscardWrite));
- buffers.dst_buffer = dst_buffer.mutable_contents();
- buffers.dst_shape = dst_local->shape;
- return kernels::MatMul::Execute(runtime_state, buffers);
-}
-
-template <typename T>
-Status ApplyMatMulOpF(kernels::MatMul::RuntimeState* runtime_state,
- BufferView* lhs_local, BufferView* rhs_local,
- BufferView* bias_local, BufferView* dst_local) {
- kernels::MatMul::Buffers<T, T> buffers;
- ASSIGN_OR_RETURN(auto lhs_buffer,
- lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- buffers.lhs_buffer = lhs_buffer.contents();
- buffers.lhs_shape = lhs_local->shape;
- ASSIGN_OR_RETURN(auto rhs_buffer,
- rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- buffers.rhs_buffer = rhs_buffer.contents();
- buffers.rhs_shape = rhs_local->shape;
- MappedMemory<T> bias_buffer;
- if (bias_local && bias_local->buffer && !bias_local->shape.empty()) {
- ASSIGN_OR_RETURN(bias_buffer,
- bias_local->buffer->MapMemory<T>(MemoryAccess::kRead));
- buffers.bias_buffer = bias_buffer.contents();
- }
- ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>(
- MemoryAccess::kDiscardWrite));
- buffers.dst_buffer = dst_buffer.mutable_contents();
- buffers.dst_shape = dst_local->shape;
- return kernels::MatMul::Execute(runtime_state, buffers);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseUnaryOpIS(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local));
- return ApplyUnaryOpIS<KERNEL>(src_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseUnaryOpIU(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local));
- return ApplyUnaryOpIU<KERNEL>(src_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseUnaryOpF(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local));
- return ApplyUnaryOpF<KERNEL>(src_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseBinaryOpIS(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local));
- return ApplyBinaryOpIS<KERNEL>(lhs_local, rhs_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseBinaryOpIU(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local));
- return ApplyBinaryOpIU<KERNEL>(lhs_local, rhs_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseBinaryOpF(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local));
- return ApplyBinaryOpF<KERNEL>(lhs_local, rhs_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseTernaryOpIS(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(
- ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local));
- return ApplyTernaryOpIS<KERNEL>(a_local, b_local, c_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseTernaryOpIU(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(
- ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local));
- return ApplyTernaryOpIU<KERNEL>(a_local, b_local, c_local, dst_local);
-}
-
-template <typename KERNEL>
-Status DispatchElementwiseTernaryOpF(vm::BytecodeReader* reader) {
- ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal());
- RETURN_IF_ERROR(
- ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local));
- return ApplyTernaryOpF<KERNEL>(a_local, b_local, c_local, dst_local);
-}
-
-Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices,
- BufferView* dst_local, absl::Span<const int32_t> dst_indices,
- absl::Span<const int32_t> lengths);
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
diff --git a/iree/hal/interpreter/bytecode_executable.cc b/iree/hal/interpreter/bytecode_executable.cc
deleted file mode 100644
index c528263..0000000
--- a/iree/hal/interpreter/bytecode_executable.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/bytecode_executable.h"
-
-#include <iostream>
-
-#include "iree/hal/interpreter/interpreter_module.h"
-#include "iree/rt/policy.h"
-
-namespace iree {
-namespace hal {
-
-// static
-StatusOr<ref_ptr<BytecodeExecutable>> BytecodeExecutable::Load(
- ref_ptr<rt::Instance> instance, hal::Allocator* allocator,
- ExecutableSpec spec, bool allow_aliasing_data) {
- // Allocate the executable now.
- // We do this here so that if we need to clone the data we are passing that
- // to the VM loader instead of the data we may not have access to later.
- auto executable = make_ref<BytecodeExecutable>(std::move(instance), allocator,
- spec, allow_aliasing_data);
-
- // Create the executable module.
- auto module_def =
- ::flatbuffers::GetRoot<ModuleDef>(executable->executable_data().data());
- ASSIGN_OR_RETURN(auto module,
- InterpreterModule::FromDef(allocator, *module_def));
- executable->module_ = add_ref(module);
- RETURN_IF_ERROR(executable->context()->RegisterModule(std::move(module)));
-
- return executable;
-}
-
-BytecodeExecutable::BytecodeExecutable(ref_ptr<rt::Instance> instance,
- hal::Allocator* allocator,
- ExecutableSpec spec,
- bool allow_aliasing_data)
- : spec_(spec),
- context_(
- make_ref<rt::Context>(std::move(instance), make_ref<rt::Policy>())) {
- if (!allow_aliasing_data) {
- // Clone data.
- cloned_executable_data_ = {spec.executable_data.begin(),
- spec.executable_data.end()};
- spec_.executable_data = absl::MakeConstSpan(cloned_executable_data_);
- }
-}
-
-BytecodeExecutable::~BytecodeExecutable() = default;
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_executable.h b/iree/hal/interpreter/bytecode_executable.h
deleted file mode 100644
index 6cd47de..0000000
--- a/iree/hal/interpreter/bytecode_executable.h
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
-
-#include <vector>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_spec.h"
-#include "iree/rt/context.h"
-#include "iree/rt/instance.h"
-#include "iree/rt/module.h"
-
-namespace iree {
-namespace hal {
-
-class BytecodeExecutable final : public Executable {
- public:
- static StatusOr<ref_ptr<BytecodeExecutable>> Load(
- ref_ptr<rt::Instance> instance, hal::Allocator* allocator,
- ExecutableSpec spec, bool allow_aliasing_data);
-
- BytecodeExecutable(ref_ptr<rt::Instance> instance, hal::Allocator* allocator,
- ExecutableSpec spec, bool allow_aliasing_data);
- ~BytecodeExecutable() override;
-
- bool supports_debugging() const override { return false; }
-
- // Reference to the bytecode blob contents.
- absl::Span<const uint8_t> executable_data() const {
- return spec_.executable_data;
- }
-
- // VM context with the executable registered.
- const ref_ptr<rt::Context>& context() const { return context_; }
-
- // VM module representing the executable.
- // Note that there may be more than one module in the Context and only this
- // module can be used to lookup executable exports.
- const ref_ptr<rt::Module>& module() const { return module_; }
-
- private:
- ExecutableSpec spec_;
- std::vector<uint8_t> cloned_executable_data_;
-
- ref_ptr<rt::Context> context_;
- ref_ptr<rt::Module> module_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
diff --git a/iree/hal/interpreter/bytecode_kernels.h b/iree/hal/interpreter/bytecode_kernels.h
deleted file mode 100644
index 8485b98..0000000
--- a/iree/hal/interpreter/bytecode_kernels.h
+++ /dev/null
@@ -1,371 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Defines kernel functions and provides their implementation via one (or more)
-// included files.
-//
-// Kernels should do the simplest possible operation. Buffer validation is
-// handled by the dispatch logic and need not be checked. Kernels may optionally
-// accept arguments beyond just the buffers, depending on the required state
-// and attributes.
-//
-// Kernels may optionally have runtime state. This is state that is allocated
-// once for the entire Runtime (and stored on RuntimeState) and shared across
-// all fibers. This enables kernels that may require thread pools or device
-// handles to be shared while kernels that require transient storage to be safe
-// to use from multiple fibers concurrently.
-//
-// All kernels are templated to enable specialization of particular types or
-// type combinations. By default the bytecode_kernels_generic.h will provide C++
-// semantics as reference and platform-specific versions can be implemented
-// as needed.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
-
-#include <cstdint>
-
-#include "absl/types/span.h"
-#include "iree/base/shape.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-namespace kernels {
-
-struct CompareEQ {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer);
-};
-struct CompareNE {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer);
-};
-struct CompareLT {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer);
-};
-struct CompareLE {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer);
-};
-struct CompareGT {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer);
-};
-struct CompareGE {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer);
-};
-
-struct Copy {
- template <int element_size>
- static Status Execute(absl::Span<const uint8_t> src_buffer,
- const Shape& src_shape,
- absl::Span<const int32_t> src_indices,
- absl::Span<uint8_t> dst_buffer, const Shape& dst_shape,
- absl::Span<const int32_t> dst_indices,
- absl::Span<const int32_t> lengths);
-};
-
-struct Select {
- template <typename T>
- static Status Execute(absl::Span<const uint8_t> cond_buffer,
- absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Transpose {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- absl::Span<const int32_t> perm);
-};
-
-struct Pad {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> padding_value,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- const Shape& dst_shape,
- absl::Span<const int32_t> edge_padding_low,
- absl::Span<const int32_t> edge_padding_high,
- absl::Span<const int32_t> interior_padding);
-};
-
-struct Reverse {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- absl::Span<const int32_t> dimensions);
-};
-
-struct Broadcast {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Tile {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- const Shape& dst_shape);
-};
-
-struct Not {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct And {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Or {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Xor {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct ShiftLeft {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct ShiftRight {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Add {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Sub {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Abs {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Mul {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Div {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-// a + (b * c)
-struct MulAdd {
- template <typename T>
- static Status Execute(absl::Span<const T> a_buffer,
- absl::Span<const T> b_buffer,
- absl::Span<const T> c_buffer, absl::Span<T> dst_buffer);
-};
-
-struct Exp {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Log {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Rsqrt {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Cos {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Sin {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Tanh {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Atan2 {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Min {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Max {
- template <typename T>
- static Status Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Clamp {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> min_buffer,
- absl::Span<const T> max_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Floor {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Ceil {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer);
-};
-
-struct Convert {
- template <typename SRC, typename DST>
- static Status Execute(absl::Span<const SRC> src_buffer,
- absl::Span<DST> dst_buffer);
-};
-
-struct MatMul {
- struct RuntimeState;
-
- static std::unique_ptr<RuntimeState> CreateRuntimeState();
-
- template <typename T, typename ACC>
- struct Buffers {
- Shape lhs_shape;
- absl::Span<const T> lhs_buffer;
- Shape rhs_shape;
- absl::Span<const T> rhs_buffer;
- Shape dst_shape;
- absl::Span<T> dst_buffer;
-
- // Optional bias buffer.
- absl::Span<const ACC> bias_buffer;
-
- // Fixed-point multiplier mantissa/exponent. May be a single value (for
- // uniform quantization) or one element per row of the destination matrix
- // for per-channel.
- absl::Span<const ACC> multiplier_mantissa_buffer;
- absl::Span<const int32_t> multiplier_exponent_buffer;
- };
-
- template <typename T, typename ACC>
- static Status Execute(RuntimeState* runtime_state,
- const Buffers<T, ACC>& buffers);
-};
-
-struct RuntimeState {
- std::unique_ptr<MatMul::RuntimeState> mat_mul_state =
- MatMul::CreateRuntimeState();
-};
-
-struct ReduceSum {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer,
- absl::Span<T> dst_buffer, int32_t dimension,
- const Shape& src_shape, const Shape& dst_shape);
-};
-
-struct ReduceMin {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer,
- absl::Span<T> dst_buffer, int32_t dimension,
- const Shape& src_shape, const Shape& dst_shape);
-};
-
-struct ReduceMax {
- template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer,
- absl::Span<T> dst_buffer, int32_t dimension,
- const Shape& src_shape, const Shape& dst_shape);
-};
-
-} // namespace kernels
-} // namespace hal
-} // namespace iree
-
-#include "iree/hal/interpreter/bytecode_kernels_generic.h" // IWYU pragma: export
-#include "iree/hal/interpreter/bytecode_kernels_ruy.h" // IWYU pragma: export
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
diff --git a/iree/hal/interpreter/bytecode_kernels_generic.h b/iree/hal/interpreter/bytecode_kernels_generic.h
deleted file mode 100644
index ddf69c7..0000000
--- a/iree/hal/interpreter/bytecode_kernels_generic.h
+++ /dev/null
@@ -1,708 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
-
-#include "absl/container/flat_hash_set.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-namespace kernels {
-
-template <typename T>
-Status CompareEQ::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] == rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status CompareNE::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] != rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status CompareLT::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] < rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status CompareLE::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] <= rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status CompareGT::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] > rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status CompareGE::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<uint8_t> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] >= rhs_buffer[i];
- }
- return OkStatus();
-}
-
-namespace impl {
-inline absl::InlinedVector<size_t, 6> ComputeCopyStrides(const Shape& shape,
- size_t element_size) {
- absl::InlinedVector<size_t, 6> strides(shape.size());
- strides.back() = element_size;
- for (int i = shape.size() - 2; i >= 0; --i) {
- strides[i] = strides[i + 1] * shape[i + 1];
- }
- return strides;
-}
-
-inline void CopyRegion(absl::Span<const uint8_t> src_buffer,
- absl::Span<const size_t> src_strides,
- absl::Span<const int32_t> src_indices,
- absl::Span<uint8_t> dst_buffer,
- absl::Span<const size_t> dst_strides,
- absl::Span<const int32_t> dst_indices,
- absl::Span<const int32_t> lengths) {
- if (lengths.size() > 1) {
- for (int i = 0; i < lengths[0]; ++i) {
- size_t src_offset = src_strides[0] * (src_indices[0] + i);
- size_t dst_offset = dst_strides[0] * (dst_indices[0] + i);
- CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1),
- src_indices.subspan(1), dst_buffer.subspan(dst_offset),
- dst_strides.subspan(1), dst_indices.subspan(1),
- lengths.subspan(1));
- }
- } else {
- DCHECK_EQ(dst_strides.size(), 1);
- DCHECK_EQ(src_strides.size(), 1);
- DCHECK_EQ(src_indices.size(), 1);
- DCHECK_EQ(dst_indices.size(), 1);
- DCHECK_EQ(lengths.size(), 1);
- auto src_offset = src_indices[0] * src_strides[0];
- auto dst_offset = dst_indices[0] * dst_strides[0];
- auto length = dst_strides[0] * lengths[0];
- std::memcpy(dst_buffer.data() + dst_offset, src_buffer.data() + src_offset,
- length);
- }
-}
-} // namespace impl
-
-// TODO(benvanik): replace with a real implementation once copy is defined.
-// TODO(gcmn): More consistent/principled handling for scalars.
-template <int element_size>
-Status Copy::Execute(absl::Span<const uint8_t> src_buffer,
- const Shape& src_shape,
- absl::Span<const int32_t> src_indices,
- absl::Span<uint8_t> dst_buffer, const Shape& dst_shape,
- absl::Span<const int32_t> dst_indices,
- absl::Span<const int32_t> lengths) {
- DCHECK_EQ(src_indices.size(), lengths.size());
- DCHECK_EQ(dst_indices.size(), lengths.size());
- DCHECK_EQ(src_shape.size(), lengths.size());
- DCHECK_EQ(dst_shape.size(), lengths.size());
- if (lengths.empty()) {
- std::memcpy(dst_buffer.data(), src_buffer.data(), element_size);
- return OkStatus();
- }
-
- // TODO(gcmn) Maybe we can fast-path earlier if we detect contiguous memory
- // across multiple rows.
- auto src_strides = impl::ComputeCopyStrides(src_shape, element_size);
- auto dst_strides = impl::ComputeCopyStrides(dst_shape, element_size);
- DCHECK_EQ(src_strides.size(), lengths.size());
- DCHECK_EQ(dst_strides.size(), lengths.size());
- impl::CopyRegion(src_buffer, src_strides, src_indices, dst_buffer,
- dst_strides, dst_indices, lengths);
- return OkStatus();
-}
-
-template <typename T>
-Status Select::Execute(absl::Span<const uint8_t> cond_buffer,
- absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = cond_buffer[i] ? lhs_buffer[i] : rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Transpose::Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- absl::Span<const int32_t> perm) {
- // This implementation is .... not fast.
- int rank = src_shape.size();
- absl::InlinedVector<int, 8> src_strides(rank);
- absl::InlinedVector<int, 8> dst_strides(rank);
- size_t src_stride = 1;
- size_t dst_stride = 1;
- for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
- src_strides[dim_i] = src_stride;
- dst_strides[dim_i] = dst_stride;
- src_stride *= src_shape[dim_i];
- dst_stride *= src_shape[perm[dim_i]];
- }
- for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
- size_t src_i = 0;
- size_t t = dst_i;
- for (int dim_i = 0; dim_i < rank; ++dim_i) {
- size_t ratio = t / dst_strides[dim_i];
- t -= ratio * dst_strides[dim_i];
- src_i += ratio * src_strides[perm[dim_i]];
- }
- dst_buffer[dst_i] = src_buffer[src_i];
- }
- return OkStatus();
-}
-
-namespace impl {
-inline void IncrementShapeIndex(absl::Span<int32_t> indices,
- const Shape& shape) {
- for (int i = indices.size() - 1; i >= 0; --i) {
- if (++indices[i] < shape[i]) return;
- indices[i] = 0;
- }
-}
-
-inline bool IsPadding(absl::Span<const int32_t> indices, const Shape& shape,
- absl::Span<const int32_t> edge_padding_low,
- absl::Span<const int32_t> edge_padding_high,
- absl::Span<const int32_t> interior_padding) {
- for (int i = 0; i < indices.size(); ++i) {
- auto index = indices[i];
- if (index < edge_padding_low[i] ||
- index >= shape[i] - edge_padding_high[i] ||
- (index - edge_padding_low[i]) % (interior_padding[i] + 1) != 0) {
- return true;
- }
- }
-
- return false;
-}
-} // namespace impl
-
-template <typename T>
-Status Pad::Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> padding_value_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- const Shape& dst_shape,
- absl::Span<const int32_t> edge_padding_low,
- absl::Span<const int32_t> edge_padding_high,
- absl::Span<const int32_t> interior_padding) {
- // This implementation is not at all fast, as it iterates every index in the
- // destination buffer individually. Potential improvements:
- // 1. Fill the dst buffer with padded value initially. Only need to iterate
- // through source buffer and can exit early.
- // 2. Use striding to advance through larger swaths of the buffer with a
- // memcpy from src and filling (or skipping) padded incides. Especially
- // useful when e.g. entire rows are padded.
-
- // TODO(b/140836672) support negative padding
-
- if (padding_value_buffer.size() != 1) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Padding value buffer is larger than one element.";
- }
- auto padding_value = padding_value_buffer.front();
-
- absl::InlinedVector<int, 8> dst_indices(src_shape.size(), 0);
-
- const T* src_ptr = src_buffer.begin();
- T* dst_ptr = dst_buffer.begin();
- while (dst_ptr != dst_buffer.end()) {
- if (impl::IsPadding(dst_indices, dst_shape, edge_padding_low,
- edge_padding_high, interior_padding)) {
- *dst_ptr++ = padding_value;
- } else {
- DCHECK(src_ptr != src_buffer.end());
- *dst_ptr++ = *src_ptr++;
- }
- impl::IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape);
- }
-
- return OkStatus();
-}
-
-template <typename T>
-Status Reverse::Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- absl::Span<const int32_t> dimensions) {
- // This implementation is not fast either
- int rank = src_shape.size();
- absl::InlinedVector<int, 8> strides(rank);
- size_t stride = 1;
- for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
- strides[dim_i] = stride;
- stride *= src_shape[dim_i];
- }
- absl::flat_hash_set<int32_t> dims_set(dimensions.begin(), dimensions.end());
- for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
- size_t src_i = 0;
- size_t t = dst_i;
- for (int dim_i = 0; dim_i < rank; ++dim_i) {
- size_t ratio = t / strides[dim_i];
- t -= ratio * strides[dim_i];
- bool do_reverse = dims_set.contains(dim_i);
- src_i += (do_reverse ? (src_shape[dim_i] - 1 - ratio) : ratio) *
- strides[dim_i];
- }
- dst_buffer[dst_i] = src_buffer[src_i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Broadcast::Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = src_buffer[0];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Tile::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer,
- const Shape& src_shape, const Shape& dst_shape) {
- // This implementation is .... not fast.
- int rank = dst_shape.size();
- absl::InlinedVector<int, 8> src_strides(rank);
- absl::InlinedVector<int, 8> dst_strides(rank);
- size_t src_stride = 1;
- size_t dst_stride = 1;
- for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
- src_strides[dim_i] = src_stride;
- dst_strides[dim_i] = dst_stride;
- src_stride *= src_shape[dim_i];
- dst_stride *= dst_shape[dim_i];
- }
- for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
- size_t src_i = 0;
- size_t t = dst_i;
- for (int dim_i = 0; dim_i < rank; ++dim_i) {
- src_i += t / dst_strides[dim_i] % src_shape[dim_i] * src_strides[dim_i];
- t %= dst_strides[dim_i];
- }
- dst_buffer[dst_i] = src_buffer[src_i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Not::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = ~src_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status And::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] & rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Or::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] | rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Xor::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] ^ rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status ShiftLeft::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] << rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status ShiftRight::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] >> rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Add::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Sub::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] - rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Abs::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::abs(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Mul::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] * rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Div::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = lhs_buffer[i] / rhs_buffer[i];
- }
- return OkStatus();
-}
-
-template <typename T>
-Status MulAdd::Execute(absl::Span<const T> a_buffer,
- absl::Span<const T> b_buffer,
- absl::Span<const T> c_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = a_buffer[i] + (b_buffer[i] * c_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Exp::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::exp(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Rsqrt::Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = 1.0 / std::sqrt(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Log::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::log(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Cos::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::cos(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Sin::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::sin(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Tanh::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::tanh(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Atan2::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::atan2(lhs_buffer[i], rhs_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Min::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::min(lhs_buffer[i], rhs_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Max::Execute(absl::Span<const T> lhs_buffer,
- absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::max(lhs_buffer[i], rhs_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Clamp::Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> min_buffer,
- absl::Span<const T> max_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- T src = src_buffer[i];
- T min = min_buffer[i];
- T max = max_buffer[i];
- dst_buffer[i] = src <= min ? min : src >= max ? max : src;
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Floor::Execute(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::floor(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename T>
-Status Ceil::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = std::ceil(src_buffer[i]);
- }
- return OkStatus();
-}
-
-template <typename SRC, typename DST>
-Status Convert::Execute(absl::Span<const SRC> src_buffer,
- absl::Span<DST> dst_buffer) {
- DCHECK_EQ(src_buffer.size(), dst_buffer.size());
- for (size_t i = 0; i < dst_buffer.size(); ++i) {
- dst_buffer[i] = static_cast<DST>(src_buffer[i]);
- }
- return OkStatus();
-}
-
-namespace impl {
-
-struct SumKernel {
- template <typename T>
- inline void operator()(T* value0, const T value1) {
- *value0 += value1;
- }
-};
-
-struct MinKernel {
- template <typename T>
- inline void operator()(T* value0, const T value1) {
- *value0 = std::min(*value0, value1);
- }
-};
-
-struct MaxKernel {
- template <typename T>
- inline void operator()(T* value0, const T value1) {
- *value0 = std::max(*value0, value1);
- }
-};
-
-template <typename T, typename KernelImpl>
-inline void ReduceDimension(absl::Span<const T> src_buffer,
- absl::Span<T> dst_buffer, const Shape& src_shape,
- absl::Span<const int32_t> reduce_dims,
- absl::Span<const int> dst_strides, int dim,
- absl::Span<int> src_indices, size_t flat_src_i,
- size_t src_stride) {
- if (dim < 0) {
- // Base case of the recursion - figure out which elements should be acted
- // upon and apply the reduction kernel to them.
-
- // Derive destination indices from source indices.
- // For example,
- // reduce_dims: [1, 2]
- // src_indices: [2, 1, 3, 0]
- // ^ ^
- // | |
- // |----- remove these dimensions
- // dst_indices: [2, 0]
- //
- // TODO(scotttodd): Clean this up somehow, share across recursion levels?
- size_t dst_size = src_shape.size() - reduce_dims.size();
- absl::InlinedVector<int, 8> dst_indices;
- for (size_t i = 0; i < src_indices.size(); ++i) {
- if (std::find(std::begin(reduce_dims), std::end(reduce_dims), i) ==
- std::end(reduce_dims)) {
- dst_indices.push_back(src_indices[i]);
- }
- }
- // Compute the flattened index into dst_buffer at [dst_indices].
- size_t dst_i = 0;
- for (size_t i = 0; i < dst_indices.size(); ++i) {
- dst_i += dst_indices[i] * dst_strides[dst_size - 1 - i];
- }
-
- // Flattened src and dst indices have been computed, invoke the kernel.
- KernelImpl()(&dst_buffer[dst_i], src_buffer[flat_src_i]);
- return;
- }
-
- // Iterate through the current dimension in the source shape, recursing
- // down one dimension at a time.
- //
- // This touches each element in the source buffer once, tracking complete
- // dimensions within the shaped source buffer and using them to compute
- // the corresponding indices (shaped and flattened) within the destination
- // buffer. Each element in the destination buffer will be touched multiple
- // times.
- //
- // Note that cache coherency isn't considered here, and some computations
- // are redundant, so this could be optimized substantially.
- for (size_t dim_i = 0; dim_i < src_shape[dim]; ++dim_i) {
- src_indices[dim] = dim_i;
-
- // Recurse down to the next dimension (e.g. 2 -> 1 -> 0 -> base case)
- // * Add the current stride to flat_src_i
- // * Multiply src_stride by this dimension's shape
- ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape,
- reduce_dims, dst_strides, dim - 1,
- src_indices, flat_src_i + dim_i * src_stride,
- src_stride * src_shape[dim]);
- }
-}
-
-template <typename T, typename KernelImpl>
-Status GenericReduce(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer, absl::Span<T> dst_buffer,
- int32_t dimension, const Shape& src_shape,
- const Shape& dst_shape) {
- // Initialize using init_buffer, which is expected to be a scalar.
- std::fill_n(dst_buffer.data(), dst_buffer.size(), init_buffer[0]);
-
- // Precompute destination strides.
- int dst_rank = dst_shape.size();
- absl::InlinedVector<int, 8> dst_strides;
- size_t dst_stride = 1;
- for (int dim_i = dst_rank - 1; dim_i >= 0; --dim_i) {
- dst_strides.push_back(dst_stride);
- dst_stride *= dst_shape[dim_i];
- }
-
- // Call the helper (recursive) function, starting with:
- // * source index [0, 0, ..., 0]
- // * the innermost dimension (last in the shape)
- // * flat_src_i of 0 (corresponds to [0, 0, ..., 0] above)
- // * source stride 1
- absl::InlinedVector<int, 8> src_indices(src_shape.size(), 0);
- ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, {dimension},
- absl::MakeSpan(dst_strides),
- src_shape.size() - 1,
- absl::MakeSpan(src_indices), 0, 1);
-
- return OkStatus();
-}
-
-} // namespace impl
-
-template <typename T>
-Status ReduceSum::Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer,
- absl::Span<T> dst_buffer, int32_t dimension,
- const Shape& src_shape, const Shape& dst_shape) {
- return impl::GenericReduce<T, impl::SumKernel>(
- src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-template <typename T>
-Status ReduceMin::Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer,
- absl::Span<T> dst_buffer, int32_t dimension,
- const Shape& src_shape, const Shape& dst_shape) {
- return impl::GenericReduce<T, impl::MinKernel>(
- src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-template <typename T>
-Status ReduceMax::Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> init_buffer,
- absl::Span<T> dst_buffer, int32_t dimension,
- const Shape& src_shape, const Shape& dst_shape) {
- return impl::GenericReduce<T, impl::MaxKernel>(
- src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
-}
-
-} // namespace kernels
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
diff --git a/iree/hal/interpreter/bytecode_kernels_ruy.h b/iree/hal/interpreter/bytecode_kernels_ruy.h
deleted file mode 100644
index 60bcb44..0000000
--- a/iree/hal/interpreter/bytecode_kernels_ruy.h
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
-#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
-
-#include "absl/base/thread_annotations.h"
-#include "absl/memory/memory.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "tensorflow/lite/experimental/ruy/context.h"
-#include "tensorflow/lite/experimental/ruy/ruy.h"
-
-namespace iree {
-namespace hal {
-namespace kernels {
-
-// TODO(benvanik): something more clever for making this shareable.
-// Maybe a factory fn based on the impl selected?
-struct MatMul::RuntimeState {
- // TODO(benvanik): share the thread pool but keep context per-fiber?
- ruy::Context context;
-};
-
-inline std::unique_ptr<MatMul::RuntimeState> MatMul::CreateRuntimeState() {
- return absl::make_unique<RuntimeState>();
-}
-
-template <typename T, typename ACC>
-Status MatMul::Execute(RuntimeState* runtime_state,
- const Buffers<T, ACC>& buffers) {
- ruy::Matrix<T> lhs_matrix;
- ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1],
- ruy::Order::kRowMajor, &lhs_matrix.layout);
- lhs_matrix.data.set(buffers.lhs_buffer.data());
-
- ruy::Matrix<T> rhs_matrix;
- ruy::MakeSimpleLayout(buffers.rhs_shape[0], buffers.rhs_shape[1],
- ruy::Order::kRowMajor, &rhs_matrix.layout);
- rhs_matrix.data.set(buffers.rhs_buffer.data());
-
- ruy::Matrix<T> dst_matrix;
- ruy::MakeSimpleLayout(buffers.dst_shape[0], buffers.dst_shape[1],
- ruy::Order::kRowMajor, &dst_matrix.layout);
- dst_matrix.data.set(buffers.dst_buffer.data());
-
- ruy::BasicSpec<ACC, T> spec;
- spec.bias = buffers.bias_buffer.data();
-
- if (buffers.multiplier_mantissa_buffer.size() == 1) {
- spec.multiplier_fixedpoint = buffers.multiplier_mantissa_buffer[0];
- spec.multiplier_exponent = buffers.multiplier_exponent_buffer[0];
- } else {
- spec.multiplier_fixedpoint_perchannel =
- buffers.multiplier_mantissa_buffer.data();
- spec.multiplier_exponent_perchannel =
- buffers.multiplier_exponent_buffer.data();
- }
-
- ruy::Mul<ruy::kAllPaths>(lhs_matrix, rhs_matrix, spec,
- &runtime_state->context, &dst_matrix);
-
- return OkStatus();
-}
-
-} // namespace kernels
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
diff --git a/iree/hal/interpreter/bytecode_kernels_test.cc b/iree/hal/interpreter/bytecode_kernels_test.cc
deleted file mode 100644
index db9a63b..0000000
--- a/iree/hal/interpreter/bytecode_kernels_test.cc
+++ /dev/null
@@ -1,407 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/bytecode_kernels.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/memory.h"
-#include "iree/base/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace kernels {
-
-namespace {
-
-constexpr float kEpsilon = 0.0001f;
-
-template <typename T>
-std::vector<T> MakeIota(int size) {
- std::vector<T> v(size);
- std::iota(v.begin(), v.end(), static_cast<T>(1));
- return v;
-}
-
-TEST(Copy, WholeBuffer) {
- Shape src_shape = {2, 2};
- auto src_buffer = MakeIota<uint8_t>(4);
- std::vector<int32_t> src_indices = {0, 0};
- Shape dst_shape = src_shape;
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {0, 0};
- std::vector<int32_t> lengths = {2, 2};
- auto expected_dst = src_buffer;
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, FirstRow) {
- Shape src_shape = {3, 4};
- auto src_buffer = MakeIota<uint8_t>(12);
- std::vector<int32_t> src_indices = {0, 0};
- Shape dst_shape = {1, 4};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {0, 0};
- std::vector<int32_t> lengths = {1, 4};
- std::vector<uint8_t> expected_dst = {1, 2, 3, 4};
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, RowPart) {
- Shape src_shape = {3, 4};
- auto src_buffer = MakeIota<uint8_t>(12);
- std::vector<int32_t> src_indices = {1, 1};
- Shape dst_shape = {1, 2};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {0, 0};
- std::vector<int32_t> lengths = {1, 2};
- std::vector<uint8_t> expected_dst = {6, 7};
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, MultiRow) {
- Shape src_shape = {3, 4};
- auto src_buffer = MakeIota<uint8_t>(12);
- std::vector<int32_t> src_indices = {1, 0};
- Shape dst_shape = {2, 4};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {0, 0};
- std::vector<int32_t> lengths = {2, 4};
- std::vector<uint8_t> expected_dst = {5, 6, 7, 8, 9, 10, 11, 12};
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, NonContiguous) {
- Shape src_shape = {3, 4};
- auto src_buffer = MakeIota<uint8_t>(12);
- std::vector<int32_t> src_indices = {1, 1};
- Shape dst_shape = {2, 2};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {0, 0};
- std::vector<int32_t> lengths = {2, 2};
- std::vector<uint8_t> expected_dst = {6, 7, 10, 11};
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, MultiByte) {
- Shape src_shape = {3, 4};
- auto src_vals = MakeIota<int32_t>(12);
- auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals));
- std::vector<int32_t> src_indices = {1, 1};
- Shape dst_shape = {2, 2};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count() * sizeof(int32_t));
- std::vector<int32_t> dst_indices = {0, 0};
- std::vector<int32_t> lengths = {2, 2};
- std::vector<int32_t> expected_dst = {6, 7, 10, 11};
-
- EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
-
- absl::Span<int32_t> dst_buffer_int32_t =
- ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer));
-
- EXPECT_EQ(dst_buffer_int32_t, expected_dst);
-}
-
-TEST(Copy, NotFullDst) {
- Shape src_shape = {3, 4};
- auto src_buffer = MakeIota<uint8_t>(12);
- std::vector<int32_t> src_indices = {0, 0};
- Shape dst_shape = {4, 3};
- std::vector<uint8_t> dst_buffer(12, 42);
- std::vector<int32_t> dst_indices = {1, 1};
- std::vector<int32_t> lengths = {2, 2};
- // clang-format off
- std::vector<uint8_t> expected_dst = {42, 42, 42,
- 42, 1, 2,
- 42, 5, 6,
- 42, 42, 42};
- // clang-format on
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, HighRank) {
- Shape src_shape = {3, 3, 3, 3};
- auto src_buffer = MakeIota<uint8_t>(81);
- std::vector<int32_t> src_indices = {1, 1, 1, 1};
- Shape dst_shape = {2, 2, 2, 2};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {0, 0, 0, 0};
- std::vector<int32_t> lengths = {2, 2, 2, 2};
- std::vector<uint8_t> expected_dst = {41, 42, 44, 45, 50, 51, 53, 54,
- 68, 69, 71, 72, 77, 78, 80, 81};
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, Scalar) {
- Shape src_shape = {};
- std::vector<uint8_t> src_buffer = {42};
- std::vector<int32_t> src_indices = {};
- Shape dst_shape = {};
- std::vector<uint8_t> dst_buffer(dst_shape.element_count());
- std::vector<int32_t> dst_indices = {};
- std::vector<int32_t> lengths = {};
- std::vector<uint8_t> expected_dst = {42};
-
- EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Copy, ScalarMultiByte) {
- Shape src_shape = {};
- std::vector<int32_t> src_vals = {INT32_MAX};
- auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals));
- std::vector<int32_t> src_indices = {};
- Shape dst_shape = {};
- std::vector<uint8_t> dst_buffer(sizeof(int32_t));
- std::vector<int32_t> dst_indices = {};
- std::vector<int32_t> lengths = {};
- std::vector<int32_t> expected_dst = {INT32_MAX};
-
- EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices,
- absl::MakeSpan(dst_buffer), dst_shape, dst_indices,
- lengths));
-
- absl::Span<int32_t> dst_buffer_int32_t =
- ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer));
-
- EXPECT_EQ(dst_buffer_int32_t, expected_dst);
-}
-
-TEST(Pad, NoPadding) {
- Shape src_shape = {2, 3};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {0, 0};
- std::vector<int32_t> edge_padding_high = {0, 0};
- std::vector<int32_t> interior_padding = {0, 0};
- Shape dst_shape = src_shape;
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- auto expected_dst = src_buffer;
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, LowHighPadding) {
- Shape src_shape = {2, 3};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {0, 1};
- std::vector<int32_t> edge_padding_high = {1, 2};
- std::vector<int32_t> interior_padding = {0, 0};
- Shape dst_shape = {3, 6};
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- // clang-format off
- std::vector<uint16_t> expected_dst = {0, 1, 2, 3, 0, 0,
- 0, 4, 5, 6, 0, 0,
- 0, 0, 0, 0, 0, 0};
- // clang-format on
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, OnlyHighPadding) {
- Shape src_shape = {2, 3};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {0, 0};
- std::vector<int32_t> edge_padding_high = {1, 3};
- std::vector<int32_t> interior_padding = {0, 0};
- Shape dst_shape = {3, 6};
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- // clang-format off
- std::vector<uint16_t> expected_dst = {1, 2, 3, 0, 0, 0,
- 4, 5, 6, 0, 0, 0,
- 0, 0, 0, 0, 0, 0};
- // clang-format on
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, OnlyLowPadding) {
- Shape src_shape = {2, 3};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {1, 3};
- std::vector<int32_t> edge_padding_high = {0, 0};
- std::vector<int32_t> interior_padding = {0, 0};
- Shape dst_shape = {3, 6};
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- // clang-format off
- std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0,
- 0, 0, 0, 1, 2, 3,
- 0, 0, 0, 4, 5, 6};
- // clang-format on
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, OnlyInteriorPadding) {
- Shape src_shape = {2, 3};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {0, 0};
- std::vector<int32_t> edge_padding_high = {0, 0};
- std::vector<int32_t> interior_padding = {1, 1};
- Shape dst_shape = {3, 5};
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- // clang-format off
- std::vector<uint16_t> expected_dst = {1, 0, 2, 0, 3,
- 0, 0, 0, 0, 0,
- 4, 0, 5, 0, 6};
- // clang-format on
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, AllPaddingTypes) {
- Shape src_shape = {2, 3};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {1, 1};
- std::vector<int32_t> edge_padding_high = {1, 2};
- std::vector<int32_t> interior_padding = {1, 1};
- Shape dst_shape = {5, 8};
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- // clang-format off
- std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, 0, 0,
- 0, 1, 0, 2, 0, 3, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 4, 0, 5, 0, 6, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0};
- // clang-format on
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(Pad, HighRank) {
- Shape src_shape = {2, 2, 2, 2};
- auto src_buffer = MakeIota<uint16_t>(src_shape.element_count());
- std::vector<uint16_t> pad_value_buffer = {0};
- std::vector<int32_t> edge_padding_low = {1, 0, 0, 0};
- std::vector<int32_t> edge_padding_high = {0, 1, 0, 0};
- std::vector<int32_t> interior_padding = {0, 0, 1, 0};
- Shape dst_shape = {3, 3, 3, 2};
- std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX);
- // clang-format off
- std::vector<uint16_t> expected_dst = { 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0,
-
- 1, 2, 0, 0, 3, 4,
- 5, 6, 0, 0, 7, 8,
- 0, 0, 0, 0, 0, 0,
-
- 9, 10, 0, 0, 11, 12,
- 13, 14, 0, 0, 15, 16,
- 0, 0, 0, 0, 0, 0};
- // clang-format on
-
- ASSERT_EQ(dst_buffer.size(), expected_dst.size());
-
- EXPECT_OK(Pad::Execute<uint16_t>(
- src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape,
- dst_shape, edge_padding_low, edge_padding_high, interior_padding));
- EXPECT_EQ(dst_buffer, expected_dst);
-}
-
-TEST(ReduceSum, Scalar) {
- Shape src_shape = {5};
- int32_t dimension = 0;
- Shape dst_shape = {1};
- std::vector<float> src_buffer = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
- std::vector<float> init_buffer = {0.0f};
- std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f);
- std::vector<float> expected_dst = {5.0f};
-
- EXPECT_OK(ReduceSum::Execute<float>(src_buffer, init_buffer,
- absl::MakeSpan(dst_buffer), dimension,
- src_shape, dst_shape));
-
- for (int i = 0; i < dst_buffer.size(); ++i) {
- EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
- }
-}
-
-TEST(ReduceMin, TwoDimensionsToOne) {
- Shape src_shape = {3, 3};
- int32_t dimension = 0;
- Shape dst_shape = {3};
- std::vector<float> src_buffer = MakeIota<float>(src_shape.element_count());
- std::vector<float> init_buffer = {std::numeric_limits<float>::max()};
- std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f);
- std::vector<float> expected_dst = {1.0f, 2.0f, 3.0f};
-
- EXPECT_OK(ReduceMin::Execute<float>(src_buffer, init_buffer,
- absl::MakeSpan(dst_buffer), dimension,
- src_shape, dst_shape));
-
- for (int i = 0; i < dst_buffer.size(); ++i) {
- EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
- }
-}
-
-} // namespace
-} // namespace kernels
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_command_processor.cc b/iree/hal/interpreter/interpreter_command_processor.cc
deleted file mode 100644
index 783809d..0000000
--- a/iree/hal/interpreter/interpreter_command_processor.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/interpreter_command_processor.h"
-
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/interpreter/bytecode_executable.h"
-#include "iree/rt/stack.h"
-
-namespace iree {
-namespace hal {
-
-InterpreterCommandProcessor::InterpreterCommandProcessor(
- Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories)
- : HostLocalCommandProcessor(allocator, mode, command_categories) {}
-
-InterpreterCommandProcessor::~InterpreterCommandProcessor() = default;
-
-Status InterpreterCommandProcessor::Dispatch(
- const DispatchRequest& dispatch_request) {
- IREE_TRACE_SCOPE0("InterpreterCommandProcessor::Dispatch");
-
- // Lookup the exported function.
- auto* executable =
- static_cast<BytecodeExecutable*>(dispatch_request.executable);
- const auto& module = executable->module();
- ASSIGN_OR_RETURN(auto entry_function, module->LookupFunctionByOrdinal(
- rt::Function::Linkage::kExport,
- dispatch_request.entry_point));
-
- rt::Stack stack(executable->context().get());
-
- // TODO(benvanik): avoid this by directly referencing the bindings.
- absl::InlinedVector<BufferView, 8> arguments;
- arguments.reserve(dispatch_request.bindings.size());
- for (auto& binding : dispatch_request.bindings) {
- arguments.push_back(BufferView{add_ref(binding.buffer), binding.shape,
- binding.element_size});
- }
- absl::InlinedVector<BufferView, 8> results;
-
- RETURN_IF_ERROR(executable->module()->Execute(
- &stack, entry_function, std::move(arguments), &results));
-
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_command_processor.h b/iree/hal/interpreter/interpreter_command_processor.h
deleted file mode 100644
index b3e474d..0000000
--- a/iree/hal/interpreter/interpreter_command_processor.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
-#define IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
-
-#include "iree/hal/host/host_local_command_processor.h"
-
-namespace iree {
-namespace hal {
-
-class InterpreterCommandProcessor final : public HostLocalCommandProcessor {
- public:
- InterpreterCommandProcessor(Allocator* allocator,
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories);
- ~InterpreterCommandProcessor() override;
-
- Status Dispatch(const DispatchRequest& dispatch_request) override;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
diff --git a/iree/hal/interpreter/interpreter_device.cc b/iree/hal/interpreter/interpreter_device.cc
deleted file mode 100644
index 7a61cc3..0000000
--- a/iree/hal/interpreter/interpreter_device.cc
+++ /dev/null
@@ -1,173 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/interpreter_device.h"
-
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/command_buffer_validation.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/fence.h"
-#include "iree/hal/host/async_command_queue.h"
-#include "iree/hal/host/host_event.h"
-#include "iree/hal/host/host_submission_queue.h"
-#include "iree/hal/host/inproc_command_buffer.h"
-#include "iree/hal/interpreter/bytecode_cache.h"
-#include "iree/hal/interpreter/interpreter_command_processor.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-
-// A CommandQueue that performs no synchronization (semaphores/fences) and just
-// directly executes command buffers inline.
-//
-// This is meant to be wrapped by SyncCommandQueue or AsyncCommandQueue that
-// themselves perform the synchronization/threading/etc. As such we ignore
-// all semaphores in the provided batches under the assumption that if Submit is
-// being called then all dependencies are valid. The wrapping queue is also
-// responsible for signaling the fence as well as propagating errors in a way
-// that is dependent on how it is performing its synchronization.
-class UnsynchronizedCommandQueue final : public CommandQueue {
- public:
- UnsynchronizedCommandQueue(Allocator* allocator, std::string name,
- CommandCategoryBitfield supported_categories)
- : CommandQueue(std::move(name), supported_categories),
- allocator_(allocator) {}
- ~UnsynchronizedCommandQueue() override = default;
-
- Status Submit(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) override {
- IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::Submit");
- DCHECK_EQ(nullptr, fence.first)
- << "Fences must be handled by the wrapping queue";
-
- // Process command buffers and propagate errors asynchronously through the
- // fence. This ensures that even if we are running synchronously we still
- // get consistent failure behavior with drivers that are purely async.
- for (auto& batch : batches) {
- DCHECK(batch.wait_semaphores.empty() && batch.signal_semaphores.empty())
- << "Semaphores must be handled by the wrapping queue";
- RETURN_IF_ERROR(ProcessCommandBuffers(batch.command_buffers));
- }
-
- // NOTE: fence is ignored here.
- return OkStatus();
- }
-
- Status WaitIdle(absl::Time deadline) override {
- // No-op.
- return OkStatus();
- }
-
- private:
- // Processes each command buffer in-turn with a fresh processor.
- // This ensures we don't have any state that can carry across buffers.
- Status ProcessCommandBuffers(
- absl::Span<CommandBuffer* const> command_buffers) {
- IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::ProcessCommandBuffers");
- for (auto* command_buffer : command_buffers) {
- auto* inproc_command_buffer =
- static_cast<InProcCommandBuffer*>(command_buffer->impl());
- InterpreterCommandProcessor command_processor(
- allocator_, command_buffer->mode(), supported_categories());
- RETURN_IF_ERROR(inproc_command_buffer->Process(&command_processor));
- }
- return OkStatus();
- }
-
- Allocator* const allocator_;
-};
-
-} // namespace
-
-InterpreterDevice::InterpreterDevice(DeviceInfo device_info)
- : Device(std::move(device_info)), instance_(make_ref<rt::Instance>()) {
- // We currently only expose a single command queue.
- auto command_queue = absl::make_unique<UnsynchronizedCommandQueue>(
- &allocator_, "cpu0",
- CommandCategory::kTransfer | CommandCategory::kDispatch);
-
- // TODO(benvanik): allow injection of the wrapper type to support
- // SyncCommandQueue without always linking in both.
- auto async_command_queue =
- absl::make_unique<AsyncCommandQueue>(std::move(command_queue));
- command_queues_.push_back(std::move(async_command_queue));
-}
-
-InterpreterDevice::~InterpreterDevice() = default;
-
-std::shared_ptr<ExecutableCache> InterpreterDevice::CreateExecutableCache() {
- return std::make_shared<BytecodeCache>(add_ref(instance_), &allocator_);
-}
-
-StatusOr<ref_ptr<CommandBuffer>> InterpreterDevice::CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) {
- // TODO(b/140026716): conditionally enable validation.
- auto impl =
- make_ref<InProcCommandBuffer>(&allocator_, mode, command_categories);
- return WrapCommandBufferWithValidation(std::move(impl));
-}
-
-StatusOr<ref_ptr<Event>> InterpreterDevice::CreateEvent() {
- return make_ref<HostEvent>();
-}
-
-StatusOr<ref_ptr<BinarySemaphore>> InterpreterDevice::CreateBinarySemaphore(
- bool initial_value) {
- IREE_TRACE_SCOPE0("InterpreterDevice::CreateBinarySemaphore");
- return make_ref<HostBinarySemaphore>(initial_value);
-}
-
-StatusOr<ref_ptr<TimelineSemaphore>> InterpreterDevice::CreateTimelineSemaphore(
- uint64_t initial_value) {
- IREE_TRACE_SCOPE0("InterpreterDevice::CreateTimelineSemaphore");
-
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Timeline semaphores not yet implemented";
-}
-
-StatusOr<ref_ptr<Fence>> InterpreterDevice::CreateFence(
- uint64_t initial_value) {
- IREE_TRACE_SCOPE0("InterpreterDevice::CreateFence");
- return make_ref<HostFence>(initial_value);
-}
-
-Status InterpreterDevice::WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) {
- IREE_TRACE_SCOPE0("InterpreterDevice::WaitAllFences");
- return HostFence::WaitForFences(fences, /*wait_all=*/true, deadline);
-}
-
-StatusOr<int> InterpreterDevice::WaitAnyFence(
- absl::Span<const FenceValue> fences, absl::Time deadline) {
- IREE_TRACE_SCOPE0("InterpreterDevice::WaitAnyFence");
- return HostFence::WaitForFences(fences, /*wait_all=*/false, deadline);
-}
-
-Status InterpreterDevice::WaitIdle(absl::Time deadline) {
- for (auto& command_queue : command_queues_) {
- RETURN_IF_ERROR(command_queue->WaitIdle(deadline));
- }
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_device.h b/iree/hal/interpreter/interpreter_device.h
deleted file mode 100644
index 5987ab9..0000000
--- a/iree/hal/interpreter/interpreter_device.h
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
-#define IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
-
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/base/memory.h"
-#include "iree/hal/device.h"
-#include "iree/hal/host/host_local_allocator.h"
-#include "iree/hal/interpreter/bytecode_kernels.h"
-#include "iree/rt/instance.h"
-
-namespace iree {
-namespace hal {
-
-class InterpreterDevice final : public Device {
- public:
- explicit InterpreterDevice(DeviceInfo device_info);
- ~InterpreterDevice() override;
-
- kernels::RuntimeState* kernel_runtime_state() {
- return &kernel_runtime_state_;
- }
-
- Allocator* allocator() const override { return &allocator_; }
-
- absl::Span<CommandQueue*> dispatch_queues() const override {
- return RawPtrSpan(absl::MakeSpan(command_queues_));
- }
-
- absl::Span<CommandQueue*> transfer_queues() const override {
- return RawPtrSpan(absl::MakeSpan(command_queues_));
- }
-
- std::shared_ptr<ExecutableCache> CreateExecutableCache() override;
-
- StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) override;
-
- StatusOr<ref_ptr<Event>> CreateEvent() override;
-
- StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
- bool initial_value) override;
- StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
- uint64_t initial_value) override;
-
- StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override;
- Status WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) override;
- StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Time deadline) override;
-
- Status WaitIdle(absl::Time deadline) override;
-
- private:
- ref_ptr<rt::Instance> instance_;
- kernels::RuntimeState kernel_runtime_state_;
- mutable HostLocalAllocator allocator_;
- mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
diff --git a/iree/hal/interpreter/interpreter_driver.cc b/iree/hal/interpreter/interpreter_driver.cc
deleted file mode 100644
index d9074ef..0000000
--- a/iree/hal/interpreter/interpreter_driver.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/interpreter_driver.h"
-
-#include <memory>
-
-#include "iree/hal/device_info.h"
-#include "iree/hal/interpreter/interpreter_device.h"
-
-namespace iree {
-namespace hal {
-
-namespace {
-
-DeviceInfo GetDefaultDeviceInfo() {
- DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
- // TODO(benvanik): implement debugging/profiling features.
- // supported_features |= DeviceFeature::kDebugging;
- // supported_features |= DeviceFeature::kCoverage;
- // supported_features |= DeviceFeature::kProfiling;
- DeviceInfo device_info("interpreter", supported_features);
- // TODO(benvanik): device info.
- return device_info;
-}
-
-} // namespace
-
-InterpreterDriver::InterpreterDriver() : Driver("interpreter") {}
-
-InterpreterDriver::~InterpreterDriver() = default;
-
-StatusOr<std::vector<DeviceInfo>>
-InterpreterDriver::EnumerateAvailableDevices() {
- std::vector<DeviceInfo> device_infos;
- device_infos.push_back(GetDefaultDeviceInfo());
- return device_infos;
-}
-
-StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDefaultDevice() {
- return CreateDevice(GetDefaultDeviceInfo());
-}
-
-StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDevice(
- const DeviceInfo& device_info) {
- auto device = std::make_shared<InterpreterDevice>(device_info);
- return device;
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_driver.h b/iree/hal/interpreter/interpreter_driver.h
deleted file mode 100644
index b217996..0000000
--- a/iree/hal/interpreter/interpreter_driver.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
-#define IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
-
-#include "iree/hal/driver.h"
-
-namespace iree {
-namespace hal {
-
-class InterpreterDriver final : public Driver {
- public:
- InterpreterDriver();
- ~InterpreterDriver() override;
-
- StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
-
- StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override;
-
- StatusOr<std::shared_ptr<Device>> CreateDevice(
- const DeviceInfo& device_info) override;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
diff --git a/iree/hal/interpreter/interpreter_driver_module.cc b/iree/hal/interpreter/interpreter_driver_module.cc
deleted file mode 100644
index e9e78e7..0000000
--- a/iree/hal/interpreter/interpreter_driver_module.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/hal/interpreter/interpreter_driver.h"
-
-namespace iree {
-namespace hal {
-namespace {
-
-StatusOr<std::shared_ptr<Driver>> CreateInterpreterDriver() {
- return std::make_shared<InterpreterDriver>();
-}
-
-} // namespace
-} // namespace hal
-} // namespace iree
-
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_interpreter_driver, {
- QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "interpreter", ::iree::hal::CreateInterpreterDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal,
- iree_hal_interpreter_driver);
diff --git a/iree/hal/interpreter/interpreter_module.cc b/iree/hal/interpreter/interpreter_module.cc
deleted file mode 100644
index 84e962c..0000000
--- a/iree/hal/interpreter/interpreter_module.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/interpreter/interpreter_module.h"
-
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/interpreter/bytecode_dispatch.h"
-#include "iree/vm/bytecode_tables_interpreter.h"
-
-namespace iree {
-namespace hal {
-
-// static
-StatusOr<ref_ptr<rt::Module>> InterpreterModule::FromDef(
- hal::Allocator* allocator, const ModuleDef& module_def) {
- ASSIGN_OR_RETURN(auto module_file,
- vm::ModuleFile::Create(&module_def, []() {}));
- if (module_file->root() == nullptr) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present";
- }
-
- auto module =
- assign_ref(new InterpreterModule(allocator, std::move(module_file)));
-
- // TODO(benvanik): validate internals here? or make explicit?
-
- return {std::move(module)};
-}
-
-InterpreterModule::InterpreterModule(
- hal::Allocator* allocator, std::unique_ptr<vm::ModuleFile> module_file)
- : vm::BytecodeModule(std::move(module_file),
- vm::interpreter_opcode_table()),
- allocator_(allocator) {}
-
-Status InterpreterModule::Execute(
- rt::Stack* stack, const rt::Function function,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::InlinedVector<hal::BufferView, 8>* results) const {
- IREE_TRACE_SCOPE0("InterperterModule::Execute");
-
- // Push stack frame for the function we are calling.
- ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function));
-
- // TODO(benvanik): rework register storage interface.
- ASSIGN_OR_RETURN(const auto* function_def,
- GetFunctionDef(function.linkage(), function.ordinal()));
- auto* registers = callee_stack_frame->mutable_registers();
- registers->buffer_views.resize(function_def->bytecode()->local_count());
-
- // Marshal input arguments.
- for (int i = 0; i < arguments.size(); ++i) {
- registers->buffer_views[i] = std::move(arguments[i]);
- }
-
- // Run main dispatch loop until it exits (or errors).
- RETURN_IF_ERROR(Dispatch(allocator_, &kernel_runtime_state_, stack,
- callee_stack_frame, absl::MakeSpan(*results)));
-
- // Pop the callee frame to balance out the stack.
- RETURN_IF_ERROR(stack->PopFrame());
-
- return OkStatus();
-}
-
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_module.h b/iree/hal/interpreter/interpreter_module.h
deleted file mode 100644
index d7d98e3..0000000
--- a/iree/hal/interpreter/interpreter_module.h
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
-#define IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
-
-#include <memory>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/interpreter/bytecode_kernels.h"
-#include "iree/rt/function.h"
-#include "iree/rt/module.h"
-#include "iree/rt/stack.h"
-#include "iree/vm/bytecode_module.h"
-#include "iree/vm/bytecode_tables_interpreter.h"
-
-namespace iree {
-namespace hal {
-
-class InterpreterModule final : public vm::BytecodeModule {
- public:
- static StatusOr<ref_ptr<rt::Module>> FromDef(hal::Allocator* allocator,
- const ModuleDef& module_def);
-
- Status Execute(
- rt::Stack* stack, const rt::Function function,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::InlinedVector<hal::BufferView, 8>* results) const override;
-
- private:
- InterpreterModule(hal::Allocator* allocator,
- std::unique_ptr<vm::ModuleFile> module_file);
-
- hal::Allocator* allocator_;
- mutable kernels::RuntimeState kernel_runtime_state_;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
diff --git a/iree/hal/resource.h b/iree/hal/resource.h
deleted file mode 100644
index 311a7ee..0000000
--- a/iree/hal/resource.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_RESOURCE_H_
-#define IREE_HAL_RESOURCE_H_
-
-#include "iree/base/ref_ptr.h"
-
-namespace iree {
-namespace hal {
-
-// Abstract resource type whose lifetime is managed by a ResourceSet.
-// Used mostly just to get a virtual dtor, though we could add nicer logging.
-class Resource : public RefObject<Resource> {
- public:
- virtual ~Resource() = default;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_RESOURCE_H_
diff --git a/iree/hal/semaphore.h b/iree/hal/semaphore.h
deleted file mode 100644
index 74665d3..0000000
--- a/iree/hal/semaphore.h
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_SEMAPHORE_H_
-#define IREE_HAL_SEMAPHORE_H_
-
-#include "absl/types/variant.h"
-#include "iree/hal/resource.h"
-
-namespace iree {
-namespace hal {
-
-// A synchronization primitive used to indicate submission dependencies.
-// Semaphores are either of type binary (signaled or unsignaled) or timeline
-// (uint64 payload with >= semantics).
-class Semaphore : public Resource {
- public:
-};
-
-// Binary semaphores have strict ordering requirements and must be carefully
-// balanced. Each binary semaphore must only be waited on after a signal
-// operation has been issued and each wait requires exactly one signal. They
-// are commonly used only when interacting with external handles that may
-// cross device or process boundaries.
-class BinarySemaphore : public Semaphore {
- public:
-};
-
-// Timeline semaphores act as a fence along a per-semaphore timeline where
-// signaling is done by setting the payload to a monotonically increasing
-// 64-bit integer and waiting is done by blocking until the payload is set
-// greater-than or equal-to the specified value. Timeline semaphores may be
-// waited on or signaled in any order and can be significantly more
-// efficient due to system-level coalescing.
-class TimelineSemaphore : public Semaphore {
- public:
- // TODO(benvanik): add value query support.
- // TODO(benvanik): add host-side signal/wait.
-};
-
-// A reference to a strongly-typed semaphore and associated information.
-// For TimelineSemaphores the provided payload is used to specify either the
-// payload to wait for or new payload value.
-using SemaphoreValue =
- absl::variant<BinarySemaphore*, std::pair<TimelineSemaphore*, uint64_t>>;
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_SEMAPHORE_H_
diff --git a/iree/hal/testing/BUILD b/iree/hal/testing/BUILD
deleted file mode 100644
index ac80497..0000000
--- a/iree/hal/testing/BUILD
+++ /dev/null
@@ -1,36 +0,0 @@
-# Test utilities for HAL-specific code.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "mock_allocator",
- testonly = True,
- hdrs = ["mock_allocator.h"],
- deps = [
- "//iree/hal:allocator",
- "@com_google_googletest//:gtest",
- ],
-)
-
-cc_library(
- name = "mock_command_buffer",
- testonly = True,
- hdrs = ["mock_command_buffer.h"],
- deps = [
- "//iree/hal:command_buffer",
- "@com_google_googletest//:gtest",
- ],
-)
-
-cc_library(
- name = "mock_command_queue",
- testonly = True,
- hdrs = ["mock_command_queue.h"],
- deps = [
- "//iree/hal:command_queue",
- "@com_google_googletest//:gtest",
- ],
-)
diff --git a/iree/hal/testing/mock_allocator.h b/iree/hal/testing/mock_allocator.h
deleted file mode 100644
index cc5548e..0000000
--- a/iree/hal/testing/mock_allocator.h
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
-#define IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
-
-#include "gmock/gmock.h"
-#include "iree/hal/allocator.h"
-
-namespace iree {
-namespace hal {
-namespace testing {
-
-class MockAllocator : public ::testing::StrictMock<Allocator> {
- public:
- MockAllocator() : ::testing::StrictMock<Allocator>() {}
-
- MOCK_CONST_METHOD4(CanUseBufferLike,
- bool(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage));
-
- MOCK_CONST_METHOD3(CanAllocate, bool(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size));
-
- MOCK_METHOD3(Allocate,
- StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size));
-
- MOCK_METHOD5(WrapMutable,
- StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage,
- void* data, size_t data_length));
-};
-
-} // namespace testing
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
diff --git a/iree/hal/testing/mock_command_buffer.h b/iree/hal/testing/mock_command_buffer.h
deleted file mode 100644
index 6c55499..0000000
--- a/iree/hal/testing/mock_command_buffer.h
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
-#define IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
-
-#include "gmock/gmock.h"
-#include "iree/hal/command_buffer.h"
-
-namespace iree {
-namespace hal {
-namespace testing {
-
-class MockCommandBuffer : public ::testing::StrictMock<CommandBuffer> {
- public:
- MockCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories)
- : ::testing::StrictMock<CommandBuffer>(allocator, mode,
- command_categories) {}
-
- bool is_recording() const override { return false; }
-
- MOCK_METHOD0(Begin, Status());
- MOCK_METHOD0(End, Status());
-
- MOCK_METHOD4(ExecutionBarrier,
- Status(ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers));
-
- MOCK_METHOD2(SignalEvent,
- Status(Event* event, ExecutionStageBitfield source_stage_mask));
-
- MOCK_METHOD2(ResetEvent,
- Status(Event* event, ExecutionStageBitfield source_stage_mask));
-
- MOCK_METHOD5(WaitEvents,
- Status(absl::Span<Event*> events,
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers));
-
- MOCK_METHOD5(FillBuffer,
- Status(Buffer* target_buffer, device_size_t target_offset,
- device_size_t length, const void* pattern,
- size_t pattern_length));
-
- MOCK_METHOD1(DiscardBuffer, Status(Buffer* buffer));
-
- MOCK_METHOD5(UpdateBuffer,
- Status(const void* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length));
-
- MOCK_METHOD5(CopyBuffer,
- Status(Buffer* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length));
-
- MOCK_METHOD1(Dispatch, Status(const DispatchRequest& dispatch_request));
-};
-
-} // namespace testing
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
diff --git a/iree/hal/testing/mock_command_queue.h b/iree/hal/testing/mock_command_queue.h
deleted file mode 100644
index 1f5fe8b..0000000
--- a/iree/hal/testing/mock_command_queue.h
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
-#define IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
-
-#include "gmock/gmock.h"
-#include "iree/hal/command_queue.h"
-
-namespace iree {
-namespace hal {
-namespace testing {
-
-class MockCommandQueue : public ::testing::StrictMock<CommandQueue> {
- public:
- MockCommandQueue(std::string name,
- CommandCategoryBitfield supported_categories)
- : ::testing::StrictMock<CommandQueue>(std::move(name),
- supported_categories) {}
-
- MOCK_METHOD2(Submit, Status(absl::Span<const SubmissionBatch> batches,
- FenceValue fence));
-
- MOCK_METHOD1(WaitIdle, Status(absl::Time deadline));
-};
-
-} // namespace testing
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD
deleted file mode 100644
index 56bcb98..0000000
--- a/iree/hal/vulkan/BUILD
+++ /dev/null
@@ -1,380 +0,0 @@
-# HAL implementation using Vulkan and (likely) SPIR-V executables.
-
-load("//iree:build_defs.bzl", "PLATFORM_VULKAN_LOADER_COPTS", "PLATFORM_VULKAN_TEST_DEPS")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-# --define=IREE_VK=native to use the native Vulkan drivers (and real hardware).
-config_setting(
- name = "native_vk",
- values = {
- "define": "IREE_VK=native",
- },
-)
-
-# --define=IREE_VK=swiftshader to use SwiftShader.
-config_setting(
- name = "swiftshader_vk",
- values = {
- "define": "IREE_VK=swiftshader",
- },
-)
-
-cc_library(
- name = "debug_reporter",
- srcs = ["debug_reporter.cc"],
- hdrs = ["debug_reporter.h"],
- deps = [
- ":dynamic_symbols",
- ":status_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "descriptor_pool_cache",
- srcs = ["descriptor_pool_cache.cc"],
- hdrs = ["descriptor_pool_cache.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:tracing",
- "@com_google_absl//absl/container:inlined_vector",
- ],
-)
-
-cc_library(
- name = "descriptor_set_arena",
- srcs = ["descriptor_set_arena.cc"],
- hdrs = ["descriptor_set_arena.h"],
- deps = [
- ":descriptor_pool_cache",
- ":pipeline_executable",
- ":status_util",
- ":vma_allocator",
- "//iree/base:arena",
- "//iree/base:math",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- ],
-)
-
-cc_library(
- name = "direct_command_buffer",
- srcs = ["direct_command_buffer.cc"],
- hdrs = ["direct_command_buffer.h"],
- deps = [
- ":descriptor_pool_cache",
- ":descriptor_set_arena",
- ":dynamic_symbols",
- ":handle_util",
- ":native_event",
- ":pipeline_executable",
- ":status_util",
- ":vma_allocator",
- "//iree/base:math",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "direct_command_queue",
- srcs = ["direct_command_queue.cc"],
- hdrs = ["direct_command_queue.h"],
- deps = [
- ":direct_command_buffer",
- ":dynamic_symbols",
- ":handle_util",
- ":legacy_fence",
- ":native_binary_semaphore",
- ":status_util",
- "//iree/base:arena",
- "//iree/base:memory",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "dynamic_symbols",
- srcs = ["dynamic_symbols.cc"],
- hdrs = [
- "dynamic_symbol_tables.h",
- "dynamic_symbols.h",
- ],
- copts = PLATFORM_VULKAN_LOADER_COPTS,
- linkopts = [
- "-ldl",
- ],
- deps = [
- "//iree/base:file_path",
- "//iree/base:ref_ptr",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/memory",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_test(
- name = "dynamic_symbols_test",
- srcs = ["dynamic_symbols_test.cc"],
- deps = [
- ":status_util",
- ":dynamic_symbols",
- "//iree/base:status_matchers",
- ] + PLATFORM_VULKAN_TEST_DEPS,
-)
-
-cc_library(
- name = "extensibility_util",
- srcs = ["extensibility_util.cc"],
- hdrs = ["extensibility_util.h"],
- deps = [
- ":dynamic_symbols",
- ":status_util",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "@com_google_absl//absl/types:span",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "handle_util",
- hdrs = ["handle_util.h"],
- deps = [
- ":dynamic_symbols",
- ":extensibility_util",
- "//iree/base:ref_ptr",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "legacy_fence",
- srcs = ["legacy_fence.cc"],
- hdrs = ["legacy_fence.h"],
- deps = [
- ":handle_util",
- ":status_util",
- "//iree/base:intrusive_list",
- "//iree/base:ref_ptr",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:fence",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_binary_semaphore",
- srcs = ["native_binary_semaphore.cc"],
- hdrs = ["native_binary_semaphore.h"],
- deps = [
- ":handle_util",
- "//iree/hal:semaphore",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_event",
- srcs = ["native_event.cc"],
- hdrs = ["native_event.h"],
- deps = [
- ":handle_util",
- "//iree/hal:event",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_cache",
- srcs = ["pipeline_cache.cc"],
- hdrs = ["pipeline_cache.h"],
- deps = [
- ":handle_util",
- ":pipeline_executable",
- ":status_util",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_executable",
- srcs = ["pipeline_executable.cc"],
- hdrs = ["pipeline_executable.h"],
- deps = [
- ":handle_util",
- ":status_util",
- "//iree/base:memory",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_spec",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_google_absl//absl/container:inlined_vector",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "status_util",
- srcs = ["status_util.cc"],
- hdrs = ["status_util.h"],
- deps = [
- "//iree/base:status",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "vma_allocator",
- srcs = [
- "internal_vk_mem_alloc.cc",
- "internal_vk_mem_alloc.h",
- "vma_allocator.cc",
- "vma_buffer.cc",
- ],
- hdrs = [
- "vma_allocator.h",
- "vma_buffer.h",
- ],
- copts = [
- # Only needed in the implementation cc and not by external users.
- "-DVMA_STATIC_VULKAN_FUNCTIONS=0",
- ],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- "//iree/base:logging",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer",
- "@com_google_absl//absl/container:flat_hash_map",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/synchronization",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- "@vulkan_memory_allocator//:impl_header_only",
- ],
-)
-
-cc_library(
- name = "vulkan_device",
- srcs = ["vulkan_device.cc"],
- hdrs = ["vulkan_device.h"],
- deps = [
- ":descriptor_pool_cache",
- ":direct_command_buffer",
- ":direct_command_queue",
- ":dynamic_symbols",
- ":extensibility_util",
- ":handle_util",
- ":legacy_fence",
- ":native_binary_semaphore",
- ":native_event",
- ":pipeline_cache",
- ":status_util",
- ":vma_allocator",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:command_buffer_validation",
- "//iree/hal:command_queue",
- "//iree/hal:device",
- "//iree/hal:fence",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:span",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "vulkan_driver",
- srcs = ["vulkan_driver.cc"],
- hdrs = ["vulkan_driver.h"],
- deps = [
- ":debug_reporter",
- ":dynamic_symbols",
- ":extensibility_util",
- ":status_util",
- ":vulkan_device",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "@com_google_absl//absl/container:inlined_vector",
- "@vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "vulkan_driver_module",
- srcs = ["vulkan_driver_module.cc"],
- deps = [
- ":dynamic_symbols",
- ":vulkan_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:driver_registry",
- "@com_google_absl//absl/flags:flag",
- ],
- alwayslink = 1,
-)
diff --git a/iree/hal/vulkan/debug_reporter.cc b/iree/hal/vulkan/debug_reporter.cc
deleted file mode 100644
index 46ec68d..0000000
--- a/iree/hal/vulkan/debug_reporter.cc
+++ /dev/null
@@ -1,160 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/debug_reporter.h"
-
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-// NOTE: |user_data| may be nullptr if we are being called during instance
-// creation. Otherwise it is a pointer to the DebugReporter instance.
-
-// NOTE: this callback must be thread safe and must be careful not to reach too
-// far outside of the call - it is called in-context from arbitrary threads with
-// some amount of Vulkan state on the stack. Assume that creating or deleting
-// Vulkan objects, issuing most Vulkan commands, etc are off-limits.
-
-VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback(
- VkDebugUtilsMessageSeverityFlagBitsEXT message_severity,
- VkDebugUtilsMessageTypeFlagsEXT message_type,
- const VkDebugUtilsMessengerCallbackDataEXT* callback_data,
- void* user_data) {
- // TODO(benvanik): better logging once we have switched logging APIs.
- LOG(ERROR) << callback_data->pMessage;
-
- return VK_FALSE; // VK_TRUE is reserved for future use.
-}
-
-VKAPI_ATTR VkBool32 VKAPI_CALL DebugReportCallback(
- VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT object_type,
- uint64_t object, size_t location, int32_t message_code,
- const char* layer_prefix, const char* message, void* user_data) {
- // TODO(benvanik): better logging once we have switched logging APIs.
- LOG(ERROR) << message;
-
- return VK_FALSE; // VK_TRUE is reserved for future use.
-}
-
-} // namespace
-
-// static
-void DebugReporter::PopulateStaticCreateInfo(
- VkDebugUtilsMessengerCreateInfoEXT* create_info) {
- create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
- create_info->pNext = nullptr;
- create_info->flags = 0;
-
- // TODO(benvanik): only enable the severities that logging has enabled.
- create_info->messageSeverity =
- VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
- VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT |
- VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT |
- VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
-
- // TODO(benvanik): allow filtering by category as a flag.
- create_info->messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT |
- VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT |
- VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
-
- create_info->pfnUserCallback = DebugUtilsMessageCallback;
- create_info->pUserData = nullptr;
-}
-
-// static
-void DebugReporter::PopulateStaticCreateInfo(
- VkDebugReportCallbackCreateInfoEXT* create_info) {
- create_info->sType = VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT;
- create_info->pNext = nullptr;
- create_info->flags = 0;
-
- // TODO(benvanik): only enable the severities that logging has enabled.
- create_info->flags |=
- VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT |
- VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT |
- VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_DEBUG_BIT_EXT;
-
- create_info->pfnCallback = DebugReportCallback;
- create_info->pUserData = nullptr;
-}
-
-// static
-StatusOr<std::unique_ptr<DebugReporter>>
-DebugReporter::CreateDebugUtilsMessenger(
- VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
- const VkAllocationCallbacks* allocation_callbacks) {
- IREE_TRACE_SCOPE0("DebugReporter::CreateDebugUtilsMessenger");
-
- auto debug_reporter =
- absl::WrapUnique(new DebugReporter(instance, syms, allocation_callbacks));
-
- VkDebugUtilsMessengerCreateInfoEXT create_info;
- PopulateStaticCreateInfo(&create_info);
- create_info.pUserData = debug_reporter.get();
-
- VK_RETURN_IF_ERROR(syms->vkCreateDebugUtilsMessengerEXT(
- instance, &create_info, allocation_callbacks,
- &debug_reporter->messenger_));
-
- return debug_reporter;
-}
-
-// static
-StatusOr<std::unique_ptr<DebugReporter>>
-DebugReporter::CreateDebugReportCallback(
- VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
- const VkAllocationCallbacks* allocation_callbacks) {
- IREE_TRACE_SCOPE0("DebugReporter::CreateDebugReportCallback");
-
- auto debug_reporter =
- absl::WrapUnique(new DebugReporter(instance, syms, allocation_callbacks));
-
- VkDebugReportCallbackCreateInfoEXT create_info;
- PopulateStaticCreateInfo(&create_info);
- create_info.pUserData = debug_reporter.get();
-
- VK_RETURN_IF_ERROR(syms->vkCreateDebugReportCallbackEXT(
- instance, &create_info, allocation_callbacks,
- &debug_reporter->callback_));
-
- return debug_reporter;
-}
-
-DebugReporter::DebugReporter(VkInstance instance,
- const ref_ptr<DynamicSymbols>& syms,
- const VkAllocationCallbacks* allocation_callbacks)
- : instance_(instance),
- syms_(add_ref(syms)),
- allocation_callbacks_(allocation_callbacks) {}
-
-DebugReporter::~DebugReporter() {
- IREE_TRACE_SCOPE0("DebugReporter::dtor");
- if (messenger_ != VK_NULL_HANDLE) {
- syms_->vkDestroyDebugUtilsMessengerEXT(instance_, messenger_,
- allocation_callbacks_);
- }
- if (callback_ != VK_NULL_HANDLE) {
- syms_->vkDestroyDebugReportCallbackEXT(instance_, callback_,
- allocation_callbacks_);
- }
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/debug_reporter.h b/iree/hal/vulkan/debug_reporter.h
deleted file mode 100644
index ffc594b..0000000
--- a/iree/hal/vulkan/debug_reporter.h
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_
-#define IREE_HAL_VULKAN_DEBUG_REPORTER_H_
-
-#include <vulkan/vulkan.h>
-
-#include "iree/base/status.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// A debug reporter that works with the VK_EXT_debug_utils extension.
-// One reporter should be created per VkInstance to receive callbacks from the
-// API and route them to our logging systems. In general VK_EXT_debug_utils
-// should be preferred if available as it provides a much cleaner interface and
-// more plug-points than VK_EXT_debug_report.
-//
-// Since creating a reporter requires a VkInstance it's not possible to report
-// on messages during instance creation. To work around this it's possible to
-// pass a *CreateInfo struct to vkCreateInstance as part of the
-// VkInstanceCreateInfo::pNext chain. The callback will only be used this way
-// during the creation call after which users can create the real
-// instance-specific reporter.
-class DebugReporter final {
- public:
- // Populates |create_info| with an instance-agnostic callback.
- // This can be used during instance creation by chaining the |create_info| to
- // VkInstanceCreateInfo::pNext.
- //
- // Only use if VK_EXT_debug_utils is present.
- static void PopulateStaticCreateInfo(
- VkDebugUtilsMessengerCreateInfoEXT* create_info);
-
- // Populates |create_info| with an instance-agnostic callback.
- // This can be used during instance creation by chaining the |create_info| to
- // VkInstanceCreateInfo::pNext.
- //
- // Only use if VK_EXT_debug_report is present.
- static void PopulateStaticCreateInfo(
- VkDebugReportCallbackCreateInfoEXT* create_info);
-
- // Creates a debug messenger for the given Vulkan |instance| with
- // VK_EXT_debug_utils enabled.
- static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugUtilsMessenger(
- VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
- const VkAllocationCallbacks* allocation_callbacks);
-
- // Creates a debug report callback for the given Vulkan |instance| with
- // VK_EXT_debug_report enabled.
- static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugReportCallback(
- VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
- const VkAllocationCallbacks* allocation_callbacks);
-
- ~DebugReporter();
-
- private:
- DebugReporter(VkInstance instance, const ref_ptr<DynamicSymbols>& syms,
- const VkAllocationCallbacks* allocation_callbacks);
-
- VkInstance instance_ = VK_NULL_HANDLE;
- ref_ptr<DynamicSymbols> syms_;
- const VkAllocationCallbacks* allocation_callbacks_ = nullptr;
-
- VkDebugUtilsMessengerEXT messenger_ = VK_NULL_HANDLE;
- VkDebugReportCallbackEXT callback_ = VK_NULL_HANDLE;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_DEBUG_REPORTER_H_
diff --git a/iree/hal/vulkan/descriptor_pool_cache.cc b/iree/hal/vulkan/descriptor_pool_cache.cc
deleted file mode 100644
index 3e1e4ef..0000000
--- a/iree/hal/vulkan/descriptor_pool_cache.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/descriptor_pool_cache.h"
-
-#include <array>
-
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-// TODO(benvanik): be more conservative with descriptor set count or allow
-// chaining in the command buffer when pools run out.
-static constexpr int kMaxDescriptorSets = 4096;
-
-} // namespace
-
-DescriptorSetGroup::~DescriptorSetGroup() {
- CHECK(descriptor_pools_.empty())
- << "DescriptorSetGroup must be reset explicitly";
-}
-
-Status DescriptorSetGroup::Reset() {
- IREE_TRACE_SCOPE0("DescriptorSetGroup::Reset");
-
- RETURN_IF_ERROR(descriptor_pool_cache_->ReleaseDescriptorPools(
- absl::MakeSpan(descriptor_pools_)));
- descriptor_pools_.clear();
-
- return OkStatus();
-}
-
-DescriptorPoolCache::DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device)
- : logical_device_(std::move(logical_device)) {}
-
-StatusOr<DescriptorPool> DescriptorPoolCache::AcquireDescriptorPool(
- VkDescriptorType descriptor_type, int max_descriptor_count) {
- IREE_TRACE_SCOPE0("DescriptorPoolCache::AcquireDescriptorPool");
-
- // TODO(benvanik): lookup in cache.
-
- VkDescriptorPoolCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- create_info.maxSets = kMaxDescriptorSets;
- std::array<VkDescriptorPoolSize, 1> pool_sizes;
- pool_sizes[0].type = descriptor_type;
- pool_sizes[0].descriptorCount = max_descriptor_count;
- create_info.poolSizeCount = pool_sizes.size();
- create_info.pPoolSizes = pool_sizes.data();
-
- DescriptorPool descriptor_pool;
- descriptor_pool.descriptor_type = descriptor_type;
- descriptor_pool.max_descriptor_count = max_descriptor_count;
- descriptor_pool.handle = VK_NULL_HANDLE;
-
- VK_RETURN_IF_ERROR(syms().vkCreateDescriptorPool(
- *logical_device_, &create_info, logical_device_->allocator(),
- &descriptor_pool.handle));
-
- return descriptor_pool;
-}
-
-Status DescriptorPoolCache::ReleaseDescriptorPools(
- absl::Span<DescriptorPool> descriptor_pools) {
- IREE_TRACE_SCOPE0("DescriptorPoolCache::ReleaseDescriptorPools");
-
- for (const auto& descriptor_pool : descriptor_pools) {
- // Always reset immediately. We could do this on allocation instead however
- // this leads to better errors when using the validation layers as we'll
- // throw if there are in-flight command buffers using the sets in the pool.
- VK_RETURN_IF_ERROR(syms().vkResetDescriptorPool(*logical_device_,
- descriptor_pool.handle, 0));
-
- // TODO(benvanik): release to cache.
- syms().vkDestroyDescriptorPool(*logical_device_, descriptor_pool.handle,
- logical_device_->allocator());
- }
-
- return OkStatus();
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/descriptor_pool_cache.h b/iree/hal/vulkan/descriptor_pool_cache.h
deleted file mode 100644
index bb4fa33..0000000
--- a/iree/hal/vulkan/descriptor_pool_cache.h
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
-#define IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/handle_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class DescriptorPoolCache;
-
-// A descriptor pool with a single descriptor type of some number.
-// We only support a single descriptor type for now as we only generate SPIR-V
-// that uses a single type.
-struct DescriptorPool {
- // Type of the descriptor in the set.
- VkDescriptorType descriptor_type = VK_DESCRIPTOR_TYPE_MAX_ENUM;
- // Maximum number of descriptors of the given type per allocation.
- int max_descriptor_count = 0;
- // Pool handle.
- VkDescriptorPool handle = VK_NULL_HANDLE;
-};
-
-// A group of descriptor sets allocated and released together.
-// The group must be explicitly reset with Reset() prior to disposing.
-class DescriptorSetGroup final {
- public:
- DescriptorSetGroup() = default;
- DescriptorSetGroup(ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
- absl::InlinedVector<DescriptorPool, 8> descriptor_pools)
- : descriptor_pool_cache_(std::move(descriptor_pool_cache)),
- descriptor_pools_(std::move(descriptor_pools)) {}
- DescriptorSetGroup(const DescriptorSetGroup&) = delete;
- DescriptorSetGroup& operator=(const DescriptorSetGroup&) = delete;
- DescriptorSetGroup(DescriptorSetGroup&& other) noexcept
- : descriptor_pool_cache_(std::move(other.descriptor_pool_cache_)),
- descriptor_pools_(std::move(other.descriptor_pools_)) {}
- DescriptorSetGroup& operator=(DescriptorSetGroup&& other) {
- std::swap(descriptor_pool_cache_, other.descriptor_pool_cache_);
- std::swap(descriptor_pools_, other.descriptor_pools_);
- return *this;
- }
- ~DescriptorSetGroup();
-
- Status Reset();
-
- private:
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache_;
- absl::InlinedVector<DescriptorPool, 8> descriptor_pools_;
-};
-
-// A "cache" (or really, pool) of descriptor pools. These pools are allocated
-// as needed to satisfy different descriptor size requirements and are given
-// to command buffers during recording to write descriptor updates and bind
-// resources. After the descriptors in the pool are no longer used (all
-// command buffers using descriptor sets allocated from the pool have retired)
-// the pool is returned here to be reused in the future.
-class DescriptorPoolCache final : public RefObject<DescriptorPoolCache> {
- public:
- explicit DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device);
-
- const ref_ptr<VkDeviceHandle>& logical_device() const {
- return logical_device_;
- }
- const DynamicSymbols& syms() const { return *logical_device_->syms(); }
-
- // Acquires a new descriptor pool for use by the caller.
- // The pool will have been reset and have all descriptor sets available.
- // When all sets allocated from the pool are no longer in use it must be
- // returned to the cache with ReleaseDescriptorPool.
- StatusOr<DescriptorPool> AcquireDescriptorPool(
- VkDescriptorType descriptor_type, int max_descriptor_count);
-
- // Releases descriptor pools back to the cache. The pools will be reset
- // immediately and must no longer be in use by any in-flight command.
- Status ReleaseDescriptorPools(absl::Span<DescriptorPool> descriptor_pools);
-
- private:
- ref_ptr<VkDeviceHandle> logical_device_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc
deleted file mode 100644
index 636cfcd..0000000
--- a/iree/hal/vulkan/descriptor_set_arena.cc
+++ /dev/null
@@ -1,204 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/descriptor_set_arena.h"
-
-#include "iree/base/math.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-#include "iree/hal/vulkan/vma_buffer.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) {
- // TODO(benvanik): assert that the buffer is from the right allocator and
- // that it is compatible with our target queue family.
- return static_cast<VmaBuffer*>(buffer->allocated_buffer());
-}
-
-StatusOr<absl::Span<VkWriteDescriptorSet>> PopulateDescriptorSetWriteInfos(
- const PipelineDescriptorSets& pipeline_descriptor_sets,
- absl::Span<const BufferBinding> bindings, VkDescriptorSet dst_set,
- Arena* arena) {
- int required_descriptor_count =
- pipeline_descriptor_sets.buffer_binding_set_map.size();
-
- arena->Reset();
- auto buffer_infos =
- arena->AllocateSpan<VkDescriptorBufferInfo>(required_descriptor_count);
- auto write_infos = arena->AllocateSpan<VkWriteDescriptorSet>(bindings.size());
-
- for (int i = 0; i < bindings.size(); ++i) {
- const auto& binding = bindings[i];
-
- auto& buffer_info = buffer_infos[i];
- ASSIGN_OR_RETURN(auto buffer, CastBuffer(binding.buffer));
- buffer_info.buffer = buffer->handle();
- // TODO(benvanik): properly subrange (add to BufferBinding).
- buffer_info.offset = binding.buffer->byte_offset();
- buffer_info.range = binding.buffer->byte_length();
-
- auto& write_info = write_infos[i];
- write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
- write_info.pNext = nullptr;
- write_info.dstSet = dst_set;
- write_info.dstBinding = pipeline_descriptor_sets.buffer_binding_set_map[i];
- write_info.dstArrayElement = 0;
- write_info.descriptorCount = 1;
- write_info.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
- write_info.pImageInfo = nullptr;
- write_info.pBufferInfo = &buffer_info;
- write_info.pTexelBufferView = nullptr;
- }
-
- return write_infos;
-}
-
-} // namespace
-
-DescriptorSetArena::DescriptorSetArena(
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache)
- : logical_device_(add_ref(descriptor_pool_cache->logical_device())),
- descriptor_pool_cache_(std::move(descriptor_pool_cache)) {}
-
-DescriptorSetArena::~DescriptorSetArena() {
- if (!used_descriptor_pools_.empty()) {
- descriptor_pool_cache_
- ->ReleaseDescriptorPools(absl::MakeSpan(used_descriptor_pools_))
- .IgnoreError();
- used_descriptor_pools_.clear();
- }
-}
-
-Status DescriptorSetArena::BindDescriptorSet(
- VkCommandBuffer command_buffer, PipelineExecutable* executable,
- absl::Span<const BufferBinding> bindings) {
- // Always prefer using push descriptors when available as we can avoid the
- // additional API overhead of updating/resetting pools.
- if (logical_device_->enabled_extensions().push_descriptors) {
- return PushDescriptorSet(command_buffer, executable, bindings);
- }
-
- IREE_TRACE_SCOPE0("DescriptorSetArena::BindDescriptorSet");
-
- // Pick a bucket based on the number of descriptors required.
- // NOTE: right now we are 1:1 with bindings.
- int required_descriptor_count = bindings.size() * 1;
- int max_descriptor_count =
- std::max(8, RoundUpToNearestPow2(required_descriptor_count));
- int bucket = TrailingZeros(max_descriptor_count >> 3);
- if (bucket >= descriptor_pool_buckets_.size()) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Too many descriptors required: " << required_descriptor_count
- << " (max=" << (1 << (descriptor_pool_buckets_.size() + 3)) << ")";
- }
- if (descriptor_pool_buckets_[bucket].handle == VK_NULL_HANDLE) {
- // Acquire a pool for this max_descriptor_count bucket.
- ASSIGN_OR_RETURN(
- descriptor_pool_buckets_[bucket],
- descriptor_pool_cache_->AcquireDescriptorPool(
- VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count));
- used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]);
- }
- auto& descriptor_pool = descriptor_pool_buckets_[bucket];
-
- const auto& pipeline_descriptor_sets = executable->descriptor_sets();
-
- VkDescriptorSetAllocateInfo allocate_info;
- allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
- allocate_info.pNext = nullptr;
- allocate_info.descriptorPool = descriptor_pool.handle;
- allocate_info.descriptorSetCount = 1;
- allocate_info.pSetLayouts =
- &pipeline_descriptor_sets.buffer_binding_set_layout;
- VkDescriptorSet descriptor_set = VK_NULL_HANDLE;
- VkResult result = syms().vkAllocateDescriptorSets(
- *logical_device_, &allocate_info, &descriptor_set);
- if (result == VK_ERROR_OUT_OF_POOL_MEMORY) {
- // Allocation failed because the pool is either out of descriptors or too
- // fragmented. We'll just allocate another pool.
- ASSIGN_OR_RETURN(
- descriptor_pool_buckets_[bucket],
- descriptor_pool_cache_->AcquireDescriptorPool(
- VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count));
- used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]);
- }
-
- // Get a list of VkWriteDescriptorSet structs with all bound buffers.
- ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos(
- pipeline_descriptor_sets, bindings,
- descriptor_set, &scratch_arena_));
-
- // This is the reason why push descriptor sets are good.
- // We can't batch these effectively as we don't know prior to recording what
- // descriptor sets we will need and what buffers they will point to (without
- // doing just as much work as actually recording the buffer to try to find
- // out).
- syms().vkUpdateDescriptorSets(*logical_device_, write_infos.size(),
- write_infos.data(), 0, nullptr);
-
- // Bind the descriptor set.
- syms().vkCmdBindDescriptorSets(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
- executable->pipeline_layout(),
- pipeline_descriptor_sets.buffer_binding_set, 1,
- &descriptor_set, 0, nullptr);
-
- return OkStatus();
-}
-
-Status DescriptorSetArena::PushDescriptorSet(
- VkCommandBuffer command_buffer, PipelineExecutable* executable,
- absl::Span<const BufferBinding> bindings) {
- IREE_TRACE_SCOPE0("DescriptorSetArena::PushDescriptorSet");
-
- const auto& pipeline_descriptor_sets = executable->descriptor_sets();
-
- // Get a list of VkWriteDescriptorSet structs with all bound buffers.
- ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos(
- pipeline_descriptor_sets, bindings,
- VK_NULL_HANDLE, &scratch_arena_));
-
- // Fast path using push descriptors. These are pooled internally by the
- // command buffer and prevent the need for our own pooling mechanisms.
- syms().vkCmdPushDescriptorSetKHR(command_buffer,
- VK_PIPELINE_BIND_POINT_COMPUTE,
- executable->pipeline_layout(),
- pipeline_descriptor_sets.buffer_binding_set,
- write_infos.size(), write_infos.data());
-
- return OkStatus();
-}
-
-StatusOr<DescriptorSetGroup> DescriptorSetArena::Flush() {
- IREE_TRACE_SCOPE0("DescriptorSetArena::Flush");
-
- if (used_descriptor_pools_.empty()) {
- // No resources to free.
- return DescriptorSetGroup{};
- }
-
- for (auto& bucket : descriptor_pool_buckets_) {
- bucket = {};
- }
- return DescriptorSetGroup(add_ref(descriptor_pool_cache_),
- std::move(used_descriptor_pools_));
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/descriptor_set_arena.h b/iree/hal/vulkan/descriptor_set_arena.h
deleted file mode 100644
index bbbe5bc..0000000
--- a/iree/hal/vulkan/descriptor_set_arena.h
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
-#define IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
-
-#include <array>
-#include <vector>
-
-#include "iree/base/arena.h"
-#include "iree/base/status.h"
-#include "iree/hal/command_buffer.h"
-#include "iree/hal/vulkan/descriptor_pool_cache.h"
-#include "iree/hal/vulkan/pipeline_executable.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// A reusable arena for allocating descriptor sets and batching updates.
-class DescriptorSetArena final {
- public:
- explicit DescriptorSetArena(
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache);
- ~DescriptorSetArena();
-
- // Allocates and binds a descriptor set from the arena.
- // The command buffer will have the descriptor set containing |bindings| bound
- // to it.
- Status BindDescriptorSet(VkCommandBuffer command_buffer,
- PipelineExecutable* executable,
- absl::Span<const BufferBinding> bindings);
-
- // Flushes all pending writes to descriptor sets allocated from the arena and
- // returns a group that - when dropped - will release the descriptor sets
- // back to the pools they were allocated from.
- StatusOr<DescriptorSetGroup> Flush();
-
- private:
- const DynamicSymbols& syms() const { return *logical_device_->syms(); }
-
- // Pushes the descriptor set to the command buffer, if supported.
- Status PushDescriptorSet(VkCommandBuffer command_buffer,
- PipelineExecutable* executable,
- absl::Span<const BufferBinding> bindings);
-
- ref_ptr<VkDeviceHandle> logical_device_;
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache_;
-
- // Arena used for temporary binding information used during allocation.
- Arena scratch_arena_;
-
- // A list of pools acquired on demand as different descriptor counts are
- // needed. Allocation granularity is max_descriptor_count=[8, 16, 32, 64].
- std::array<DescriptorPool, 4> descriptor_pool_buckets_;
-
- // All pools that have been used during allocation.
- absl::InlinedVector<DescriptorPool, 8> used_descriptor_pools_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc
deleted file mode 100644
index 7ced0a2..0000000
--- a/iree/hal/vulkan/direct_command_buffer.cc
+++ /dev/null
@@ -1,403 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/direct_command_buffer.h"
-
-#include "absl/base/attributes.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/math.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-VkPipelineStageFlags ConvertPipelineStageFlags(
- ExecutionStageBitfield stage_mask) {
- VkPipelineStageFlags flags = 0;
- flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandIssue)
- ? VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT
- : 0;
- flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandProcess)
- ? VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT
- : 0;
- flags |= AnyBitSet(stage_mask & ExecutionStage::kDispatch)
- ? VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT
- : 0;
- flags |= AnyBitSet(stage_mask & ExecutionStage::kTransfer)
- ? VK_PIPELINE_STAGE_TRANSFER_BIT
- : 0;
- flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandRetire)
- ? VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT
- : 0;
- flags |= AnyBitSet(stage_mask & ExecutionStage::kHost)
- ? VK_PIPELINE_STAGE_HOST_BIT
- : 0;
- return flags;
-}
-
-VkAccessFlags ConvertAccessMask(AccessScopeBitfield access_mask) {
- VkAccessFlags flags = 0;
- flags |= AnyBitSet(access_mask & AccessScope::kIndirectCommandRead)
- ? VK_ACCESS_INDIRECT_COMMAND_READ_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kConstantRead)
- ? VK_ACCESS_UNIFORM_READ_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kDispatchRead)
- ? VK_ACCESS_SHADER_READ_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kDispatchWrite)
- ? VK_ACCESS_SHADER_WRITE_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kTransferRead)
- ? VK_ACCESS_TRANSFER_READ_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kTransferWrite)
- ? VK_ACCESS_TRANSFER_WRITE_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kHostRead)
- ? VK_ACCESS_HOST_READ_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kHostWrite)
- ? VK_ACCESS_HOST_WRITE_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kMemoryRead)
- ? VK_ACCESS_MEMORY_READ_BIT
- : 0;
- flags |= AnyBitSet(access_mask & AccessScope::kMemoryWrite)
- ? VK_ACCESS_MEMORY_WRITE_BIT
- : 0;
- return flags;
-}
-
-// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value.
-uint32_t SplatPattern(const void* pattern, size_t pattern_length) {
- switch (pattern_length) {
- case 1: {
- uint32_t pattern_value = *static_cast<const uint8_t*>(pattern);
- return (pattern_value << 24) | (pattern_value << 16) |
- (pattern_value << 8) | pattern_value;
- }
- case 2: {
- uint32_t pattern_value = *static_cast<const uint16_t*>(pattern);
- return (pattern_value << 16) | pattern_value;
- }
- case 4: {
- uint32_t pattern_value = *static_cast<const uint32_t*>(pattern);
- return pattern_value;
- }
- default:
- return 0; // Already verified that this should not be possible.
- }
-}
-
-} // namespace
-
-DirectCommandBuffer::DirectCommandBuffer(
- Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories,
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
- ref_ptr<VkCommandPoolHandle> command_pool, VkCommandBuffer command_buffer)
- : CommandBuffer(allocator, mode, command_categories),
- command_pool_(std::move(command_pool)),
- command_buffer_(command_buffer),
- descriptor_set_arena_(std::move(descriptor_pool_cache)) {}
-
-DirectCommandBuffer::~DirectCommandBuffer() {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor");
- descriptor_set_group_.Reset().IgnoreError();
- absl::MutexLock lock(command_pool_->mutex());
- syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_,
- 1, &command_buffer_);
-}
-
-StatusOr<NativeEvent*> DirectCommandBuffer::CastEvent(Event* event) const {
- // TODO(benvanik): assert the event is valid.
- return static_cast<NativeEvent*>(event);
-}
-
-StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer(Buffer* buffer) const {
- // TODO(benvanik): assert that the buffer is from the right allocator and
- // that it is compatible with our target queue family.
- return static_cast<VmaBuffer*>(buffer->allocated_buffer());
-}
-
-Status DirectCommandBuffer::Begin() {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin");
-
- is_recording_ = true;
-
- // NOTE: we require that command buffers not be recorded while they are
- // in-flight so this is safe.
- RETURN_IF_ERROR(descriptor_set_group_.Reset());
-
- VkCommandBufferBeginInfo begin_info;
- begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
- begin_info.pNext = nullptr;
- begin_info.flags = AllBitsSet(mode(), CommandBufferMode::kOneShot)
- ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT
- : 0;
- begin_info.pInheritanceInfo = nullptr;
- VK_RETURN_IF_ERROR(
- syms()->vkBeginCommandBuffer(command_buffer_, &begin_info));
-
- return OkStatus();
-}
-
-Status DirectCommandBuffer::End() {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::End");
-
- VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_));
-
- // Flush all pending descriptor set writes (if any).
- ASSIGN_OR_RETURN(descriptor_set_group_, descriptor_set_arena_.Flush());
-
- is_recording_ = false;
-
- return OkStatus();
-}
-
-Status DirectCommandBuffer::ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier");
-
- absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos(
- memory_barriers.size());
- for (int i = 0; i < memory_barriers.size(); ++i) {
- const auto& memory_barrier = memory_barriers[i];
- auto& info = memory_barrier_infos[i];
- info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
- info.pNext = nullptr;
- info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope);
- info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope);
- }
-
- absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos(
- buffer_barriers.size());
- for (int i = 0; i < buffer_barriers.size(); ++i) {
- const auto& buffer_barrier = buffer_barriers[i];
- auto& info = buffer_barrier_infos[i];
- info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
- info.pNext = nullptr;
- info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope);
- info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope);
- info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
- info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
- ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer));
- info.buffer = device_buffer->handle();
- info.offset = buffer_barrier.offset;
- info.size = buffer_barrier.length;
- }
-
- syms()->vkCmdPipelineBarrier(
- command_buffer_, ConvertPipelineStageFlags(source_stage_mask),
- ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0,
- memory_barrier_infos.size(), memory_barrier_infos.data(),
- buffer_barrier_infos.size(), buffer_barrier_infos.data(), 0, nullptr);
-
- return OkStatus();
-}
-
-Status DirectCommandBuffer::SignalEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent");
- ASSIGN_OR_RETURN(auto* device_event, CastEvent(event));
- syms()->vkCmdSetEvent(command_buffer_, device_event->handle(),
- ConvertPipelineStageFlags(source_stage_mask));
- return OkStatus();
-}
-
-Status DirectCommandBuffer::ResetEvent(
- Event* event, ExecutionStageBitfield source_stage_mask) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent");
- ASSIGN_OR_RETURN(auto* device_event, CastEvent(event));
- syms()->vkCmdResetEvent(command_buffer_, device_event->handle(),
- ConvertPipelineStageFlags(source_stage_mask));
- return OkStatus();
-}
-
-Status DirectCommandBuffer::WaitEvents(
- absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents");
-
- absl::InlinedVector<VkEvent, 4> event_handles(events.size());
- for (int i = 0; i < events.size(); ++i) {
- ASSIGN_OR_RETURN(auto* device_event, CastEvent(events[i]));
- event_handles[i] = device_event->handle();
- }
-
- absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos(
- memory_barriers.size());
- for (int i = 0; i < memory_barriers.size(); ++i) {
- const auto& memory_barrier = memory_barriers[i];
- auto& info = memory_barrier_infos[i];
- info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
- info.pNext = nullptr;
- info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope);
- info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope);
- }
-
- absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos(
- buffer_barriers.size());
- for (int i = 0; i < buffer_barriers.size(); ++i) {
- const auto& buffer_barrier = buffer_barriers[i];
- auto& info = buffer_barrier_infos[i];
- info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
- info.pNext = nullptr;
- info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope);
- info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope);
- info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
- info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
- ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer));
- info.buffer = device_buffer->handle();
- info.offset = buffer_barrier.offset;
- info.size = buffer_barrier.length;
- }
-
- syms()->vkCmdWaitEvents(
- command_buffer_, event_handles.size(), event_handles.data(),
- ConvertPipelineStageFlags(source_stage_mask),
- ConvertPipelineStageFlags(target_stage_mask), memory_barrier_infos.size(),
- memory_barrier_infos.data(), buffer_barrier_infos.size(),
- buffer_barrier_infos.data(), 0, nullptr);
- return OkStatus();
-}
-
-Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length,
- const void* pattern,
- size_t pattern_length) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer");
- ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
-
- // Note that fill only accepts 4-byte aligned values so we need to splat out
- // our variable-length pattern.
- target_offset += target_buffer->byte_offset();
- uint32_t dword_pattern = SplatPattern(pattern, pattern_length);
- syms()->vkCmdFillBuffer(command_buffer_, target_device_buffer->handle(),
- target_offset, length, dword_pattern);
-
- return OkStatus();
-}
-
-Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer");
- // NOTE: we could use this to prevent queue family transitions.
- return OkStatus();
-}
-
-Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer");
- ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
-
- // Vulkan only allows updates of <= 65536 because you really, really, really
- // shouldn't do large updates like this (as it wastes command buffer space and
- // may be slower than just using write-through mapped memory). The
- // recommendation in the spec for larger updates is to split the single update
- // into multiple updates over the entire desired range.
- const auto* source_buffer_ptr = static_cast<const uint8_t*>(source_buffer);
- target_offset += target_buffer->byte_offset();
- while (length > 0) {
- device_size_t chunk_length =
- std::min(static_cast<device_size_t>(65536u), length);
- syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(),
- target_offset, chunk_length, source_buffer_ptr);
- source_buffer_ptr += chunk_length;
- target_offset += chunk_length;
- length -= chunk_length;
- }
-
- return OkStatus();
-}
-
-Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer,
- device_size_t source_offset,
- Buffer* target_buffer,
- device_size_t target_offset,
- device_size_t length) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer");
- ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer));
- ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
-
- VkBufferCopy region;
- region.srcOffset = source_buffer->byte_offset() + source_offset;
- region.dstOffset = target_buffer->byte_offset() + target_offset;
- region.size = length;
- syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(),
- target_device_buffer->handle(), 1, &region);
-
- return OkStatus();
-}
-
-Status DirectCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) {
- IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch");
-
- // Get the compiled and linked pipeline for the specified entry point and
- // bind it to the command buffer.
- auto* executable =
- static_cast<PipelineExecutable*>(dispatch_request.executable);
- ASSIGN_OR_RETURN(VkPipeline pipeline, executable->GetPipelineForEntryPoint(
- dispatch_request.entry_point));
- syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
- pipeline);
-
- // Either allocate, update, and bind a descriptor set or use push descriptor
- // sets to use the command buffer pool when supported.
- RETURN_IF_ERROR(descriptor_set_arena_.BindDescriptorSet(
- command_buffer_, executable, dispatch_request.bindings));
-
- // TODO(benvanik): divide workload by caps and issue multiple dispatches.
- // TODO(benvanik): track local workgroup/subgroup size and divide into groups.
- if (dispatch_request.workload_buffer) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Dynamic dispatches not yet implemented";
- }
- uint32_t group_count_x = dispatch_request.workload[0];
- uint32_t group_count_y = dispatch_request.workload[1];
- uint32_t group_count_z = dispatch_request.workload[2];
-
- // TODO(GH-67): pre-divide workload by tile size.
- if (executable->is_matmul()) {
- group_count_x = (group_count_x + 16 - 1) / 16;
- group_count_y = (group_count_y + 16 - 1) / 16;
- group_count_z = 1;
- }
-
- syms()->vkCmdDispatch(command_buffer_, group_count_x, group_count_y,
- group_count_z);
-
- return OkStatus();
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h
deleted file mode 100644
index 989ea77..0000000
--- a/iree/hal/vulkan/direct_command_buffer.h
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
-#define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
-
-#include <vulkan/vulkan.h>
-
-#include "iree/hal/command_buffer.h"
-#include "iree/hal/vulkan/descriptor_pool_cache.h"
-#include "iree/hal/vulkan/descriptor_set_arena.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/handle_util.h"
-#include "iree/hal/vulkan/native_event.h"
-#include "iree/hal/vulkan/pipeline_executable.h"
-#include "iree/hal/vulkan/vma_buffer.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// Command buffer implementation that directly maps to VkCommandBuffer.
-// This records the commands on the calling thread without additional threading
-// indirection.
-class DirectCommandBuffer final : public CommandBuffer {
- public:
- DirectCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories,
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
- ref_ptr<VkCommandPoolHandle> command_pool,
- VkCommandBuffer command_buffer);
- ~DirectCommandBuffer() override;
-
- VkCommandBuffer handle() const { return command_buffer_; }
-
- bool is_recording() const override { return is_recording_; }
-
- Status Begin() override;
- Status End() override;
-
- Status ExecutionBarrier(
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
- Status SignalEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
- Status ResetEvent(Event* event,
- ExecutionStageBitfield source_stage_mask) override;
- Status WaitEvents(absl::Span<Event*> events,
- ExecutionStageBitfield source_stage_mask,
- ExecutionStageBitfield target_stage_mask,
- absl::Span<const MemoryBarrier> memory_barriers,
- absl::Span<const BufferBarrier> buffer_barriers) override;
-
- Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
- device_size_t length, const void* pattern,
- size_t pattern_length) override;
- Status DiscardBuffer(Buffer* buffer) override;
- Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
- Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
- Buffer* target_buffer, device_size_t target_offset,
- device_size_t length) override;
-
- Status Dispatch(const DispatchRequest& dispatch_request) override;
-
- private:
- const ref_ptr<DynamicSymbols>& syms() const { return command_pool_->syms(); }
-
- StatusOr<NativeEvent*> CastEvent(Event* event) const;
- StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) const;
-
- bool is_recording_ = false;
- ref_ptr<VkCommandPoolHandle> command_pool_;
- VkCommandBuffer command_buffer_;
-
- // TODO(b/140026716): may grow large - should try to reclaim or reuse.
- DescriptorSetArena descriptor_set_arena_;
-
- // The current descriptor set group in use by the command buffer, if any.
- // This must remain valid until all in-flight submissions of the command
- // buffer complete.
- DescriptorSetGroup descriptor_set_group_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
diff --git a/iree/hal/vulkan/direct_command_queue.cc b/iree/hal/vulkan/direct_command_queue.cc
deleted file mode 100644
index fb5cd59..0000000
--- a/iree/hal/vulkan/direct_command_queue.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/direct_command_queue.h"
-
-#include <cstdint>
-
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "iree/base/memory.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/direct_command_buffer.h"
-#include "iree/hal/vulkan/legacy_fence.h"
-#include "iree/hal/vulkan/native_binary_semaphore.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-DirectCommandQueue::DirectCommandQueue(
- std::string name, CommandCategoryBitfield supported_categories,
- const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue)
- : CommandQueue(std::move(name), supported_categories),
- logical_device_(add_ref(logical_device)),
- queue_(queue) {}
-
-DirectCommandQueue::~DirectCommandQueue() {
- IREE_TRACE_SCOPE0("DirectCommandQueue::dtor");
- absl::MutexLock lock(&queue_mutex_);
- syms()->vkQueueWaitIdle(queue_);
-}
-
-Status DirectCommandQueue::TranslateBatchInfo(const SubmissionBatch& batch,
- VkSubmitInfo* submit_info,
- Arena* arena) {
- // TODO(benvanik): see if we can go to finer-grained stages.
- // For example, if this was just queue ownership transfers then we can use
- // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT.
- VkPipelineStageFlags dst_stage_mask =
- VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
-
- auto wait_semaphore_handles =
- arena->AllocateSpan<VkSemaphore>(batch.wait_semaphores.size());
- auto wait_dst_stage_masks =
- arena->AllocateSpan<VkPipelineStageFlags>(batch.wait_semaphores.size());
- for (int i = 0; i < batch.wait_semaphores.size(); ++i) {
- const auto& semaphore_value = batch.wait_semaphores[i];
- if (semaphore_value.index() == 0) {
- const auto& binary_semaphore =
- static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value));
- wait_semaphore_handles[i] = binary_semaphore->handle();
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Timeline semaphores not yet implemented";
- }
- wait_dst_stage_masks[i] = dst_stage_mask;
- }
-
- auto signal_semaphore_handles =
- arena->AllocateSpan<VkSemaphore>(batch.signal_semaphores.size());
- for (int i = 0; i < batch.signal_semaphores.size(); ++i) {
- const auto& semaphore_value = batch.signal_semaphores[i];
- if (semaphore_value.index() == 0) {
- const auto& binary_semaphore =
- static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value));
- signal_semaphore_handles[i] = binary_semaphore->handle();
- } else {
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Timeline semaphores not yet implemented";
- }
- }
-
- auto command_buffer_handles =
- arena->AllocateSpan<VkCommandBuffer>(batch.command_buffers.size());
- for (int i = 0; i < batch.command_buffers.size(); ++i) {
- const auto& command_buffer = batch.command_buffers[i];
- auto* direct_command_buffer =
- static_cast<DirectCommandBuffer*>(command_buffer->impl());
- command_buffer_handles[i] = direct_command_buffer->handle();
- }
-
- submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
- submit_info->pNext = nullptr;
- submit_info->waitSemaphoreCount = wait_semaphore_handles.size();
- submit_info->pWaitSemaphores = wait_semaphore_handles.data();
- submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
- submit_info->commandBufferCount = command_buffer_handles.size();
- submit_info->pCommandBuffers = command_buffer_handles.data();
- submit_info->signalSemaphoreCount = signal_semaphore_handles.size();
- submit_info->pSignalSemaphores = signal_semaphore_handles.data();
-
- return OkStatus();
-}
-
-Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) {
- IREE_TRACE_SCOPE0("DirectCommandQueue::Submit");
-
- // Map the submission batches to VkSubmitInfos.
- // Note that we must keep all arrays referenced alive until submission
- // completes and since there are a bunch of them we use an arena.
- Arena arena(4 * 1024);
- auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batches.size());
- for (int i = 0; i < batches.size(); ++i) {
- RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i], &arena));
- }
-
- // TODO(b/140141417): implement timeline semaphore fences and switch here.
- auto legacy_fence = reinterpret_cast<LegacyFence*>(fence.first);
- ASSIGN_OR_RETURN(VkFence fence_handle,
- legacy_fence->AcquireSignalFence(fence.second));
-
- {
- absl::MutexLock lock(&queue_mutex_);
- VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
- queue_, submit_infos.size(), submit_infos.data(), fence_handle));
- }
-
- return OkStatus();
-}
-
-Status DirectCommandQueue::WaitIdle(absl::Time deadline) {
- if (deadline == absl::InfiniteFuture()) {
- // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it
- // requires fewer calls into the driver).
- IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle");
- absl::MutexLock lock(&queue_mutex_);
- VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_));
- return OkStatus();
- }
-
- IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence");
-
- // Create a new fence just for this wait. This keeps us thread-safe as the
- // behavior of wait+reset is racey.
- VkFenceCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- VkFence fence = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(syms()->vkCreateFence(
- *logical_device_, &create_info, logical_device_->allocator(), &fence));
- auto fence_cleanup = MakeCleanup([this, fence]() {
- syms()->vkDestroyFence(*logical_device_, fence,
- logical_device_->allocator());
- });
-
- uint64_t timeout;
- if (deadline == absl::InfinitePast()) {
- // Do not wait.
- timeout = 0;
- } else if (deadline == absl::InfiniteFuture()) {
- // Wait forever.
- timeout = UINT64_MAX;
- } else {
- // Convert to relative time in nanoseconds.
- // The implementation may not wait with this granularity (like, by 10000x).
- absl::Time now = absl::Now();
- if (deadline < now) {
- return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline in the past";
- }
- timeout = static_cast<uint64_t>(absl::ToInt64Nanoseconds(deadline - now));
- }
-
- {
- absl::MutexLock lock(&queue_mutex_);
- VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence));
- }
-
- VkResult result =
- syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout);
- switch (result) {
- case VK_SUCCESS:
- return OkStatus();
- case VK_TIMEOUT:
- return DeadlineExceededErrorBuilder(IREE_LOC)
- << "Deadline exceeded waiting for idle";
- default:
- return VkResultToStatus(result);
- }
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/direct_command_queue.h b/iree/hal/vulkan/direct_command_queue.h
deleted file mode 100644
index 99fd0dd..0000000
--- a/iree/hal/vulkan/direct_command_queue.h
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
-#define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
-
-#include <vulkan/vulkan.h>
-
-#include <cstdint>
-#include <string>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "iree/base/arena.h"
-#include "iree/base/status.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/handle_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// Command queue implementation directly maps to VkQueue.
-class DirectCommandQueue final : public CommandQueue {
- public:
- DirectCommandQueue(std::string name,
- CommandCategoryBitfield supported_categories,
- const ref_ptr<VkDeviceHandle>& logical_device,
- VkQueue queue);
- ~DirectCommandQueue() override;
-
- const ref_ptr<DynamicSymbols>& syms() const {
- return logical_device_->syms();
- }
-
- Status Submit(absl::Span<const SubmissionBatch> batches,
- FenceValue fence) override;
-
- Status WaitIdle(absl::Time deadline) override;
-
- private:
- Status TranslateBatchInfo(const SubmissionBatch& batch,
- VkSubmitInfo* submit_info, Arena* arena);
-
- ref_ptr<VkDeviceHandle> logical_device_;
-
- // VkQueue needs to be externally synchronized.
- mutable absl::Mutex queue_mutex_;
- VkQueue queue_ ABSL_GUARDED_BY(queue_mutex_);
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/dynamic_symbols.cc b/iree/hal/vulkan/dynamic_symbols.cc
deleted file mode 100644
index 00380ec..0000000
--- a/iree/hal/vulkan/dynamic_symbols.cc
+++ /dev/null
@@ -1,238 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/dynamic_symbols.h"
-
-#include <cstddef>
-#include <cstdlib>
-
-#include "absl/base/attributes.h"
-#include "absl/base/macros.h"
-#include "absl/memory/memory.h"
-#include "iree/base/file_path.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/target_platform.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/dynamic_symbol_tables.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-#include <windows.h>
-#else
-#include <dlfcn.h>
-#endif // IREE_PLATFORM_WINDOWS
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// Read-only table of function pointer information designed to be in .rdata.
-// To reduce binary size this structure is packed (knowing that we won't have
-// gigabytes of function pointers :).
-struct FunctionPtrInfo {
- // Name of the function (like 'vkSomeFunction').
- const char* function_name;
- // 1 if the function pointer can be resolved via vkGetDeviceProcAddr.
- uint32_t is_device : 1;
- // 1 if the function is required and the loader should bail if not found.
- uint32_t is_required : 1;
- // TODO(benvanik): remove from table by manually walking sizeof(uintptr_t).
- // An offset in bytes from the base of &syms to where the PFN_vkSomeFunction
- // member is located.
- uint32_t member_offset : 30;
-} ABSL_ATTRIBUTE_PACKED;
-
-namespace {
-
-#define REQUIRED_PFN_FUNCTION_PTR(function_name, is_device) \
- {#function_name, is_device, 1, offsetof(DynamicSymbols, function_name)},
-#define OPTIONAL_PFN_FUNCTION_PTR(function_name, is_device) \
- {#function_name, is_device, 0, offsetof(DynamicSymbols, function_name)},
-#define EXCLUDED_PFN_FUNCTION_PTR(function_name, is_device)
-#define INS_PFN_FUNCTION_PTR(requirement, function_name) \
- requirement##_PFN_FUNCTION_PTR(function_name, 0)
-#define DEV_PFN_FUNCTION_PTR(requirement, function_name) \
- requirement##_PFN_FUNCTION_PTR(function_name, 1)
-
-// Defines the table of mandatory FunctionPtrInfos resolved prior to instance
-// creation. These are safe to call with no instance parameter and should be
-// exported by all loaders/ICDs.
-static constexpr const FunctionPtrInfo kInstancelessFunctionPtrInfos[] = {
- REQUIRED_PFN_FUNCTION_PTR(vkCreateInstance, false) //
- REQUIRED_PFN_FUNCTION_PTR(vkEnumerateInstanceLayerProperties, false) //
- REQUIRED_PFN_FUNCTION_PTR(vkEnumerateInstanceExtensionProperties, false) //
-};
-
-// Defines the table of FunctionPtrInfos for dynamic loading that must wait
-// until an instance has been created to be resolved.
-static constexpr const FunctionPtrInfo kDynamicFunctionPtrInfos[] = {
- IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(INS_PFN_FUNCTION_PTR,
- DEV_PFN_FUNCTION_PTR)};
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::Create(
- const GetProcAddrFn& get_proc_addr) {
- IREE_TRACE_SCOPE0("DynamicSymbols::Create");
-
- auto syms = make_ref<DynamicSymbols>();
-
- // Resolve the method the shared object uses to resolve other functions.
- // Some libraries will export all symbols while others will only export this
- // single function.
- syms->vkGetInstanceProcAddr = reinterpret_cast<PFN_vkGetInstanceProcAddr>(
- get_proc_addr("vkGetInstanceProcAddr"));
- if (!syms->vkGetInstanceProcAddr) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Required method vkGetInstanceProcAddr not "
- "found in provided Vulkan library (did you pick the wrong file?)";
- }
-
- // Resolve the mandatory functions that we need to create instances.
- // If the provided |get_proc_addr| cannot resolve these then it's not a loader
- // or ICD we want to use, anyway.
- for (int i = 0; i < ABSL_ARRAYSIZE(kInstancelessFunctionPtrInfos); ++i) {
- const auto& function_ptr = kInstancelessFunctionPtrInfos[i];
- auto* member_ptr = reinterpret_cast<PFN_vkVoidFunction*>(
- reinterpret_cast<uint8_t*>(syms.get()) + function_ptr.member_offset);
- *member_ptr =
- syms->vkGetInstanceProcAddr(VK_NULL_HANDLE, function_ptr.function_name);
- if (*member_ptr == nullptr) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Mandatory Vulkan function " << function_ptr.function_name
- << " not available; invalid loader/ICD?";
- }
- }
-
- return syms;
-}
-
-// static
-StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::CreateFromSystemLoader() {
- IREE_TRACE_SCOPE0("DynamicSymbols::CreateFromSystemLoader");
-
-#if defined(IREE_VK_ICD_FILENAMES)
-#define IREE_STRINGIFY_(x) #x
-#define IREE_STRING_(x) IREE_STRINGIFY_(x)
- std::string vk_icd_filenames = IREE_STRING_(IREE_VK_ICD_FILENAMES);
-#undef IREE_STRINGIFY_
-#undef IREE_STRING_
-#if defined(IREE_PLATFORM_WINDOWS)
- // TODO(b/138220713): Set VK_ICD_FILENAMES on Windows
-#else
- ::setenv("VK_ICD_FILENAMES", vk_icd_filenames.c_str(), 0);
-#endif // IREE_PLATFORM_WINDOWS
-#else
- // Leave VK_ICD_FILENAMES unchanged and rely on the system Vulkan loader to
- // discover ICDs.
-#endif // IREE_VK_ICD_FILENAMES
-
-// NOTE: we could factor this out into base, but this is the only place we use
-// it right now so it's fine.
-#if defined(IREE_PLATFORM_WINDOWS)
- HMODULE library = ::LoadLibraryA("vulkan-1.dll");
- if (!library) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Unable to open vulkan-1.dll; driver not installed/on PATH";
- }
- ASSIGN_OR_RETURN(auto syms, Create([library](const char* function_name) {
- return reinterpret_cast<PFN_vkVoidFunction>(
- ::GetProcAddress(library, function_name));
- }));
- syms->close_fn_ = [library]() {
- // TODO(benvanik): disable if we want to get profiling results. Sometimes
- // closing the library can prevent proper symbolization on crashes or
- // in sampling profilers.
- ::FreeLibrary(library);
- };
- return syms;
-#else
- void* library = ::dlopen("libvulkan.so.1", RTLD_LAZY | RTLD_LOCAL);
- if (!library) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Unable to open libvulkan.so; driver not installed/on "
- "LD_LIBRARY_PATH";
- }
- ASSIGN_OR_RETURN(auto syms, Create([library](const char* function_name) {
- return reinterpret_cast<PFN_vkVoidFunction>(
- ::dlsym(library, function_name));
- }));
- syms->close_fn_ = [library]() {
- // TODO(benvanik): disable if we want to get profiling results. Sometimes
- // closing the library can prevent proper symbolization on crashes or
- // in sampling profilers.
- ::dlclose(library);
- };
- return syms;
-#endif // IREE_PLATFORM_WINDOWS
-}
-
-Status DynamicSymbols::LoadFromInstance(VkInstance instance) {
- IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromInstance");
- return LoadFromDevice(instance, VK_NULL_HANDLE);
-}
-
-Status DynamicSymbols::LoadFromDevice(VkInstance instance, VkDevice device) {
- IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromDevice");
-
- if (!instance) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Instance must have been created and a default instance proc "
- "lookup function is required";
- }
-
- // Setup the lookup methods first. The rest of the syms uses these to
- // resolve function pointers.
- this->vkGetDeviceProcAddr = reinterpret_cast<PFN_vkGetDeviceProcAddr>(
- this->vkGetInstanceProcAddr(instance, "vkGetDeviceProcAddr"));
- if (!this->vkGetDeviceProcAddr) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Required Vulkan function vkGetDeviceProcAddr not available; "
- "invalid driver handle?";
- }
-
- // Load the rest of the functions.
- for (int i = 0; i < ABSL_ARRAYSIZE(kDynamicFunctionPtrInfos); ++i) {
- const auto& function_ptr = kDynamicFunctionPtrInfos[i];
- auto* member_ptr = reinterpret_cast<PFN_vkVoidFunction*>(
- reinterpret_cast<uint8_t*>(this) + function_ptr.member_offset);
- if (function_ptr.is_device && device) {
- *member_ptr =
- this->vkGetDeviceProcAddr(device, function_ptr.function_name);
- } else {
- *member_ptr =
- this->vkGetInstanceProcAddr(instance, function_ptr.function_name);
- }
- if (*member_ptr == nullptr && function_ptr.is_required) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Required Vulkan function " << function_ptr.function_name
- << " not available";
- }
- }
-
- return OkStatus();
-}
-
-DynamicSymbols::DynamicSymbols() = default;
-
-DynamicSymbols::~DynamicSymbols() {
- if (close_fn_) {
- close_fn_();
- }
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/dynamic_symbols.h b/iree/hal/vulkan/dynamic_symbols.h
deleted file mode 100644
index 429c7f6..0000000
--- a/iree/hal/vulkan/dynamic_symbols.h
+++ /dev/null
@@ -1,129 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
-#define IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
-
-#include <vulkan/vulkan.h>
-
-#include <cstdint>
-#include <functional>
-#include <memory>
-
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/vulkan/dynamic_symbol_tables.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-struct FunctionPtrInfo;
-
-// Dynamic Vulkan function loader for use with vulkan.hpp.
-// This loader is a subset of the DispatchLoaderDynamic implementation that only
-// loads functions we are interested in (a compute-specific subset) and avoids
-// extensions we will never use.
-//
-// This exposes all Vulkan methods as function pointer members. Optional
-// methods will be nullptr if not present. Excluded methods will be omitted.
-//
-// DynamicSymbols instances are designed to be passed to vulkan.hpp methods as
-// the last argument, though they may also be called directly.
-// **Always make sure to pass the loader to vulkan.hpp methods!**
-//
-// Loading is performed by walking a table of required and optional functions
-// (defined in dynamic_symbol_tables.h) and populating the member function
-// pointers exposed on this struct when available. For example, if the
-// vkSomeFunction method is marked in the table as OPTIONAL the loader will
-// attempt to lookup the function and if successful set the
-// DynamicSymbols::vkSomeFunction pointer to the resolved address. If the
-// function is not found then it will be set to nullptr so users can check for
-// function availability.
-//
-// Documentation:
-// https://github.com/KhronosGroup/Vulkan-Hpp#extensions--per-device-function-pointers
-//
-// Usage:
-// ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader());
-// VkInstance instance = VK_NULL_HANDLE;
-// syms->vkCreateInstance(..., &instance);
-// RETURN_IF_ERROR(syms->LoadFromInstance(instance));
-struct DynamicSymbols : public RefObject<DynamicSymbols> {
- using GetProcAddrFn =
- std::function<PFN_vkVoidFunction(const char* function_name)>;
-
- DynamicSymbols();
- ~DynamicSymbols();
-
- // Creates the dynamic symbol table using the given |get_proc_addr| to resolve
- // the vkCreateInstance function.
- //
- // After the instance is created the caller must use LoadFromInstance (or
- // LoadFromDevice) to load the remaining symbols.
- static StatusOr<ref_ptr<DynamicSymbols>> Create(
- const GetProcAddrFn& get_proc_addr);
-
- // Loads all required and optional Vulkan functions from the Vulkan loader.
- // This will look for a Vulkan loader on the system (like libvulkan.so) and
- // dlsym the functions from that.
- //
- // The loaded function pointers will point to thunks in the ICD. This may
- // enable additional debug checking and more readable stack traces (as
- // errors come from within the ICD, where we have symbols).
- static StatusOr<ref_ptr<DynamicSymbols>> CreateFromSystemLoader();
-
- // Loads all required and optional Vulkan functions from the given instance.
- //
- // The loaded function pointers will point to thunks in the ICD. This may
- // enable additional debug checking and more readable stack traces (as
- // errors come from within the ICD, where we have symbols).
- Status LoadFromInstance(VkInstance instance);
-
- // Loads all required and optional Vulkan functions from the given device,
- // falling back to the instance when required.
- //
- // This attempts to directly query the methods from the device, bypassing any
- // ICD or shim layers. These methods will generally have less overhead at
- // runtime as they need not jump through the various trampolines.
- Status LoadFromDevice(VkInstance instance, VkDevice device);
-
- // Define members for each function pointer.
- // See dynamic_symbol_tables.h for the full list of methods.
- //
- // Each required and optional function in the loader tables will expand to
- // the following member, such as for example 'vkSomeFunction':
- // PFN_vkSomeFunction vkSomeFunction;
-#define REQUIRED_PFN(function_name) PFN_##function_name function_name
-#define OPTIONAL_PFN(function_name) PFN_##function_name function_name
-#define EXCLUDED_PFN(function_name)
-#define PFN_MEMBER(requirement, function_name) requirement##_PFN(function_name);
- REQUIRED_PFN(vkGetInstanceProcAddr);
- REQUIRED_PFN(vkGetDeviceProcAddr);
- IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(PFN_MEMBER, PFN_MEMBER);
-#undef REQUIRED_PFN
-#undef OPTIONAL_PFN
-#undef EXCLUDED_PFN
-#undef PFN_MEMBER
-
- private:
- // Optional callback on loader destruction.
- std::function<void()> close_fn_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/vulkan/dynamic_symbols_test.cc b/iree/hal/vulkan/dynamic_symbols_test.cc
deleted file mode 100644
index 2576a44..0000000
--- a/iree/hal/vulkan/dynamic_symbols_test.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/dynamic_symbols.h"
-
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-namespace {
-
-VkApplicationInfo GetApplicationInfo() {
- VkApplicationInfo app_info;
- app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
- app_info.pNext = nullptr;
- app_info.pApplicationName = "IREE-ML-TEST";
- app_info.applicationVersion = 0;
- app_info.pEngineName = "IREE";
- app_info.engineVersion = 0;
- app_info.apiVersion = VK_API_VERSION_1_0;
- return app_info;
-}
-
-VkInstanceCreateInfo GetInstanceCreateInfo(VkApplicationInfo* app_info) {
- VkInstanceCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- create_info.pApplicationInfo = app_info;
- create_info.enabledLayerCount = 0;
- create_info.ppEnabledLayerNames = nullptr;
- create_info.enabledExtensionCount = 0;
- create_info.ppEnabledExtensionNames = nullptr;
- return create_info;
-}
-
-TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
- auto status_or_syms = DynamicSymbols::CreateFromSystemLoader();
- ASSERT_OK(status_or_syms);
- ref_ptr<DynamicSymbols> syms = std::move(status_or_syms.ValueOrDie());
-
- // Create and destroy a VkInstance using the symbols. This is mainly testing
- // that the symbols were loaded successfully and are actually able to be used.
- VkApplicationInfo app_info = GetApplicationInfo();
- VkInstanceCreateInfo create_info = GetInstanceCreateInfo(&app_info);
- VkInstance instance = VK_NULL_HANDLE;
- VK_CHECK_OK(
- syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance));
-
- ASSERT_OK(syms->LoadFromInstance(instance));
-
- syms->vkDestroyInstance(instance, /*pAllocator=*/nullptr);
-}
-
-} // namespace
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/extensibility_util.cc b/iree/hal/vulkan/extensibility_util.cc
deleted file mode 100644
index ca1d2a7..0000000
--- a/iree/hal/vulkan/extensibility_util.cc
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/extensibility_util.h"
-
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-StatusOr<std::vector<const char*>> MatchAvailableLayers(
- absl::Span<const char* const> required_layers,
- absl::Span<const char* const> optional_layers,
- absl::Span<const VkLayerProperties> properties) {
- IREE_TRACE_SCOPE0("MatchAvailableLayers");
-
- std::vector<const char*> enabled_layers;
- enabled_layers.reserve(required_layers.size() + optional_layers.size());
-
- for (const char* layer_name : required_layers) {
- bool found = false;
- for (const auto& layer_properties : properties) {
- if (std::strcmp(layer_name, layer_properties.layerName) == 0) {
- VLOG(1) << "Enabling required layer: " << layer_name;
- found = true;
- enabled_layers.push_back(layer_name);
- break;
- }
- }
- if (!found) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Required layer " << layer_name << " not available";
- }
- }
-
- for (const char* layer_name : optional_layers) {
- bool found = false;
- for (const auto& layer_properties : properties) {
- if (std::strcmp(layer_name, layer_properties.layerName) == 0) {
- VLOG(1) << "Enabling optional layer: " << layer_name;
- found = true;
- enabled_layers.push_back(layer_name);
- break;
- }
- }
- if (!found) {
- VLOG(1) << "Optional layer " << layer_name << " not available";
- }
- }
-
- return enabled_layers;
-}
-
-StatusOr<std::vector<const char*>> MatchAvailableExtensions(
- absl::Span<const char* const> required_extensions,
- absl::Span<const char* const> optional_extensions,
- absl::Span<const VkExtensionProperties> properties) {
- IREE_TRACE_SCOPE0("MatchAvailableExtensions");
-
- std::vector<const char*> enabled_extensions;
- enabled_extensions.reserve(required_extensions.size() +
- optional_extensions.size());
-
- for (const char* extension_name : required_extensions) {
- bool found = false;
- for (const auto& extension_properties : properties) {
- if (std::strcmp(extension_name, extension_properties.extensionName) ==
- 0) {
- VLOG(1) << "Enabling required extension: " << extension_name;
- found = true;
- enabled_extensions.push_back(extension_name);
- break;
- }
- }
- if (!found) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Required extension " << extension_name << " not available";
- }
- }
-
- for (const char* extension_name : optional_extensions) {
- bool found = false;
- for (const auto& extension_properties : properties) {
- if (std::strcmp(extension_name, extension_properties.extensionName) ==
- 0) {
- VLOG(1) << "Enabling optional extension: " << extension_name;
- found = true;
- enabled_extensions.push_back(extension_name);
- break;
- }
- }
- if (!found) {
- VLOG(1) << "Optional extension " << extension_name << " not available";
- }
- }
-
- return enabled_extensions;
-}
-
-} // namespace
-
-StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers(
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
- uint32_t layer_property_count = 0;
- VK_RETURN_IF_ERROR(
- syms.vkEnumerateInstanceLayerProperties(&layer_property_count, nullptr));
- std::vector<VkLayerProperties> layer_properties(layer_property_count);
- VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceLayerProperties(
- &layer_property_count, layer_properties.data()));
- ASSIGN_OR_RETURN(auto enabled_layers,
- MatchAvailableLayers(extensibility_spec.required_layers,
- extensibility_spec.optional_layers,
- layer_properties),
- _ << "Unable to find all required instance layers");
- return enabled_layers;
-}
-
-StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions(
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
- uint32_t extension_property_count = 0;
- // Warning: leak checks remain disabled if an error is returned.
- IREE_DISABLE_LEAK_CHECKS();
- VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties(
- nullptr, &extension_property_count, nullptr));
- std::vector<VkExtensionProperties> extension_properties(
- extension_property_count);
- VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties(
- nullptr, &extension_property_count, extension_properties.data()));
- ASSIGN_OR_RETURN(
- auto enabled_extensions,
- MatchAvailableExtensions(extensibility_spec.required_extensions,
- extensibility_spec.optional_extensions,
- extension_properties),
- _ << "Unable to find all required instance extensions");
- IREE_ENABLE_LEAK_CHECKS();
- return enabled_extensions;
-}
-
-StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers(
- VkPhysicalDevice physical_device,
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
- uint32_t layer_property_count = 0;
- VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties(
- physical_device, &layer_property_count, nullptr));
- std::vector<VkLayerProperties> layer_properties(layer_property_count);
- VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties(
- physical_device, &layer_property_count, layer_properties.data()));
- ASSIGN_OR_RETURN(auto enabled_layers,
- MatchAvailableLayers(extensibility_spec.required_layers,
- extensibility_spec.optional_layers,
- layer_properties),
- _ << "Unable to find all required device layers");
- return enabled_layers;
-}
-
-StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions(
- VkPhysicalDevice physical_device,
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) {
- uint32_t extension_property_count = 0;
- VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties(
- physical_device, nullptr, &extension_property_count, nullptr));
- std::vector<VkExtensionProperties> extension_properties(
- extension_property_count);
- VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties(
- physical_device, nullptr, &extension_property_count,
- extension_properties.data()));
- ASSIGN_OR_RETURN(
- auto enabled_extensions,
- MatchAvailableExtensions(extensibility_spec.required_extensions,
- extensibility_spec.optional_extensions,
- extension_properties),
- _ << "Unable to find all required device extensions");
- return enabled_extensions;
-}
-
-InstanceExtensions PopulateEnabledInstanceExtensions(
- absl::Span<const char* const> extension_names) {
- InstanceExtensions extensions = {0};
- for (const char* extension_name : extension_names) {
- if (std::strcmp(extension_name, VK_EXT_DEBUG_REPORT_EXTENSION_NAME) == 0) {
- extensions.debug_report = true;
- } else if (std::strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) ==
- 0) {
- extensions.debug_utils = true;
- }
- }
- return extensions;
-}
-
-DeviceExtensions PopulateEnabledDeviceExtensions(
- absl::Span<const char* const> extension_names) {
- DeviceExtensions extensions = {0};
- for (const char* extension_name : extension_names) {
- if (std::strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) ==
- 0) {
- extensions.push_descriptors = true;
- }
- }
- return extensions;
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/extensibility_util.h b/iree/hal/vulkan/extensibility_util.h
deleted file mode 100644
index 8f2b784..0000000
--- a/iree/hal/vulkan/extensibility_util.h
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utilities for working with layers and extensions.
-
-#ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
-#define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
-
-#include <vulkan/vulkan.h>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// Describes required and optional extensibility points.
-struct ExtensibilitySpec {
- // A list of required and optional layers.
- std::vector<const char*> required_layers;
- std::vector<const char*> optional_layers;
-
- // A list of required and optional extensions.
- // Prefer using the _EXTENSION_NAME macros to make tracking easier (such as
- // 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME').
- std::vector<const char*> required_extensions;
- std::vector<const char*> optional_extensions;
-};
-
-// Returns a list of layer names available for instances.
-// Fails if any required_layers are unavailable.
-StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers(
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
-
-// Returns a list of extension names available for instances.
-// Fails if any required_extensions are unavailable.
-StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions(
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
-
-// Returns a list of layer names available for the given |physical_device|.
-// Fails if any required_layers are unavailable.
-StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers(
- VkPhysicalDevice physical_device,
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
-
-// Returns a list of extension names available for the given |physical_device|.
-// Fails if any required_extensions are unavailable.
-StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions(
- VkPhysicalDevice physical_device,
- const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms);
-
-// Bits for enabled instance extensions.
-// We must use this to query support instead of just detecting symbol names as
-// ICDs will resolve the functions sometimes even if they don't support the
-// extension (or we didn't ask for it to be enabled).
-struct InstanceExtensions {
- // VK_EXT_debug_report is enabled and a callback is regsitered.
- // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_report
- bool debug_report : 1;
-
- // VK_EXT_debug_utils is enabled and a debug messenger is registered.
- // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_utils
- bool debug_utils : 1;
-};
-
-// Returns a bitfield with all of the provided extension names.
-InstanceExtensions PopulateEnabledInstanceExtensions(
- absl::Span<const char* const> extension_names);
-
-// Bits for enabled device extensions.
-// We must use this to query support instead of just detecting symbol names as
-// ICDs will resolve the functions sometimes even if they don't support the
-// extension (or we didn't ask for it to be enabled).
-struct DeviceExtensions {
- // VK_KHR_push_descriptor is enabled and vkCmdPushDescriptorSetKHR is valid.
- bool push_descriptors : 1;
-};
-
-// Returns a bitfield with all of the provided extension names.
-DeviceExtensions PopulateEnabledDeviceExtensions(
- absl::Span<const char* const> extension_names);
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
diff --git a/iree/hal/vulkan/handle_util.h b/iree/hal/vulkan/handle_util.h
deleted file mode 100644
index 3173de4..0000000
--- a/iree/hal/vulkan/handle_util.h
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Helpers for wrapping Vulkan handles that don't require us to wrap every type.
-// This keeps our compilation time reasonable (as the vulkancpp library is
-// insane) while giving us nice safety around cleanup and ensuring we use
-// dynamic symbols and consistent allocators.
-//
-// Do not add functionality beyond handle management to these types. Keep our
-// Vulkan usage mostly functional and C-like to ensure minimal code size and
-// readability.
-
-#ifndef IREE_HAL_VULKAN_HANDLE_UTIL_H_
-#define IREE_HAL_VULKAN_HANDLE_UTIL_H_
-
-#include <vulkan/vulkan.h>
-
-#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/extensibility_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class VkDeviceHandle : public RefObject<VkDeviceHandle> {
- public:
- VkDeviceHandle(const ref_ptr<DynamicSymbols>& syms,
- DeviceExtensions enabled_extensions,
- const VkAllocationCallbacks* allocator = nullptr)
- : syms_(add_ref(syms)),
- enabled_extensions_(enabled_extensions),
- allocator_(allocator) {}
- ~VkDeviceHandle() { reset(); }
-
- VkDeviceHandle(const VkDeviceHandle&) = delete;
- VkDeviceHandle& operator=(const VkDeviceHandle&) = delete;
- VkDeviceHandle(VkDeviceHandle&& other) noexcept
- : value_(absl::exchange(other.value_,
- static_cast<VkDevice>(VK_NULL_HANDLE))),
- syms_(std::move(other.syms_)),
- enabled_extensions_(other.enabled_extensions_),
- allocator_(other.allocator_) {}
-
- void reset() {
- if (value_ == VK_NULL_HANDLE) return;
- syms_->vkDestroyDevice(value_, allocator_);
- value_ = VK_NULL_HANDLE;
- }
-
- VkDevice value() const noexcept { return value_; }
- VkDevice* mutable_value() noexcept { return &value_; }
- operator VkDevice() const noexcept { return value_; }
-
- const ref_ptr<DynamicSymbols>& syms() const noexcept { return syms_; }
- const VkAllocationCallbacks* allocator() const noexcept { return allocator_; }
-
- const DeviceExtensions& enabled_extensions() const {
- return enabled_extensions_;
- }
-
- private:
- VkDevice value_ = VK_NULL_HANDLE;
- ref_ptr<DynamicSymbols> syms_;
- DeviceExtensions enabled_extensions_;
- const VkAllocationCallbacks* allocator_ = nullptr;
-};
-
-class VkCommandPoolHandle : public RefObject<VkCommandPoolHandle> {
- public:
- explicit VkCommandPoolHandle(const ref_ptr<VkDeviceHandle>& logical_device)
- : logical_device_(add_ref(logical_device)) {}
- ~VkCommandPoolHandle() { reset(); }
-
- VkCommandPoolHandle(const VkCommandPoolHandle&) = delete;
- VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete;
- VkCommandPoolHandle(VkCommandPoolHandle&& other) noexcept
- : logical_device_(std::move(other.logical_device_)),
- value_(absl::exchange(other.value_,
- static_cast<VkCommandPool>(VK_NULL_HANDLE))) {}
- VkCommandPoolHandle& operator=(VkCommandPoolHandle&& other) {
- std::swap(logical_device_, other.logical_device_);
- std::swap(value_, other.value_);
- return *this;
- }
-
- void reset() {
- if (value_ == VK_NULL_HANDLE) return;
- syms()->vkDestroyCommandPool(*logical_device_, value_, allocator());
- value_ = VK_NULL_HANDLE;
- }
-
- VkCommandPool value() const noexcept { return value_; }
- VkCommandPool* mutable_value() noexcept { return &value_; }
- operator VkCommandPool() const noexcept { return value_; }
-
- const ref_ptr<VkDeviceHandle>& logical_device() const noexcept {
- return logical_device_;
- }
- const ref_ptr<DynamicSymbols>& syms() const noexcept {
- return logical_device_->syms();
- }
- const VkAllocationCallbacks* allocator() const noexcept {
- return logical_device_->allocator();
- }
-
- absl::Mutex* mutex() const { return &mutex_; }
-
- private:
- ref_ptr<VkDeviceHandle> logical_device_;
- VkCommandPool value_ = VK_NULL_HANDLE;
-
- // Vulkan command pools are not thread safe and require external
- // synchronization. Since we allow arbitrary threads to allocate and
- // deallocate the HAL command buffers we need to externally synchronize.
- mutable absl::Mutex mutex_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_HANDLE_UTIL_H_
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.cc b/iree/hal/vulkan/internal_vk_mem_alloc.cc
deleted file mode 100644
index 0884a34..0000000
--- a/iree/hal/vulkan/internal_vk_mem_alloc.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file configures VMA to use common Google/Abseil types in an effort to
-// better integrate with applications compiled using other Google code. By using
-// the same types that dependers are likely using we can often reduce binary
-// size and ease debugging (such as by using absl::Mutex to get better tsan
-// warnings).
-
-// Only compile if an external implementation has not been otherwise linked.
-#if !defined(VULKAN_MEMORY_ALLOCATOR_EXTERNAL_IMPL)
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/logging.h"
-
-// Use std::vector instead of the VMA version.
-#define VMA_USE_STL_VECTOR 1
-
-// TODO(benvanik): figure out why std::list cannot be used.
-// #define VMA_USE_STL_LIST 1
-
-// Use absl::flat_hash_map instead of std::unordered_map.
-#define VmaPair std::pair
-#define VMA_MAP_TYPE(KeyT, ValueT) \
- absl::flat_hash_map<KeyT, ValueT, std::hash<KeyT>, std::equal_to<KeyT>, \
- VmaStlAllocator<std::pair<KeyT, ValueT> > >
-
-// Use CHECK for assertions.
-#define VMA_ASSERT CHECK
-#define VMA_HEAVY_ASSERT DCHECK
-
-// Use LOG for logging.
-#ifndef NDEBUG
-#define VMA_DEBUG_LOG(...) ABSL_RAW_LOG(INFO, __VA_ARGS__)
-#else
-#define VMA_DEBUG_LOG(...)
-#endif // !NDEBUG
-
-// Use absl::Mutex for VMA_MUTEX.
-#define VMA_MUTEX absl::Mutex
-class AbslVmaRWMutex {
- public:
- void LockRead() ABSL_SHARED_LOCK_FUNCTION() { mutex_.ReaderLock(); }
- void UnlockRead() ABSL_UNLOCK_FUNCTION() { mutex_.ReaderUnlock(); }
- void LockWrite() ABSL_EXCLUSIVE_LOCK_FUNCTION() { mutex_.WriterLock(); }
- void UnlockWrite() ABSL_UNLOCK_FUNCTION() { mutex_.WriterUnlock(); }
-
- private:
- absl::Mutex mutex_;
-};
-#define VMA_RW_MUTEX AbslVmaRWMutex
-
-#define VMA_IMPLEMENTATION
-#include "vk_mem_alloc.h"
-
-#endif
diff --git a/iree/hal/vulkan/legacy_fence.cc b/iree/hal/vulkan/legacy_fence.cc
deleted file mode 100644
index 481e855..0000000
--- a/iree/hal/vulkan/legacy_fence.cc
+++ /dev/null
@@ -1,396 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/legacy_fence.h"
-
-#include <cstdint>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/time/time.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-// Inserts the given |fence_signal| into |list| in ascending order.
-void InsertOutstandingFenceSignal(OutstandingFenceSignal* fence_signal,
- IntrusiveList<OutstandingFenceSignal>* list) {
- for (auto existing_signal : *list) {
- if (existing_signal->value > fence_signal->value) {
- list->insert(existing_signal, fence_signal);
- return;
- }
- }
- list->push_back(fence_signal);
-}
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<LegacyFencePool>> LegacyFencePool::Create(
- ref_ptr<VkDeviceHandle> logical_device) {
- IREE_TRACE_SCOPE0("LegacyFencePool::Create");
- ref_ptr<LegacyFencePool> fence_pool(
- new LegacyFencePool(std::move(logical_device)));
- RETURN_IF_ERROR(fence_pool->PreallocateFences());
- return fence_pool;
-}
-
-LegacyFencePool::LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device)
- : logical_device_(std::move(logical_device)) {}
-
-LegacyFencePool::~LegacyFencePool() {
- IREE_TRACE_SCOPE0("LegacyFencePool::dtor");
-
- absl::MutexLock lock(&mutex_);
- for (auto& fence_signal : storage_) {
- syms()->vkDestroyFence(*logical_device_, fence_signal.fence,
- logical_device_->allocator());
- }
- unused_fences_.clear();
- unresolved_fences_.clear();
-}
-
-Status LegacyFencePool::PreallocateFences() {
- IREE_TRACE_SCOPE0("LegacyFencePool::PreallocateFences");
-
- VkFenceCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
-
- absl::MutexLock lock(&mutex_);
- for (int i = 0; i < kMaxInFlightFenceCount; ++i) {
- auto* fence_signal = &storage_[i];
- VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info,
- logical_device_->allocator(),
- &fence_signal->fence));
- unused_fences_.push_back(fence_signal);
- }
-
- return OkStatus();
-}
-
-StatusOr<OutstandingFenceSignal*> LegacyFencePool::Acquire() {
- IREE_TRACE_SCOPE0("LegacyFencePool::Acquire");
-
- absl::MutexLock lock(&mutex_);
- if (unused_fences_.empty()) {
- return ResourceExhaustedErrorBuilder(IREE_LOC)
- << "Fence pool out of unused fences";
- }
-
- auto* fence_signal = unused_fences_.front();
- unused_fences_.pop_front();
- return fence_signal;
-}
-
-void LegacyFencePool::ReleaseResolved(
- IntrusiveList<OutstandingFenceSignal>* fence_signals) {
- IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseResolved");
-
- // Get a list of fences we need to reset. Note that not all fences may have
- // been signaled and we can avoid resetting them.
- absl::InlinedVector<VkFence, 8> handles;
- handles.reserve(fence_signals->size());
- for (auto* fence_signal : *fence_signals) {
- if (fence_signal->is_pending) {
- handles.push_back(fence_signal->fence);
- }
- }
- if (!handles.empty()) {
- syms()->vkResetFences(*logical_device_, handles.size(), handles.data());
- }
-
- absl::MutexLock lock(&mutex_);
- unused_fences_.merge_from(fence_signals);
-}
-
-void LegacyFencePool::ReleaseUnresolved(
- IntrusiveList<OutstandingFenceSignal>* fence_signals) {
- IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseUnresolved");
-
- absl::MutexLock lock(&mutex_);
- while (!fence_signals->empty()) {
- auto* fence_signal = fence_signals->front();
- fence_signals->pop_front();
- if (fence_signal->is_pending) {
- // Fence was submitted and may still have a pending signal on it. We can't
- // reuse it until it has resolved.
- // TODO(benvanik): fix these fences by reallocating? We aren't leaking
- // here (technically) but we will exhaust the pool pretty quickly.
- unresolved_fences_.push_back(fence_signal);
- } else {
- // Fence was never actually submitted so we can reuse it no problem.
- unused_fences_.push_back(fence_signal);
- }
- }
-}
-
-// static
-Status LegacyFence::WaitForFences(VkDeviceHandle* logical_device,
- absl::Span<const FenceValue> fences,
- bool wait_all, absl::Time deadline) {
- IREE_TRACE_SCOPE0("LegacyFence::WaitForFences");
-
- // NOTE: we could pool this state too (probably right on the LegacyFencePool)
- // or be smarter about using stack-allocated storage. The best idea is to use
- // real timeline semaphores, though, so not much effort has been spent on
- // optimizing this.
- absl::InlinedVector<VkFence, 4> handles;
- handles.reserve(fences.size());
-
- // Loop over the fences and wait for any/all to signal. In wait_all mode we
- // perform the bookkeeping to remove fences that have already been signaled so
- // that we only wait on ones we need to (and possibly avoid making the vk call
- // entirely!).
- while (true) {
- // Grab handles and acquire fences for all fences not yet at the requested
- // timeline value.
- for (const auto& fence_value : fences) {
- auto* fence = reinterpret_cast<LegacyFence*>(fence_value.first);
- // NOTE: this will return the sticky fence error if the fence has failed.
- ASSIGN_OR_RETURN(VkFence handle,
- fence->AcquireWaitFence(fence_value.second));
- if (handle != VK_NULL_HANDLE) {
- // Fence is unresolved and we need to really wait for it.
- handles.push_back(handle);
- }
- }
- if (handles.empty()) {
- // All fences resolved.
- return OkStatus();
- }
-
- uint64_t timeout_nanos;
- if (deadline == absl::InfiniteFuture()) {
- timeout_nanos = UINT64_MAX;
- } else if (deadline == absl::InfinitePast()) {
- timeout_nanos = 0;
- } else {
- auto relative_nanos = absl::ToInt64Nanoseconds(deadline - absl::Now());
- timeout_nanos = relative_nanos < 0 ? 0 : relative_nanos;
- }
-
- // Wait on the fences we still need.
- // Note that waking does not actually indicate all fences were hit! We need
- // to do another pass above on the next iteration to make sure that we don't
- // need to wait again on another fence.
- VK_RETURN_IF_ERROR(logical_device->syms()->vkWaitForFences(
- *logical_device, handles.size(), handles.data(), wait_all,
- timeout_nanos));
- handles.clear();
- }
-
- return OkStatus();
-}
-
-LegacyFence::LegacyFence(ref_ptr<LegacyFencePool> fence_pool,
- uint64_t initial_value)
- : fence_pool_(std::move(fence_pool)), value_(initial_value) {}
-
-LegacyFence::~LegacyFence() {
- IREE_TRACE_SCOPE0("LegacyFence::dtor");
- CHECK_OK(TryResolveOutstandingFences(UINT64_MAX));
- absl::MutexLock lock(&mutex_);
- CHECK(outstanding_signals_.empty())
- << "Destroying a fence without first waiting on outstanding signals";
-}
-
-Status LegacyFence::status() const {
- if (value_.load() != UINT64_MAX) {
- return OkStatus();
- }
- absl::MutexLock lock(&mutex_);
- return status_;
-}
-
-StatusOr<uint64_t> LegacyFence::QueryValue() {
- RETURN_IF_ERROR(TryResolveOutstandingFences(UINT64_MAX));
- return value_.load();
-}
-
-StatusOr<VkFence> LegacyFence::AcquireSignalFence(uint64_t value) {
- absl::MutexLock lock(&mutex_);
-
- // It's an error to signal out of order (as that requires a lot more
- // tracking and magic to get right).
- if (value_.load() >= value) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Attempting to signal a timeline fence out of order; value="
- << value_ << ", new_value=" << value;
- }
-
- // Scan to see if there's waiters for this value (or values before it).
- // We may be able to reuse a previously allocated fence in the case that a
- // user is waiting prior to actually submitting the signal operation.
- OutstandingFenceSignal* signal_state = nullptr;
- for (auto* fence_signal : outstanding_signals_) {
- if (fence_signal->value == value) {
- // Fence is going to be signaled at exactly the required value.
- if (fence_signal->is_pending) {
- // Already have signaled to this value - that's a paddlin'.
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Duplicate signal of timeline fence for value=" << value;
- }
- signal_state = fence_signal;
- break;
- }
- }
- if (!signal_state) {
- // Allocate a signal state entry and a VkFence to submit with.
- // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush.
- ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire());
- signal_state->value = value;
- InsertOutstandingFenceSignal(signal_state, &outstanding_signals_);
- }
-
- signal_state->is_pending = true;
- return signal_state->fence;
-}
-
-StatusOr<VkFence> LegacyFence::AcquireWaitFence(uint64_t value) {
- // If we've already resolved then we want to avoid doing any kind of wait.
- // Since the value is monotonically increasing we can do a lock-free peek
- // here to see if we need to bother taking a full lock.
- if (value_.load() >= value) {
- return VK_NULL_HANDLE;
- }
-
- absl::MutexLock lock(&mutex_);
-
- // Try to resolve any outstanding fence signals.
- RETURN_IF_ERROR(TryResolveOutstandingFencesLocked(value));
- if (value_.load() >= value) {
- return VK_NULL_HANDLE;
- }
-
- // Try to find an existing fence we can reuse based on the required value.
- OutstandingFenceSignal* signal_state = nullptr;
- for (auto* fence_signal : outstanding_signals_) {
- if (fence_signal->value >= value) {
- // Fence is going to be signaled at or above the required value.
- signal_state = fence_signal;
- break; // |outstanding_signals_| is in sorted order.
- }
- }
- if (!signal_state) {
- // Allocate a signal state entry and a VkFence that we will need to signal
- // in the future. We can't yet insert it into the queue but it will go in
- // when the user tries to signal a value >= the required value.
- // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush.
- ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire());
- signal_state->value = value;
- InsertOutstandingFenceSignal(signal_state, &outstanding_signals_);
- }
-
- return signal_state->fence;
-}
-
-Status LegacyFence::TryResolveOutstandingFences(uint64_t upper_value) {
- absl::MutexLock lock(&mutex_);
- return TryResolveOutstandingFencesLocked(upper_value);
-}
-
-Status LegacyFence::TryResolveOutstandingFencesLocked(uint64_t upper_value) {
- // Fast-path for when we have no outstanding fences.
- // NOTE: we hold the lock during the entire resolve process so that any waiter
- // will only be woken once we have resolved to the furthest possible value.
- if (outstanding_signals_.empty() || value_ > upper_value) {
- return OkStatus();
- }
-
- IREE_TRACE_SCOPE0("LegacyFence::TryResolveOutstandingFences");
-
- IntrusiveList<OutstandingFenceSignal> resolved_fences;
- IntrusiveList<OutstandingFenceSignal> unresolved_fences;
- VkDevice device = *fence_pool_->logical_device();
- const auto& syms = fence_pool_->syms();
- bool keep_resolving = true;
- while (keep_resolving && !outstanding_signals_.empty()) {
- auto* fence_signal = outstanding_signals_.front();
- if (fence_signal->value > upper_value) {
- // Signal is for a value beyond our upper limit - early exit so that we
- // don't spend time dealing with signals we don't yet care about. This can
- // prevent live lock where one thread is signaling fences as fast/faster
- // than another thread can consume them.
- keep_resolving = false;
- break;
- }
- VkResult fence_status = syms->vkGetFenceStatus(device, fence_signal->fence);
- switch (fence_status) {
- case VK_SUCCESS: {
- // Fence has signaled meaning that we have reached this point in the
- // timeline and can advance the value.
- value_.store(fence_signal->value);
- outstanding_signals_.erase(fence_signal);
- resolved_fences.push_back(fence_signal);
-
- // Run backwards and resolve any non-pending fences as they will never
- // be used.
- for (auto* it = fence_signal; it != nullptr;) {
- auto* prev_fence_signal = it;
- it = outstanding_signals_.previous(it);
- if (!prev_fence_signal->is_pending) {
- outstanding_signals_.erase(prev_fence_signal);
- unresolved_fences.push_back(prev_fence_signal);
- }
- }
- break;
- }
- case VK_NOT_READY:
- if (fence_signal->is_pending) {
- // Fence has not yet been signaled. We stop here and wait for future
- // attempts at resolution.
- keep_resolving = false;
- }
- // Fence is not even pending yet - we may have skipped it. Keep
- // resolving to see if there's a higher value we can use.
- break;
- default:
- // Fence indicates an error (device lost, out of memory, etc).
- // Propagate this back to our status (and thus any waiters).
- // Since we only take the first error we find we skip all remaining
- // fences.
- status_ = VkResultToStatus(fence_status);
- value_.store(UINT64_MAX);
- outstanding_signals_.erase(fence_signal);
- resolved_fences.push_back(fence_signal);
- break;
- }
- }
-
- // Release resolved fences back to the pool. Note that we can only do this
- // to fences we know have actually completed: unresolved fences after an error
- // may still be in-flight and we don't want to reuse them.
- fence_pool_->ReleaseResolved(&resolved_fences);
- fence_pool_->ReleaseUnresolved(&unresolved_fences);
- if (!status_.ok()) {
- fence_pool_->ReleaseUnresolved(&outstanding_signals_);
- }
-
- return status_;
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/legacy_fence.h b/iree/hal/vulkan/legacy_fence.h
deleted file mode 100644
index 42f0362..0000000
--- a/iree/hal/vulkan/legacy_fence.h
+++ /dev/null
@@ -1,200 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// TODO(b/140141417): share the pool (and possibly most of the fence impl) with
-// the timeline semaphores fallback.
-
-#ifndef IREE_HAL_VULKAN_LEGACY_FENCE_H_
-#define IREE_HAL_VULKAN_LEGACY_FENCE_H_
-
-#include <vulkan/vulkan.h>
-
-#include <array>
-#include <atomic>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/fence.h"
-#include "iree/hal/vulkan/handle_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// An outstanding legacy fence signal for a particular timeline value.
-// Each signal to a new value gets a new VkFence and these are stored in a
-// LegacyFence to quickly scan and process signaled fences.
-//
-// Must be externally synchronized via the LegacyFence mutex.
-struct OutstandingFenceSignal : public IntrusiveLinkBase<void> {
- // Allocated fence that is passed to vkQueueSubmit/vkWaitForFences.
- // Represents a point in the timeline of value.
- VkFence fence = VK_NULL_HANDLE;
-
- // Value that the fence payload should be when the fence is signaled.
- // Note that since fences may resolve out of order we still need to check that
- // we are only ever advancing the timeline and not just setting this value.
- uint64_t value = UINT64_MAX;
-
- // True when the fence has been submitted and is pending on the device.
- bool is_pending = false;
-};
-
-// A pool of VkFences that can be used by LegacyFence to simulate individual
-// payload value signaling. Note that we prefer a pool instead of a ringbuffer
-// as we want to allow out-of-order completion.
-class LegacyFencePool final : public RefObject<LegacyFencePool> {
- public:
- static constexpr int kMaxInFlightFenceCount = 64;
-
- // Allocates a new fence pool and all fences.
- static StatusOr<ref_ptr<LegacyFencePool>> Create(
- ref_ptr<VkDeviceHandle> logical_device);
-
- ~LegacyFencePool();
-
- const ref_ptr<VkDeviceHandle>& logical_device() const {
- return logical_device_;
- }
- const ref_ptr<DynamicSymbols>& syms() const {
- return logical_device_->syms();
- }
-
- // Acquires a fence from the pool for use by the caller.
- // The fence is guaranteed to not be in-flight and will have been reset to an
- // unsignaled state.
- //
- // Returns RESOURCE_EXHAUSTED if the pool has no more available fences.
- // Callers are expected to handle this by waiting on previous fences or for
- // complete device idle. Yes, that's as bad as it sounds, and if we start
- // seeing that we should bump up the max count.
- StatusOr<OutstandingFenceSignal*> Acquire();
-
- // Releases one or more fences back to the pool.
- // The fences must either be signaled or not be in-flight.
- void ReleaseResolved(IntrusiveList<OutstandingFenceSignal>* fence_signals);
-
- // Releases one or more unresolved fences back to the pool.
- // These may be in any state and will be assumed as untouchable.
- void ReleaseUnresolved(IntrusiveList<OutstandingFenceSignal>* fence_signals);
-
- private:
- explicit LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device);
-
- Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_);
-
- ref_ptr<VkDeviceHandle> logical_device_;
-
- absl::Mutex mutex_;
- std::array<OutstandingFenceSignal, kMaxInFlightFenceCount> storage_
- ABSL_GUARDED_BY(mutex_);
- IntrusiveList<OutstandingFenceSignal> unused_fences_ ABSL_GUARDED_BY(mutex_);
- IntrusiveList<OutstandingFenceSignal> unresolved_fences_
- ABSL_GUARDED_BY(mutex_);
-};
-
-// A fence implemented using a pool of native VkFences.
-// This is supported unconditionally on all versions of Vulkan. When timeline
-// semaphores are available we prefer using those instead and this is only
-// present as a fallback. We keep this implementation separate so that it can be
-// compiled out when the target is known to have the extension.
-//
-// Simulation of timeline semaphore-based fences is done via a pool of native
-// VkFences that each represent a single signaled value. This means that worst
-// case we are using one fence per submit however that's no different than if
-// we did anything else. Though we can't cancel previously-queued fences when
-// increasing values are signaled we can be clever when querying and releasing
-// by always walking in reverse relying on the monotonically increasing values.
-//
-// Valid usage patterns we need to handle:
-// 1. fence signaled and waited on (common case)
-// 2. fence waited on before beginning signaling
-// 3. fence signaled and never waited on
-//
-// Case 1 is fairly straightforward: we acquire a VkFence, pass that to the
-// queue submit, and then vkWaitForFences/query it for completion.
-//
-// Case 2 requires that we reserve a fence during the wait so that we can pass
-// it to vkWaitForFences and track it such that we can reuse it during a future
-// signal operation. Since we don't know during signaling if the specific value
-// we waited on will ever have its own dedicated signal operation we need to be
-// conservative and try to coalesce for correctness. This means that if a wait
-// for a value of 1 is performed and we get a signal for a value of 2 we need to
-// combine the two. If a signal for a value of 1 is later performed it then
-// becomes a no-op. This could lead to some additional latency however that's a
-// risk (or benefit!) of using timelines. Rule of thumb: don't do out of order
-// signaling.
-//
-// Case 3 is like case 2 where we need to reserve a fence to wait on, however
-// since we don't know if it will ever be signaled we need to take care to
-// properly release the VkFence back to the pool for reuse: we don't want to
-// return it while there are still waiters for its original event. For this
-// reason we track the waiters on a given fence during their wait operation and
-// if a fence is released with waiters active we put them in a special
-// unresolved until the waiters continue on.
-class LegacyFence final : public Fence {
- public:
- // Waits for one or more (or all) fences to reach or exceed the given values.
- static Status WaitForFences(VkDeviceHandle* logical_device,
- absl::Span<const FenceValue> fences,
- bool wait_all, absl::Time deadline);
-
- LegacyFence(ref_ptr<LegacyFencePool> fence_pool, uint64_t initial_value);
- ~LegacyFence() override;
-
- Status status() const override;
-
- StatusOr<uint64_t> QueryValue() override;
-
- // Acquires a new fence for signaling a specific value.
- StatusOr<VkFence> AcquireSignalFence(uint64_t value);
-
- private:
- // Acquires a new fence for waiting on a specific value.
- // Returns VK_NULL_HANDLE if the fence already resolved and the sticky error
- // if the fence is in an error state.
- StatusOr<VkFence> AcquireWaitFence(uint64_t value);
-
- // Runs down the outstanding fences list and resolves to the latest signaled
- // value. Will early exit if the value moves beyond |upper_value|.
- Status TryResolveOutstandingFences(uint64_t upper_value)
- ABSL_LOCKS_EXCLUDED(mutex_);
- Status TryResolveOutstandingFencesLocked(uint64_t upper_value)
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- ref_ptr<LegacyFencePool> fence_pool_;
-
- // The current highest value of the fence as verified during a wait or query.
- // Kept outside of |mutex_| so that queries do not require a lock.
- std::atomic<uint64_t> value_;
-
- mutable absl::Mutex mutex_;
-
- // Sticky status failure value set on first failure.
- Status status_ ABSL_GUARDED_BY(mutex_);
-
- // Outstanding VkFences representing signal values.
- // Expected to be sorted in ascending order by value.
- IntrusiveList<OutstandingFenceSignal> outstanding_signals_
- ABSL_GUARDED_BY(mutex_);
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_LEGACY_FENCE_H_
diff --git a/iree/hal/vulkan/native_binary_semaphore.cc b/iree/hal/vulkan/native_binary_semaphore.cc
deleted file mode 100644
index 5439a67..0000000
--- a/iree/hal/vulkan/native_binary_semaphore.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/native_binary_semaphore.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-NativeBinarySemaphore::NativeBinarySemaphore(
- ref_ptr<VkDeviceHandle> logical_device, VkSemaphore handle)
- : logical_device_(std::move(logical_device)), handle_(handle) {}
-
-NativeBinarySemaphore::~NativeBinarySemaphore() {
- logical_device_->syms()->vkDestroySemaphore(*logical_device_, handle_,
- logical_device_->allocator());
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/native_binary_semaphore.h b/iree/hal/vulkan/native_binary_semaphore.h
deleted file mode 100644
index fc55ebb..0000000
--- a/iree/hal/vulkan/native_binary_semaphore.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
-#define IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
-
-#include <vulkan/vulkan.h>
-
-#include "iree/hal/semaphore.h"
-#include "iree/hal/vulkan/handle_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// A binary semaphore implemented using the native VkSemaphore type.
-// This is supported unconditionally on all versions of Vulkan.
-class NativeBinarySemaphore final : public BinarySemaphore {
- public:
- NativeBinarySemaphore(ref_ptr<VkDeviceHandle> logical_device,
- VkSemaphore handle);
- ~NativeBinarySemaphore() override;
-
- VkSemaphore handle() const { return handle_; }
-
- private:
- ref_ptr<VkDeviceHandle> logical_device_;
- VkSemaphore handle_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/native_event.cc b/iree/hal/vulkan/native_event.cc
deleted file mode 100644
index 28dbc56..0000000
--- a/iree/hal/vulkan/native_event.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/native_event.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-NativeEvent::NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle)
- : logical_device_(std::move(logical_device)), handle_(handle) {}
-
-NativeEvent::~NativeEvent() {
- logical_device_->syms()->vkDestroyEvent(*logical_device_, handle_,
- logical_device_->allocator());
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/native_event.h b/iree/hal/vulkan/native_event.h
deleted file mode 100644
index 691ef6d..0000000
--- a/iree/hal/vulkan/native_event.h
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_
-#define IREE_HAL_VULKAN_NATIVE_EVENT_H_
-
-#include <vulkan/vulkan.h>
-
-#include "iree/hal/event.h"
-#include "iree/hal/vulkan/handle_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// An event implemented with the native VkEvent type.
-class NativeEvent final : public Event {
- public:
- NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle);
- ~NativeEvent() override;
-
- VkEvent handle() const { return handle_; }
-
- private:
- ref_ptr<VkDeviceHandle> logical_device_;
- VkEvent handle_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_NATIVE_EVENT_H_
diff --git a/iree/hal/vulkan/pipeline_cache.cc b/iree/hal/vulkan/pipeline_cache.cc
deleted file mode 100644
index aba1917..0000000
--- a/iree/hal/vulkan/pipeline_cache.cc
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/pipeline_cache.h"
-
-#include "absl/synchronization/mutex.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/executable_format.h"
-#include "iree/hal/vulkan/status_util.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-PipelineCache::PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device)
- : logical_device_(add_ref(logical_device)) {}
-
-PipelineCache::~PipelineCache() {
- IREE_TRACE_SCOPE0("PipelineCache::dtor");
- ClearLayoutCaches();
-}
-
-bool PipelineCache::CanPrepareFormat(ExecutableFormat format) const {
- return format == kExecutableFormatSpirV;
-}
-
-StatusOr<ref_ptr<Executable>> PipelineCache::PrepareExecutable(
- ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
- IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable");
- if (!CanPrepareFormat(spec.format)) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unsupported 4CC format: 0x" << std::hex << spec.format;
- }
- if (spec.executable_data.size() <= 4 ||
- !SpirVExecutableDefBufferHasIdentifier(spec.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Supplied executable data does not contain a SpirVExecutableDef";
- }
-
- // Get the SPIR-V executable def flatbuffer.
- const auto& spirv_executable_def =
- *::flatbuffers::GetRoot<SpirVExecutableDef>(spec.executable_data.data());
-
- // Create (or reuse) a pipeline layout.
- if (!spirv_executable_def.pipeline_layout()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Missing pipeline layout def";
- }
- ASSIGN_OR_RETURN(
- auto pipeline_layout_entry,
- LookupOrInsertPipelineLayout(*spirv_executable_def.pipeline_layout()));
-
- // Create the executable (which may itself own many pipelines).
- ASSIGN_OR_RETURN(auto executable, PipelineExecutable::Create(
- logical_device_,
- /*pipeline_cache=*/VK_NULL_HANDLE,
- pipeline_layout_entry->pipeline_layout,
- pipeline_layout_entry->descriptor_sets,
- mode, spirv_executable_def));
- return executable;
-}
-
-StatusOr<const PipelineCache::CachedPipelineLayout*>
-PipelineCache::LookupOrInsertPipelineLayout(
- const VkPipelineLayoutDef& pipeline_layout_def) {
- IREE_TRACE_SCOPE0("PipelineCache::LookupOrInsertPipelineLayout");
- absl::MutexLock lock(&mutex_);
-
- // Build a list of the required descriptor set layouts and push constants.
- // If we were being fast about this we would just hash the def and directly
- // look up the pipeline layout.
- PipelineDescriptorSets descriptor_sets;
- descriptor_sets.buffer_binding_set = pipeline_layout_def.buffer_binding_set();
- descriptor_sets.buffer_binding_set_layout = VK_NULL_HANDLE;
- absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts;
- if (pipeline_layout_def.descriptor_set_layouts()) {
- const auto& layout_defs = *pipeline_layout_def.descriptor_set_layouts();
- descriptor_set_layouts.resize(layout_defs.size());
- for (int i = 0; i < descriptor_set_layouts.size(); ++i) {
- if (!layout_defs[i]) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "Missing layout def";
- }
- ASSIGN_OR_RETURN(descriptor_set_layouts[i],
- LookupOrInsertDescriptorSetLayout(*layout_defs[i]));
- if (i == pipeline_layout_def.buffer_binding_set()) {
- descriptor_sets.buffer_binding_set_layout = descriptor_set_layouts[i];
- descriptor_sets.buffer_binding_set_map.resize(
- layout_defs[i]->bindings()->size());
- for (int j = 0; j < layout_defs[i]->bindings()->size(); ++j) {
- descriptor_sets.buffer_binding_set_map[j] =
- layout_defs[i]->bindings()->Get(j)->binding();
- }
- }
- }
- }
-
- absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges;
- if (pipeline_layout_def.push_constant_ranges()) {
- const auto& range_defs = *pipeline_layout_def.push_constant_ranges();
- push_constant_ranges.resize(range_defs.size());
- for (int i = 0; i < push_constant_ranges.size(); ++i) {
- if (!range_defs[i]) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Missing push constant range def";
- }
- push_constant_ranges[i].stageFlags = range_defs[i]->stage_flags();
- push_constant_ranges[i].offset = range_defs[i]->offset();
- push_constant_ranges[i].size = range_defs[i]->size();
- }
- }
-
- // Scan for an existing pipeline layout that matches the descriptor sets.
- for (auto& entry : pipeline_layout_cache_) {
- if (entry.descriptor_set_layouts.size() != descriptor_set_layouts.size() ||
- entry.push_constant_ranges.size() != push_constant_ranges.size()) {
- continue;
- }
- if (std::memcmp(
- descriptor_set_layouts.data(), entry.descriptor_set_layouts.data(),
- descriptor_set_layouts.size() * sizeof(VkDescriptorSetLayout)) ==
- 0 &&
- std::memcmp(
- push_constant_ranges.data(), entry.push_constant_ranges.data(),
- push_constant_ranges.size() * sizeof(VkPushConstantRange)) == 0) {
- return &entry;
- }
- }
-
- VkPipelineLayoutCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- create_info.setLayoutCount = descriptor_set_layouts.size();
- create_info.pSetLayouts = descriptor_set_layouts.data();
- create_info.pushConstantRangeCount = push_constant_ranges.size();
- create_info.pPushConstantRanges = push_constant_ranges.data();
-
- // Create and insert into the cache.
- VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(syms()->vkCreatePipelineLayout(
- *logical_device_, &create_info, logical_device_->allocator(),
- &pipeline_layout));
- pipeline_layout_cache_.push_back({std::move(descriptor_set_layouts),
- std::move(push_constant_ranges),
- pipeline_layout, descriptor_sets});
- return &pipeline_layout_cache_.back();
-}
-
-StatusOr<VkDescriptorSetLayout>
-PipelineCache::LookupOrInsertDescriptorSetLayout(
- const VkDescriptorSetLayoutDef& descriptor_set_layout_def) {
- // Build a list of bindings in the set.
- // If we were being fast we would hash the bindings and directly lookup
- // without doing this allocation.
- absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings;
- if (descriptor_set_layout_def.bindings()) {
- const auto& binding_defs = *descriptor_set_layout_def.bindings();
- bindings.resize(binding_defs.size());
- for (int i = 0; i < binding_defs.size(); ++i) {
- bindings[i].binding = binding_defs[i]->binding();
- bindings[i].descriptorType =
- static_cast<VkDescriptorType>(binding_defs[i]->descriptor_type());
- bindings[i].descriptorCount = binding_defs[i]->descriptor_count();
- bindings[i].stageFlags = binding_defs[i]->stage_flags();
- bindings[i].pImmutableSamplers = nullptr;
- }
- }
-
- // Scan for an existing descriptor set layout that matches the bindings.
- for (auto& entry : descriptor_set_layout_cache_) {
- if (entry.bindings.size() != bindings.size()) continue;
- if (std::memcmp(bindings.data(), entry.bindings.data(),
- bindings.size() * sizeof(VkDescriptorSetLayoutBinding)) ==
- 0) {
- return entry.descriptor_set_layout;
- }
- }
-
- VkDescriptorSetLayoutCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- if (logical_device_->enabled_extensions().push_descriptors) {
- // Note that we can *only* use push descriptor sets if we set this create
- // flag. That's fine, though, as the command buffer recording logic always
- // prefers the extension if available.
- create_info.flags |=
- VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
- }
- create_info.bindingCount = bindings.size();
- create_info.pBindings = bindings.data();
-
- // Create and insert into the cache.
- VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(syms()->vkCreateDescriptorSetLayout(
- *logical_device_, &create_info, logical_device_->allocator(),
- &descriptor_set_layout));
- descriptor_set_layout_cache_.push_back(
- {std::move(bindings), descriptor_set_layout});
- return descriptor_set_layout;
-}
-
-void PipelineCache::ClearLayoutCaches() {
- absl::MutexLock lock(&mutex_);
- for (auto& entry : pipeline_layout_cache_) {
- syms()->vkDestroyPipelineLayout(*logical_device_, entry.pipeline_layout,
- logical_device_->allocator());
- }
- pipeline_layout_cache_.clear();
- for (auto& entry : descriptor_set_layout_cache_) {
- syms()->vkDestroyDescriptorSetLayout(*logical_device_,
- entry.descriptor_set_layout,
- logical_device_->allocator());
- }
- descriptor_set_layout_cache_.clear();
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_cache.h b/iree/hal/vulkan/pipeline_cache.h
deleted file mode 100644
index 1847f36..0000000
--- a/iree/hal/vulkan/pipeline_cache.h
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_
-#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_
-
-#include <vulkan/vulkan.h>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/hal/vulkan/handle_util.h"
-#include "iree/hal/vulkan/pipeline_executable.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class PipelineCache final : public ExecutableCache {
- public:
- explicit PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device);
- ~PipelineCache() override;
-
- const ref_ptr<DynamicSymbols>& syms() const {
- return logical_device_->syms();
- }
-
- bool CanPrepareFormat(ExecutableFormat format) const override;
-
- StatusOr<ref_ptr<Executable>> PrepareExecutable(
- ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override;
-
- private:
- struct CachedDescriptorSetLayout {
- absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings;
- VkDescriptorSetLayout descriptor_set_layout;
- };
- struct CachedPipelineLayout {
- absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts;
- absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges;
- VkPipelineLayout pipeline_layout;
- PipelineDescriptorSets descriptor_sets;
- };
-
- StatusOr<const CachedPipelineLayout*> LookupOrInsertPipelineLayout(
- const VkPipelineLayoutDef& pipeline_layout_def)
- ABSL_LOCKS_EXCLUDED(mutex_);
- StatusOr<VkDescriptorSetLayout> LookupOrInsertDescriptorSetLayout(
- const VkDescriptorSetLayoutDef& descriptor_set_layout_def)
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- void ClearLayoutCaches() ABSL_LOCKS_EXCLUDED(mutex_);
-
- ref_ptr<VkDeviceHandle> logical_device_;
-
- // A "cache" of descriptor set and pipeline layouts for various values.
- // We never evict and just do a simple linear scan on lookup. This is fine for
- // now as we only support a single descriptor type and really we only need to
- // check for binding count. As we go toward more general usage of descriptors
- // (images/etc) we will likely want to change this to a real cache.
- absl::Mutex mutex_;
- std::vector<CachedDescriptorSetLayout> descriptor_set_layout_cache_
- ABSL_GUARDED_BY(mutex_);
- std::vector<CachedPipelineLayout> pipeline_layout_cache_
- ABSL_GUARDED_BY(mutex_);
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_PIPELINE_CACHE_H_
diff --git a/iree/hal/vulkan/pipeline_executable.cc b/iree/hal/vulkan/pipeline_executable.cc
deleted file mode 100644
index 9c1514e..0000000
--- a/iree/hal/vulkan/pipeline_executable.cc
+++ /dev/null
@@ -1,191 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/pipeline_executable.h"
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/memory.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-// Generates the baked specialization constant data based on the flatbuffer.
-// We only support uint32_t right now so this is easy.
-// Note that the returned vectors are referenced by pointers in |out_info| and
-// must remain valid until the info is no longer in use.
-std::pair<std::vector<VkSpecializationMapEntry>, std::vector<uint8_t>>
-PopulateSpecializationInfo(const VkSpecializationInfoDef* info_def) {
- int entry_count =
- info_def && info_def->map_entries() ? info_def->map_entries()->size() : 0;
- if (!entry_count) {
- return {};
- }
-
- std::vector<VkSpecializationMapEntry> entries;
- entries.reserve(entry_count);
- std::vector<uint8_t> data;
- data.resize(entry_count * sizeof(uint32_t));
-
- uint32_t offset = 0;
- for (const auto* entry_def : *info_def->map_entries()) {
- if (!entry_def) continue;
- entries.push_back({});
- auto& entry = entries.back();
- entry.constantID = entry_def->constant_id();
- entry.offset = offset;
- entry.size = sizeof(uint32_t);
- uint32_t value = entry_def->uint32_value();
- std::memcpy(data.data() + offset, &value, sizeof(value));
- offset += entry.size;
- }
-
- return {std::move(entries), std::move(data)};
-}
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create(
- const ref_ptr<VkDeviceHandle>& logical_device,
- VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout,
- PipelineDescriptorSets descriptor_sets, ExecutableCachingModeBitfield mode,
- const SpirVExecutableDef& spirv_executable_def) {
- IREE_TRACE_SCOPE0("PipelineExecutable::Create");
- const auto& syms = logical_device->syms();
- if (!spirv_executable_def.entry_points() ||
- spirv_executable_def.entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!spirv_executable_def.code()) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No SPIR-V code present";
- }
- const auto& code = *spirv_executable_def.code();
-
- // Create the shader module.
- VkShaderModuleCreateInfo shader_module_create_info;
- shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
- shader_module_create_info.pNext = nullptr;
- shader_module_create_info.flags = 0;
- shader_module_create_info.codeSize = code.size() * sizeof(uint32_t);
- shader_module_create_info.pCode = code.data();
- VkShaderModule shader_module = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(
- syms->vkCreateShaderModule(*logical_device, &shader_module_create_info,
- logical_device->allocator(), &shader_module));
-
- // We only need to keep this around during pipeline creation so ensure we
- // always clean it up when we exit this function.
- auto shader_module_cleanup = MakeCleanup([&logical_device, shader_module]() {
- logical_device->syms()->vkDestroyShaderModule(
- *logical_device, shader_module, logical_device->allocator());
- });
-
- // Specialization info is currently constant against all entry points.
- std::vector<VkSpecializationMapEntry> spec_entries;
- std::vector<uint8_t> spec_data;
- std::tie(spec_entries, spec_data) =
- PopulateSpecializationInfo(spirv_executable_def.specialization_info());
- VkSpecializationInfo specialization_info;
- specialization_info.mapEntryCount = spec_entries.size();
- specialization_info.pMapEntries = spec_entries.data();
- specialization_info.dataSize = spec_data.size();
- specialization_info.pData = spec_data.data();
-
- // Create pipelines for each entry point.
- const auto& entry_points = *spirv_executable_def.entry_points();
- absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos;
- pipeline_create_infos.resize(entry_points.size());
- for (int entry_ordinal = 0; entry_ordinal < entry_points.size();
- ++entry_ordinal) {
- auto& pipeline_create_info = pipeline_create_infos[entry_ordinal];
- pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
- pipeline_create_info.pNext = nullptr;
- pipeline_create_info.flags = 0;
- if (!AllBitsSet(mode, ExecutableCachingMode::kAllowOptimization)) {
- pipeline_create_info.flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT;
- }
- if (entry_ordinal == 0) {
- pipeline_create_info.flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT;
- } else {
- pipeline_create_info.flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT;
- }
- pipeline_create_info.layout = pipeline_layout;
- pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE;
- pipeline_create_info.basePipelineIndex = 0;
- auto& stage_create_info = pipeline_create_info.stage;
- stage_create_info.sType =
- VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
- stage_create_info.pNext = nullptr;
- stage_create_info.flags = 0;
- stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT;
- stage_create_info.module = shader_module;
- stage_create_info.pName = entry_points[entry_ordinal]->c_str();
- stage_create_info.pSpecializationInfo = &specialization_info;
- }
- absl::InlinedVector<VkPipeline, 1> pipelines;
- pipelines.resize(entry_points.size());
-
- // Some ICDs appear to leak in here, out of our control.
- // Warning: leak checks remain disabled if an error is returned.
- IREE_DISABLE_LEAK_CHECKS();
- VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines(
- *logical_device, pipeline_cache, pipeline_create_infos.size(),
- pipeline_create_infos.data(), logical_device->allocator(),
- pipelines.data()));
- IREE_ENABLE_LEAK_CHECKS();
-
- auto executable =
- make_ref<PipelineExecutable>(CtorKey{}, logical_device, pipeline_layout,
- descriptor_sets, std::move(pipelines));
- executable->tag_ =
- spirv_executable_def.tag() ? spirv_executable_def.tag()->str() : "";
- return executable;
-}
-
-PipelineExecutable::PipelineExecutable(
- CtorKey ctor_key, const ref_ptr<VkDeviceHandle>& logical_device,
- VkPipelineLayout pipeline_layout, PipelineDescriptorSets descriptor_sets,
- absl::InlinedVector<VkPipeline, 1> pipelines)
- : logical_device_(add_ref(logical_device)),
- pipeline_layout_(pipeline_layout),
- descriptor_sets_(descriptor_sets),
- pipelines_(std::move(pipelines)) {}
-
-PipelineExecutable::~PipelineExecutable() {
- IREE_TRACE_SCOPE0("PipelineExecutable::dtor");
- for (auto pipeline : pipelines_) {
- syms()->vkDestroyPipeline(*logical_device_, pipeline,
- logical_device_->allocator());
- }
- pipelines_.clear();
-}
-
-StatusOr<VkPipeline> PipelineExecutable::GetPipelineForEntryPoint(
- int entry_ordinal) const {
- if (entry_ordinal < 0 || entry_ordinal >= pipelines_.size()) {
- return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal";
- }
- return pipelines_[entry_ordinal];
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_executable.h b/iree/hal/vulkan/pipeline_executable.h
deleted file mode 100644
index 1ca21aa..0000000
--- a/iree/hal/vulkan/pipeline_executable.h
+++ /dev/null
@@ -1,91 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
-#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
-
-#include <vulkan/vulkan.h>
-
-#include <vector>
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/status.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/hal/executable_spec.h"
-#include "iree/hal/vulkan/handle_util.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-struct PipelineDescriptorSets {
- uint32_t buffer_binding_set;
- VkDescriptorSetLayout buffer_binding_set_layout;
- absl::InlinedVector<uint32_t, 8> buffer_binding_set_map;
-};
-
-class PipelineExecutable final : public Executable {
- public:
- static StatusOr<ref_ptr<PipelineExecutable>> Create(
- const ref_ptr<VkDeviceHandle>& logical_device,
- VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout,
- PipelineDescriptorSets descriptor_sets,
- ExecutableCachingModeBitfield mode,
- const SpirVExecutableDef& spirv_executable_def);
-
- // Private constructor.
- struct CtorKey {
- private:
- friend class PipelineExecutable;
- CtorKey() = default;
- };
- PipelineExecutable(CtorKey ctor_key,
- const ref_ptr<VkDeviceHandle>& logical_device,
- VkPipelineLayout pipeline_layout,
- PipelineDescriptorSets descriptor_sets,
- absl::InlinedVector<VkPipeline, 1> pipelines);
- ~PipelineExecutable() override;
-
- const ref_ptr<DynamicSymbols>& syms() const {
- return logical_device_->syms();
- }
-
- bool supports_debugging() const override { return false; }
-
- VkPipelineLayout pipeline_layout() const { return pipeline_layout_; }
- const PipelineDescriptorSets& descriptor_sets() const {
- return descriptor_sets_;
- }
-
- bool is_matmul() const { return tag_ == "__matmul__"; }
-
- StatusOr<VkPipeline> GetPipelineForEntryPoint(int entry_ordinal) const;
-
- private:
- ref_ptr<VkDeviceHandle> logical_device_;
- VkPipelineLayout pipeline_layout_;
- PipelineDescriptorSets descriptor_sets_;
- std::string tag_;
-
- // One pipeline per entry point.
- absl::InlinedVector<VkPipeline, 1> pipelines_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
diff --git a/iree/hal/vulkan/status_util.cc b/iree/hal/vulkan/status_util.cc
deleted file mode 100644
index 4b08e21..0000000
--- a/iree/hal/vulkan/status_util.cc
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/status_util.h"
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-Status VkResultToStatus(VkResult result) {
- switch (result) {
- // Success codes.
- case VK_SUCCESS:
- // Command successfully completed.
- return OkStatus();
- case VK_NOT_READY:
- // A fence or query has not yet completed.
- return OkStatus();
- case VK_TIMEOUT:
- // A wait operation has not completed in the specified time.
- return OkStatus();
- case VK_EVENT_SET:
- // An event is signaled.
- return OkStatus();
- case VK_EVENT_RESET:
- // An event is unsignaled.
- return OkStatus();
- case VK_INCOMPLETE:
- // A return array was too small for the result.
- return OkStatus();
- case VK_SUBOPTIMAL_KHR:
- // A swapchain no longer matches the surface properties exactly, but can
- // still be used to present to the surface successfully.
- return OkStatus();
-
- // Error codes.
- case VK_ERROR_OUT_OF_HOST_MEMORY:
- // A host memory allocation has failed.
- return ResourceExhaustedError("VK_ERROR_OUT_OF_HOST_MEMORY");
- case VK_ERROR_OUT_OF_DEVICE_MEMORY:
- // A device memory allocation has failed.
- return ResourceExhaustedError("VK_ERROR_OUT_OF_DEVICE_MEMORY");
- case VK_ERROR_INITIALIZATION_FAILED:
- // Initialization of an object could not be completed for
- // implementation-specific reasons.
- return InternalError("VK_ERROR_INITIALIZATION_FAILED");
- case VK_ERROR_DEVICE_LOST:
- // The logical or physical device has been lost.
- //
- // A logical device may become lost for a number of
- // implementation-specific reasons, indicating that pending and future
- // command execution may fail and cause resources and backing memory to
- // become undefined.
- //
- // Typical reasons for device loss will include things like execution
- // timing out (to prevent denial of service), power management events,
- // platform resource management, or implementation errors.
- //
- // When this happens, certain commands will return
- // VK_ERROR_DEVICE_LOST (see Error Codes for a list of such
- // commands). After any such event, the logical device is considered lost.
- // It is not possible to reset the logical device to a non-lost state,
- // however the lost state is specific to a logical device (VkDevice), and
- // the corresponding physical device (VkPhysicalDevice) may be otherwise
- // unaffected.
- //
- // In some cases, the physical device may also be lost, and attempting to
- // create a new logical device will fail, returning VK_ERROR_DEVICE_LOST.
- // This is usually indicative of a problem with the underlying
- // implementation, or its connection to the host. If the physical device
- // has not been lost, and a new logical device is successfully created
- // from that physical device, it must be in the non-lost state.
- //
- // Whilst logical device loss may be recoverable, in the case of physical
- // device loss, it is unlikely that an application will be able to recover
- // unless additional, unaffected physical devices exist on the system. The
- // error is largely informational and intended only to inform the user
- // that a platform issue has occurred, and should be investigated further.
- // For example, underlying hardware may have developed a fault or become
- // physically disconnected from the rest of the system. In many cases,
- // physical device loss may cause other more serious issues such as the
- // operating system crashing; in which case it may not be reported via the
- // Vulkan API.
- //
- // Undefined behavior caused by an application error may cause a device to
- // become lost. However, such undefined behavior may also cause
- // unrecoverable damage to the process, and it is then not guaranteed that
- // the API objects, including the VkPhysicalDevice or the VkInstance are
- // still valid or that the error is recoverable.
- //
- // When a device is lost, its child objects are not implicitly destroyed
- // and their handles are still valid. Those objects must still be
- // destroyed before their parents or the device can be destroyed (see the
- // Object Lifetime section). The host address space corresponding to
- // device memory mapped using vkMapMemory is still valid, and host memory
- // accesses to these mapped regions are still valid, but the contents are
- // undefined. It is still legal to call any API command on the device and
- // child objects.
- //
- // Once a device is lost, command execution may fail, and commands that
- // return a VkResult may return VK_ERROR_DEVICE_LOST.
- // Commands that do not allow run-time errors must still operate correctly
- // for valid usage and, if applicable, return valid data.
- //
- // Commands that wait indefinitely for device execution (namely
- // vkDeviceWaitIdle, vkQueueWaitIdle, vkWaitForFences with a maximum
- // timeout, and vkGetQueryPoolResults with the VK_QUERY_RESULT_WAIT_BIT
- // bit set in flags) must return in finite time even in the case
- // of a lost device, and return either VK_SUCCESS or
- // VK_ERROR_DEVICE_LOST. For any command that may return
- // VK_ERROR_DEVICE_LOST, for the purpose of determining whether a
- // command buffer is in the pending state, or whether resources are
- // considered in-use by the device, a return value of
- // VK_ERROR_DEVICE_LOST is equivalent to VK_SUCCESS.
- return InternalError("VK_ERROR_DEVICE_LOST");
- case VK_ERROR_MEMORY_MAP_FAILED:
- // Mapping of a memory object has failed.
- return InternalError("VK_ERROR_MEMORY_MAP_FAILED");
- case VK_ERROR_LAYER_NOT_PRESENT:
- // A requested layer is not present or could not be loaded.
- return UnimplementedError("VK_ERROR_LAYER_NOT_PRESENT");
- case VK_ERROR_EXTENSION_NOT_PRESENT:
- // A requested extension is not supported.
- return UnimplementedError("VK_ERROR_EXTENSION_NOT_PRESENT");
- case VK_ERROR_FEATURE_NOT_PRESENT:
- // A requested feature is not supported.
- return UnimplementedError("VK_ERROR_FEATURE_NOT_PRESENT");
- case VK_ERROR_INCOMPATIBLE_DRIVER:
- // The requested version of Vulkan is not supported by the driver or is
- // otherwise incompatible for implementation-specific reasons.
- return FailedPreconditionError("VK_ERROR_INCOMPATIBLE_DRIVER");
- case VK_ERROR_TOO_MANY_OBJECTS:
- // Too many objects of the type have already been created.
- return ResourceExhaustedError("VK_ERROR_TOO_MANY_OBJECTS");
- case VK_ERROR_FORMAT_NOT_SUPPORTED:
- // A requested format is not supported on this device.
- return UnimplementedError("VK_ERROR_FORMAT_NOT_SUPPORTED");
- case VK_ERROR_FRAGMENTED_POOL:
- // A pool allocation has failed due to fragmentation of the pool’s memory.
- // This must only be returned if no attempt to allocate host or device
- // memory was made to accommodate the new allocation.
- return ResourceExhaustedError("VK_ERROR_FRAGMENTED_POOL");
- case VK_ERROR_OUT_OF_POOL_MEMORY:
- // A pool memory allocation has failed. This must only be returned if no
- // attempt to allocate host or device memory was made to accommodate the
- // new allocation. If the failure was definitely due to fragmentation of
- // the pool, VK_ERROR_FRAGMENTED_POOL should be returned instead.
- return ResourceExhaustedError("VK_ERROR_OUT_OF_POOL_MEMORY");
- case VK_ERROR_INVALID_EXTERNAL_HANDLE:
- // An external handle is not a valid handle of the specified type.
- return InvalidArgumentError("VK_ERROR_INVALID_EXTERNAL_HANDLE");
- case VK_ERROR_SURFACE_LOST_KHR:
- // A surface is no longer available.
- return UnavailableError("VK_ERROR_SURFACE_LOST_KHR");
- case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR:
- // The requested window is already in use by Vulkan or another API in a
- // manner which prevents it from being used again.
- return InvalidArgumentError("VK_ERROR_NATIVE_WINDOW_IN_USE_KHR");
- case VK_ERROR_OUT_OF_DATE_KHR:
- // A surface has changed in such a way that it is no longer compatible
- // with the swapchain, and further presentation requests using the
- // swapchain will fail. Applications must query the new surface properties
- // and recreate their swapchain if they wish to continue presenting to the
- // surface.
- return FailedPreconditionError("VK_ERROR_OUT_OF_DATE_KHR");
- case VK_ERROR_INCOMPATIBLE_DISPLAY_KHR:
- // The display used by a swapchain does not use the same presentable image
- // layout, or is incompatible in a way that prevents sharing an image.
- return InvalidArgumentError("VK_ERROR_INCOMPATIBLE_DISPLAY_KHR");
- case VK_ERROR_VALIDATION_FAILED_EXT:
- // Validation layer testing failed. It is not expected that an
- // application would see this this error code during normal use of the
- // validation layers.
- return InvalidArgumentError("VK_ERROR_VALIDATION_FAILED_EXT");
- case VK_ERROR_INVALID_SHADER_NV:
- // One or more shaders failed to compile or link. More details are
- // reported back to the application when the validation layer is enabled
- // using the extension VK_EXT_debug_report.
- return InvalidArgumentError("VK_ERROR_INVALID_SHADER_NV");
- case VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT:
- // When creating an image with
- // VkImageDrmFormatModifierExplicitCreateInfoEXT, it is the application’s
- // responsibility to satisfy all Valid Usage requirements. However, the
- // implementation must validate that the provided pPlaneLayouts, when
- // combined with the provided drmFormatModifier and other creation
- // parameters in VkImageCreateInfo and its pNext chain, produce a valid
- // image. (This validation is necessarily implementation-dependent and
- // outside the scope of Vulkan, and therefore not described by Valid Usage
- // requirements). If this validation fails, then vkCreateImage returns
- // VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT.
- return InvalidArgumentError(
- "VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT");
- case VK_ERROR_FRAGMENTATION_EXT:
- // A descriptor pool creation has failed due to fragmentation.
- return ResourceExhaustedError("VK_ERROR_FRAGMENTATION_EXT");
- case VK_ERROR_NOT_PERMITTED_EXT:
- // When creating a queue, the caller does not have sufficient privileges
- // to request to acquire a priority above the default priority
- // (VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_EXT).
- return PermissionDeniedError("VK_ERROR_NOT_PERMITTED_EXT");
- case VK_ERROR_INVALID_DEVICE_ADDRESS_EXT:
- // A buffer creation failed because the requested address is not
- // available.
- return OutOfRangeError("VK_ERROR_INVALID_DEVICE_ADDRESS_EXT");
- case VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT:
- // An operation on a swapchain created with
- // VK_FULL_SCREEN_EXCLUSIVE_APPLICATION_CONTROLLED_EXT failed as it did
- // not have exlusive full-screen access. This may occur due to
- // implementation-dependent reasons, outside of the application’s control.
- return UnavailableError("VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT");
- default:
- return UnknownError(std::to_string(result));
- }
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/status_util.h b/iree/hal/vulkan/status_util.h
deleted file mode 100644
index 85f9def..0000000
--- a/iree/hal/vulkan/status_util.h
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_STATUS_UTIL_H_
-#define IREE_HAL_VULKAN_STATUS_UTIL_H_
-
-#include <vulkan/vulkan.h>
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// RETURN_IF_ERROR but implicitly converts the VkResult return value to
-// a Status.
-//
-// Usage:
-// VK_RETURN_IF_ERROR(vkDoThing(...));
-#define VK_RETURN_IF_ERROR(expr) \
- RETURN_IF_ERROR(::iree::hal::vulkan::VkResultToStatus(expr))
-
-// CHECK_OK but implicitly converts the VkResults return value to a
-// Status and checks that it is OkStatus.
-//
-// Usage:
-// VK_CHECK_OK(vkDoThing(...));
-#define VK_CHECK_OK(expr) CHECK_OK(::iree::hal::vulkan::VkResultToStatus(expr))
-
-// Converts a VkResult to a Status object.
-//
-// Vulkan considers the following as "success codes" and users should ensure
-// they first check the result prior to converting:
-//
-// - VK_SUCCESS -> OkStatus()
-// - VK_NOT_READY -> OkStatus()
-// - VK_TIMEOUT -> OkStatus()
-// - VK_EVENT_SET -> OkStatus()
-// - VK_EVENT_RESET -> OkStatus()
-// - VK_INCOMPLETE -> OkStatus()
-// - VK_SUBOPTIMAL_KHR -> OkStatus()
-//
-// The rest are considered as "error codes":
-//
-// - VK_ERROR_OUT_OF_HOST_MEMORY -> ResourceExhaustedError("VK...")
-// - VK_ERROR_OUT_OF_DEVICE_MEMORY -> ResourceExhaustedError("VK...")
-// - VK_ERROR_INITIALIZATION_FAILED -> InternalError("VK...")
-// - VK_ERROR_DEVICE_LOST -> InternalError("VK...")
-// - VK_ERROR_MEMORY_MAP_FAILED -> InternalError("VK...")
-// - VK_ERROR_LAYER_NOT_PRESENT -> NotFoundError("VK...")
-// - VK_ERROR_EXTENSION_NOT_PRESENT -> NotFoundError("VK...")
-// - VK_ERROR_FEATURE_NOT_PRESENT -> NotFoundError("VK...")
-// - VK_ERROR_INCOMPATIBLE_DRIVER -> FailedPreconditionError("VK...")
-// - VK_ERROR_TOO_MANY_OBJECTS -> ResourceExhaustedError("VK...")
-// - VK_ERROR_FORMAT_NOT_SUPPORTED -> UnimplementedError("VK...")
-// - VK_ERROR_FRAGMENTED_POOL -> ResourceExhaustedError("VK...")
-// - VK_ERROR_OUT_OF_POOL_MEMORY -> ResourceExhaustedError("VK...")
-// - VK_ERROR_INVALID_EXTERNAL_HANDLE -> InvalidArgumentError("VK...")
-// - VK_ERROR_SURFACE_LOST_KHR -> InternalError("VK...")
-// - VK_ERROR_NATIVE_WINDOW_IN_USE_KHR -> InternalError("VK...")
-// - VK_ERROR_OUT_OF_DATE_KHR -> InternalError("VK...")
-// - VK_ERROR_INCOMPATIBLE_DISPLAY_KHR -> InternalError("VK...")
-// - VK_ERROR_VALIDATION_FAILED_EXT -> InternalError("VK...")
-// - VK_ERROR_INVALID_SHADER_NV -> InternalError("VK...")
-// - VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT -> InternalError
-// - VK_ERROR_FRAGMENTATION_EXT -> ResourceExhaustedError("VK...")
-// - VK_ERROR_NOT_PERMITTED_EXT -> PermissionDeniedError("VK...")
-// - VK_ERROR_INVALID_DEVICE_ADDRESS_EXT -> OutOfRangeError("VK...")
-// - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT -> InternalError("VK...")
-Status VkResultToStatus(VkResult result);
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_STATUS_UTIL_H_
diff --git a/iree/hal/vulkan/vma_allocator.cc b/iree/hal/vulkan/vma_allocator.cc
deleted file mode 100644
index 181d6e1..0000000
--- a/iree/hal/vulkan/vma_allocator.cc
+++ /dev/null
@@ -1,248 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/vma_allocator.h"
-
-#include "absl/flags/flag.h"
-#include "absl/memory/memory.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/buffer.h"
-#include "iree/hal/vulkan/status_util.h"
-#include "iree/hal/vulkan/vma_buffer.h"
-
-#if VMA_RECORDING_ENABLED
-ABSL_FLAG(std::string, vma_recording_file, "",
- "File path to write a CSV containing the VMA recording.");
-ABSL_FLAG(bool, vma_recording_flush_after_call, false,
- "Flush the VMA recording file after every call (useful if "
- "crashing/not exiting cleanly).");
-#endif // VMA_RECORDING_ENABLED
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// static
-StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create(
- VkPhysicalDevice physical_device,
- const ref_ptr<VkDeviceHandle>& logical_device) {
- IREE_TRACE_SCOPE0("VmaAllocator::Create");
-
- const auto& syms = logical_device->syms();
- VmaVulkanFunctions vulkan_fns;
- vulkan_fns.vkGetPhysicalDeviceProperties =
- syms->vkGetPhysicalDeviceProperties;
- vulkan_fns.vkGetPhysicalDeviceMemoryProperties =
- syms->vkGetPhysicalDeviceMemoryProperties;
- vulkan_fns.vkAllocateMemory = syms->vkAllocateMemory;
- vulkan_fns.vkFreeMemory = syms->vkFreeMemory;
- vulkan_fns.vkMapMemory = syms->vkMapMemory;
- vulkan_fns.vkUnmapMemory = syms->vkUnmapMemory;
- vulkan_fns.vkFlushMappedMemoryRanges = syms->vkFlushMappedMemoryRanges;
- vulkan_fns.vkInvalidateMappedMemoryRanges =
- syms->vkInvalidateMappedMemoryRanges;
- vulkan_fns.vkBindBufferMemory = syms->vkBindBufferMemory;
- vulkan_fns.vkBindImageMemory = syms->vkBindImageMemory;
- vulkan_fns.vkGetBufferMemoryRequirements =
- syms->vkGetBufferMemoryRequirements;
- vulkan_fns.vkGetImageMemoryRequirements = syms->vkGetImageMemoryRequirements;
- vulkan_fns.vkCreateBuffer = syms->vkCreateBuffer;
- vulkan_fns.vkDestroyBuffer = syms->vkDestroyBuffer;
- vulkan_fns.vkCreateImage = syms->vkCreateImage;
- vulkan_fns.vkDestroyImage = syms->vkDestroyImage;
- vulkan_fns.vkCmdCopyBuffer = syms->vkCmdCopyBuffer;
-
- VmaRecordSettings record_settings;
-#if VMA_RECORDING_ENABLED
- record_settings.flags = absl::GetFlag(FLAGS_vma_recording_flush_after_call)
- ? VMA_RECORD_FLUSH_AFTER_CALL_BIT
- : 0;
- record_settings.pFilePath = absl::GetFlag(FLAGS_vma_recording_file).c_str();
-#else
- record_settings.flags = 0;
- record_settings.pFilePath = nullptr;
-#endif // VMA_RECORDING_ENABLED
-
- VmaAllocatorCreateInfo create_info;
- create_info.flags = 0;
- create_info.physicalDevice = physical_device;
- create_info.device = *logical_device;
- create_info.preferredLargeHeapBlockSize = 64 * 1024 * 1024;
- create_info.pAllocationCallbacks = logical_device->allocator();
- create_info.pDeviceMemoryCallbacks = nullptr;
- create_info.frameInUseCount = 0;
- create_info.pHeapSizeLimit = nullptr;
- create_info.pVulkanFunctions = &vulkan_fns;
- create_info.pRecordSettings = &record_settings;
- ::VmaAllocator vma = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(vmaCreateAllocator(&create_info, &vma));
-
- auto allocator =
- absl::WrapUnique(new VmaAllocator(physical_device, logical_device, vma));
- // TODO(benvanik): query memory properties/types.
- return allocator;
-}
-
-VmaAllocator::VmaAllocator(VkPhysicalDevice physical_device,
- const ref_ptr<VkDeviceHandle>& logical_device,
- ::VmaAllocator vma)
- : physical_device_(physical_device),
- logical_device_(add_ref(logical_device)),
- vma_(vma) {}
-
-VmaAllocator::~VmaAllocator() {
- IREE_TRACE_SCOPE0("VmaAllocator::dtor");
- vmaDestroyAllocator(vma_);
-}
-
-bool VmaAllocator::CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const {
- // TODO(benvanik): ensure there is a memory type that can satisfy the request.
- return source_allocator == this;
-}
-
-bool VmaAllocator::CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const {
- // TODO(benvnik): ensure there is a memory type that can satisfy the request.
- return true;
-}
-
-Status VmaAllocator::MakeCompatible(MemoryTypeBitfield* memory_type,
- BufferUsageBitfield* buffer_usage) const {
- // TODO(benvanik): mutate to match supported memory types.
- return OkStatus();
-}
-
-StatusOr<ref_ptr<VmaBuffer>> VmaAllocator::AllocateInternal(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- MemoryAccessBitfield allowed_access, size_t allocation_size,
- VmaAllocationCreateFlags flags) {
- IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal");
-
- VkBufferCreateInfo buffer_create_info;
- buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
- buffer_create_info.pNext = nullptr;
- buffer_create_info.flags = 0;
- buffer_create_info.size = allocation_size;
- buffer_create_info.usage = 0;
- if (AllBitsSet(buffer_usage, BufferUsage::kTransfer)) {
- buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
- buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT;
- }
- if (AllBitsSet(buffer_usage, BufferUsage::kDispatch)) {
- buffer_create_info.usage |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
- buffer_create_info.usage |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
- buffer_create_info.usage |= VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT;
- }
- buffer_create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
- buffer_create_info.queueFamilyIndexCount = 0;
- buffer_create_info.pQueueFamilyIndices = nullptr;
-
- VmaAllocationCreateInfo allocation_create_info;
- allocation_create_info.flags = flags;
- allocation_create_info.usage = VMA_MEMORY_USAGE_UNKNOWN;
- allocation_create_info.requiredFlags = 0;
- allocation_create_info.preferredFlags = 0;
- allocation_create_info.memoryTypeBits = 0; // Automatic selection.
- allocation_create_info.pool = VK_NULL_HANDLE;
- allocation_create_info.pUserData = nullptr;
- if (AllBitsSet(memory_type, MemoryType::kDeviceLocal)) {
- if (AllBitsSet(memory_type, MemoryType::kHostVisible)) {
- // Device-local, host-visible.
- allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_TO_GPU;
- allocation_create_info.preferredFlags |=
- VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
- } else {
- // Device-local only.
- allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_ONLY;
- allocation_create_info.requiredFlags |=
- VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
- }
- } else {
- if (AllBitsSet(memory_type, MemoryType::kDeviceVisible)) {
- // Host-local, device-visible.
- allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU;
- } else {
- // Host-local only.
- allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_ONLY;
- }
- }
- if (AllBitsSet(memory_type, MemoryType::kHostCached)) {
- allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
- }
- if (AllBitsSet(memory_type, MemoryType::kHostCoherent)) {
- allocation_create_info.requiredFlags |=
- VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
- }
- if (AllBitsSet(memory_type, MemoryType::kTransient)) {
- allocation_create_info.preferredFlags |=
- VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT;
- }
- if (AllBitsSet(buffer_usage, BufferUsage::kMapping)) {
- allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
- }
-
- VkBuffer buffer = VK_NULL_HANDLE;
- VmaAllocation allocation = VK_NULL_HANDLE;
- VmaAllocationInfo allocation_info;
- VK_RETURN_IF_ERROR(vmaCreateBuffer(vma_, &buffer_create_info,
- &allocation_create_info, &buffer,
- &allocation, &allocation_info));
-
- return make_ref<VmaBuffer>(this, memory_type, allowed_access, buffer_usage,
- allocation_size, 0, allocation_size, buffer,
- allocation, allocation_info);
-}
-
-StatusOr<ref_ptr<Buffer>> VmaAllocator::Allocate(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- size_t allocation_size) {
- IREE_TRACE_SCOPE0("VmaAllocator::Allocate");
- return AllocateInternal(memory_type, buffer_usage, MemoryAccess::kAll,
- allocation_size, /*flags=*/0);
-}
-
-StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateConstant(
- BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) {
- IREE_TRACE_SCOPE0("VmaAllocator::AllocateConstant");
- // TODO(benvanik): import memory to avoid the copy.
- ASSIGN_OR_RETURN(
- auto buffer,
- AllocateInternal(MemoryType::kDeviceLocal | MemoryType::kHostVisible,
- buffer_usage,
- MemoryAccess::kRead | MemoryAccess::kDiscardWrite,
- source_buffer->byte_length(),
- /*flags=*/0));
- RETURN_IF_ERROR(buffer->CopyData(0, source_buffer.get(), 0, kWholeBuffer));
- buffer->set_allowed_access(MemoryAccess::kRead);
- return buffer;
-}
-
-StatusOr<ref_ptr<Buffer>> VmaAllocator::WrapMutable(
- MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage, void* data, size_t data_length) {
- IREE_TRACE_SCOPE0("VmaAllocator::WrapMutable");
- // TODO(benvanik): import memory.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Wrapping host memory is not yet implemented";
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/vma_allocator.h b/iree/hal/vulkan/vma_allocator.h
deleted file mode 100644
index 2c0054c..0000000
--- a/iree/hal/vulkan/vma_allocator.h
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
-#define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
-
-#include <vulkan/vulkan.h>
-
-#include <memory>
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/handle_util.h"
-#include "iree/hal/vulkan/internal_vk_mem_alloc.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class VmaBuffer;
-
-// A HAL allocator using the Vulkan Memory Allocator (VMA) to manage memory.
-// VMA (//third_party/vulkan_memory_allocator) provides dlmalloc-like behavior
-// with suballocations made with various policies (best fit, first fit, etc).
-// This reduces the number of allocations we need from the Vulkan implementation
-// (which can sometimes be limited to as little as 4096 total allowed) and
-// manages higher level allocation semantics like slab allocation and
-// defragmentation.
-//
-// VMA is internally synchronized and the functionality exposed on the HAL
-// interface is thread-safe.
-//
-// More information:
-// https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator
-// https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/
-class VmaAllocator final : public Allocator {
- public:
- static StatusOr<std::unique_ptr<VmaAllocator>> Create(
- VkPhysicalDevice physical_device,
- const ref_ptr<VkDeviceHandle>& logical_device);
-
- ~VmaAllocator() override;
-
- const ref_ptr<DynamicSymbols>& syms() const {
- return logical_device_->syms();
- }
-
- ::VmaAllocator vma() const { return vma_; }
-
- bool CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const override;
-
- bool CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const override;
-
- Status MakeCompatible(MemoryTypeBitfield* memory_type,
- BufferUsageBitfield* buffer_usage) const override;
-
- StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) override;
-
- StatusOr<ref_ptr<Buffer>> AllocateConstant(
- BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) override;
-
- StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield buffer_usage,
- void* data,
- size_t data_length) override;
-
- private:
- VmaAllocator(VkPhysicalDevice physical_device,
- const ref_ptr<VkDeviceHandle>& logical_device,
- ::VmaAllocator vma);
-
- StatusOr<ref_ptr<VmaBuffer>> AllocateInternal(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- MemoryAccessBitfield allowed_access, size_t allocation_size,
- VmaAllocationCreateFlags flags);
-
- VkPhysicalDevice physical_device_;
- ref_ptr<VkDeviceHandle> logical_device_;
-
- // Internally synchronized. We could externally synchronize if we thought it
- // was worth it, however I'm not sure we'd be able to do much better with the
- // current Allocator API.
- ::VmaAllocator vma_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc
deleted file mode 100644
index f7d3ed0..0000000
--- a/iree/hal/vulkan/vma_buffer.cc
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/vma_buffer.h"
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-#include "iree/hal/vulkan/vma_allocator.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-VmaBuffer::VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access,
- BufferUsageBitfield usage, device_size_t allocation_size,
- device_size_t byte_offset, device_size_t byte_length,
- VkBuffer buffer, VmaAllocation allocation,
- VmaAllocationInfo allocation_info)
- : Buffer(allocator, memory_type, allowed_access, usage, allocation_size,
- byte_offset, byte_length),
- vma_(allocator->vma()),
- buffer_(buffer),
- allocation_(allocation),
- allocation_info_(allocation_info) {
- // TODO(benvanik): set debug name instead and use the
- // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag.
- vmaSetAllocationUserData(vma_, allocation_, this);
-}
-
-VmaBuffer::~VmaBuffer() {
- IREE_TRACE_SCOPE0("VmaBuffer::dtor");
- vmaDestroyBuffer(vma_, buffer_, allocation_);
-}
-
-Status VmaBuffer::FillImpl(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length) {
- ASSIGN_OR_RETURN(auto mapping, MapMemory<uint8_t>(MemoryAccess::kDiscardWrite,
- byte_offset, byte_length));
- void* data_ptr = static_cast<void*>(mapping.mutable_data());
- switch (pattern_length) {
- case 1: {
- uint8_t* data = static_cast<uint8_t*>(data_ptr);
- uint8_t value_bits = *static_cast<const uint8_t*>(pattern);
- std::fill_n(data + byte_offset, byte_length, value_bits);
- break;
- }
- case 2: {
- uint16_t* data = static_cast<uint16_t*>(data_ptr);
- uint16_t value_bits = *static_cast<const uint16_t*>(pattern);
- std::fill_n(data + byte_offset / sizeof(uint16_t),
- byte_length / sizeof(uint16_t), value_bits);
- break;
- }
- case 4: {
- uint32_t* data = static_cast<uint32_t*>(data_ptr);
- uint32_t value_bits = *static_cast<const uint32_t*>(pattern);
- std::fill_n(data + byte_offset / sizeof(uint32_t),
- byte_length / sizeof(uint32_t), value_bits);
- break;
- }
- default:
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unsupported scalar data size: " << pattern_length;
- }
- return OkStatus();
-}
-
-Status VmaBuffer::ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) {
- ASSIGN_OR_RETURN(
- auto mapping,
- MapMemory<uint8_t>(MemoryAccess::kRead, source_offset, data_length));
- std::memcpy(data, mapping.data(), mapping.byte_length());
- return OkStatus();
-}
-
-Status VmaBuffer::WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) {
- ASSIGN_OR_RETURN(auto mapping,
- MapMemory<uint8_t>(MemoryAccess::kDiscardWrite,
- target_offset, data_length));
- std::memcpy(mapping.mutable_data(), data, mapping.byte_length());
- return OkStatus();
-}
-
-Status VmaBuffer::CopyDataImpl(device_size_t target_offset,
- Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) {
- // This is pretty terrible. Let's not do this.
- // TODO(benvanik): a way for allocators to indicate transfer compat.
- ASSIGN_OR_RETURN(auto source_mapping,
- source_buffer->MapMemory<uint8_t>(
- MemoryAccess::kRead, source_offset, data_length));
- CHECK_EQ(data_length, source_mapping.size());
- ASSIGN_OR_RETURN(auto target_mapping,
- MapMemory<uint8_t>(MemoryAccess::kDiscardWrite,
- target_offset, data_length));
- CHECK_EQ(data_length, target_mapping.size());
- std::memcpy(target_mapping.mutable_data() + target_offset,
- source_mapping.data(), data_length);
- return OkStatus();
-}
-
-Status VmaBuffer::MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) {
- uint8_t* data_ptr = nullptr;
- VK_RETURN_IF_ERROR(
- vmaMapMemory(vma_, allocation_, reinterpret_cast<void**>(&data_ptr)));
- *out_data = data_ptr + local_byte_offset;
-
- // If we mapped for discard scribble over the bytes. This is not a mandated
- // behavior but it will make debugging issues easier. Alternatively for
- // heap buffers we could reallocate them such that ASAN yells, but that
- // would only work if the entire buffer was discarded.
-#ifndef NDEBUG
- if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) {
- std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
- }
-#endif // !NDEBUG
-
- return OkStatus();
-}
-
-Status VmaBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data) {
- vmaUnmapMemory(vma_, allocation_);
- return OkStatus();
-}
-
-Status VmaBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- vmaInvalidateAllocation(vma_, allocation_, local_byte_offset,
- local_byte_length);
- return OkStatus();
-}
-
-Status VmaBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) {
- vmaFlushAllocation(vma_, allocation_, local_byte_offset, local_byte_length);
- return OkStatus();
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/vma_buffer.h b/iree/hal/vulkan/vma_buffer.h
deleted file mode 100644
index b768f71..0000000
--- a/iree/hal/vulkan/vma_buffer.h
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_
-#define IREE_HAL_VULKAN_VMA_BUFFER_H_
-
-#include <vulkan/vulkan.h>
-
-#include "iree/hal/buffer.h"
-#include "vk_mem_alloc.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class VmaAllocator;
-
-// A buffer implementation representing an allocation made from within a pool of
-// a Vulkan Memory Allocator instance. See VmaAllocator for more information.
-class VmaBuffer final : public Buffer {
- public:
- VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type,
- MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
- device_size_t allocation_size, device_size_t byte_offset,
- device_size_t byte_length, VkBuffer buffer,
- VmaAllocation allocation, VmaAllocationInfo allocation_info);
- ~VmaBuffer() override;
-
- VkBuffer handle() const { return buffer_; }
- VmaAllocation allocation() const { return allocation_; }
- const VmaAllocationInfo& allocation_info() const { return allocation_info_; }
-
- // Exposed so that VmaAllocator can reset access after initial mapping.
- using Buffer::set_allowed_access;
-
- private:
- Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
- const void* pattern, device_size_t pattern_length) override;
- Status ReadDataImpl(device_size_t source_offset, void* data,
- device_size_t data_length) override;
- Status WriteDataImpl(device_size_t target_offset, const void* data,
- device_size_t data_length) override;
- Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
- device_size_t source_offset,
- device_size_t data_length) override;
- Status MapMemoryImpl(MappingMode mapping_mode,
- MemoryAccessBitfield memory_access,
- device_size_t local_byte_offset,
- device_size_t local_byte_length,
- void** out_data) override;
- Status UnmapMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length, void* data) override;
- Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override;
- Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
- device_size_t local_byte_length) override;
-
- ::VmaAllocator vma_;
- VkBuffer buffer_;
- VmaAllocation allocation_;
- VmaAllocationInfo allocation_info_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_VMA_BUFFER_H_
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc
deleted file mode 100644
index 56f3483..0000000
--- a/iree/hal/vulkan/vulkan_device.cc
+++ /dev/null
@@ -1,500 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/vulkan_device.h"
-
-#include <functional>
-#include <utility>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/command_buffer_validation.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/fence.h"
-#include "iree/hal/vulkan/direct_command_buffer.h"
-#include "iree/hal/vulkan/direct_command_queue.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/extensibility_util.h"
-#include "iree/hal/vulkan/legacy_fence.h"
-#include "iree/hal/vulkan/native_binary_semaphore.h"
-#include "iree/hal/vulkan/native_event.h"
-#include "iree/hal/vulkan/pipeline_cache.h"
-#include "iree/hal/vulkan/status_util.h"
-#include "iree/hal/vulkan/vma_allocator.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-constexpr uint32_t kInvalidQueueFamilyIndex = -1;
-
-struct QueueFamilyInfo {
- uint32_t dispatch_index = kInvalidQueueFamilyIndex;
- uint32_t dispatch_queue_count = 0;
- uint32_t transfer_index = kInvalidQueueFamilyIndex;
- uint32_t transfer_queue_count = 0;
-};
-
-// Finds the first queue in the listing (which is usually the driver-preferred)
-// that has all of the |required_queue_flags| and none of the
-// |excluded_queue_flags|.
-// Returns kInvalidQueueFamilyIndex if no matching queue is found.
-uint32_t FindFirstQueueFamilyWithFlags(
- absl::Span<const VkQueueFamilyProperties> queue_family_properties,
- uint32_t required_queue_flags, uint32_t excluded_queue_flags) {
- for (int queue_family_index = 0;
- queue_family_index < queue_family_properties.size();
- ++queue_family_index) {
- const auto& properties = queue_family_properties[queue_family_index];
- if ((properties.queueFlags & required_queue_flags) ==
- required_queue_flags &&
- (properties.queueFlags & excluded_queue_flags) == 0) {
- return queue_family_index;
- }
- }
- return kInvalidQueueFamilyIndex;
-}
-
-// Selects queue family indices for compute and transfer queues.
-// Note that both queue families may be the same if there is only one family
-// available.
-StatusOr<QueueFamilyInfo> SelectQueueFamilies(
- VkPhysicalDevice physical_device, const ref_ptr<DynamicSymbols>& syms) {
- // Enumerate queue families available on the device.
- uint32_t queue_family_count = 0;
- syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device,
- &queue_family_count, nullptr);
- absl::InlinedVector<VkQueueFamilyProperties, 4> queue_family_properties(
- queue_family_count);
- syms->vkGetPhysicalDeviceQueueFamilyProperties(
- physical_device, &queue_family_count, queue_family_properties.data());
-
- QueueFamilyInfo queue_family_info;
-
- // Try to find a dedicated compute queue (no graphics caps).
- // Some may support both transfer and compute. If that fails then fallback to
- // any queue that supports compute.
- queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags(
- queue_family_properties, VK_QUEUE_COMPUTE_BIT, VK_QUEUE_GRAPHICS_BIT);
- if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) {
- queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags(
- queue_family_properties, VK_QUEUE_COMPUTE_BIT, 0);
- }
- if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Unable to find any queue family support compute operations";
- }
- queue_family_info.dispatch_queue_count =
- queue_family_properties[queue_family_info.dispatch_index].queueCount;
-
- // Try to find a dedicated transfer queue (no compute or graphics caps).
- // Not all devices have one, and some have only a queue family for everything
- // and possibly a queue family just for compute/etc. If that fails then
- // fallback to any queue that supports transfer. Finally, if /that/ fails then
- // we just won't create a transfer queue and instead use the compute queue for
- // all operations.
- queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags(
- queue_family_properties, VK_QUEUE_TRANSFER_BIT,
- VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT);
- if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) {
- queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags(
- queue_family_properties, VK_QUEUE_TRANSFER_BIT, VK_QUEUE_GRAPHICS_BIT);
- }
- if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) {
- queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags(
- queue_family_properties, VK_QUEUE_TRANSFER_BIT, 0);
- }
- if (queue_family_info.transfer_index != kInvalidQueueFamilyIndex) {
- queue_family_info.transfer_queue_count =
- queue_family_properties[queue_family_info.transfer_index].queueCount;
- }
-
- // Ensure that we don't share the dispatch queues with transfer queues if that
- // would put us over the queue count.
- if (queue_family_info.dispatch_index == queue_family_info.transfer_index) {
- queue_family_info.transfer_queue_count = std::min(
- queue_family_properties[queue_family_info.dispatch_index].queueCount -
- queue_family_info.dispatch_queue_count,
- queue_family_info.transfer_queue_count);
- }
-
- return queue_family_info;
-}
-
-// Creates a transient command pool for the given queue family.
-// Command buffers allocated from the pool must only be issued on queues
-// belonging to the specified family.
-StatusOr<ref_ptr<VkCommandPoolHandle>> CreateTransientCommandPool(
- const ref_ptr<VkDeviceHandle>& logical_device,
- uint32_t queue_family_index) {
- VkCommandPoolCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT;
- create_info.queueFamilyIndex = queue_family_index;
-
- auto command_pool = make_ref<VkCommandPoolHandle>(logical_device);
- VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateCommandPool(
- *logical_device, &create_info, logical_device->allocator(),
- command_pool->mutable_value()));
- return command_pool;
-}
-
-} // namespace
-
-// static
-StatusOr<std::shared_ptr<VulkanDevice>> VulkanDevice::Create(
- const DeviceInfo& device_info, VkPhysicalDevice physical_device,
- const ExtensibilitySpec& extensibility_spec,
- const ref_ptr<DynamicSymbols>& syms) {
- IREE_TRACE_SCOPE0("VulkanDevice::Create");
-
- // Find the layers and extensions we need (or want) that are also available
- // on the device. This will fail when required ones are not present.
- ASSIGN_OR_RETURN(
- auto enabled_layer_names,
- MatchAvailableDeviceLayers(physical_device, extensibility_spec, *syms));
- ASSIGN_OR_RETURN(auto enabled_extension_names,
- MatchAvailableDeviceExtensions(physical_device,
- extensibility_spec, *syms));
- auto enabled_device_extensions =
- PopulateEnabledDeviceExtensions(enabled_extension_names);
-
- // Find queue families we will expose as HAL queues.
- ASSIGN_OR_RETURN(auto queue_family_info,
- SelectQueueFamilies(physical_device, syms));
-
- // Limit the number of queues we create (for now).
- // We may want to allow this to grow, but each queue adds overhead and we need
- // to measure to make sure we can effectively use them all.
- queue_family_info.dispatch_queue_count =
- std::min(2u, queue_family_info.dispatch_queue_count);
- queue_family_info.transfer_queue_count =
- std::min(1u, queue_family_info.transfer_queue_count);
- bool has_dedicated_transfer_queues =
- queue_family_info.transfer_queue_count > 0;
-
- // Setup the queue info we'll be using.
- // Each queue here (created from within a family) will map to a HAL queue.
- //
- // Note that we need to handle the case where we have transfer queues that are
- // of the same queue family as the dispatch queues: Vulkan requires that all
- // queues created from the same family are done in the same
- // VkDeviceQueueCreateInfo struct.
- DVLOG(1) << "Creating " << queue_family_info.dispatch_queue_count
- << " dispatch queue(s) in queue family "
- << queue_family_info.dispatch_index;
- absl::InlinedVector<VkDeviceQueueCreateInfo, 2> queue_create_info;
- absl::InlinedVector<float, 4> dispatch_queue_priorities;
- absl::InlinedVector<float, 4> transfer_queue_priorities;
- queue_create_info.push_back({});
- auto& dispatch_queue_info = queue_create_info.back();
- dispatch_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
- dispatch_queue_info.pNext = nullptr;
- dispatch_queue_info.flags = 0;
- dispatch_queue_info.queueFamilyIndex = queue_family_info.dispatch_index;
- dispatch_queue_info.queueCount = queue_family_info.dispatch_queue_count;
- if (has_dedicated_transfer_queues) {
- if (queue_family_info.dispatch_index == queue_family_info.transfer_index) {
- DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count
- << " dedicated transfer queue(s) in shared queue family "
- << queue_family_info.transfer_index;
- dispatch_queue_info.queueCount += queue_family_info.transfer_queue_count;
- } else {
- DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count
- << " dedicated transfer queue(s) in independent queue family "
- << queue_family_info.transfer_index;
- queue_create_info.push_back({});
- auto& transfer_queue_info = queue_create_info.back();
- transfer_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
- transfer_queue_info.pNext = nullptr;
- transfer_queue_info.queueFamilyIndex = queue_family_info.transfer_index;
- transfer_queue_info.queueCount = queue_family_info.transfer_queue_count;
- transfer_queue_info.flags = 0;
- transfer_queue_priorities.resize(transfer_queue_info.queueCount);
- transfer_queue_info.pQueuePriorities = transfer_queue_priorities.data();
- }
- }
- dispatch_queue_priorities.resize(dispatch_queue_info.queueCount);
- dispatch_queue_info.pQueuePriorities = dispatch_queue_priorities.data();
-
- // TODO(benvanik): specify features with VkPhysicalDeviceFeatures.
-
- // Create device and its queues.
- VkDeviceCreateInfo device_create_info = {};
- device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
- device_create_info.pNext = nullptr;
- device_create_info.enabledLayerCount = enabled_layer_names.size();
- device_create_info.ppEnabledLayerNames = enabled_layer_names.data();
- device_create_info.enabledExtensionCount = enabled_extension_names.size();
- device_create_info.ppEnabledExtensionNames = enabled_extension_names.data();
- device_create_info.queueCreateInfoCount = queue_create_info.size();
- device_create_info.pQueueCreateInfos = queue_create_info.data();
- device_create_info.pEnabledFeatures = nullptr;
- auto logical_device = make_ref<VkDeviceHandle>(
- syms, enabled_device_extensions, /*allocator=*/nullptr);
- VK_RETURN_IF_ERROR(syms->vkCreateDevice(physical_device, &device_create_info,
- logical_device->allocator(),
- logical_device->mutable_value()));
-
- // Create the device memory allocator.
- // TODO(benvanik): allow other types to be plugged in.
- ASSIGN_OR_RETURN(auto allocator,
- VmaAllocator::Create(physical_device, logical_device));
-
- // Create command pools for each queue family. If we don't have a transfer
- // queue then we'll ignore that one and just use the dispatch pool.
- // If we wanted to expose the pools through the HAL to allow the VM to more
- // effectively manage them (pool per fiber, etc) we could, however I doubt the
- // overhead of locking the pool will be even a blip.
- ASSIGN_OR_RETURN(auto dispatch_command_pool,
- CreateTransientCommandPool(
- logical_device, queue_family_info.dispatch_index));
- ref_ptr<VkCommandPoolHandle> transfer_command_pool;
- if (has_dedicated_transfer_queues) {
- ASSIGN_OR_RETURN(transfer_command_pool,
- CreateTransientCommandPool(
- logical_device, queue_family_info.transfer_index));
- }
-
- // Get the queues and create the HAL wrappers.
- absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues;
- for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) {
- VkQueue queue = VK_NULL_HANDLE;
- syms->vkGetDeviceQueue(*logical_device, queue_family_info.dispatch_index, i,
- &queue);
- std::string queue_name = absl::StrCat(device_info.name(), ":d", i);
- command_queues.push_back(absl::make_unique<DirectCommandQueue>(
- std::move(queue_name),
- CommandCategory::kDispatch | CommandCategory::kTransfer, logical_device,
- queue));
- }
- if (has_dedicated_transfer_queues) {
- uint32_t base_queue_index = 0;
- if (queue_family_info.dispatch_index == queue_family_info.transfer_index) {
- // Sharing a family, so transfer queues follow compute queues.
- base_queue_index = queue_family_info.dispatch_index;
- }
- for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) {
- VkQueue queue = VK_NULL_HANDLE;
- syms->vkGetDeviceQueue(*logical_device, queue_family_info.transfer_index,
- base_queue_index + i, &queue);
- std::string queue_name = absl::StrCat(device_info.name(), ":t", i);
- command_queues.push_back(absl::make_unique<DirectCommandQueue>(
- std::move(queue_name), CommandCategory::kTransfer, logical_device,
- queue));
- }
- }
-
- // TODO(b/140141417): implement timeline semaphore fences and switch here.
- ASSIGN_OR_RETURN(auto legacy_fence_pool,
- LegacyFencePool::Create(add_ref(logical_device)));
-
- return std::make_shared<VulkanDevice>(
- CtorKey{}, device_info, physical_device, std::move(logical_device),
- std::move(allocator), std::move(command_queues),
- std::move(dispatch_command_pool), std::move(transfer_command_pool),
- std::move(legacy_fence_pool));
-}
-
-VulkanDevice::VulkanDevice(
- CtorKey ctor_key, const DeviceInfo& device_info,
- VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device,
- std::unique_ptr<Allocator> allocator,
- absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues,
- ref_ptr<VkCommandPoolHandle> dispatch_command_pool,
- ref_ptr<VkCommandPoolHandle> transfer_command_pool,
- ref_ptr<LegacyFencePool> legacy_fence_pool)
- : Device(device_info),
- physical_device_(physical_device),
- logical_device_(std::move(logical_device)),
- allocator_(std::move(allocator)),
- command_queues_(std::move(command_queues)),
- descriptor_pool_cache_(
- make_ref<DescriptorPoolCache>(add_ref(logical_device_))),
- dispatch_command_pool_(std::move(dispatch_command_pool)),
- transfer_command_pool_(std::move(transfer_command_pool)),
- legacy_fence_pool_(std::move(legacy_fence_pool)) {
- // Populate the queue lists based on queue capabilities.
- for (auto& command_queue : command_queues_) {
- if (command_queue->can_dispatch()) {
- dispatch_queues_.push_back(command_queue.get());
- if (transfer_command_pool_ == VK_NULL_HANDLE) {
- transfer_queues_.push_back(command_queue.get());
- }
- } else {
- transfer_queues_.push_back(command_queue.get());
- }
- }
-}
-
-VulkanDevice::~VulkanDevice() {
- IREE_TRACE_SCOPE0("VulkanDevice::dtor");
-
- // Drop all command queues. These may wait until idle.
- command_queues_.clear();
- dispatch_queues_.clear();
- transfer_queues_.clear();
-
- // Drop command pools now that we know there are no more outstanding command
- // buffers.
- dispatch_command_pool_.reset();
- transfer_command_pool_.reset();
-
- // Now that no commands are outstanding we can release all descriptor sets.
- descriptor_pool_cache_.reset();
-
- // Finally, destroy the device.
- logical_device_.reset();
-}
-
-std::shared_ptr<ExecutableCache> VulkanDevice::CreateExecutableCache() {
- IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableCache");
- return std::make_shared<PipelineCache>(logical_device_);
-}
-
-StatusOr<ref_ptr<CommandBuffer>> VulkanDevice::CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) {
- IREE_TRACE_SCOPE0("VulkanDevice::CreateCommandBuffer");
-
- // Select the command pool to used based on the types of commands used.
- // Note that we may not have a dedicated transfer command pool if there are no
- // dedicated transfer queues.
- ref_ptr<VkCommandPoolHandle> command_pool;
- if (transfer_command_pool_ &&
- !AllBitsSet(command_categories, CommandCategory::kDispatch)) {
- command_pool = add_ref(transfer_command_pool_);
- } else {
- command_pool = add_ref(dispatch_command_pool_);
- }
-
- VkCommandBufferAllocateInfo allocate_info;
- allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
- allocate_info.pNext = nullptr;
- allocate_info.commandPool = *command_pool;
- allocate_info.commandBufferCount = 1;
- allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
-
- VkCommandBuffer command_buffer = VK_NULL_HANDLE;
- {
- absl::MutexLock lock(command_pool->mutex());
- VK_RETURN_IF_ERROR(syms()->vkAllocateCommandBuffers(
- *logical_device_, &allocate_info, &command_buffer));
- }
-
- // TODO(b/140026716): conditionally enable validation.
- auto impl = make_ref<DirectCommandBuffer>(
- allocator(), mode, command_categories, add_ref(descriptor_pool_cache_),
- add_ref(command_pool), command_buffer);
- return WrapCommandBufferWithValidation(std::move(impl));
-}
-
-StatusOr<ref_ptr<Event>> VulkanDevice::CreateEvent() {
- IREE_TRACE_SCOPE0("VulkanDevice::CreateEvent");
-
- // TODO(b/138729892): pool events.
- VkEventCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- VkEvent event_handle = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(syms()->vkCreateEvent(*logical_device_, &create_info,
- logical_device_->allocator(),
- &event_handle));
-
- return make_ref<NativeEvent>(add_ref(logical_device_), event_handle);
-}
-
-StatusOr<ref_ptr<BinarySemaphore>> VulkanDevice::CreateBinarySemaphore(
- bool initial_value) {
- IREE_TRACE_SCOPE0("VulkanDevice::CreateBinarySemaphore");
-
- VkSemaphoreCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = initial_value ? VK_FENCE_CREATE_SIGNALED_BIT : 0;
- VkSemaphore semaphore_handle = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info,
- logical_device_->allocator(),
- &semaphore_handle));
-
- return make_ref<NativeBinarySemaphore>(add_ref(logical_device_),
- semaphore_handle);
-}
-
-StatusOr<ref_ptr<TimelineSemaphore>> VulkanDevice::CreateTimelineSemaphore(
- uint64_t initial_value) {
- IREE_TRACE_SCOPE0("VulkanDevice::CreateTimelineSemaphore");
-
- // TODO(b/140141417): implement timeline semaphores.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Timeline semaphores not yet implemented";
-}
-
-StatusOr<ref_ptr<Fence>> VulkanDevice::CreateFence(uint64_t initial_value) {
- IREE_TRACE_SCOPE0("VulkanDevice::CreateFence");
-
- // TODO(b/140141417): implement timeline semaphore fences and switch here.
- // NOTE: we'll want some magic factory so that we can cleanly compile out the
- // legacy implementation and pool.
-
- return make_ref<LegacyFence>(add_ref(legacy_fence_pool_), initial_value);
-}
-
-Status VulkanDevice::WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) {
- IREE_TRACE_SCOPE0("VulkanDevice::WaitAllFences");
-
- // TODO(b/140141417): implement timeline semaphore fences and switch here.
-
- return LegacyFence::WaitForFences(logical_device_.get(), fences,
- /*wait_all=*/true, deadline);
-}
-
-StatusOr<int> VulkanDevice::WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Time deadline) {
- IREE_TRACE_SCOPE0("VulkanDevice::WaitAnyFence");
-
- // TODO(b/140141417): implement timeline semaphore fences and switch here.
-
- return LegacyFence::WaitForFences(logical_device_.get(), fences,
- /*wait_all=*/false, deadline);
-}
-
-Status VulkanDevice::WaitIdle(absl::Time deadline) {
- if (deadline == absl::InfiniteFuture()) {
- // Fast path for using vkDeviceWaitIdle, which is usually cheaper (as it
- // requires fewer calls into the driver).
- IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#vkDeviceWaitIdle");
- VK_RETURN_IF_ERROR(syms()->vkDeviceWaitIdle(*logical_device_));
- return OkStatus();
- }
-
- IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#Fences");
- for (auto& command_queue : command_queues_) {
- RETURN_IF_ERROR(command_queue->WaitIdle(deadline));
- }
- return OkStatus();
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h
deleted file mode 100644
index f4557c1..0000000
--- a/iree/hal/vulkan/vulkan_device.h
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_
-#define IREE_HAL_VULKAN_VULKAN_DEVICE_H_
-
-#include <vulkan/vulkan.h>
-
-#include <functional>
-#include <memory>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/base/memory.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/device.h"
-#include "iree/hal/vulkan/descriptor_pool_cache.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/extensibility_util.h"
-#include "iree/hal/vulkan/handle_util.h"
-#include "iree/hal/vulkan/legacy_fence.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class VulkanDevice final : public Device {
- public:
- static StatusOr<std::shared_ptr<VulkanDevice>> Create(
- const DeviceInfo& device_info, VkPhysicalDevice physical_device,
- const ExtensibilitySpec& extensibility_spec,
- const ref_ptr<DynamicSymbols>& syms);
-
- // Private constructor.
- struct CtorKey {
- private:
- friend class VulkanDevice;
- CtorKey() = default;
- };
- VulkanDevice(
- CtorKey ctor_key, const DeviceInfo& device_info,
- VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device,
- std::unique_ptr<Allocator> allocator,
- absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues,
- ref_ptr<VkCommandPoolHandle> dispatch_command_pool,
- ref_ptr<VkCommandPoolHandle> transfer_command_pool,
- ref_ptr<LegacyFencePool> legacy_fence_pool);
- ~VulkanDevice() override;
-
- const ref_ptr<DynamicSymbols>& syms() const {
- return logical_device_->syms();
- }
-
- Allocator* allocator() const override { return allocator_.get(); }
-
- absl::Span<CommandQueue*> dispatch_queues() const override {
- return absl::MakeSpan(dispatch_queues_);
- }
-
- absl::Span<CommandQueue*> transfer_queues() const override {
- return absl::MakeSpan(transfer_queues_);
- }
-
- std::shared_ptr<ExecutableCache> CreateExecutableCache() override;
-
- StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
- CommandBufferModeBitfield mode,
- CommandCategoryBitfield command_categories) override;
-
- StatusOr<ref_ptr<Event>> CreateEvent() override;
-
- StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore(
- bool initial_value) override;
- StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore(
- uint64_t initial_value) override;
-
- StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override;
- Status WaitAllFences(absl::Span<const FenceValue> fences,
- absl::Time deadline) override;
- StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences,
- absl::Time deadline) override;
-
- Status WaitIdle(absl::Time deadline) override;
-
- private:
- VkPhysicalDevice physical_device_;
- ref_ptr<VkDeviceHandle> logical_device_;
-
- std::unique_ptr<Allocator> allocator_;
-
- mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues_;
- mutable absl::InlinedVector<CommandQueue*, 4> dispatch_queues_;
- mutable absl::InlinedVector<CommandQueue*, 4> transfer_queues_;
-
- ref_ptr<DescriptorPoolCache> descriptor_pool_cache_;
-
- ref_ptr<VkCommandPoolHandle> dispatch_command_pool_;
- ref_ptr<VkCommandPoolHandle> transfer_command_pool_;
-
- // TODO(b/140141417): implement timeline semaphore fences and conditionally
- // compile the legacy fence pool out.
- ref_ptr<LegacyFencePool> legacy_fence_pool_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_VULKAN_DEVICE_H_
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc
deleted file mode 100644
index a5238d0..0000000
--- a/iree/hal/vulkan/vulkan_driver.cc
+++ /dev/null
@@ -1,229 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/vulkan_driver.h"
-
-#include <memory>
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/device_info.h"
-#include "iree/hal/vulkan/extensibility_util.h"
-#include "iree/hal/vulkan/status_util.h"
-#include "iree/hal/vulkan/vulkan_device.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-// Returns a VkApplicationInfo struct populated with the default app info.
-// We may allow hosting applications to override this via weak-linkage if it's
-// useful, otherwise this is enough to create the application.
-VkApplicationInfo GetDefaultApplicationInfo() {
- VkApplicationInfo info;
- info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
- info.pNext = nullptr;
- info.pApplicationName = "IREE-ML";
- info.applicationVersion = 0;
- info.pEngineName = "IREE";
- info.engineVersion = 0;
- info.apiVersion = VK_API_VERSION_1_0;
- return info;
-}
-
-// Populates device information from the given Vulkan physical device handle.
-StatusOr<DeviceInfo> PopulateDeviceInfo(VkPhysicalDevice physical_device,
- const ref_ptr<DynamicSymbols>& syms) {
- VkPhysicalDeviceFeatures physical_device_features;
- syms->vkGetPhysicalDeviceFeatures(physical_device, &physical_device_features);
- // TODO(benvanik): check and optionally require these features:
- // - physical_device_features.robustBufferAccess
- // - physical_device_features.shaderInt16
- // - physical_device_features.shaderInt64
- // - physical_device_features.shaderFloat64
-
- VkPhysicalDeviceProperties physical_device_properties;
- syms->vkGetPhysicalDeviceProperties(physical_device,
- &physical_device_properties);
- // TODO(benvanik): check and optionally require reasonable limits.
-
- // TODO(benvanik): more clever/sanitized device naming.
- std::string name = std::string(physical_device_properties.deviceName);
-
- DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
- // TODO(benvanik): implement debugging/profiling features.
- // TODO(benvanik): use props to determine if we have timing info.
- // supported_features |= DeviceFeature::kDebugging;
- // supported_features |= DeviceFeature::kCoverage;
- // supported_features |= DeviceFeature::kProfiling;
- return DeviceInfo(std::move(name), supported_features, physical_device);
-}
-
-} // namespace
-
-// static
-StatusOr<std::shared_ptr<VulkanDriver>> VulkanDriver::Create(
- Options options, ref_ptr<DynamicSymbols> syms) {
- IREE_TRACE_SCOPE0("VulkanDriver::Create");
-
- // Find the layers and extensions we need (or want) that are also available
- // on the instance. This will fail when required ones are not present.
- ASSIGN_OR_RETURN(
- auto enabled_layer_names,
- MatchAvailableInstanceLayers(options.instance_extensibility, *syms));
- ASSIGN_OR_RETURN(
- auto enabled_extension_names,
- MatchAvailableInstanceExtensions(options.instance_extensibility, *syms));
- auto instance_extensions =
- PopulateEnabledInstanceExtensions(enabled_extension_names);
-
- // Create the instance this driver will use for all requests.
- VkApplicationInfo app_info = GetDefaultApplicationInfo();
- app_info.apiVersion = options.api_version;
- VkInstanceCreateInfo create_info;
- create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
- create_info.pNext = nullptr;
- create_info.flags = 0;
- create_info.pApplicationInfo = &app_info;
- create_info.enabledLayerCount = enabled_layer_names.size();
- create_info.ppEnabledLayerNames = enabled_layer_names.data();
- create_info.enabledExtensionCount = enabled_extension_names.size();
- create_info.ppEnabledExtensionNames = enabled_extension_names.data();
-
- // If we have the debug_utils extension then we can chain a one-shot messenger
- // callback that we can use to log out the instance creation errors. Once we
- // have the real instance we can then register a real messenger.
- union {
- VkDebugUtilsMessengerCreateInfoEXT debug_utils_create_info;
- VkDebugReportCallbackCreateInfoEXT debug_report_create_info;
- };
- if (instance_extensions.debug_utils) {
- create_info.pNext = &debug_utils_create_info;
- DebugReporter::PopulateStaticCreateInfo(&debug_utils_create_info);
- } else if (instance_extensions.debug_report) {
- create_info.pNext = &debug_report_create_info;
- DebugReporter::PopulateStaticCreateInfo(&debug_report_create_info);
- }
-
- // Some ICDs appear to leak in here, out of our control.
- // Warning: leak checks remain disabled if an error is returned.
- IREE_DISABLE_LEAK_CHECKS();
- VkInstance instance = VK_NULL_HANDLE;
- VK_RETURN_IF_ERROR(
- syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance))
- << "Unable to create Vulkan instance";
- IREE_ENABLE_LEAK_CHECKS();
-
- // TODO(benvanik): enable validation layers if needed.
-
- // Now that the instance has been created we can fetch all of the instance
- // symbols.
- RETURN_IF_ERROR(syms->LoadFromInstance(instance));
-
- // The real debug messenger (not just the static one used above) can now be
- // created as we've loaded all the required symbols.
- // TODO(benvanik): strip in release builds.
- std::unique_ptr<DebugReporter> debug_reporter;
- if (instance_extensions.debug_utils) {
- ASSIGN_OR_RETURN(debug_reporter, DebugReporter::CreateDebugUtilsMessenger(
- instance, syms,
- /*allocation_callbacks=*/nullptr));
- } else if (instance_extensions.debug_report) {
- ASSIGN_OR_RETURN(debug_reporter,
- DebugReporter::CreateDebugReportCallback(
- instance, syms, /*allocation_callbacks=*/nullptr));
- }
-
- return std::make_shared<VulkanDriver>(
- CtorKey{}, std::move(syms), instance, std::move(debug_reporter),
- std::move(options.device_extensibility));
-}
-
-VulkanDriver::VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms,
- VkInstance instance,
- std::unique_ptr<DebugReporter> debug_reporter,
- ExtensibilitySpec device_extensibility_spec)
- : Driver("vulkan"),
- syms_(std::move(syms)),
- instance_(instance),
- debug_reporter_(std::move(debug_reporter)),
- device_extensibility_spec_(std::move(device_extensibility_spec)) {}
-
-VulkanDriver::~VulkanDriver() {
- IREE_TRACE_SCOPE0("VulkanDriver::dtor");
- debug_reporter_.reset();
- syms()->vkDestroyInstance(instance_, /*pAllocator=*/nullptr);
-}
-
-StatusOr<std::vector<DeviceInfo>> VulkanDriver::EnumerateAvailableDevices() {
- IREE_TRACE_SCOPE0("VulkanDriver::EnumerateAvailableDevices");
-
- // Query all available devices (at this moment, note that this may change!).
- uint32_t physical_device_count = 0;
- VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices(
- instance_, &physical_device_count, nullptr));
- absl::InlinedVector<VkPhysicalDevice, 2> physical_devices(
- physical_device_count);
- VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices(
- instance_, &physical_device_count, physical_devices.data()));
-
- // Convert to our HAL structure.
- std::vector<DeviceInfo> device_infos;
- device_infos.reserve(physical_device_count);
- for (auto physical_device : physical_devices) {
- // TODO(benvanik): if we fail should we just ignore the device in the list?
- ASSIGN_OR_RETURN(auto device_info,
- PopulateDeviceInfo(physical_device, syms()));
- device_infos.push_back(std::move(device_info));
- }
- return device_infos;
-}
-
-StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDefaultDevice() {
- IREE_TRACE_SCOPE0("VulkanDriver::CreateDefaultDevice");
-
- // Query available devices.
- ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices());
- if (available_devices.empty()) {
- return NotFoundErrorBuilder(IREE_LOC) << "No devices are available";
- }
-
- // Just create the first one we find.
- return CreateDevice(available_devices.front());
-}
-
-StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDevice(
- const DeviceInfo& device_info) {
- IREE_TRACE_SCOPE0("VulkanDriver::CreateDevice");
-
- auto physical_device =
- static_cast<VkPhysicalDevice>(device_info.driver_handle());
-
- // Attempt to create the device.
- // This may fail if the device was enumerated but is in exclusive use,
- // disabled by the system, or permission is denied.
- ASSIGN_OR_RETURN(auto device,
- VulkanDevice::Create(device_info, physical_device,
- device_extensibility_spec_, syms()));
-
- return device;
-}
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/vulkan/vulkan_driver.h b/iree/hal/vulkan/vulkan_driver.h
deleted file mode 100644
index d2383db..0000000
--- a/iree/hal/vulkan/vulkan_driver.h
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_
-#define IREE_HAL_VULKAN_VULKAN_DRIVER_H_
-
-#include <vulkan/vulkan.h>
-
-#include <memory>
-#include <vector>
-
-#include "iree/hal/driver.h"
-#include "iree/hal/vulkan/debug_reporter.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/extensibility_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class VulkanDriver final : public Driver {
- public:
- struct Options {
- // Vulkan version that will be requested.
- // Driver creation will fail if the required version is not available.
- uint32_t api_version = VK_API_VERSION_1_0;
-
- // Extensibility descriptions for instances and devices.
- // Device descriptions will be used for all devices created by the driver.
- ExtensibilitySpec instance_extensibility;
- ExtensibilitySpec device_extensibility;
- };
-
- static StatusOr<std::shared_ptr<VulkanDriver>> Create(
- Options options, ref_ptr<DynamicSymbols> syms);
-
- // TODO(benvanik): method to wrap an existing instance/device (interop).
-
- // Private constructor.
- struct CtorKey {
- private:
- friend class VulkanDriver;
- CtorKey() = default;
- };
- VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms,
- VkInstance instance,
- std::unique_ptr<DebugReporter> debug_reporter,
- ExtensibilitySpec device_extensibility_spec);
- ~VulkanDriver() override;
-
- const ref_ptr<DynamicSymbols>& syms() const { return syms_; }
-
- VkInstance instance() const { return instance_; }
-
- StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
-
- StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override;
-
- StatusOr<std::shared_ptr<Device>> CreateDevice(
- const DeviceInfo& device_info) override;
-
- private:
- ref_ptr<DynamicSymbols> syms_;
- VkInstance instance_;
- std::unique_ptr<DebugReporter> debug_reporter_;
- ExtensibilitySpec device_extensibility_spec_;
-};
-
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_VULKAN_VULKAN_DRIVER_H_
diff --git a/iree/hal/vulkan/vulkan_driver_module.cc b/iree/hal/vulkan/vulkan_driver_module.cc
deleted file mode 100644
index 4d3d745..0000000
--- a/iree/hal/vulkan/vulkan_driver_module.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-
-#include "absl/flags/flag.h"
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/vulkan_driver.h"
-
-ABSL_FLAG(bool, vulkan_validation_layers, true,
- "Enables standard Vulkan validation layers.");
-ABSL_FLAG(bool, vulkan_debug_utils, true,
- "Enables VK_EXT_debug_utils, records markers, and logs errors.");
-ABSL_FLAG(bool, vulkan_debug_report, false,
- "Enables VK_EXT_debug_report and logs errors.");
-ABSL_FLAG(bool, vulkan_push_descriptors, true,
- "Enables use of vkCmdPushDescriptorSetKHR, if available.");
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-namespace {
-
-StatusOr<std::shared_ptr<Driver>> CreateVulkanDriver() {
- IREE_TRACE_SCOPE0("CreateVulkanDriver");
-
- // Load the Vulkan library. This will fail if the library cannot be found or
- // does not have the expected functions.
- ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader());
-
- // Setup driver options from flags. We do this here as we want to enable other
- // consumers that may not be using modules/command line flags to be able to
- // set their options however they want.
- VulkanDriver::Options options;
-
- // TODO: validation layers have bugs when using VK_EXT_debug_report, so if the
- // user requested that we force them off with a warning. Prefer using
- // VK_EXT_debug_utils when available.
- if (absl::GetFlag(FLAGS_vulkan_debug_report) &&
- absl::GetFlag(FLAGS_vulkan_validation_layers)) {
- LOG(WARNING) << "VK_EXT_debug_report has issues with modern validation "
- "layers; disabling validation";
- absl::SetFlag(&FLAGS_vulkan_validation_layers, false);
- }
-
- // REQUIRED: these are required extensions that must be present for IREE to
- // work (such as those relied upon by SPIR-V kernels, etc).
- options.device_extensibility.required_extensions.push_back(
- VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME);
-
- if (absl::GetFlag(FLAGS_vulkan_validation_layers)) {
- options.instance_extensibility.optional_layers.push_back(
- "VK_LAYER_LUNARG_standard_validation");
- }
-
- if (absl::GetFlag(FLAGS_vulkan_debug_report)) {
- options.instance_extensibility.optional_extensions.push_back(
- VK_EXT_DEBUG_REPORT_EXTENSION_NAME);
- }
- if (absl::GetFlag(FLAGS_vulkan_debug_utils)) {
- options.instance_extensibility.optional_extensions.push_back(
- VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
- }
-
- if (absl::GetFlag(FLAGS_vulkan_push_descriptors)) {
- options.instance_extensibility.optional_extensions.push_back(
- VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
- options.device_extensibility.optional_extensions.push_back(
- VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME);
- }
-
- // Create the driver and VkInstance.
- ASSIGN_OR_RETURN(auto driver, VulkanDriver::Create(options, std::move(syms)));
-
- return driver;
-}
-
-} // namespace
-} // namespace vulkan
-} // namespace hal
-} // namespace iree
-
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vulkan_driver, {
- QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "vulkan", ::iree::hal::vulkan::CreateVulkanDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vulkan_driver);
diff --git a/iree/rt/BUILD b/iree/rt/BUILD
deleted file mode 100644
index 33aea24..0000000
--- a/iree/rt/BUILD
+++ /dev/null
@@ -1,83 +0,0 @@
-# Runtime API for interacting with IREE modules and invoking functions.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "api",
- srcs = ["api.cc"],
- hdrs = ["api.h"],
- visibility = ["//visibility:public"],
- deps = [
- ":api_hdrs",
- ":rt",
- "//iree/base:api",
- "//iree/base:api_util",
- "//iree/base:tracing",
- "//iree/hal:api",
- "//iree/hal:buffer_view",
- "//iree/hal:driver_registry",
- "//iree/rt/debug:debug_server_interface",
- "@com_google_absl//absl/time",
- ],
-)
-
-cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
- deps = [
- "//iree/base:api_hdrs",
- "//iree/hal:api_hdrs",
- ],
-)
-
-cc_library(
- name = "rt",
- srcs = [
- "context.cc",
- "function.cc",
- "instance.cc",
- "invocation.cc",
- "module_printer.cc",
- "source_location.cc",
- "stack.cc",
- "stack_frame.cc",
- "stack_trace.cc",
- ],
- hdrs = [
- "context.h",
- "disassembler.h",
- "function.h",
- "function_signature.h",
- "instance.h",
- "invocation.h",
- "module.h",
- "module_printer.h",
- "module_signature.h",
- "policy.h",
- "source_location.h",
- "source_resolver.h",
- "stack.h",
- "stack_frame.h",
- "stack_trace.h",
- ],
- deps = [
- "//iree/base:bitfield",
- "//iree/base:intrusive_list",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:buffer_view",
- "//iree/hal:device_manager",
- "//iree/rt/debug:debug_server_interface",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
diff --git a/iree/rt/api.cc b/iree/rt/api.cc
deleted file mode 100644
index 04cbca4..0000000
--- a/iree/rt/api.cc
+++ /dev/null
@@ -1,788 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/api.h"
-
-#include "absl/time/time.h"
-#include "iree/base/api.h"
-#include "iree/base/api_util.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/api.h"
-#include "iree/hal/api_detail.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/rt/context.h"
-#include "iree/rt/debug/debug_server.h"
-#include "iree/rt/function.h"
-#include "iree/rt/instance.h"
-#include "iree/rt/invocation.h"
-#include "iree/rt/module.h"
-#include "iree/rt/policy.h"
-
-namespace iree {
-namespace rt {
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Instance
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
- iree_allocator_t allocator, iree_rt_instance_t** out_instance) {
- IREE_TRACE_SCOPE0("iree_rt_instance_create");
-
- if (!out_instance) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_instance = nullptr;
-
- auto instance = make_ref<Instance>();
- *out_instance = reinterpret_cast<iree_rt_instance_t*>(instance.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_instance_retain(iree_rt_instance_t* instance) {
- IREE_TRACE_SCOPE0("iree_rt_instance_retain");
- auto* handle = reinterpret_cast<Instance*>(instance);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_instance_release(iree_rt_instance_t* instance) {
- IREE_TRACE_SCOPE0("iree_rt_instance_release");
- auto* handle = reinterpret_cast<Instance*>(instance);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex(
- iree_rt_instance_t* instance, iree_string_view_t driver_name) {
- IREE_TRACE_SCOPE0("iree_rt_instance_register_driver_ex");
- auto* handle = reinterpret_cast<Instance*>(instance);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- IREE_API_ASSIGN_OR_RETURN(
- auto driver, hal::DriverRegistry::shared_registry()->Create(
- absl::string_view{driver_name.data, driver_name.size}));
- IREE_API_ASSIGN_OR_RETURN(auto available_devices,
- driver->EnumerateAvailableDevices());
- for (const auto& device_info : available_devices) {
- LOG(INFO) << " Device: " << device_info.name();
- }
- LOG(INFO) << "Creating default device...";
- IREE_API_ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
- IREE_API_RETURN_IF_ERROR(handle->device_manager()->RegisterDevice(device));
- LOG(INFO) << "Successfully created device '" << device->info().name() << "'";
-
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Module
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-class ExternalModule final : public Module {
- public:
- ExternalModule(iree_rt_external_module_t impl, iree_allocator_t allocator)
- : impl_(impl), allocator_(allocator) {
- IREE_TRACE_SCOPE0("ExternalModule::ctor");
- }
-
- ~ExternalModule() override {
- IREE_TRACE_SCOPE0("ExternalModule::dtor");
- impl_.destroy(impl_.self);
- std::memset(&impl_, 0, sizeof(impl_));
- }
-
- absl::string_view name() const override {
- auto result = impl_.name(impl_.self);
- return absl::string_view{result.data, result.size};
- }
-
- const ModuleSignature signature() const override {
- auto signature = impl_.signature(impl_.self);
- return ModuleSignature{
- signature.import_function_count,
- signature.export_function_count,
- signature.internal_function_count,
- signature.state_slot_count,
- };
- }
-
- SourceResolver* source_resolver() const override { return nullptr; }
-
- Disassembler* disassembler() const override { return nullptr; }
-
- std::string DebugStringShort() const override { return std::string(name()); }
-
- StatusOr<const Function> LookupFunctionByOrdinal(
- Function::Linkage linkage, int32_t ordinal) const override {
- IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByOrdinal");
- iree_rt_function_t function;
- auto status = impl_.lookup_function_by_ordinal(
- impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
- &function);
- if (status != IREE_STATUS_OK) {
- return FromApiStatus(status, IREE_LOC);
- }
- return Function{reinterpret_cast<Module*>(function.module),
- static_cast<Function::Linkage>(function.linkage),
- function.ordinal};
- }
-
- StatusOr<const Function> LookupFunctionByName(
- Function::Linkage linkage, absl::string_view name) const override {
- IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByName");
- iree_rt_function_t function;
- auto status = impl_.lookup_function_by_name(
- impl_.self, static_cast<iree_rt_function_linkage_t>(linkage),
- iree_string_view_t{name.data(), name.size()}, &function);
- if (status != IREE_STATUS_OK) {
- return FromApiStatus(status, IREE_LOC);
- }
- return Function{reinterpret_cast<Module*>(function.module),
- static_cast<Function::Linkage>(function.linkage),
- function.ordinal};
- }
-
- StatusOr<absl::string_view> GetFunctionName(Function::Linkage linkage,
- int32_t ordinal) const override {
- IREE_TRACE_SCOPE0("ExternalModule::GetFunctionName");
- iree_string_view_t name;
- auto status = impl_.get_function_name(
- impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
- &name);
- RETURN_IF_ERROR(FromApiStatus(status, IREE_LOC));
- return absl::string_view{name.data, name.size};
- }
-
- StatusOr<const FunctionSignature> GetFunctionSignature(
- Function::Linkage linkage, int32_t ordinal) const override {
- IREE_TRACE_SCOPE0("ExternalModule::GetFunctionSignature");
- iree_rt_function_signature_t signature;
- auto status = impl_.get_function_signature(
- impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
- &signature);
- if (status != IREE_STATUS_OK) {
- return FromApiStatus(status, IREE_LOC);
- }
- return FunctionSignature{signature.argument_count, signature.result_count};
- }
-
- Status Execute(
- Stack* stack, const Function function,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::InlinedVector<hal::BufferView, 8>* results) const override {
- // TODO(benvanik): fn ptr callback to external code. Waiting on fibers.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "External calls not yet implemented";
- }
-
- private:
- iree_rt_external_module_t impl_;
- iree_allocator_t allocator_;
-};
-
-} // namespace
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
- iree_rt_external_module_t impl, iree_allocator_t allocator,
- iree_rt_module_t** out_module) {
- IREE_TRACE_SCOPE0("iree_rt_module_create_external");
-
- if (!out_module) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_module = nullptr;
-
- auto module = make_ref<ExternalModule>(impl, allocator);
- *out_module = reinterpret_cast<iree_rt_module_t*>(module.release());
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_retain(iree_rt_module_t* module) {
- IREE_TRACE_SCOPE0("iree_rt_module_retain");
- auto* handle = reinterpret_cast<Module*>(module);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_release(iree_rt_module_t* module) {
- IREE_TRACE_SCOPE0("iree_rt_module_release");
- auto* handle = reinterpret_cast<Module*>(module);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_string_view_t IREE_API_CALL
-iree_rt_module_name(const iree_rt_module_t* module) {
- IREE_TRACE_SCOPE0("iree_rt_module_name");
- const auto* handle = reinterpret_cast<const Module*>(module);
- CHECK(handle) << "NULL module handle";
- return iree_string_view_t{handle->name().data(), handle->name().size()};
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
- iree_rt_function_linkage_t linkage,
- int32_t ordinal,
- iree_rt_function_t* out_function) {
- IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_ordinal");
-
- if (!out_function) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_function, 0, sizeof(*out_function));
-
- auto* handle = reinterpret_cast<Module*>(module);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto function_or = handle->LookupFunctionByOrdinal(
- static_cast<Function::Linkage>(linkage), ordinal);
- if (!function_or.ok()) {
- // Map this invalid argument to not found, per the API spec.
- if (IsInvalidArgument(function_or.status())) {
- return IREE_STATUS_NOT_FOUND;
- }
- return ToApiStatus(std::move(function_or).status());
- }
- auto function = *function_or;
-
- out_function->module = module;
- out_function->linkage =
- static_cast<iree_rt_function_linkage_t>(function.linkage());
- out_function->ordinal = function.ordinal();
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
- iree_rt_function_linkage_t linkage,
- iree_string_view_t name,
- iree_rt_function_t* out_function) {
- IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_name");
-
- if (!out_function) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_function, 0, sizeof(*out_function));
-
- auto* handle = reinterpret_cast<Module*>(module);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- IREE_API_ASSIGN_OR_RETURN(
- auto function,
- handle->LookupFunctionByName(static_cast<Function::Linkage>(linkage),
- absl::string_view{name.data, name.size}));
-
- out_function->linkage =
- static_cast<iree_rt_function_linkage_t>(function.linkage());
- out_function->module = module;
- out_function->linkage = linkage;
- out_function->ordinal = function.ordinal();
-
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Function
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_string_view_t IREE_API_CALL
-iree_rt_function_name(const iree_rt_function_t* function) {
- IREE_TRACE_SCOPE0("iree_rt_function_name");
- CHECK(function && function->module) << "NULL function handle";
- auto* module = reinterpret_cast<Module*>(function->module);
- auto name_or = module->GetFunctionName(
- static_cast<Function::Linkage>(function->linkage), function->ordinal);
- if (!name_or.ok()) return {};
- auto name = name_or.ValueOrDie();
- return iree_string_view_t{name.data(), name.size()};
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_function_signature(const iree_rt_function_t* function,
- iree_rt_function_signature_t* out_signature) {
- IREE_TRACE_SCOPE0("iree_rt_function_signature");
-
- if (!out_signature) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_signature, 0, sizeof(*out_signature));
-
- if (!function || !function->module) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto* module = reinterpret_cast<Module*>(function->module);
- IREE_API_ASSIGN_OR_RETURN(
- auto signature, module->GetFunctionSignature(
- static_cast<Function::Linkage>(function->linkage),
- function->ordinal));
- out_signature->argument_count = signature.argument_count();
- out_signature->result_count = signature.result_count();
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Policy
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t iree_rt_policy_create(
- iree_allocator_t allocator, iree_rt_policy_t** out_policy) {
- IREE_TRACE_SCOPE0("iree_rt_policy_create");
-
- if (!out_policy) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_policy = nullptr;
-
- auto policy = make_ref<Policy>();
-
- *out_policy = reinterpret_cast<iree_rt_policy_t*>(policy.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_policy_retain(iree_rt_policy_t* policy) {
- IREE_TRACE_SCOPE0("iree_rt_policy_retain");
- auto* handle = reinterpret_cast<Policy*>(policy);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_policy_release(iree_rt_policy_t* policy) {
- IREE_TRACE_SCOPE0("iree_rt_policy_release");
- auto* handle = reinterpret_cast<Policy*>(policy);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Context
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
- iree_rt_instance_t* instance, iree_rt_policy_t* policy,
- iree_allocator_t allocator, iree_rt_context_t** out_context) {
- IREE_TRACE_SCOPE0("iree_rt_context_create");
-
- if (!out_context) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_context = nullptr;
-
- if (!instance || !policy) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto context =
- make_ref<Context>(add_ref(reinterpret_cast<Instance*>(instance)),
- add_ref(reinterpret_cast<Policy*>(policy)));
-
- *out_context = reinterpret_cast<iree_rt_context_t*>(context.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_retain(iree_rt_context_t* context) {
- IREE_TRACE_SCOPE0("iree_rt_context_retain");
- auto* handle = reinterpret_cast<Context*>(context);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_release(iree_rt_context_t* context) {
- IREE_TRACE_SCOPE0("iree_rt_context_release");
- auto* handle = reinterpret_cast<Context*>(context);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT int32_t IREE_API_CALL
-iree_rt_context_id(const iree_rt_context_t* context) {
- IREE_TRACE_SCOPE0("iree_rt_context_id");
- const auto* handle = reinterpret_cast<const Context*>(context);
- CHECK(handle) << "NULL context handle";
- return handle->id();
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules(
- iree_rt_context_t* context, iree_rt_module_t** modules,
- iree_host_size_t module_count) {
- IREE_TRACE_SCOPE0("iree_rt_context_register_modules");
- auto* handle = reinterpret_cast<Context*>(context);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- if (module_count && !modules) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- for (size_t i = 0; i < module_count; ++i) {
- auto* module = reinterpret_cast<Module*>(modules[i]);
- if (!module) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(handle->RegisterModule(add_ref(module)));
- }
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
-iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
- iree_string_view_t module_name) {
- IREE_TRACE_SCOPE0("iree_rt_context_lookup_module_by_name");
- const auto* handle = reinterpret_cast<const Context*>(context);
- CHECK(handle) << "NULL context handle";
- auto module_or = handle->LookupModuleByName(
- absl::string_view{module_name.data, module_name.size});
- if (!module_or.ok()) {
- return nullptr;
- }
- return reinterpret_cast<iree_rt_module_t*>(module_or.ValueOrDie());
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
- const iree_rt_context_t* context, iree_string_view_t full_name,
- iree_rt_function_t* out_function) {
- IREE_TRACE_SCOPE0("iree_rt_context_resolve_function");
-
- if (!out_function) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_function, 0, sizeof(*out_function));
-
- const auto* handle = reinterpret_cast<const Context*>(context);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto full_name_view = absl::string_view{full_name.data, full_name.size};
- size_t last_dot = full_name_view.rfind('.');
- if (last_dot == absl::string_view::npos) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- auto module_name = full_name_view.substr(0, last_dot);
- auto function_name = full_name_view.substr(last_dot + 1);
-
- iree_rt_module_t* module = iree_rt_context_lookup_module_by_name(
- context, iree_string_view_t{module_name.data(), module_name.size()});
- if (!module) {
- return IREE_STATUS_NOT_FOUND;
- }
-
- return iree_rt_module_lookup_function_by_name(
- module, IREE_RT_FUNCTION_LINKAGE_EXPORT,
- iree_string_view_t{function_name.data(), function_name.size()},
- out_function);
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_allocate_device_visible_buffer(
- iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
- iree_host_size_t allocation_size, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer) {
- IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_visible_buffer");
-
- if (!out_buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_buffer, 0, sizeof(*out_buffer));
-
- const auto* handle = reinterpret_cast<const Context*>(context);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!allocation_size) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- // TODO(benvanik): reroute to context based on current policy.
- auto* device_manager = handle->instance()->device_manager();
- IREE_API_ASSIGN_OR_RETURN(auto device_placement,
- device_manager->ResolvePlacement({}));
- IREE_API_ASSIGN_OR_RETURN(auto buffer,
- device_manager->AllocateDeviceVisibleBuffer(
- static_cast<hal::BufferUsage>(buffer_usage),
- allocation_size, {device_placement}));
-
- *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_allocate_device_local_buffer(
- iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
- iree_host_size_t allocation_size, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer) {
- IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_local_buffer");
-
- if (!out_buffer) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- std::memset(out_buffer, 0, sizeof(*out_buffer));
-
- const auto* handle = reinterpret_cast<const Context*>(context);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!allocation_size) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- // TODO(benvanik): reroute to context based on current policy.
- auto* device_manager = handle->instance()->device_manager();
- IREE_API_ASSIGN_OR_RETURN(auto device_placement,
- device_manager->ResolvePlacement({}));
- IREE_API_ASSIGN_OR_RETURN(auto buffer,
- device_manager->AllocateDeviceLocalBuffer(
- static_cast<hal::BufferUsage>(buffer_usage),
- allocation_size, {device_placement}));
-
- *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release());
-
- return IREE_STATUS_OK;
-}
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Invocation
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
- iree_rt_context_t* context, iree_rt_function_t* function,
- iree_rt_policy_t* policy,
- const iree_rt_invocation_dependencies_t* dependencies,
- iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
- iree_hal_buffer_view_t** results, iree_host_size_t result_count,
- iree_allocator_t allocator, iree_rt_invocation_t** out_invocation) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_create");
-
- if (!out_invocation) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_invocation = nullptr;
-
- if (!context || !function || !function->module) {
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (dependencies &&
- (dependencies->invocation_count && !dependencies->invocations)) {
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if ((argument_count && !arguments) || (result_count && !results)) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- // TODO(benvanik): unwrap without needing to retain here.
- absl::InlinedVector<ref_ptr<Invocation>, 4> dependent_invocations;
- if (dependencies) {
- dependent_invocations.resize(dependencies->invocation_count);
- for (int i = 0; i < dependencies->invocation_count; ++i) {
- dependent_invocations[i] =
- add_ref(reinterpret_cast<Invocation*>(dependencies->invocations[i]));
- }
- }
-
- // TODO(benvanik): unwrap without needing to retain here.
- absl::InlinedVector<hal::BufferView, 8> argument_views(argument_count);
- for (int i = 0; i < argument_count; ++i) {
- const auto* api_buffer_view =
- reinterpret_cast<const hal::iree_hal_buffer_view*>(arguments[i]);
- if (!api_buffer_view) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- argument_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer),
- api_buffer_view->impl.shape,
- api_buffer_view->impl.element_size};
- }
-
- // TODO(benvanik): unwrap without needing to retain here.
- absl::InlinedVector<hal::BufferView, 8> result_views(result_count);
- for (int i = 0; i < result_count; ++i) {
- const auto* api_buffer_view =
- reinterpret_cast<const hal::iree_hal_buffer_view*>(results[i]);
- if (api_buffer_view) {
- result_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer),
- api_buffer_view->impl.shape,
- api_buffer_view->impl.element_size};
- }
- }
-
- IREE_API_ASSIGN_OR_RETURN(
- auto invocation,
- Invocation::Create(
- add_ref(reinterpret_cast<Context*>(context)),
- Function{reinterpret_cast<Module*>(function->module),
- static_cast<Function::Linkage>(function->linkage),
- function->ordinal},
- add_ref(reinterpret_cast<Policy*>(policy)),
- std::move(dependent_invocations), std::move(argument_views),
- result_views.empty()
- ? absl::optional<absl::InlinedVector<hal::BufferView, 8>>(
- absl::nullopt)
- : std::move(result_views)));
-
- *out_invocation =
- reinterpret_cast<iree_rt_invocation_t*>(invocation.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_retain(iree_rt_invocation_t* invocation) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_retain");
- auto* handle = reinterpret_cast<Invocation*>(invocation);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->AddReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_release(iree_rt_invocation_t* invocation) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_release");
- auto* handle = reinterpret_cast<Invocation*>(invocation);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- handle->ReleaseReference();
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_query_status(iree_rt_invocation_t* invocation) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_query_status");
- auto* handle = reinterpret_cast<Invocation*>(invocation);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(handle->QueryStatus());
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
- iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
- iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
- iree_host_size_t* out_result_count) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_consume_results");
-
- if (!out_result_count) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_result_count = 0;
- if (!out_results) {
- std::memset(out_results, 0,
- sizeof(iree_hal_buffer_view_t*) * result_capacity);
- }
-
- auto* handle = reinterpret_cast<Invocation*>(invocation);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- const auto& function = handle->function();
- int32_t result_count = function.signature().result_count();
- *out_result_count = result_count;
- if (!out_results) {
- return IREE_STATUS_OK;
- } else if (result_capacity < result_count) {
- return IREE_STATUS_OUT_OF_RANGE;
- }
-
- IREE_API_ASSIGN_OR_RETURN(auto results, handle->ConsumeResults());
- iree_status_t status = IREE_STATUS_OK;
- int i = 0;
- for (i = 0; i < results.size(); ++i) {
- iree_shape_t shape;
- status = ToApiShape(results[i].shape, &shape);
- if (status != IREE_STATUS_OK) break;
- status = iree_hal_buffer_view_create(
- reinterpret_cast<iree_hal_buffer_t*>(results[i].buffer.get()), shape,
- results[i].element_size, allocator, &out_results[i]);
- if (status != IREE_STATUS_OK) break;
- }
- if (status != IREE_STATUS_OK) {
- // Release already-retained buffer views on failure.
- for (int j = 0; j < i; ++j) {
- iree_hal_buffer_view_release(out_results[j]);
- out_results[j] = nullptr;
- }
- }
- return status;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
- iree_rt_invocation_t* invocation, iree_time_t deadline) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_await");
- auto* handle = reinterpret_cast<Invocation*>(invocation);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(handle->Await(ToAbslTime(deadline)));
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_abort(iree_rt_invocation_t* invocation) {
- IREE_TRACE_SCOPE0("iree_rt_invocation_abort");
- auto* handle = reinterpret_cast<Invocation*>(invocation);
- if (!handle) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- IREE_API_RETURN_IF_ERROR(handle->Abort());
- return IREE_STATUS_OK;
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/api.h b/iree/rt/api.h
deleted file mode 100644
index 37e66c3..0000000
--- a/iree/rt/api.h
+++ /dev/null
@@ -1,400 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// See iree/base/api.h for documentation on the API conventions used.
-
-#ifndef IREE_RT_API_H_
-#define IREE_RT_API_H_
-
-#include <stdint.h>
-
-#include "iree/base/api.h"
-#include "iree/hal/api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-//===----------------------------------------------------------------------===//
-// Types and Enums
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_rt_instance iree_rt_instance_t;
-typedef struct iree_rt_context iree_rt_context_t;
-typedef struct iree_rt_policy iree_rt_policy_t;
-typedef struct iree_rt_module iree_rt_module_t;
-typedef struct iree_rt_invocation iree_rt_invocation_t;
-
-// Describes the type of a function reference.
-typedef enum {
- // Function is internal to the module and may not be reflectable.
- IREE_RT_FUNCTION_LINKAGE_INTERNAL = 0,
- // Function is an import from another module.
- IREE_RT_FUNCTION_LINKAGE_IMPORT = 1,
- // Function is an export from the module.
- IREE_RT_FUNCTION_LINKAGE_EXPORT = 2,
-} iree_rt_function_linkage_t;
-
-// A function reference that can be used with the iree_rt_function_* methods.
-// These should be treated as opaque and the accessor functions should be used
-// instead.
-typedef struct {
- // Module the function is contained within.
- iree_rt_module_t* module;
- // Linkage of the function. Note that IREE_RT_FUNCTION_LINKAGE_INTERNAL
- // functions may be missing reflection information.
- iree_rt_function_linkage_t linkage;
- // Ordinal within the module in the linkage scope.
- int32_t ordinal;
-} iree_rt_function_t;
-
-// Describes the expected calling convention and arguments/results of a
-// function.
-typedef struct {
- // Total number of arguments to the function.
- int32_t argument_count;
- // Total number of results from the function.
- int32_t result_count;
-} iree_rt_function_signature_t;
-
-// Describes the imports, exports, and capabilities of a module.
-typedef struct {
- // Total number of imported functions.
- int32_t import_function_count;
- // Total number of exported functions.
- int32_t export_function_count;
- // Total number of internal functions, if debugging info is present and they
- // can be queried.
- int32_t internal_function_count;
- // Total number of state block resource slots consumed.
- int32_t state_slot_count;
-} iree_rt_module_signature_t;
-
-// Dependency information used to order invocations.
-typedef struct {
- // Prior invocations that must complete before the new invocation begins.
- iree_rt_invocation_t** invocations;
- iree_host_size_t invocation_count;
-
- // TODO(benvanik): wait semaphores/importing.
-} iree_rt_invocation_dependencies_t;
-
-// Defines an external module that can be used to reflect and execute functions.
-// Modules must be thread-safe as lookups and executions may occur in any order
-// from any thread.
-//
-// Modules will have their resolve_imports function called upon registration
-// with a context and may use the provided resolver to find imported functions.
-typedef struct {
- // User-defined pointer passed to all functions.
- void* self;
- // Destroys |self| when all references to the module have been released.
- iree_status_t(IREE_API_PTR* destroy)(void* self);
- // Returns the name of the module (used during resolution).
- iree_string_view_t(IREE_API_PTR* name)(void* self);
- // Sets |out_module_signature| to the reflected signature of the module.
- iree_rt_module_signature_t(IREE_API_PTR* signature)(void* self);
- // Sets |out_function| to a resolved function by ordinal, if found.
- iree_status_t(IREE_API_PTR* lookup_function_by_ordinal)(
- void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
- iree_rt_function_t* out_function);
- // Sets |out_function| to a resolved function by name, if found.
- iree_status_t(IREE_API_PTR* lookup_function_by_name)(
- void* self, iree_rt_function_linkage_t linkage, iree_string_view_t name,
- iree_rt_function_t* out_function);
- // Sets |out_name| to the name of the function with the given ordinal, if
- // found.
- iree_status_t(IREE_API_PTR* get_function_name)(
- void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
- iree_string_view_t* out_name);
- // Sets |out_signature| to the reflected signature of the given
- // function, if found.
- iree_status_t(IREE_API_PTR* get_function_signature)(
- void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
- iree_rt_function_signature_t* out_signature);
-} iree_rt_external_module_t;
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Instance
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Creates a new instance. This should be shared with all contexts in an
-// application to ensure that resources are tracked properly and threads are
-// managed correctly.
-// |out_instance| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
- iree_allocator_t allocator, iree_rt_instance_t** out_instance);
-
-// Retains the given |instance| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_instance_retain(iree_rt_instance_t* instance);
-
-// Releases the given |instance| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_instance_release(iree_rt_instance_t* instance);
-
-// TEMPORARY: until policies and placement are performed this can be used to
-// explicitly create and register drivers by name.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex(
- iree_rt_instance_t* instance, iree_string_view_t driver_name);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Module
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Creates a module with an external backing implementation.
-// The provided |external_module| definition will be used to query the module
-// state as needed. No caching occurs within the implementation to allow calls
-// to return different values per-invocation.
-//
-// |out_module| must be released by the caller.
-// iree_rt_external_module_t::destroy is called when the last reference to the
-// iree_rt_module_t is released.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
- iree_rt_external_module_t impl, iree_allocator_t allocator,
- iree_rt_module_t** out_module);
-
-// Retains the given |module| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_retain(iree_rt_module_t* module);
-
-// Releases the given |module| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_release(iree_rt_module_t* module);
-
-// Returns the name of the module.
-IREE_API_EXPORT iree_string_view_t IREE_API_CALL
-iree_rt_module_name(const iree_rt_module_t* module);
-
-// Sets |out_function| to a function with |ordinal| in the given linkage or
-// returns IREE_STATUS_NOT_FOUND. The function reference is valid for the
-// lifetime of |module|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
- iree_rt_function_linkage_t linkage,
- int32_t ordinal,
- iree_rt_function_t* out_function);
-
-// Sets |out_function| to a function with |name| in the given linkage or returns
-// IREE_STATUS_NOT_FOUND. The function reference is valid for the lifetime of
-// |module|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
- iree_rt_function_linkage_t linkage,
- iree_string_view_t name,
- iree_rt_function_t* out_function);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Function
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Returns the name of the function as exported from the module.
-IREE_API_EXPORT iree_string_view_t IREE_API_CALL
-iree_rt_function_name(const iree_rt_function_t* function);
-
-// Sets |out_function_signature| to the reflected signature of the function.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_function_signature(const iree_rt_function_t* function,
- iree_rt_function_signature_t* out_signature);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Policy
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// TODO(benvanik): define policies. For now they are no-ops.
-IREE_API_EXPORT iree_status_t iree_rt_policy_create(
- iree_allocator_t allocator, iree_rt_policy_t** out_policy);
-
-// Retains the given |policy| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_policy_retain(iree_rt_policy_t* policy);
-
-// Releases the given |policy| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_policy_release(iree_rt_policy_t* policy);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Context
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Creates a new context that uses the given |instance| for device management.
-// |out_context| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
- iree_rt_instance_t* instance, iree_rt_policy_t* policy,
- iree_allocator_t allocator, iree_rt_context_t** out_context);
-
-// Retains the given |context| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_retain(iree_rt_context_t* context);
-
-// Releases the given |context| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_release(iree_rt_context_t* context);
-
-// Returns a process-unique ID for the |context|.
-IREE_API_EXPORT int32_t IREE_API_CALL
-iree_rt_context_id(const iree_rt_context_t* context);
-
-// Registers a list of modules with the context and resolves imports.
-// The modules will be retained by the context until destruction.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules(
- iree_rt_context_t* context, iree_rt_module_t** modules,
- iree_host_size_t module_count);
-
-// Returns a reference to the module registered with the given name or nullptr
-// if not found. The caller must retain the returned module if they want to
-// continue using it.
-IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
-iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
- iree_string_view_t module_name);
-
-// Sets |out_function| to to an exported function with the fully-qualified name
-// of |full_name| or returns IREE_STATUS_NOT_FOUND. The function reference is
-// valid for the lifetime of |context|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
- const iree_rt_context_t* context, iree_string_view_t full_name,
- iree_rt_function_t* out_function);
-
-// Allocates a host-local buffer that is optimal for use on the host but is
-// usable by the given |device_placements| (at a possible performance
-// penalty). The buffer can be used for staging uploads to device-local
-// buffers and is useful for times when the buffer will be used more on the
-// host than the device. If a buffer never needs to be used with a device
-// prefer instead HeapBuffer::Allocate.
-//
-// Fails if it is not possible to allocate and satisfy all placements for the
-// requested |buffer_usage|.
-// |out_buffer| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_allocate_device_visible_buffer(
- iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
- iree_host_size_t allocation_size, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer);
-
-// Allocates a device-local buffer that is optimal for use with the given
-// |device_placements|. The buffer will not be host-visible and can only be
-// used from compatible device queues.
-//
-// Fails if it is not possible to allocate and satisfy all placements for the
-// requested |buffer_usage|.
-// |out_buffer| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_context_allocate_device_local_buffer(
- iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
- iree_host_size_t allocation_size, iree_allocator_t allocator,
- iree_hal_buffer_t** out_buffer);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-//===----------------------------------------------------------------------===//
-// iree::rt::Invocation
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Creates a new invocation tracking object for invoking the given |function|
-// from |context|. |arguments| will be retained until the invocation is made.
-// If |dependencies| are provided then the invocation will wait until they are
-// resolved before executing. If a |policy| is provided it will override the
-// context-level policy.
-//
-// Optionally |results| may be provided with preallocated buffers that will
-// receive the outputs of the invocation. Invocation will fail if they do not
-// match expected sizes.
-//
-// Note that it's possible for the invocation to complete prior to the return of
-// this function. Any errors that occur will be set on the invocation and
-// callers should query its state prior to assuming it is in-flight.
-//
-// |out_invocation| must be released by the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
- iree_rt_context_t* context, iree_rt_function_t* function,
- iree_rt_policy_t* policy,
- const iree_rt_invocation_dependencies_t* dependencies,
- iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
- iree_hal_buffer_view_t** results, iree_host_size_t result_count,
- iree_allocator_t allocator, iree_rt_invocation_t** out_invocation);
-
-// Retains the given |invocation| for the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_retain(iree_rt_invocation_t* invocation);
-
-// Releases the given |invocation| from the caller.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_release(iree_rt_invocation_t* invocation);
-
-// Queries the completion status of the invocation.
-// Returns one of the following:
-// IREE_STATUS_OK: the invocation completed successfully.
-// IREE_STATUS_UNAVAILABLE: the invocation has not yet completed.
-// IREE_STATUS_CANCELLED: the invocation was cancelled internally.
-// IREE_STATUS_ABORTED: the invocation was aborted.
-// IREE_STATUS_*: an error occurred during invocation.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_query_status(iree_rt_invocation_t* invocation);
-
-// Populates |out_results| to the values of the results.
-// |result_capacity| defines the number of elements available in |out_results|
-// and |out_result_count| will be set with the actual number of results
-// available. If |result_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be
-// returned wtih the required capacity in |out_result_count|. To only query the
-// required capacity |out_results| may be passed as nullptr.
-//
-// Ownership of returned results will be transferred to the caller and they must
-// be released if no longer needed.
-//
-// Returns errors as with iree_rt_invocation_query_status, for example in the
-// case of not-yet-completed or aborted invocations.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
- iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
- iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
- iree_host_size_t* out_result_count);
-
-// Blocks the caller until the invocation completes (successfully or otherwise).
-//
-// Returns IREE_STATUS_DEADLINE_EXCEEDED if |deadline| elapses before the
-// invocation completes and otherwise returns iree_rt_invocation_query_status.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
- iree_rt_invocation_t* invocation, iree_time_t deadline);
-
-// Attempts to abort the invocation if it is in-flight.
-// A no-op if the invocation has already completed.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_rt_invocation_abort(iree_rt_invocation_t* invocation);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-
-#endif // IREE_RT_API_H_
diff --git a/iree/rt/context.cc b/iree/rt/context.cc
deleted file mode 100644
index 6ae7a0a..0000000
--- a/iree/rt/context.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/context.h"
-
-#include <atomic>
-
-#include "absl/strings/str_cat.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/rt/debug/debug_server.h"
-#include "iree/rt/instance.h"
-#include "iree/rt/invocation.h"
-
-namespace iree {
-namespace rt {
-
-namespace {
-
-int32_t NextUniqueContextId() {
- static std::atomic<int32_t> next_id = {0};
- return ++next_id;
-}
-
-} // namespace
-
-Context::Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy)
- : id_(NextUniqueContextId()),
- instance_(std::move(instance)),
- policy_(std::move(policy)) {
- IREE_TRACE_SCOPE("Context::ctor", int32_t)(id_);
- instance_->RegisterContext(this);
-}
-
-Context::~Context() {
- IREE_TRACE_SCOPE("Context::dtor", int32_t)(id_);
- instance_->UnregisterContext(this);
-}
-
-std::string Context::DebugStringShort() const {
- return absl::StrCat("context_", id_);
-}
-
-Status Context::RegisterModule(ref_ptr<Module> module) {
- IREE_TRACE_SCOPE0("Context::RegisterModule");
-
- // Ensure no conflicts in naming - we don't support shadowing.
- for (const auto& existing_module : modules_) {
- if (existing_module->name() == module->name()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Module '" << module->name()
- << "' has already been registered in the context";
- }
- }
-
- // Try resolving prior to actually registering; if we can't resolve an import
- // then we want to fail the entire registration.
- ASSIGN_OR_RETURN(auto import_table, ResolveImports(module.get()));
-
- auto* debug_server = instance_->debug_server();
- if (debug_server) {
- CHECK_OK(debug_server->RegisterContextModule(this, module.get()));
- }
-
- modules_.push_back(std::move(module));
- module_import_tables_.push_back(std::move(import_table));
- return OkStatus();
-}
-
-StatusOr<ModuleImportTable> Context::ResolveImports(Module* module) {
- IREE_TRACE_SCOPE0("Context::ResolveImports");
-
- int32_t import_count = module->signature().import_function_count();
- ModuleImportTable import_table;
- import_table.first = module;
- import_table.second.resize(import_count);
-
- for (int32_t i = 0; i < import_count; ++i) {
- ASSIGN_OR_RETURN(auto import_function_name,
- module->GetFunctionName(Function::Linkage::kImport, i));
- ASSIGN_OR_RETURN(import_table.second[i],
- ResolveFunction(import_function_name));
- }
-
- return import_table;
-}
-
-StatusOr<Module*> Context::LookupModuleByName(
- absl::string_view module_name) const {
- for (const auto& module : modules_) {
- if (module->name() == module_name) {
- return module.get();
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "No module with the name '" << module_name
- << "' has been registered";
-}
-
-StatusOr<const Function> Context::ResolveFunction(
- absl::string_view full_name) const {
- size_t last_dot = full_name.rfind('.');
- if (last_dot == absl::string_view::npos) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "'" << full_name
- << "' is not fully qualified (expected 'module.function')";
- }
- auto module_name = full_name.substr(0, last_dot);
- auto function_name = full_name.substr(last_dot + 1);
- ASSIGN_OR_RETURN(auto* module, LookupModuleByName(module_name));
- return module->LookupFunctionByName(Function::Linkage::kExport,
- function_name);
-}
-
-StatusOr<const Function> Context::ResolveImport(const Module* module,
- int32_t ordinal) const {
- for (const auto& import_table_ref : module_import_tables_) {
- if (import_table_ref.first == module) {
- const auto& import_table = import_table_ref.second;
- if (ordinal >= import_table.size()) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Import ordinal " << ordinal
- << " out of bounds of import table (" << import_table.size()
- << ")";
- }
- return import_table[ordinal];
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "Import ordinal " << ordinal << " not found";
-}
-
-void Context::RegisterInvocation(Invocation* invocation) {
- {
- absl::MutexLock lock(&invocations_mutex_);
- invocations_.push_back(invocation);
- }
- auto* debug_server = instance_->debug_server();
- if (debug_server) {
- CHECK_OK(debug_server->RegisterInvocation(invocation));
- }
-}
-
-void Context::UnregisterInvocation(Invocation* invocation) {
- auto* debug_server = instance_->debug_server();
- if (debug_server) {
- CHECK_OK(debug_server->UnregisterInvocation(invocation));
- }
- {
- absl::MutexLock lock(&invocations_mutex_);
- invocations_.erase(invocation);
- }
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/context.h b/iree/rt/context.h
deleted file mode 100644
index d67b79a..0000000
--- a/iree/rt/context.h
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_CONTEXT_H_
-#define IREE_RT_CONTEXT_H_
-
-#include <ostream>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "absl/types/optional.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/rt/invocation.h"
-#include "iree/rt/module.h"
-#include "iree/rt/policy.h"
-
-namespace iree {
-namespace rt {
-
-class Instance;
-
-using ModuleImportTable = std::pair<Module*, std::vector<Function>>;
-
-// An isolated execution context.
-// Effectively a sandbox where modules can be loaded and run with restricted
-// visibility and where they can maintain state.
-//
-// Modules have imports resolved automatically when registered by searching
-// existing modules registered within the context and load order is used for
-// resolution. For example, target-specific modules should be loaded prior to
-// generic modules that may import functions defined there and if a function is
-// not available in the target-specific modules the fallback provided by the
-// generic module will be used.
-//
-// Thread-compatible and must be externally synchronized.
-class Context final : public RefObject<Context> {
- public:
- Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy);
- ~Context();
-
- // A process-unique ID for the context.
- int32_t id() const { return id_; }
-
- // Instance this context uses for shared resources.
- const ref_ptr<Instance>& instance() const { return instance_; }
-
- // A short human-readable name for the context.
- std::string DebugStringShort() const;
-
- // A list of modules registered with the context.
- absl::Span<const ref_ptr<Module>> modules() const {
- return absl::MakeConstSpan(modules_);
- }
-
- // Registers a new module with the context.
- // Imports from the module will be resolved using the existing modules in the
- // context. The module will be retained by the context until destruction.
- Status RegisterModule(ref_ptr<Module> module);
-
- // Looks up a module by name.
- StatusOr<Module*> LookupModuleByName(absl::string_view module_name) const;
-
- // Resolves an exported function by fully-qualified name. The function
- // reference is valid for the lifetime of the context.
- StatusOr<const Function> ResolveFunction(absl::string_view full_name) const;
-
- // Resolves an imported function by import ordinal. The function reference is
- // valid for the lifetime of the context.
- StatusOr<const Function> ResolveImport(const Module* module,
- int32_t ordinal) const;
-
- private:
- // Resolves imports for the given module.
- StatusOr<ModuleImportTable> ResolveImports(Module* module);
-
- friend class Invocation;
- void RegisterInvocation(Invocation* invocation);
- void UnregisterInvocation(Invocation* invocation);
-
- int32_t id_;
- ref_ptr<Instance> instance_;
- ref_ptr<Policy> policy_;
-
- absl::InlinedVector<ref_ptr<Module>, 4> modules_;
- absl::InlinedVector<ModuleImportTable, 4> module_import_tables_;
-
- absl::Mutex invocations_mutex_;
- IntrusiveList<Invocation, offsetof(Invocation, context_list_link_)>
- invocations_ ABSL_GUARDED_BY(invocations_mutex_);
-
- friend class Instance;
- IntrusiveListLink instance_list_link_;
-};
-
-inline std::ostream& operator<<(std::ostream& stream, const Context& context) {
- stream << context.DebugStringShort();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_CONTEXT_H_
diff --git a/iree/rt/debug/BUILD b/iree/rt/debug/BUILD
deleted file mode 100644
index 8c4a939..0000000
--- a/iree/rt/debug/BUILD
+++ /dev/null
@@ -1,192 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-# TODO(benvanik): re-enable debugger after refactoring.
-# cc_library(
-# name = "debug_client",
-# srcs = ["debug_client.cc"],
-# hdrs = ["debug_client.h"],
-# deps = [
-# ":debug_client_interface",
-# ":debug_client_tcp", # build-cleaner: keep
-# "@com_google_absl//absl/container:flat_hash_map",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/types:optional",
-# "@com_google_absl//absl/types:span",
-# "//iree/base:source_location",
-# "//iree/base:status",
-# "//iree/schemas",
-# ],
-# )
-#
-# cc_library(
-# name = "debug_client_interface",
-# hdrs = ["debug_client.h"],
-# deps = [
-# "@com_google_absl//absl/container:flat_hash_map",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/types:optional",
-# "@com_google_absl//absl/types:span",
-# "//iree/base:status",
-# "//iree/schemas",
-# ],
-# )
-#
-# cc_library(
-# name = "debug_client_tcp",
-# srcs = ["debug_client_tcp.cc"],
-# deps = [
-# ":debug_client_interface",
-# ":debug_tcp_util",
-# "@com_google_absl//absl/container:flat_hash_map",
-# "@com_google_absl//absl/memory",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/types:span",
-# "@com_github_google_flatbuffers//:flatbuffers",
-# "//iree/base:flatbuffer_util",
-# "//iree/base:status",
-# "//iree/rt",
-# "//iree/schemas",
-# ],
-# )
-#
-# cc_library(
-# name = "debug_server",
-# hdrs = ["debug_server.h"],
-# deps = [
-# ":debug_server_interface",
-# "//third_party/flatbuffers:flatbuffers",
-# "//iree/schemas",
-# "//iree/base:status",
-# ] + select({
-# "//iree:debug": [":debug_server_tcp"],
-# "//conditions:default": [":debug_server_disabled"],
-# }),
-# )
-
-cc_library(
- name = "debug_server",
- hdrs = ["debug_server.h"],
- deps = [
- ":debug_server_disabled",
- ":debug_server_interface",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "debug_server_interface",
- hdrs = ["debug_server.h"],
- deps = ["//iree/base:status"],
-)
-
-cc_library(
- name = "debug_server_disabled",
- srcs = ["debug_server_disabled.cc"],
- deps = [
- ":debug_server_interface",
- "@com_google_absl//absl/memory",
- ],
-)
-
-# TODO(benvanik): re-enable debugger after refactoring.
-# cc_library(
-# name = "debug_server_tcp",
-# srcs = ["debug_server_tcp.cc"],
-# deps = [
-# ":debug_server_interface",
-# ":debug_service",
-# ":debug_tcp_util",
-# "@com_google_absl//absl/base:core_headers",
-# "@com_google_absl//absl/memory",
-# "@com_google_absl//absl/synchronization",
-# "@com_github_google_flatbuffers//:flatbuffers",
-# "//iree/base:status",
-# "//iree/schemas",
-# ],
-# )
-
-cc_library(
- name = "debug_server_flags",
- srcs = ["debug_server_flags.cc"],
- hdrs = ["debug_server_flags.h"],
- deps = [
- ":debug_server",
- "//iree/base:memory",
- "//iree/base:status",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- ],
-)
-
-# TODO(benvanik): re-enable debugger after refactoring.
-# cc_library(
-# name = "debug_server_flags",
-# srcs = ["debug_server_flags.cc"],
-# hdrs = ["debug_server_flags.h"],
-# copts = select({
-# "//iree:debug": [
-# "-DIREE_DEBUG_EMBEDDED_APP_PRESENT=1",
-# ],
-# "//conditions:default": [],
-# }),
-# deps = [
-# ":debug_server",
-# "@com_google_absl//absl/flags:flag",
-# "@com_google_absl//absl/strings",
-# "//iree/base:memory",
-# "//iree/base:status",
-# ] + select({
-# "//iree:debug": [
-# "//iree/tools/debugger:debug_app_embedded",
-# "//third_party/GL/native:EGL", # build-cleaner: keep
-# "//third_party/GL/native:GLESv2", # build-cleaner: keep
-# ],
-# "//conditions:default": [],
-# }),
-# )
-#
-# cc_library(
-# name = "debug_service",
-# srcs = ["debug_service.cc"],
-# hdrs = ["debug_service.h"],
-# deps = [
-# ":debug_session",
-# "@com_google_absl//absl/base:core_headers",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/synchronization",
-# "@com_github_google_flatbuffers//:flatbuffers",
-# "//iree/base:flatbuffer_util",
-# "//iree/base:source_location",
-# "//iree/base:status",
-# "//iree/rt",
-# "//iree/schemas",
-# "//iree/schemas:reflection_data",
-# ],
-# )
-#
-# cc_library(
-# name = "debug_session",
-# srcs = ["debug_session.cc"],
-# hdrs = ["debug_session.h"],
-# deps = [
-# "@com_google_absl//absl/base:core_headers",
-# "@com_google_absl//absl/synchronization",
-# "//iree/base:source_location",
-# "//iree/base:status",
-# "//iree/rt",
-# "//iree/schemas",
-# ],
-# )
-#
-# cc_library(
-# name = "debug_tcp_util",
-# hdrs = ["debug_tcp_util.h"],
-# deps = [
-# "@com_github_google_flatbuffers//:flatbuffers",
-# "//iree/base:status",
-# "//iree/schemas",
-# ],
-# )
diff --git a/iree/rt/debug/debug_adapter.h b/iree/rt/debug/debug_adapter.h
deleted file mode 100644
index ab1d495..0000000
--- a/iree/rt/debug/debug_adapter.h
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DEBUG_ADAPTER_H_
-#define IREE_RT_DEBUG_ADAPTER_H_
-
-#include <functional>
-
-#include "iree/base/status.h"
-#include "iree/rt/invocation.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-struct StepTarget {
- // TODO(benvanik): step target info (matching RPC message).
- // module / function / offset
- // relative to current: once, out, return, etc
-};
-
-// TODO(benvanik): move to fiber base.
-// Interface for debugging invocations.
-// This is only accessible in debug builds where such features are compiled in.
-class DebugAdapter {
- public:
- // Called when an invocation completes suspending (in response to a Suspend or
- // Step request). The |suspend_status| will indicate if the suspension was
- // successful.
- using SuspendCallback = std::function<void(Status suspend_status)>;
-
- // Returns true if the invocation is suspended.
- // This only returns true if the invocation has been requested to suspend with
- // Suspend and the runtime has acked the suspend. Once suspended (and until
- // resumed) invocation state will not change and may be observed from any
- // thread.
- //
- // Safe to call from any thread.
- bool IsSuspended(Invocation* invocation);
-
- // Suspends the invocation at the next possible chance.
- //
- // Fibers have a suspension depth and each call to Suspend must be matched
- // with a call to Resume. Fibers will only resume excution when all prior
- // Suspend calls have their matching Resume called.
- //
- // Optionally callers may provide a |suspend_callback| that will be called
- // from a random thread when the invocation is suspended (or fails to
- // suspend).
- //
- // Safe to call from any thread.
- // Returns StatusCode::kUnavailable if debugging is not supported.
- Status Suspend(ref_ptr<Invocation> invocation,
- SuspendCallback suspend_callback = nullptr);
-
- // Resumes the invocation if it is suspended (or cancels a pending suspend).
- // This may wake threads if they are currently waiting on the invocation to
- // execute.
- //
- // Safe to call from any thread.
- // Returns StatusCode::kUnavailable if debugging is not supported.
- Status Resume(Invocation* invocation);
-
- // Steps invocation execution.
- // This will attempt to resume the invocation and will complete
- // asynchronously. Upon returning the invocation should be assumed resumed and
- // callers must query is_suspended to wait until the invocation suspends
- // again. Optionally callers may provide a |suspend_callback| that will be
- // called from a random thread when the invocation is suspended (or fails to
- // suspend).
- //
- // Safe to call from any thread while the invocation is suspended.
- // Returns StatusCode::kUnavailable if debugging is not supported and
- // StatusCode::kFailedPrecondition if the invocation is not suspended.
- Status Step(ref_ptr<Invocation> invocation, StepTarget step_target,
- SuspendCallback suspend_callback = nullptr);
-
- // Returns a call stack that can be used to query and manipulate the
- // invocation state. The behaviors supported depend on the stack frames and
- // the backend support and may be conditionally enabled via compile-time or
- // run-time flags.
- //
- // Safe to call from any thread while the invocation is suspended.
- // Returns StatusCode::kUnavailable if debugging is not supported and
- // StatusCode::kFailedPrecondition if the invocation is not suspended.
- // The returned stack will be invalidated when the invocation is stepped or
- // resumed.
- StatusOr<Stack> CaptureStack(Invocation* invocation);
-};
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_ADAPTER_H_
diff --git a/iree/rt/debug/debug_client.cc b/iree/rt/debug/debug_client.cc
deleted file mode 100644
index 66d42a7..0000000
--- a/iree/rt/debug/debug_client.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/debug/debug_client.h"
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-Status DebugClient::GetFunction(
- std::string module_name, std::string function_name,
- std::function<void(StatusOr<RemoteFunction*> function)> callback) {
- return ResolveFunction(
- module_name, function_name,
- [this, module_name, callback](StatusOr<int> function_ordinal) {
- if (!function_ordinal.ok()) {
- callback(function_ordinal.status());
- return;
- }
- auto status =
- GetFunction(module_name, function_ordinal.ValueOrDie(), callback);
- if (!status.ok()) {
- callback(std::move(status));
- }
- });
-}
-
-Status DebugClient::StepInvocationOver(const RemoteInvocation& invocation,
- std::function<void()> callback) {
- // TODO(benvanik): implement bytecode stepping search.
- // int bytecode_offset = 0;
- // return StepInvocationToOffset(invocation, bytecode_offset,
- // std::move(callback));
- return UnimplementedErrorBuilder(IREE_LOC)
- << "StepInvocationOver not yet implemented";
-}
-
-Status DebugClient::StepInvocationOut(const RemoteInvocation& invocation,
- std::function<void()> callback) {
- // TODO(benvanik): implement bytecode stepping search.
- // int bytecode_offset = 0;
- // return StepInvocationToOffset(invocation, bytecode_offset,
- // std::move(callback));
- return UnimplementedErrorBuilder(IREE_LOC)
- << "StepInvocationOut not yet implemented";
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_client.h b/iree/rt/debug/debug_client.h
deleted file mode 100644
index 4031f5e..0000000
--- a/iree/rt/debug/debug_client.h
+++ /dev/null
@@ -1,286 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DEBUG_DEBUG_CLIENT_H_
-#define IREE_RT_DEBUG_DEBUG_CLIENT_H_
-
-#include <functional>
-#include <memory>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/schemas/debug_service_generated.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// Remote breakpoint currently active on the server.
-class RemoteBreakpoint {
- public:
- enum class Type {
- kBytecodeFunction = 0,
- kNativeFunction = 1,
- };
-
- virtual ~RemoteBreakpoint() = default;
-
- int id() const { return id_; }
- Type type() const { return type_; }
-
- virtual const std::string& module_name() const = 0;
- virtual const std::string& function_name() const = 0;
- virtual int function_ordinal() const = 0;
- virtual int bytecode_offset() const = 0;
-
- protected:
- explicit RemoteBreakpoint(int id, Type type) : id_(id), type_(type) {}
-
- private:
- int id_;
- Type type_;
-};
-
-class RemoteModule;
-
-class RemoteFunction {
- public:
- virtual ~RemoteFunction() = default;
-
- RemoteModule* module() const { return module_; }
- int ordinal() const { return function_ordinal_; }
- virtual const std::string& name() const = 0;
-
- virtual const FunctionDef& def() = 0;
-
- virtual bool is_loaded() const = 0;
- virtual bool CheckLoadedOrRequest() = 0;
-
- using LoadCallback = std::function<void(StatusOr<RemoteFunction*>)>;
- virtual void WhenLoaded(LoadCallback callback) = 0;
-
- virtual const BytecodeDef* bytecode() = 0;
-
- protected:
- RemoteFunction(RemoteModule* module, int function_ordinal)
- : module_(module), function_ordinal_(function_ordinal) {}
-
- RemoteModule* module_;
- int function_ordinal_;
-};
-
-class RemoteModule {
- public:
- virtual ~RemoteModule() = default;
-
- int context_id() const { return context_id_; }
- const std::string& name() const { return name_; }
-
- virtual const ModuleDef& def() = 0;
-
- virtual bool is_loaded() const = 0;
- virtual bool CheckLoadedOrRequest() = 0;
-
- using LoadCallback = std::function<void(StatusOr<RemoteModule*>)>;
- virtual void WhenLoaded(LoadCallback callback) = 0;
-
- virtual absl::Span<RemoteFunction*> functions() = 0;
-
- protected:
- RemoteModule(int context_id, std::string name)
- : context_id_(context_id), name_(std::move(name)) {}
-
- private:
- int context_id_;
- std::string name_;
-};
-
-class RemoteContext {
- public:
- virtual ~RemoteContext() = default;
-
- int id() const { return id_; }
-
- virtual absl::Span<RemoteModule* const> modules() const = 0;
-
- protected:
- explicit RemoteContext(int id) : id_(id) {}
-
- private:
- int id_;
-};
-
-class RemoteInvocation {
- public:
- virtual ~RemoteInvocation() = default;
-
- int id() const { return id_; }
- const std::string& name() const { return name_; }
-
- virtual const rpc::InvocationDefT& def() const = 0;
-
- protected:
- explicit RemoteInvocation(int id)
- : id_(id), name_(absl::StrCat("Invocation ", id)) {}
-
- private:
- int id_;
- std::string name_;
-};
-
-// Debugger RPC server client.
-// Statefully tracks a DebugServer to provide common client operations and
-// memoized queries.
-//
-// Thread-compatible. Do not use the client from multiple threads concurrently.
-// All remote updates of local state are performed by the Poll function. See
-// Poll for more details.
-class DebugClient {
- public:
- // Debug event listener interface.
- // Event methods will be called from within Poll calls (so on that thread).
- //
- // When the server posts an event it will mark the client as unready and
- // suspend execution of all invocations until MakeReady is used to indicate
- // that the client is ready for the server to resume. Each event needs a
- // matching MakeReady ack.
- //
- // Listeners can defer acking if they need to perform additional queries or
- // state changes to the server or wait for user interaction. Multiple events
- // may come in while unready if there was a series of events pending on the
- // server.
- class Listener {
- public:
- virtual ~Listener() = default;
-
- // Signals that a context has been registered on the server.
- virtual Status OnContextRegistered(const RemoteContext& context) = 0;
- virtual Status OnContextUnregistered(const RemoteContext& context) = 0;
-
- // Signals that a module has been loaded into a context on the server.
- virtual Status OnModuleLoaded(const RemoteContext& context,
- const RemoteModule& module) = 0;
-
- // Signals that a invocation has been registered on the server.
- virtual Status OnInvocationRegistered(
- const RemoteInvocation& invocation) = 0;
- virtual Status OnInvocationUnregistered(
- const RemoteInvocation& invocation) = 0;
-
- // Signals that a breakpoint has been hit by a invocation on the server.
- virtual Status OnBreakpointHit(const RemoteBreakpoint& breakpoint,
- const RemoteInvocation& invocation) = 0;
- };
-
- // Connects to a remote debug service at the provided IP:port.
- // The specified |listener| will receive async event notifications.
- static StatusOr<std::unique_ptr<DebugClient>> Connect(
- absl::string_view service_address, Listener* listener);
-
- virtual ~DebugClient() = default;
-
- // Returns true if the client is connected to a service.
- // virtual bool is_connected() const = 0;
-
- // A list of all contexts registered with the server.
- virtual absl::Span<RemoteContext* const> contexts() const = 0;
-
- // A list of all invocations registered with the server.
- virtual absl::Span<RemoteInvocation* const> invocations() const = 0;
-
- // A list of all breakpoints registered with the server.
- virtual absl::Span<RemoteBreakpoint* const> breakpoints() const = 0;
-
- // Resolves a function to a module ordinal.
- // This will occur asynchronously and the |callback| will be issued on the
- // polling thread.
- virtual Status ResolveFunction(
- std::string module_name, std::string function_name,
- std::function<void(StatusOr<int> function_ordinal)> callback) = 0;
-
- // Gets a function body instance.
- // The provided |callback| will be issued on the polling thread when the
- // function is available.
- virtual Status GetFunction(
- std::string module_name, int function_ordinal,
- std::function<void(StatusOr<RemoteFunction*> function)> callback) = 0;
- Status GetFunction(
- std::string module_name, std::string function_name,
- std::function<void(StatusOr<RemoteFunction*> function)> callback);
-
- // Adds a breakpoint for the given module:function:offset.
- // The breakpoint will apply to all contexts with the module loaded.
- virtual Status AddFunctionBreakpoint(
- std::string module_name, std::string function_name, int offset,
- std::function<void(const RemoteBreakpoint& breakpoint)> callback =
- nullptr) = 0;
-
- // Removes a breakpoint from the server.
- virtual Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) = 0;
-
- // Notifies the server that the debug session is ready to continue.
- // This must be called once on connection to and in acknowledgement to any
- // events posted by the server (read: any call to the Listener::On* methods).
- virtual Status MakeReady() = 0;
-
- // Suspends all invocations running on the server.
- virtual Status SuspendAllInvocations() = 0;
-
- // Resumes all invocations running on the server.
- virtual Status ResumeAllInvocations() = 0;
-
- // Suspends a list of invocations running on the server. Invocations not in
- // the provided list will not be suspended, such as new invocations created
- // while the request is pending.
- virtual Status SuspendInvocations(
- absl::Span<RemoteInvocation*> invocations) = 0;
-
- // Resumes a list of invocations running on the server.
- virtual Status ResumeInvocations(
- absl::Span<RemoteInvocation*> invocations) = 0;
-
- // Steps a invocation one bytecode operation.
- virtual Status StepInvocation(const RemoteInvocation& invocation,
- std::function<void()> callback) = 0;
- // Steps a invocation over one bytecode operation, not stopping until it
- // completes.
- Status StepInvocationOver(const RemoteInvocation& invocation,
- std::function<void()> callback);
- // Steps a invocation out of the current block.
- Status StepInvocationOut(const RemoteInvocation& invocation,
- std::function<void()> callback);
- // Steps a invocation to a specific bytecode offset within the current
- // function.
- virtual Status StepInvocationToOffset(const RemoteInvocation& invocation,
- int bytecode_offset,
- std::function<void()> callback) = 0;
-
- // TODO(benvanik): profiling modes.
-
- // Polls for the current state of the debug service and processes incoming
- // responses. Must be called as frequently as the UI is desired to update.
- // Returns CancelledError when the service is being shutdown/disconnected.
- //
- // Events on the Listener will be called from within this method.
- virtual Status Poll() = 0;
-};
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_DEBUG_CLIENT_H_
diff --git a/iree/rt/debug/debug_client_tcp.cc b/iree/rt/debug/debug_client_tcp.cc
deleted file mode 100644
index 7b00410..0000000
--- a/iree/rt/debug/debug_client_tcp.cc
+++ /dev/null
@@ -1,1127 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <netdb.h>
-#include <netinet/in.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <algorithm>
-#include <cstring>
-#include <queue>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/memory/memory.h"
-#include "absl/strings/ascii.h"
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_split.h"
-#include "absl/types/span.h"
-#include "flatbuffers/base.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/status.h"
-#include "iree/rt/debug/debug_client.h"
-#include "iree/rt/debug/debug_tcp_util.h"
-#include "iree/rt/module.h"
-#include "iree/schemas/debug_service_generated.h"
-#include "iree/schemas/module_def_generated.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-namespace {
-
-using ::flatbuffers::FlatBufferBuilder;
-
-// Parses a host:port address, with support for the RFC 3986 IPv6 [host]:port
-// format. Returns a pair of (hostname, port), with port being 0 if none was
-// specified.
-//
-// Parses:
-// foo (port 0) / foo:123
-// 1.2.3.4 (port 0) / 1.2.3.4:123
-// [foo] (port 0) / [foo]:123
-// [::1] (port 0) / [::1]:123
-StatusOr<std::pair<std::string, int>> ParseAddress(absl::string_view address) {
- address = absl::StripAsciiWhitespace(address);
- absl::string_view hostname;
- absl::string_view port_str;
- size_t bracket_loc = address.find_last_of(']');
- if (bracket_loc != std::string::npos) {
- // Has at least a ]. Let's assume it's mostly right.
- if (address.find('[') != 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Mismatched brackets in address: " << address;
- }
- hostname = address.substr(1, bracket_loc - 1);
- port_str = address.substr(bracket_loc + 1);
- if (port_str.find(':') == 0) {
- port_str.remove_prefix(1);
- }
- } else {
- size_t colon_loc = address.find_last_of(':');
- if (colon_loc != std::string::npos) {
- hostname = address.substr(0, colon_loc);
- port_str = address.substr(colon_loc + 1);
- } else {
- hostname = address;
- port_str = "";
- }
- }
- int port = 0;
- if (!port_str.empty() && !absl::SimpleAtoi(port_str, &port)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Unable to parse port '" << port_str << "' from " << address;
- }
- return std::make_pair(std::string(hostname), port);
-}
-
-class TcpDebugClient final : public DebugClient {
- public:
- class TcpRemoteBreakpoint : public RemoteBreakpoint {
- public:
- TcpRemoteBreakpoint(int id, Type type, TcpDebugClient* client)
- : RemoteBreakpoint(id, type) {}
-
- const std::string& module_name() const override { return def_.module_name; }
- const std::string& function_name() const override {
- return def_.function_name;
- }
- int function_ordinal() const override { return def_.function_ordinal; }
- int bytecode_offset() const override { return def_.bytecode_offset; }
-
- Status MergeFrom(const rpc::BreakpointDef& breakpoint_def) {
- breakpoint_def.UnPackTo(&def_);
- return OkStatus();
- }
-
- private:
- rpc::BreakpointDefT def_;
- };
-
- class TcpRemoteFunction final : public RemoteFunction {
- public:
- TcpRemoteFunction(RemoteModule* module, int function_ordinal,
- const FunctionDef* function_def, TcpDebugClient* client)
- : RemoteFunction(module, function_ordinal),
- def_(function_def),
- client_(client) {
- name_ = def_->name() ? std::string(WrapString(def_->name())) : "";
- }
-
- const std::string& name() const override { return name_; }
-
- const FunctionDef& def() override { return *def_; }
-
- bool is_loaded() const override {
- return contents_.flatbuffers_buffer.size() > 0;
- }
-
- bool CheckLoadedOrRequest() override {
- if (!is_loaded()) {
- DemandContents();
- }
- return is_loaded();
- }
-
- void WhenLoaded(LoadCallback callback) override {
- if (is_loaded()) {
- callback(this);
- return;
- }
- load_callbacks_.push_back(std::move(callback));
- }
-
- const BytecodeDef* bytecode() override {
- CHECK(is_loaded());
- return contents_.bytecode_def;
- }
-
- private:
- void DemandContents() {
- if (!has_requested_contents_) {
- VLOG(2) << "Client " << client_->fd() << ": GetFunction("
- << module()->context_id() << ", " << module()->name() << ", "
- << ordinal() << ")";
- FlatBufferBuilder fbb;
- rpc::GetFunctionRequestT request;
- request.session_id = client_->session_id();
- request.context_id = module()->context_id();
- request.module_name = module()->name();
- request.function_ordinal = ordinal();
- auto status =
- client_->IssueRequest<rpc::GetFunctionRequest,
- rpc::ResponseUnion::GetFunctionResponse>(
- rpc::GetFunctionRequest::Pack(fbb, &request), std::move(fbb),
- [this](Status status,
- const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- const auto& response =
- *response_union.message_as_GetFunctionResponse();
- VLOG(2) << "Client " << client_->fd() << ": GetFunction("
- << module()->context_id() << ", " << module()->name()
- << ", " << ordinal() << ") = ...";
- RETURN_IF_ERROR(MergeFrom(response));
- for (auto& callback : load_callbacks_) {
- callback(this);
- }
- load_callbacks_.clear();
- return OkStatus();
- });
- if (!status.ok()) {
- LOG(ERROR) << "Failed to request module: " << status;
- return;
- }
- has_requested_contents_ = true;
- }
- }
-
- Status MergeFrom(const rpc::GetFunctionResponse& response) {
- // Clone and retain the contents.
- // TODO(benvanik): find a way to steal to avoid the reserialization.
- BytecodeDefT bytecode_def_storage;
- response.bytecode()->UnPackTo(&bytecode_def_storage);
- ::flatbuffers::FlatBufferBuilder fbb;
- fbb.Finish(response.bytecode()->Pack(fbb, &bytecode_def_storage));
- contents_.flatbuffers_buffer = fbb.Release();
- contents_.bytecode_def = ::flatbuffers::GetRoot<BytecodeDef>(
- contents_.flatbuffers_buffer.data());
- return OkStatus();
- }
-
- const FunctionDef* def_;
- TcpDebugClient* client_;
- std::string name_;
- bool has_requested_contents_ = false;
- std::vector<LoadCallback> load_callbacks_;
- struct {
- ::flatbuffers::DetachedBuffer flatbuffers_buffer;
- const BytecodeDef* bytecode_def = nullptr;
- } contents_;
- };
-
- class TcpRemoteModule final : public RemoteModule {
- public:
- TcpRemoteModule(int context_id, std::string module_name,
- TcpDebugClient* client)
- : RemoteModule(context_id, std::move(module_name)), client_(client) {}
-
- const ModuleDef& def() override {
- CHECK(is_loaded());
- return *module_file_->root();
- }
-
- bool is_loaded() const override { return module_file_ != nullptr; }
-
- bool CheckLoadedOrRequest() override {
- if (!is_loaded()) {
- DemandModuleDef();
- }
- return is_loaded();
- }
-
- void WhenLoaded(LoadCallback callback) override {
- if (is_loaded()) {
- callback(this);
- return;
- }
- load_callbacks_.push_back(std::move(callback));
- }
-
- absl::Span<RemoteFunction*> functions() override {
- auto* module_def = DemandModuleDef();
- if (!module_def) return {};
- return {reinterpret_cast<RemoteFunction**>(functions_.data()),
- functions_.size()};
- }
-
- private:
- const ModuleDef* DemandModuleDef() {
- if (module_file_) {
- return module_file_->root();
- }
- if (!has_requested_module_def_) {
- VLOG(2) << "Client " << client_->fd() << ": GetModule(" << context_id()
- << ", " << name() << ")";
- FlatBufferBuilder fbb;
- rpc::GetModuleRequestT request;
- request.session_id = client_->session_id();
- request.context_id = context_id();
- request.module_name = name();
- auto status =
- client_->IssueRequest<rpc::GetModuleRequest,
- rpc::ResponseUnion::GetModuleResponse>(
- rpc::GetModuleRequest::Pack(fbb, &request), std::move(fbb),
- [this](Status status,
- const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- const auto& response =
- *response_union.message_as_GetModuleResponse();
- VLOG(2) << "Client " << client_->fd() << ": GetModule("
- << context_id() << ", " << name() << ") = ...";
- RETURN_IF_ERROR(MergeFrom(response));
- for (auto& callback : load_callbacks_) {
- callback(this);
- }
- load_callbacks_.clear();
- return OkStatus();
- });
- if (!status.ok()) {
- LOG(ERROR) << "Failed to request module: " << status;
- return nullptr;
- }
- has_requested_module_def_ = true;
- }
- return nullptr;
- }
-
- Status MergeFrom(const rpc::GetModuleResponse& response) {
- // Clone and retain the module.
- // TODO(benvanik): find a way to steal to avoid the reserialization.
- ModuleDefT module_def_storage;
- response.module_()->UnPackTo(&module_def_storage);
- FlatBufferBuilder fbb;
- auto module_offs = response.module_()->Pack(fbb, &module_def_storage);
- FinishModuleDefBuffer(fbb, module_offs);
- ASSIGN_OR_RETURN(auto module_file,
- ModuleFile::CreateWithBackingBuffer(fbb.Release()));
-
- const auto& module_def = module_file->root();
- const auto& function_table = *module_def->function_table();
- functions_.reserve(function_table.functions()->size());
- for (int i = 0; i < function_table.functions()->size(); ++i) {
- const auto* function_def = function_table.functions()->Get(i);
- functions_.push_back(absl::make_unique<TcpRemoteFunction>(
- this, i, function_def, client_));
- }
-
- module_file_ = std::move(module_file);
- return OkStatus();
- }
-
- TcpDebugClient* client_;
- bool has_requested_module_def_ = false;
- std::vector<LoadCallback> load_callbacks_;
- std::unique_ptr<ModuleFile> module_file_;
- std::vector<std::unique_ptr<RemoteFunction>> functions_;
- };
-
- class TcpRemoteContext final : public RemoteContext {
- public:
- TcpRemoteContext(int context_id, TcpDebugClient* client)
- : RemoteContext(context_id), client_(client) {}
-
- absl::Span<RemoteModule* const> modules() const override {
- return absl::MakeConstSpan(modules_);
- }
-
- Status AddModule(std::unique_ptr<TcpRemoteModule> module) {
- modules_.push_back(module.get());
- module_map_.insert({module->name(), std::move(module)});
- return OkStatus();
- }
-
- Status MergeFrom(const rpc::ContextDef& context_def) { return OkStatus(); }
-
- private:
- TcpDebugClient* client_;
- std::vector<RemoteModule*> modules_;
- absl::flat_hash_map<std::string, std::unique_ptr<TcpRemoteModule>>
- module_map_;
- };
-
- class TcpRemoteInvocation final : public RemoteInvocation {
- public:
- TcpRemoteInvocation(int invocation_id, TcpDebugClient* client)
- : RemoteInvocation(invocation_id), client_(client) {}
-
- const rpc::InvocationDefT& def() const override { return def_; }
-
- Status MergeFrom(const rpc::InvocationDef& invocation_def) {
- invocation_def.UnPackTo(&def_);
- return OkStatus();
- }
-
- private:
- TcpDebugClient* client_;
- rpc::InvocationDefT def_;
- };
-
- static StatusOr<std::unique_ptr<TcpDebugClient>> Create(int fd,
- Listener* listener) {
- VLOG(2) << "Client " << fd << ": Setting up socket options...";
- // Disable Nagel's algorithm to ensure we have low latency.
- RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(fd, false));
- // Enable keepalive assuming the client is local and this high freq is ok.
- RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(fd, true));
- // Linger around for a bit to flush all data.
- RETURN_IF_ERROR(tcp::ToggleSocketLinger(fd, true));
- // Disable blocking as we are poll based.
- RETURN_IF_ERROR(tcp::ToggleSocketBlocking(fd, false));
-
- auto client = absl::make_unique<TcpDebugClient>(fd, listener);
- RETURN_IF_ERROR(client->Refresh());
- return client;
- }
-
- TcpDebugClient(int fd, Listener* listener) : fd_(fd), listener_(listener) {}
-
- ~TcpDebugClient() override {
- VLOG(2) << "Client " << fd_ << ": Shutting down session socket...";
- ::shutdown(fd_, SHUT_WR);
- VLOG(2) << "Client " << fd_ << ": Closing session socket...";
- ::close(fd_);
- VLOG(2) << "Client " << fd_ << ": Closed session socket!";
- fd_ = -1;
- }
-
- int fd() const { return fd_; }
- int session_id() const { return session_id_; }
-
- absl::Span<RemoteContext* const> contexts() const override {
- return absl::MakeConstSpan(contexts_);
- }
-
- absl::Span<RemoteInvocation* const> invocations() const override {
- return absl::MakeConstSpan(invocations_);
- }
-
- absl::Span<RemoteBreakpoint* const> breakpoints() const override {
- return absl::MakeConstSpan(breakpoints_);
- }
-
- // Writes the given typed request message to the given fd by wrapping it in
- // a size-prefixed rpc::Request union.
- //
- // Example:
- // FlatBufferBuilder fbb;
- // rpc::SuspendInvocationRequestBuilder request(fbb);
- // RETURN_IF_ERROR(WriteRequest(fd_, request.Finish(), std::move(fbb)));
- template <typename T>
- Status WriteRequest(int fd, ::flatbuffers::Offset<T> request_offs,
- FlatBufferBuilder fbb) {
- rpc::RequestBuilder request_builder(fbb);
- request_builder.add_message_type(rpc::RequestUnionTraits<T>::enum_value);
- request_builder.add_message(request_offs.Union());
- fbb.FinishSizePrefixed(request_builder.Finish());
- auto write_status = tcp::WriteBuffer(fd, fbb.Release());
- if (shutdown_pending_ && IsUnavailable(write_status)) {
- return OkStatus();
- }
- return write_status;
- }
-
- Status ResolveFunction(
- std::string module_name, std::string function_name,
- std::function<void(StatusOr<int> function_ordinal)> callback) override {
- VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name << ", "
- << function_name << ")";
- FlatBufferBuilder fbb;
- rpc::ResolveFunctionRequestT request;
- request.session_id = session_id_;
- request.module_name = module_name;
- request.function_name = function_name;
- return IssueRequest<rpc::ResolveFunctionRequest,
- rpc::ResponseUnion::ResolveFunctionResponse>(
- rpc::ResolveFunctionRequest::Pack(fbb, &request), std::move(fbb),
- [this, module_name, function_name, callback](
- Status status, const rpc::Response& response_union) -> Status {
- if (status.ok()) {
- const auto& response =
- *response_union.message_as_ResolveFunctionResponse();
- VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name
- << ", " << function_name
- << ") = " << response.function_ordinal();
- callback(response.function_ordinal());
- } else {
- callback(std::move(status));
- }
- return OkStatus();
- });
- }
-
- Status GetFunction(std::string module_name, int function_ordinal,
- std::function<void(StatusOr<RemoteFunction*> function)>
- callback) override {
- // See if we have the module already. If not, we'll fetch it first.
- RemoteModule* target_module = nullptr;
- for (auto* context : contexts_) {
- for (auto* module : context->modules()) {
- if (module->name() == module_name) {
- target_module = module;
- break;
- }
- }
- if (target_module) break;
- }
- if (!target_module) {
- // TODO(benvanik): fetch contexts first.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Demand fetch contexts not yet implemented";
- }
- // Found at least one module with the right name.
- if (target_module->is_loaded()) {
- callback(target_module->functions()[function_ordinal]);
- return OkStatus();
- } else {
- // Wait until the module completes loading.
- target_module->WhenLoaded(
- [callback, function_ordinal](StatusOr<RemoteModule*> module_or) {
- if (!module_or.ok()) {
- callback(module_or.status());
- return;
- }
- callback(module_or.ValueOrDie()->functions()[function_ordinal]);
- });
- return OkStatus();
- }
- }
-
- Status AddFunctionBreakpoint(
- std::string module_name, std::string function_name, int offset,
- std::function<void(const RemoteBreakpoint& breakpoint)> callback)
- override {
- VLOG(2) << "Client " << fd_ << ": AddFunctionBreakpoint(" << module_name
- << ", " << function_name << ", " << offset << ")";
- FlatBufferBuilder fbb;
-
- auto breakpoint = absl::make_unique<rpc::BreakpointDefT>();
- breakpoint->module_name = module_name;
- breakpoint->function_name = function_name;
- breakpoint->function_ordinal = -1;
- breakpoint->bytecode_offset = offset;
- rpc::AddBreakpointRequestT request;
- request.session_id = session_id_;
- request.breakpoint = std::move(breakpoint);
- return IssueRequest<rpc::AddBreakpointRequest,
- rpc::ResponseUnion::AddBreakpointResponse>(
- rpc::AddBreakpointRequest::Pack(fbb, &request), std::move(fbb),
- [this, callback](Status status,
- const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- const auto& response =
- *response_union.message_as_AddBreakpointResponse();
- RETURN_IF_ERROR(RegisterBreakpoint(*response.breakpoint()));
- if (callback) {
- ASSIGN_OR_RETURN(
- auto breakpoint,
- GetBreakpoint(response.breakpoint()->breakpoint_id()));
- callback(*breakpoint);
- }
- return OkStatus();
- });
- }
-
- Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) override {
- VLOG(2) << "Client " << fd_ << ": RemoveBreakpoint(" << breakpoint.id()
- << ")";
- int breakpoint_id = breakpoint.id();
- ASSIGN_OR_RETURN(auto* breakpoint_ptr, GetBreakpoint(breakpoint_id));
- RETURN_IF_ERROR(UnregisterBreakpoint(breakpoint_ptr));
- FlatBufferBuilder fbb;
- rpc::RemoveBreakpointRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- request.add_breakpoint_id(breakpoint_id);
- return IssueRequest<rpc::RemoveBreakpointRequest,
- rpc::ResponseUnion::RemoveBreakpointResponse>(
- request.Finish(), std::move(fbb),
- [](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- // No non-error status.
- return OkStatus();
- });
- }
-
- Status MakeReady() override {
- FlatBufferBuilder fbb;
- rpc::MakeReadyRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- return IssueRequest<rpc::MakeReadyRequest,
- rpc::ResponseUnion::MakeReadyResponse>(
- request.Finish(), std::move(fbb),
- [](Status status, const rpc::Response& response_union) {
- return status;
- });
- }
-
- Status SuspendAllInvocations() override {
- VLOG(2) << "Client " << fd_ << ": SuspendAllInvocations()";
- FlatBufferBuilder fbb;
- rpc::SuspendInvocationsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- return IssueRequest<rpc::SuspendInvocationsRequest,
- rpc::ResponseUnion::SuspendInvocationsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- return RefreshInvocations();
- });
- }
-
- Status ResumeAllInvocations() override {
- VLOG(2) << "Client " << fd_ << ": ResumeAllInvocations()";
- FlatBufferBuilder fbb;
- rpc::ResumeInvocationsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- return IssueRequest<rpc::ResumeInvocationsRequest,
- rpc::ResponseUnion::ResumeInvocationsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- return RefreshInvocations();
- });
- }
-
- Status SuspendInvocations(
- absl::Span<RemoteInvocation*> invocations) override {
- VLOG(2) << "Client " << fd_ << ": SuspendInvocations(...)";
- FlatBufferBuilder fbb;
- auto invocation_ids_offs = fbb.CreateVector<int32_t>(
- invocations.size(),
- [&invocations](size_t i) { return invocations[i]->id(); });
- rpc::SuspendInvocationsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- request.add_invocation_ids(invocation_ids_offs);
- return IssueRequest<rpc::SuspendInvocationsRequest,
- rpc::ResponseUnion::SuspendInvocationsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- return RefreshInvocations();
- });
- }
-
- Status ResumeInvocations(absl::Span<RemoteInvocation*> invocations) override {
- VLOG(2) << "Client " << fd_ << ": ResumeInvocations(...)";
- FlatBufferBuilder fbb;
- auto invocation_ids_offs = fbb.CreateVector<int32_t>(
- invocations.size(),
- [&invocations](size_t i) { return invocations[i]->id(); });
- rpc::ResumeInvocationsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- request.add_invocation_ids(invocation_ids_offs);
- return IssueRequest<rpc::ResumeInvocationsRequest,
- rpc::ResponseUnion::ResumeInvocationsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- return RefreshInvocations();
- });
- }
-
- Status StepInvocation(const RemoteInvocation& invocation,
- std::function<void()> callback) override {
- int step_id = next_step_id_++;
- VLOG(2) << "Client " << fd_ << ": StepInvocation(" << invocation.id()
- << ") as step_id=" << step_id;
- rpc::StepInvocationRequestT step_request;
- step_request.step_id = step_id;
- step_request.invocation_id = invocation.id();
- step_request.step_mode = rpc::StepMode::STEP_ONCE;
- return StepInvocation(&step_request, std::move(callback));
- }
-
- Status StepInvocationToOffset(const RemoteInvocation& invocation,
- int bytecode_offset,
- std::function<void()> callback) override {
- int step_id = next_step_id_++;
- VLOG(2) << "Client " << fd_ << ": StepInvocationToOffset("
- << invocation.id() << ", " << bytecode_offset
- << ") as step_id=" << step_id;
- rpc::StepInvocationRequestT step_request;
- step_request.step_id = step_id;
- step_request.invocation_id = invocation.id();
- step_request.step_mode = rpc::StepMode::STEP_TO_OFFSET;
- step_request.bytecode_offset = bytecode_offset;
- return StepInvocation(&step_request, std::move(callback));
- }
-
- Status Poll() override {
- while (true) {
- // If nothing awaiting then return immediately.
- if (!tcp::CanReadBuffer(fd_)) {
- break;
- }
-
- // Read the pending response and dispatch.
- auto packet_buffer_or = tcp::ReadBuffer<rpc::ServicePacket>(fd_);
- if (!packet_buffer_or.ok()) {
- if (shutdown_pending_ && IsUnavailable(packet_buffer_or.status())) {
- // This is a graceful close.
- return CancelledErrorBuilder(IREE_LOC) << "Service shutdown";
- }
- return packet_buffer_or.status();
- }
- const auto& packet = packet_buffer_or.ValueOrDie().GetRoot();
- if (packet.response()) {
- RETURN_IF_ERROR(DispatchResponse(*packet.response()));
- }
- if (packet.event()) {
- RETURN_IF_ERROR(DispatchEvent(packet));
- }
- }
- return OkStatus();
- }
-
- using ResponseCallback =
- std::function<Status(Status status, const rpc::Response& response)>;
-
- template <typename T, rpc::ResponseUnion response_type>
- Status IssueRequest(::flatbuffers::Offset<T> request_offs,
- FlatBufferBuilder fbb, ResponseCallback callback) {
- RETURN_IF_ERROR(WriteRequest(fd_, request_offs, std::move(fbb)));
- pending_responses_.push({response_type, std::move(callback)});
- return OkStatus();
- }
-
- private:
- Status Refresh() {
- RETURN_IF_ERROR(RefreshContexts());
- RETURN_IF_ERROR(RefreshInvocations());
- RETURN_IF_ERROR(RefreshBreakpoints());
- return OkStatus();
- }
-
- Status RefreshContexts() {
- VLOG(2) << "Request contexts refresh...";
- FlatBufferBuilder fbb;
- rpc::ListContextsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- return IssueRequest<rpc::ListContextsRequest,
- rpc::ResponseUnion::ListContextsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- VLOG(2) << "Refreshing contexts...";
- const auto& response =
- *response_union.message_as_ListContextsResponse();
- for (auto* context_def : *response.contexts()) {
- auto context_or = GetContext(context_def->context_id());
- if (!context_or.ok()) {
- // Not found; add new.
- RETURN_IF_ERROR(RegisterContext(context_def->context_id()));
- context_or = GetContext(context_def->context_id());
- }
- RETURN_IF_ERROR(context_or.status());
- RETURN_IF_ERROR(context_or.ValueOrDie()->MergeFrom(*context_def));
- }
- VLOG(2) << "Refreshed contexts!";
- return OkStatus();
- });
- }
-
- Status RefreshInvocations() {
- VLOG(2) << "Request invocation states refresh...";
- FlatBufferBuilder fbb;
- rpc::ListInvocationsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- return IssueRequest<rpc::ListInvocationsRequest,
- rpc::ResponseUnion::ListInvocationsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- VLOG(2) << "Refreshing invocation states...";
- const auto& response =
- *response_union.message_as_ListInvocationsResponse();
- for (auto* invocation_def : *response.invocations()) {
- auto invocation_or = GetInvocation(invocation_def->invocation_id());
- if (!invocation_or.ok()) {
- // Not found; add new.
- RETURN_IF_ERROR(
- RegisterInvocation(invocation_def->invocation_id()));
- invocation_or = GetInvocation(invocation_def->invocation_id());
- }
- RETURN_IF_ERROR(invocation_or.status());
- RETURN_IF_ERROR(
- invocation_or.ValueOrDie()->MergeFrom(*invocation_def));
- }
- // TODO(benvanik): handle removals/deaths.
- VLOG(2) << "Refreshed invocation states!";
- return OkStatus();
- });
- }
-
- Status RefreshBreakpoints() {
- VLOG(2) << "Requesting breakpoint refresh...";
- FlatBufferBuilder fbb;
- rpc::ListBreakpointsRequestBuilder request(fbb);
- request.add_session_id(session_id_);
- return IssueRequest<rpc::ListBreakpointsRequest,
- rpc::ResponseUnion::ListBreakpointsResponse>(
- request.Finish(), std::move(fbb),
- [this](Status status, const rpc::Response& response_union) -> Status {
- if (!status.ok()) return status;
- VLOG(2) << "Refreshing breakpoints...";
- const auto& response =
- *response_union.message_as_ListBreakpointsResponse();
- for (auto* breakpoint_def : *response.breakpoints()) {
- auto breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id());
- if (!breakpoint_or.ok()) {
- // Not found; add new.
- RETURN_IF_ERROR(RegisterBreakpoint(*breakpoint_def));
- breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id());
- }
- RETURN_IF_ERROR(breakpoint_or.status());
- RETURN_IF_ERROR(
- breakpoint_or.ValueOrDie()->MergeFrom(*breakpoint_def));
- }
- // TODO(benvanik): handle removals/deaths.
- VLOG(2) << "Refreshed breakpoints!";
- return OkStatus();
- });
- }
-
- Status DispatchResponse(const rpc::Response& response) {
- if (pending_responses_.empty()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Response received but no request is pending";
- }
- auto type_callback = std::move(pending_responses_.front());
- pending_responses_.pop();
-
- if (response.status()) {
- const auto& status = *response.status();
- Status client_status =
- StatusBuilder(static_cast<StatusCode>(status.code()), IREE_LOC)
- << "Server request failed: " << WrapString(status.message());
- return type_callback.second(std::move(client_status), response);
- }
-
- if (!response.message()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Response contains no message body";
- }
-
- if (response.message_type() != type_callback.first) {
- return DataLossErrorBuilder(IREE_LOC)
- << "Out of order response (mismatch pending)";
- }
- return type_callback.second(OkStatus(), response);
- }
-
- Status DispatchEvent(const rpc::ServicePacket& packet) {
- switch (packet.event_type()) {
-#define DISPATCH_EVENT(event_name) \
- case rpc::EventUnion::event_name##Event: { \
- VLOG(2) << "EVENT: " << #event_name; \
- return On##event_name(*packet.event_as_##event_name##Event()); \
- }
- DISPATCH_EVENT(ServiceShutdown);
- DISPATCH_EVENT(ContextRegistered);
- DISPATCH_EVENT(ContextUnregistered);
- DISPATCH_EVENT(ModuleLoaded);
- DISPATCH_EVENT(InvocationRegistered);
- DISPATCH_EVENT(InvocationUnregistered);
- DISPATCH_EVENT(BreakpointResolved);
- DISPATCH_EVENT(BreakpointHit);
- DISPATCH_EVENT(StepCompleted);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented debug service event: "
- << static_cast<int>(packet.event_type());
- }
- }
-
- StatusOr<TcpRemoteContext*> GetContext(int context_id) {
- auto it = context_map_.find(context_id);
- if (it == context_map_.end()) {
- return NotFoundErrorBuilder(IREE_LOC) << "Context was never registered";
- }
- return it->second.get();
- }
-
- Status OnServiceShutdown(const rpc::ServiceShutdownEvent& event) {
- LOG(INFO) << "Service is shutting down; setting pending shutdown flag";
- shutdown_pending_ = true;
- return OkStatus();
- }
-
- Status RegisterContext(int context_id) {
- auto context = absl::make_unique<TcpRemoteContext>(context_id, this);
- VLOG(2) << "RegisterContext(" << context_id << ")";
- auto context_ptr = context.get();
- context_map_.insert({context_id, std::move(context)});
- contexts_.push_back(context_ptr);
- return listener_->OnContextRegistered(*context_ptr);
- }
-
- Status OnContextRegistered(const rpc::ContextRegisteredEvent& event) {
- VLOG(2) << "OnContextRegistered(" << event.context_id() << ")";
- auto it = context_map_.find(event.context_id());
- if (it != context_map_.end()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Context already registered";
- }
- return RegisterContext(event.context_id());
- }
-
- Status OnContextUnregistered(const rpc::ContextUnregisteredEvent& event) {
- VLOG(2) << "OnContextUnregistered(" << event.context_id() << ")";
- auto it = context_map_.find(event.context_id());
- if (it == context_map_.end()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Context was never registered";
- }
- auto context = std::move(it->second);
- context_map_.erase(it);
- auto list_it = std::find(contexts_.begin(), contexts_.end(), context.get());
- contexts_.erase(list_it);
- return listener_->OnContextUnregistered(*context);
- }
-
- Status OnModuleLoaded(const rpc::ModuleLoadedEvent& event) {
- VLOG(2) << "OnModuleLoaded(" << event.context_id() << ", "
- << WrapString(event.module_name()) << ")";
- ASSIGN_OR_RETURN(auto* context, GetContext(event.context_id()));
- auto module_name = WrapString(event.module_name());
- auto module = absl::make_unique<TcpRemoteModule>(
- event.context_id(), std::string(module_name), this);
- auto* module_ptr = module.get();
- RETURN_IF_ERROR(context->AddModule(std::move(module)));
- return listener_->OnModuleLoaded(*context, *module_ptr);
- }
-
- StatusOr<TcpRemoteInvocation*> GetInvocation(int invocation_id) {
- auto it = invocation_map_.find(invocation_id);
- if (it == invocation_map_.end()) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Invocation was never registered";
- }
- return it->second.get();
- }
-
- Status RegisterInvocation(int invocation_id) {
- VLOG(2) << "RegisterInvocation(" << invocation_id << ")";
- auto invocation =
- absl::make_unique<TcpRemoteInvocation>(invocation_id, this);
- auto invocation_ptr = invocation.get();
- invocation_map_.insert({invocation_id, std::move(invocation)});
- invocations_.push_back(invocation_ptr);
- RETURN_IF_ERROR(RefreshInvocations());
- return listener_->OnInvocationRegistered(*invocation_ptr);
- }
-
- Status OnInvocationRegistered(const rpc::InvocationRegisteredEvent& event) {
- VLOG(2) << "OnInvocationRegistered(" << event.invocation_id() << ")";
- auto it = invocation_map_.find(event.invocation_id());
- if (it != invocation_map_.end()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Invocation already registered";
- }
- return RegisterInvocation(event.invocation_id());
- }
-
- Status OnInvocationUnregistered(
- const rpc::InvocationUnregisteredEvent& event) {
- VLOG(2) << "OnInvocationUnregistered(" << event.invocation_id() << ")";
- auto it = invocation_map_.find(event.invocation_id());
- if (it == invocation_map_.end()) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Invocation was never registered";
- }
- auto invocation = std::move(it->second);
- invocation_map_.erase(it);
- auto list_it =
- std::find(invocations_.begin(), invocations_.end(), invocation.get());
- invocations_.erase(list_it);
- return listener_->OnInvocationUnregistered(*invocation);
- }
-
- StatusOr<TcpRemoteBreakpoint*> GetBreakpoint(int breakpoint_id) {
- auto it = breakpoint_map_.find(breakpoint_id);
- if (it == breakpoint_map_.end()) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Breakpoint " << breakpoint_id << " was never registered";
- }
- return it->second.get();
- }
-
- Status RegisterBreakpoint(const rpc::BreakpointDef& breakpoint_def) {
- auto it = breakpoint_map_.find(breakpoint_def.breakpoint_id());
- if (it != breakpoint_map_.end()) {
- VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id()
- << ") (update)";
- return it->second->MergeFrom(breakpoint_def);
- }
-
- VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id() << ")";
- auto breakpoint = absl::make_unique<TcpRemoteBreakpoint>(
- breakpoint_def.breakpoint_id(),
- static_cast<RemoteBreakpoint::Type>(breakpoint_def.breakpoint_type()),
- this);
- RETURN_IF_ERROR(breakpoint->MergeFrom(breakpoint_def));
- breakpoints_.push_back(breakpoint.get());
- breakpoint_map_.insert({breakpoint->id(), std::move(breakpoint)});
- return OkStatus();
- }
-
- Status UnregisterBreakpoint(RemoteBreakpoint* breakpoint) {
- VLOG(2) << "UnregisterBreakpoint(" << breakpoint->id() << ")";
- auto it = breakpoint_map_.find(breakpoint->id());
- if (it == breakpoint_map_.end()) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Breakpoint was never registered";
- }
- breakpoint_map_.erase(it);
- auto list_it =
- std::find(breakpoints_.begin(), breakpoints_.end(), breakpoint);
- breakpoints_.erase(list_it);
- return OkStatus();
- }
-
- Status OnBreakpointResolved(const rpc::BreakpointResolvedEvent& event) {
- VLOG(2) << "OnBreakpointResolved(" << event.breakpoint()->breakpoint_id()
- << ")";
- auto it = breakpoint_map_.find(event.breakpoint()->breakpoint_id());
- if (it == breakpoint_map_.end()) {
- RETURN_IF_ERROR(RegisterBreakpoint(*event.breakpoint()));
- } else {
- RETURN_IF_ERROR(it->second->MergeFrom(*event.breakpoint()));
- }
- return OkStatus();
- }
-
- Status OnBreakpointHit(const rpc::BreakpointHitEvent& event) {
- VLOG(2) << "OnBreakpointHit(" << event.breakpoint_id() << ")";
- ASSIGN_OR_RETURN(auto* breakpoint, GetBreakpoint(event.breakpoint_id()));
- auto* invocation_def = event.invocation();
- auto invocation_or = GetInvocation(invocation_def->invocation_id());
- if (!invocation_or.ok()) {
- // Not found; add new.
- RETURN_IF_ERROR(RegisterInvocation(invocation_def->invocation_id()));
- invocation_or = GetInvocation(invocation_def->invocation_id());
- }
- RETURN_IF_ERROR(invocation_or.status());
- RETURN_IF_ERROR(invocation_or.ValueOrDie()->MergeFrom(*invocation_def));
- return listener_->OnBreakpointHit(*breakpoint, *invocation_or.ValueOrDie());
- }
-
- Status StepInvocation(rpc::StepInvocationRequestT* step_request,
- std::function<void()> callback) {
- FlatBufferBuilder fbb;
- auto status = IssueRequest<rpc::StepInvocationRequest,
- rpc::ResponseUnion::StepInvocationResponse>(
- rpc::StepInvocationRequest::Pack(fbb, step_request), std::move(fbb),
- [](Status status, const rpc::Response& response_union) -> Status {
- return status;
- });
- RETURN_IF_ERROR(status);
- pending_step_callbacks_[step_request->step_id] = std::move(callback);
- return OkStatus();
- }
-
- Status OnStepCompleted(const rpc::StepCompletedEvent& event) {
- VLOG(2) << "OnStepCompleted(" << event.step_id() << ")";
-
- // Update all invocation states that are contained.
- // This may only be a subset of relevant states.
- for (auto* invocation_def : *event.invocations()) {
- ASSIGN_OR_RETURN(auto invocation,
- GetInvocation(invocation_def->invocation_id()));
- RETURN_IF_ERROR(invocation->MergeFrom(*invocation_def));
- }
-
- // Dispatch step callback. Note that it may have been cancelled and that's
- // ok. We'll just make ready to resume execution.
- auto it = pending_step_callbacks_.find(event.step_id());
- if (it != pending_step_callbacks_.end()) {
- it->second();
- pending_step_callbacks_.erase(it);
- } else {
- LOG(WARNING) << "Step " << event.step_id()
- << " not found; was cancelled?";
- RETURN_IF_ERROR(MakeReady());
- }
- return OkStatus();
- }
-
- int session_id_ = 123;
-
- int fd_ = -1;
- Listener* listener_;
- bool shutdown_pending_ = false;
- std::queue<std::pair<rpc::ResponseUnion, ResponseCallback>>
- pending_responses_;
-
- std::vector<RemoteContext*> contexts_;
- absl::flat_hash_map<int, std::unique_ptr<TcpRemoteContext>> context_map_;
- std::vector<RemoteInvocation*> invocations_;
- absl::flat_hash_map<int, std::unique_ptr<TcpRemoteInvocation>>
- invocation_map_;
- std::vector<RemoteBreakpoint*> breakpoints_;
- absl::flat_hash_map<int, std::unique_ptr<TcpRemoteBreakpoint>>
- breakpoint_map_;
-
- int next_step_id_ = 1;
- absl::flat_hash_map<int, std::function<void()>> pending_step_callbacks_;
-};
-
-} // namespace
-
-// static
-StatusOr<std::unique_ptr<DebugClient>> DebugClient::Connect(
- absl::string_view service_address, Listener* listener) {
- // Parse address into hostname and port.
- ASSIGN_OR_RETURN(auto hostname_port, ParseAddress(service_address));
- if (hostname_port.second == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "No port specified in service address; port must match the "
- "server: "
- << service_address;
- }
-
- // Attempt to resolve the address.
- // Note that if we only wanted local debugging we could remove the dep on
- // getaddrinfo/having a valid DNS setup.
- addrinfo hints = {0};
- hints.ai_family = AF_UNSPEC;
- hints.ai_socktype = SOCK_STREAM;
- addrinfo* resolved_address = nullptr;
- auto port_str = std::to_string(hostname_port.second);
- int getaddrinfo_ret = ::getaddrinfo(
- hostname_port.first.c_str(), port_str.c_str(), &hints, &resolved_address);
- if (getaddrinfo_ret != 0) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Unable to resolve debug service address for " << service_address
- << ": (" << getaddrinfo_ret << ") "
- << ::gai_strerror(getaddrinfo_ret);
- }
-
- // Attempt to connect with each address returned from the query.
- int fd = -1;
- for (addrinfo* rp = resolved_address; rp != nullptr; rp = rp->ai_next) {
- fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
- if (fd == -1) continue;
- if (::connect(fd, rp->ai_addr, rp->ai_addrlen) == 0) {
- break; // Success!
- }
- ::close(fd);
- fd = -1;
- }
- ::freeaddrinfo(resolved_address);
- if (fd == -1) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Unable to connect to " << service_address << " on any address: ("
- << errno << ") " << ::strerror(errno);
- }
-
- LOG(INFO) << "Connected to debug service at " << service_address;
-
- return TcpDebugClient::Create(fd, listener);
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_server.h b/iree/rt/debug/debug_server.h
deleted file mode 100644
index c10e0d4..0000000
--- a/iree/rt/debug/debug_server.h
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DEBUG_DEBUG_SERVER_H_
-#define IREE_RT_DEBUG_DEBUG_SERVER_H_
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-class Context;
-class Instance;
-class Invocation;
-class Module;
-} // namespace rt
-} // namespace iree
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// Runtime debugging server.
-// Enabled only when compiled in (by defining IREE_DEBUG=1), this provides an
-// RPC server that allows debuggers to attach, query, and manipulate contexts.
-// This interface is used by various parts of the runtime such as dispatch to
-// query the current debug state and signal events.
-//
-// Thread-safe. Contexts may be registered and unregistered from any thread.
-class DebugServer {
- public:
- // Creates a new debug service listening on the provided |port|.
- // Even when disabled the device can still be created however it will not
- // perform any actual operations and act as if the debugger is not attached.
- static StatusOr<std::unique_ptr<DebugServer>> Create(int listen_port);
-
- // TODO(benvanik): ensure this gets optimized out when disabled.
- // Seems to be the case: https://gcc.godbolt.org/z/0zf-L4
- virtual ~DebugServer() = default;
-
- // Attaches a callback that will be made when the debug server is shutting
- // down. This can be used to keep resources alive that require the debugger.
- // The callback will be made from a random thread.
- virtual void AtExit(std::function<void()> callback) = 0;
-
- // Blocks the caller until a client session connects and resumes all fibers.
- // Returns AbortedError if a session connects/is connected but disconnects
- // during the wait.
- virtual Status WaitUntilSessionReady() = 0;
-
- protected:
- friend class ::iree::rt::Instance;
-
- // Registers a context with the debug service.
- // Ownership remains with the caller and UnregisterContext must be called
- // prior to the context being destroyed.
- virtual Status RegisterContext(Context* context) = 0;
- virtual Status UnregisterContext(Context* context) = 0;
-
- friend class ::iree::rt::Context;
-
- // Registers a new module linked into an existing Context.
- virtual Status RegisterContextModule(Context* context, Module* module) = 0;
-
- friend class ::iree::rt::Invocation;
-
- // Registers an invocation with the debug service.
- // Ownership remains with the caller and UnregisterInvocation must be called
- // prior to the fiber state being destroyed.
- virtual Status RegisterInvocation(Invocation* invocation) = 0;
- virtual Status UnregisterInvocation(Invocation* invocation) = 0;
-};
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_DEBUG_SERVER_H_
diff --git a/iree/rt/debug/debug_server_disabled.cc b/iree/rt/debug/debug_server_disabled.cc
deleted file mode 100644
index e9b9bd8..0000000
--- a/iree/rt/debug/debug_server_disabled.cc
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/debug/debug_server.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// static
-StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) {
- return {nullptr};
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_server_flags.cc b/iree/rt/debug/debug_server_flags.cc
deleted file mode 100644
index a7814c0..0000000
--- a/iree/rt/debug/debug_server_flags.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/debug/debug_server_flags.h"
-
-#include "absl/flags/flag.h"
-#include "absl/strings/str_cat.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-
-#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT)
-#include "iree/tools/debugger/debug_app_embedded.h"
-#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT
-
-ABSL_FLAG(int32_t, iree_debug_service_port, 6000,
- "TCP port to listen for debug service connections.");
-ABSL_FLAG(bool, iree_wait_for_debugger, false,
- "Waits until a debugger connects to continue startup.");
-ABSL_FLAG(bool, iree_attach_debugger, false, "Attaches a debugger at startup.");
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags() {
- // Create the server based on whatever version is compiled in.
- // Note that this will return nullptr if no server is available.
- ASSIGN_OR_RETURN(
- auto debug_server,
- DebugServer::Create(absl::GetFlag(FLAGS_iree_debug_service_port)));
- if (!debug_server) {
- return nullptr;
- }
-
-#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT)
- // If the embedded debug UI is present then we can launch that now.
- std::unique_ptr<EmbeddedDebugger> debugger;
- if (absl::GetFlag(FLAGS_iree_attach_debugger)) {
- LOG(INFO) << "Attaching debugger at startup...";
- ASSIGN_OR_RETURN(
- debugger,
- AttachDebugger(absl::StrCat(
- "localhost:", absl::GetFlag(FLAGS_iree_debug_service_port))));
- RETURN_IF_ERROR(debug_server->WaitUntilSessionReady());
- LOG(INFO) << "Debugger attached";
- // TODO(benvanik): C++14 to avoid this.
- auto debugger_baton = IreeMoveToLambda(debugger);
- debug_server->AtExit([debugger_baton]() { debugger_baton.value.reset(); });
- }
-#else
- if (absl::GetFlag(FLAGS_iree_attach_debugger)) {
- LOG(WARNING) << "--iree_attach_debugger specified but no embedded debugger "
- "is present. Build with --define=IREE_DEBUG=1.";
- }
-#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT
-
- // Wait for a debugger to connect.
- if (absl::GetFlag(FLAGS_iree_wait_for_debugger)) {
- LOG(INFO) << "Waiting for a debugger to connect...";
- RETURN_IF_ERROR(debug_server->WaitUntilSessionReady());
- LOG(INFO) << "Debugger ready, resuming...";
- }
-
- return std::move(debug_server);
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_server_flags.h b/iree/rt/debug/debug_server_flags.h
deleted file mode 100644
index 21ed287..0000000
--- a/iree/rt/debug/debug_server_flags.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
-#define IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
-
-#include "iree/base/status.h"
-#include "iree/rt/debug/debug_server.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// Creates a debug server based on the current --iree_* debug flags.
-// Returns nullptr if no server is compiled in or the flags are not set.
-StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags();
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
diff --git a/iree/rt/debug/debug_server_tcp.cc b/iree/rt/debug/debug_server_tcp.cc
deleted file mode 100644
index 14240e7..0000000
--- a/iree/rt/debug/debug_server_tcp.cc
+++ /dev/null
@@ -1,459 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <netinet/in.h>
-#include <netinet/tcp.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <algorithm>
-#include <cerrno>
-#include <exception>
-#include <thread> // NOLINT
-
-#include "absl/base/thread_annotations.h"
-#include "absl/memory/memory.h"
-#include "absl/synchronization/mutex.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/status.h"
-#include "iree/rt/debug/debug_server.h"
-#include "iree/rt/debug/debug_service.h"
-#include "iree/rt/debug/debug_tcp_util.h"
-#include "iree/schemas/debug_service_generated.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-namespace {
-
-// Writes the given typed response message to the given fd by wrapping it in
-// a size-prefixed rpc::Request union.
-//
-// Example:
-// ::flatbuffers::FlatBufferBuilder fbb;
-// rpc::SuspendInvocationResponseBuilder response(fbb);
-// RETURN_IF_ERROR(WriteResponse(fd_, response.Finish(), std::move(fbb)));
-template <typename T>
-Status WriteResponse(int fd, ::flatbuffers::Offset<T> message_offs,
- ::flatbuffers::FlatBufferBuilder fbb) {
- rpc::ResponseBuilder response_builder(fbb);
- response_builder.add_message_type(rpc::ResponseUnionTraits<T>::enum_value);
- response_builder.add_message(message_offs.Union());
- auto response_offs = response_builder.Finish();
- rpc::ServicePacketBuilder packet_builder(fbb);
- packet_builder.add_response(response_offs);
- fbb.FinishSizePrefixed(packet_builder.Finish());
- return tcp::WriteBuffer(fd, fbb.Release());
-}
-
-class TcpDebugSession : public DebugSession {
- public:
- using ClosedCallback =
- std::function<void(TcpDebugSession* session, Status status)>;
-
- static StatusOr<std::unique_ptr<TcpDebugSession>> Accept(
- DebugService* debug_service, int client_fd,
- ClosedCallback closed_callback) {
- VLOG(2) << "Client " << client_fd << ": Setting up socket options...";
- // Disable Nagel's algorithm to ensure we have low latency.
- RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(client_fd, false));
- // Enable keepalive assuming the client is local and this high freq is ok.
- RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(client_fd, true));
- // Linger around for a bit to flush all data.
- RETURN_IF_ERROR(tcp::ToggleSocketLinger(client_fd, true));
-
- return absl::make_unique<TcpDebugSession>(debug_service, client_fd,
- std::move(closed_callback));
- }
-
- TcpDebugSession(DebugService* debug_service, int client_fd,
- ClosedCallback closed_callback)
- : debug_service_(debug_service),
- client_fd_(client_fd),
- closed_callback_(std::move(closed_callback)) {
- CHECK_OK(debug_service_->RegisterDebugSession(this));
- session_thread_ = std::thread([this]() { SessionThread(); });
- }
-
- ~TcpDebugSession() override {
- CHECK_OK(debug_service_->UnregisterDebugSession(this));
- VLOG(2) << "Client " << client_fd_ << ": Shutting down session socket...";
- ::shutdown(client_fd_, SHUT_RD);
- if (session_thread_.joinable() &&
- session_thread_.get_id() != std::this_thread::get_id()) {
- VLOG(2) << "Client " << client_fd_ << ": Joining socket thread...";
- session_thread_.join();
- VLOG(2) << "Client " << client_fd_ << ": Joined socket thread!";
- } else {
- VLOG(2) << "Client " << client_fd_ << ": Detaching socket thread...";
- session_thread_.detach();
- }
- VLOG(2) << "Client " << client_fd_ << ": Closing session socket...";
- ::close(client_fd_);
- VLOG(2) << "Client " << client_fd_ << ": Closed session socket!";
- client_fd_ = -1;
- }
-
- Status OnServiceShutdown() {
- VLOG(2) << "Client " << client_fd_ << ": Post OnServiceShutdown()";
- ::flatbuffers::FlatBufferBuilder fbb;
- rpc::ServiceShutdownEventBuilder event(fbb);
- return PostEvent(event.Finish(), std::move(fbb));
- }
-
- Status OnContextRegistered(Context* context) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnContextRegistered("
- << context->id() << ")";
- ::flatbuffers::FlatBufferBuilder fbb;
- rpc::ContextRegisteredEventBuilder event(fbb);
- event.add_context_id(context->id());
- return PostEvent(event.Finish(), std::move(fbb));
- }
- Status OnContextUnregistered(Context* context) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnContextUnregistered("
- << context->id() << ")";
- ::flatbuffers::FlatBufferBuilder fbb;
- rpc::ContextUnregisteredEventBuilder event(fbb);
- event.add_context_id(context->id());
- return PostEvent(event.Finish(), std::move(fbb));
- }
-
- Status OnModuleLoaded(Context* context, Module* module) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnModuleLoaded("
- << context->id() << ", " << module->name() << ")";
- ::flatbuffers::FlatBufferBuilder fbb;
- auto module_name_offs =
- fbb.CreateString(module->name().data(), module->name().size());
- rpc::ModuleLoadedEventBuilder event(fbb);
- event.add_context_id(context->id());
- event.add_module_name(module_name_offs);
- return PostEvent(event.Finish(), std::move(fbb));
- }
-
- Status OnInvocationRegistered(Invocation* invocation) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationRegistered("
- << invocation->id() << ")";
- ::flatbuffers::FlatBufferBuilder fbb;
- rpc::InvocationRegisteredEventBuilder event(fbb);
- event.add_invocation_id(invocation->id());
- return PostEvent(event.Finish(), std::move(fbb));
- }
- Status OnInvocationUnregistered(Invocation* invocation) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationUnregistered("
- << invocation->id() << ")";
- ::flatbuffers::FlatBufferBuilder fbb;
- rpc::InvocationUnregisteredEventBuilder event(fbb);
- event.add_invocation_id(invocation->id());
- return PostEvent(event.Finish(), std::move(fbb));
- }
-
- Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint,
- Context* context) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointResolved("
- << breakpoint.breakpoint_id << ", " << context->id() << ", "
- << breakpoint.function_ordinal << ")";
- rpc::BreakpointResolvedEventT event;
- event.breakpoint = absl::make_unique<rpc::BreakpointDefT>();
- *event.breakpoint = breakpoint;
- event.context_id = context->id();
- ::flatbuffers::FlatBufferBuilder fbb;
- return PostEvent(rpc::BreakpointResolvedEvent::Pack(fbb, &event),
- std::move(fbb));
- }
-
- Status OnBreakpointHit(int breakpoint_id,
- const Invocation& invocation) override {
- VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointHit("
- << breakpoint_id << ", " << invocation.id() << ")";
- ::flatbuffers::FlatBufferBuilder fbb;
- ASSIGN_OR_RETURN(auto invocation_offs,
- debug_service_->SerializeInvocation(invocation, &fbb));
- rpc::BreakpointHitEventBuilder event(fbb);
- event.add_breakpoint_id(breakpoint_id);
- event.add_invocation(invocation_offs);
- return PostEvent(event.Finish(), std::move(fbb));
- }
-
- private:
- void SessionThread() {
- VLOG(2) << "Client " << client_fd_ << ": Thread entry";
- Status session_status = OkStatus();
- while (session_status.ok()) {
- auto buffer_or = tcp::ReadBuffer<rpc::Request>(client_fd_);
- if (!buffer_or.ok()) {
- if (IsCancelled(buffer_or.status())) {
- // Graceful shutdown.
- VLOG(2) << "Client " << client_fd_ << ": Graceful shutdown requested";
- break;
- }
- // Error reading.
- session_status = std::move(buffer_or).status();
- LOG(ERROR) << "Client " << client_fd_
- << ": Error reading request buffer: " << session_status;
- break;
- }
- auto request_buffer = std::move(buffer_or).ValueOrDie();
- session_status = DispatchRequest(request_buffer.GetRoot());
- if (!session_status.ok()) {
- LOG(ERROR) << "Client " << client_fd_
- << ": Error dispatching request: " << session_status;
- break;
- }
- }
- VLOG(2) << "Client " << client_fd_ << ": Thread exit";
- AbortSession(session_status);
- }
-
- void AbortSession(Status status) {
- if (status.ok()) {
- VLOG(2) << "Debug client disconnected";
- } else {
- LOG(ERROR) << "Debug session aborted; " << status;
- ::flatbuffers::FlatBufferBuilder fbb;
- auto message_offs =
- fbb.CreateString(status.message().data(), status.message().size());
- rpc::StatusBuilder status_builder(fbb);
- status_builder.add_code(static_cast<int>(status.code()));
- status_builder.add_message(message_offs);
- auto status_offs = status_builder.Finish();
- rpc::ResponseBuilder response(fbb);
- response.add_status(status_offs);
- fbb.FinishSizePrefixed(response.Finish());
- tcp::WriteBuffer(client_fd_, fbb.Release()).IgnoreError();
- }
- closed_callback_(this, std::move(status));
- }
-
- template <typename T>
- Status PostEvent(::flatbuffers::Offset<T> event_offs,
- ::flatbuffers::FlatBufferBuilder fbb) {
- rpc::ServicePacketBuilder packet_builder(fbb);
- packet_builder.add_event_type(rpc::EventUnionTraits<T>::enum_value);
- packet_builder.add_event(event_offs.Union());
- fbb.FinishSizePrefixed(packet_builder.Finish());
- return tcp::WriteBuffer(client_fd_, fbb.Release());
- }
-
- Status DispatchRequest(const rpc::Request& request) {
- ::flatbuffers::FlatBufferBuilder fbb;
- switch (request.message_type()) {
-#define DISPATCH_REQUEST(method_name) \
- case rpc::RequestUnion::method_name##Request: { \
- VLOG(2) << "Client " << client_fd_ \
- << ": DispatchRequest(" #method_name ")..."; \
- ASSIGN_OR_RETURN(auto response_offs, \
- debug_service_->method_name( \
- *request.message_as_##method_name##Request(), &fbb)); \
- return WriteResponse(client_fd_, response_offs, std::move(fbb)); \
- }
- DISPATCH_REQUEST(MakeReady);
- DISPATCH_REQUEST(GetStatus);
- DISPATCH_REQUEST(ListContexts);
- DISPATCH_REQUEST(GetModule);
- DISPATCH_REQUEST(GetFunction);
- DISPATCH_REQUEST(ListInvocations);
- DISPATCH_REQUEST(SuspendInvocations);
- DISPATCH_REQUEST(ResumeInvocations);
- DISPATCH_REQUEST(StepInvocation);
- DISPATCH_REQUEST(GetInvocationLocal);
- DISPATCH_REQUEST(SetInvocationLocal);
- DISPATCH_REQUEST(ListBreakpoints);
- DISPATCH_REQUEST(AddBreakpoint);
- DISPATCH_REQUEST(RemoveBreakpoint);
- DISPATCH_REQUEST(StartProfiling);
- DISPATCH_REQUEST(StopProfiling);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented debug service request: "
- << static_cast<int>(request.message_type());
- }
- }
-
- DebugService* debug_service_;
- int client_fd_;
- ClosedCallback closed_callback_;
- std::thread session_thread_;
-};
-
-class TcpDebugServer final : public DebugServer {
- public:
- static StatusOr<std::unique_ptr<TcpDebugServer>> Listen(int port) {
- // We support both IPv4 and IPv6 by using the IN6ADDR_ANY. This requires
- // that we setup the socket as INET6 and enable reuse (so the same port can
- // be bound for both IPv4 and IPv6).
- int listen_fd = ::socket(AF_INET6, SOCK_STREAM, 0);
- RETURN_IF_ERROR(tcp::ToggleSocketAddressReuse(listen_fd, true));
-
- struct sockaddr_in6 socket_addr = {0};
- socket_addr.sin6_family = AF_INET6;
- socket_addr.sin6_port = htons(port);
- socket_addr.sin6_addr = in6addr_any;
- if (::bind(listen_fd, reinterpret_cast<struct sockaddr*>(&socket_addr),
- sizeof(socket_addr)) < 0) {
- return AlreadyExistsErrorBuilder(IREE_LOC)
- << "Unable to bind socket to port " << port << ": (" << errno
- << ") " << ::strerror(errno);
- }
- if (::listen(listen_fd, 1)) {
- ::close(listen_fd);
- return AlreadyExistsErrorBuilder(IREE_LOC)
- << "Unable to listen on port " << port << ": (" << errno << ") "
- << ::strerror(errno);
- }
- return absl::make_unique<TcpDebugServer>(listen_fd);
- }
-
- TcpDebugServer(int listen_fd) : listen_fd_(listen_fd) {
- server_thread_ = std::thread([this]() { ListenThread(); });
- }
-
- ~TcpDebugServer() ABSL_LOCKS_EXCLUDED(mutex_) override {
- absl::ReleasableMutexLock lock(&mutex_);
- LOG(INFO) << "Shutting down debug server...";
-
- // Notify all sessions.
- for (auto& session : sessions_) {
- session->OnServiceShutdown().IgnoreError();
- }
-
- // Shut down listen socket first so that we can't accept new connections.
- VLOG(2) << "Shutting down listen socket...";
- ::shutdown(listen_fd_, SHUT_RDWR);
- if (server_thread_.joinable()) {
- VLOG(2) << "Joining listen thread...";
- server_thread_.join();
- VLOG(2) << "Joined listen thread!";
- }
- VLOG(2) << "Closing listen socket...";
- ::close(listen_fd_);
- listen_fd_ = -1;
- VLOG(2) << "Closed listen socket!";
-
- // Kill all active sessions. Note that we must do this outside of our lock.
- std::vector<std::unique_ptr<TcpDebugSession>> sessions =
- std::move(sessions_);
- std::vector<std::function<void()>> at_exit_callbacks =
- std::move(at_exit_callbacks_);
- lock.Release();
- VLOG(2) << "Clearing live sessions...";
- sessions.clear();
- VLOG(2) << "Calling AtExit callbacks...";
- for (auto& callback : at_exit_callbacks) {
- callback();
- }
- LOG(INFO) << "Debug server shutdown!";
- }
-
- DebugService* debug_service() { return &debug_service_; }
-
- Status AcceptNewSession(int client_fd) {
- LOG(INFO) << "Accepting new client session as " << client_fd;
- ASSIGN_OR_RETURN(auto session,
- TcpDebugSession::Accept(
- &debug_service_, client_fd,
- [this](TcpDebugSession* session, Status status) {
- absl::MutexLock lock(&mutex_);
- for (auto it = sessions_.begin();
- it != sessions_.end(); ++it) {
- if (it->get() == session) {
- sessions_.erase(it);
- break;
- }
- }
- return OkStatus();
- }));
-
- absl::MutexLock lock(&mutex_);
- sessions_.push_back(std::move(session));
- return OkStatus();
- }
-
- void AtExit(std::function<void()> callback) override {
- absl::MutexLock lock(&mutex_);
- at_exit_callbacks_.push_back(std::move(callback));
- }
-
- Status WaitUntilSessionReady() override {
- return debug_service_.WaitUntilAllSessionsReady();
- }
-
- protected:
- Status RegisterContext(Context* context) override {
- return debug_service_.RegisterContext(context);
- }
- Status UnregisterContext(Context* context) override {
- return debug_service_.UnregisterContext(context);
- }
- Status RegisterContextModule(Context* context, Module* module) override {
- return debug_service_.RegisterContextModule(context, module);
- }
- Status RegisterInvocation(Invocation* invocation) override {
- return debug_service_.RegisterInvocation(invocation);
- }
- Status UnregisterInvocation(Invocation* invocation) override {
- return debug_service_.UnregisterInvocation(invocation);
- }
-
- private:
- void ListenThread() {
- VLOG(2) << "Listen thread entry";
- while (true) {
- struct sockaddr_in accept_socket_addr;
- socklen_t accept_socket_addr_length = sizeof(accept_socket_addr);
- int accepted_fd = ::accept(
- listen_fd_, reinterpret_cast<struct sockaddr*>(&accept_socket_addr),
- &accept_socket_addr_length);
- if (accepted_fd < 0) {
- if (errno == EINVAL) {
- // Shutting down gracefully.
- break;
- }
- // We may be able to recover from some of these cases, but... shrug.
- LOG(FATAL) << "Failed to accept client socket: (" << errno << ") "
- << ::strerror(errno);
- break;
- }
- auto accept_status = AcceptNewSession(accepted_fd);
- if (!accept_status.ok()) {
- LOG(ERROR) << "Failed to accept incoming debug client: "
- << accept_status;
- }
- }
- VLOG(2) << "Listen thread exit";
- }
-
- int listen_fd_;
- std::thread server_thread_;
-
- absl::Mutex mutex_;
- std::vector<std::unique_ptr<TcpDebugSession>> sessions_
- ABSL_GUARDED_BY(mutex_);
- std::vector<std::function<void()>> at_exit_callbacks_ ABSL_GUARDED_BY(mutex_);
-
- DebugService debug_service_;
-};
-
-} // namespace
-
-// static
-StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) {
- ASSIGN_OR_RETURN(auto debug_server, TcpDebugServer::Listen(listen_port));
- LOG(INFO) << "Debug server listening on localhost:" << listen_port;
- return debug_server;
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_service.cc b/iree/rt/debug/debug_service.cc
deleted file mode 100644
index 9c1f7ac..0000000
--- a/iree/rt/debug/debug_service.cc
+++ /dev/null
@@ -1,850 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/debug/debug_service.h"
-
-#include <algorithm>
-#include <memory>
-
-#include "absl/strings/str_join.h"
-#include "absl/synchronization/mutex.h"
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/reflection.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/rt/instance.h"
-#include "iree/schemas/debug_service_generated.h"
-#include "iree/schemas/reflection_data.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-namespace {
-
-using ::flatbuffers::FlatBufferBuilder;
-using ::flatbuffers::Offset;
-using ::iree::hal::BufferView;
-
-int32_t NextUniqueBreakpointId() {
- static std::atomic<int32_t> next_id = 0;
- return ++next_id;
-}
-
-// Gets an embedded flatbuffers reflection schema.
-const ::reflection::Schema& GetSchema(const char* schema_name) {
- for (const auto* file_toc = schemas::reflection_data_create();
- file_toc != nullptr; ++file_toc) {
- if (std::strcmp(file_toc->name, schema_name) == 0) {
- return *::reflection::GetSchema(file_toc->data);
- }
- }
- LOG(FATAL) << "FlatBuffer schema '" << schema_name
- << "' not found in binary; ensure it is in :reflection_data";
-}
-
-// Recursively copies a flatbuffer table, returning the root offset in |fbb|.
-template <typename T>
-StatusOr<Offset<T>> DeepCopyTable(const char* schema_name, const T& table_def,
- FlatBufferBuilder* fbb) {
- const auto* root_table =
- reinterpret_cast<const ::flatbuffers::Table*>(std::addressof(table_def));
- const auto& schema = GetSchema(schema_name);
- return {::flatbuffers::CopyTable(*fbb, schema, *schema.root_table(),
- *root_table,
- /*use_string_pooling=*/false)
- .o};
-}
-
-// Serializes a buffer_view value, optionally including the entire buffer
-// contents.
-StatusOr<Offset<rpc::BufferViewDef>> SerializeBufferView(
- const BufferView& buffer_view, bool include_buffer_contents,
- FlatBufferBuilder* fbb) {
- auto shape_offs = fbb->CreateVector(buffer_view.shape.subspan().data(),
- buffer_view.shape.subspan().size());
- rpc::BufferViewDefBuilder value(*fbb);
- value.add_is_valid(buffer_view.buffer != nullptr);
- value.add_shape(shape_offs);
- value.add_element_size(buffer_view.element_size);
- if (include_buffer_contents) {
- // TODO(benvanik): add buffer data.
- }
- return value.Finish();
-}
-
-// Serializes a stack frame.
-StatusOr<Offset<rpc::StackFrameDef>> SerializeStackFrame(
- const StackFrame& stack_frame, FlatBufferBuilder* fbb) {
- ASSIGN_OR_RETURN(int function_ordinal,
- stack_frame.module().function_table().LookupFunctionOrdinal(
- stack_frame.function()));
- auto module_name_offs = fbb->CreateString(stack_frame.module().name().data(),
- stack_frame.module().name().size());
- std::vector<Offset<rpc::BufferViewDef>> local_offs_list;
- for (const auto& local : stack_frame.locals()) {
- ASSIGN_OR_RETURN(
- auto local_offs,
- SerializeBufferView(local, /*include_buffer_contents=*/false, fbb));
- local_offs_list.push_back(local_offs);
- }
- auto locals_offs = fbb->CreateVector(local_offs_list);
- rpc::StackFrameDefBuilder sfb(*fbb);
- sfb.add_module_name(module_name_offs);
- sfb.add_function_ordinal(function_ordinal);
- sfb.add_offset(stack_frame.offset());
- sfb.add_locals(locals_offs);
- return sfb.Finish();
-}
-
-// Resolves a local from a invocation:frame:local_index to a BufferView.
-StatusOr<BufferView*> ResolveInvocationLocal(Invocation* invocation,
- int frame_index, int local_index) {
- auto frames = invocation->mutable_stack()->mutable_frames();
- if (frame_index < 0 || frame_index > frames.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Frame index " << frame_index << " out of bounds ("
- << frames.size() << ")";
- }
- auto locals = frames[frame_index].mutable_locals();
- if (local_index < 0 || local_index > locals.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Local index " << local_index << " out of bounds ("
- << locals.size() << ")";
- }
- return &locals[local_index];
-}
-
-// Suspends a set of invocations and blocks until all have been suspended (or
-// one or more fails to suspend). This works only when the caller is *not* one
-// of the threads executing a invocation in |invocations| (this normally
-// shouldn't happen, but may if we support eval()-like semantics).
-Status SuspendInvocationsAndWait(absl::Span<Invocation*> invocations) {
- absl::Mutex suspend_mutex;
- Status one_suspend_status = OkStatus();
- std::list<int> pending_suspend_ids;
- for (auto* invocation : invocations) {
- pending_suspend_ids.push_back(invocation->id());
- }
- for (auto* invocation : invocations) {
- auto suspend_callback = [&, invocation](Status suspend_status) {
- absl::MutexLock lock(&suspend_mutex);
- auto it = std::find(pending_suspend_ids.begin(),
- pending_suspend_ids.end(), invocation->id());
- CHECK(it != pending_suspend_ids.end());
- pending_suspend_ids.erase(it);
- if (!suspend_status.ok()) {
- one_suspend_status = std::move(suspend_status);
- }
- };
- RETURN_IF_ERROR(invocation->Suspend(suspend_callback));
- }
- suspend_mutex.LockWhen(absl::Condition(
- +[](std::list<int>* pending_suspend_ids) {
- return pending_suspend_ids->empty();
- },
- &pending_suspend_ids));
- suspend_mutex.Unlock();
- return one_suspend_status;
-}
-
-} // namespace
-
-Status DebugService::SuspendAllInvocations() {
- VLOG(2) << "SuspendAllInvocations";
- for (auto* invocation : invocations_) {
- RETURN_IF_ERROR(invocation->Suspend());
- }
- return OkStatus();
-}
-
-Status DebugService::ResumeAllInvocations() {
- VLOG(2) << "ResumeAllInvocations";
- for (auto* invocation : invocations_) {
- RETURN_IF_ERROR(invocation->Resume());
- }
- return OkStatus();
-}
-
-Status DebugService::RegisterContext(Context* context) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "RegisterContext(" << context->id() << ")";
- RETURN_IF_ERROR(SuspendAllInvocations());
- RETURN_IF_ERROR(UnreadyAllSessions());
- contexts_.push_back(context);
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnContextRegistered(context));
- }
- RETURN_IF_ERROR(ResumeAllInvocations());
- return OkStatus();
-}
-
-Status DebugService::UnregisterContext(Context* context) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "UnregisterContext(" << context->id() << ")";
- auto it = std::find(contexts_.begin(), contexts_.end(), context);
- if (it == contexts_.end()) {
- return NotFoundErrorBuilder(IREE_LOC) << "Context not registered";
- }
- RETURN_IF_ERROR(SuspendAllInvocations());
- RETURN_IF_ERROR(UnreadyAllSessions());
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnContextUnregistered(context));
- }
- contexts_.erase(it);
- RETURN_IF_ERROR(ResumeAllInvocations());
- return OkStatus();
-}
-
-StatusOr<Context*> DebugService::GetContext(int context_id) const {
- for (auto* context : contexts_) {
- if (context->id() == context_id) {
- return context;
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "Context with ID " << context_id
- << " not registered with the debug service";
-}
-
-Status DebugService::RegisterContextModule(Context* context, Module* module) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "RegisterContextModule(" << context->id() << ", " << module->name()
- << ")";
- RETURN_IF_ERROR(SuspendAllInvocations());
- RETURN_IF_ERROR(UnreadyAllSessions());
- RETURN_IF_ERROR(RegisterModuleBreakpoints(context, module));
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnModuleLoaded(context, module));
- }
- RETURN_IF_ERROR(ResumeAllInvocations());
- return OkStatus();
-}
-
-StatusOr<Module*> DebugService::GetModule(int context_id,
- absl::string_view module_name) const {
- ASSIGN_OR_RETURN(auto* context, GetContext(context_id));
- for (const auto& module : context->modules()) {
- if (module->name() == module_name) {
- return module.get();
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "Module '" << module_name << "' not found on context "
- << context_id;
-}
-
-Status DebugService::RegisterInvocation(Invocation* invocation) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "RegisterInvocation(" << invocation->id() << ")";
- RETURN_IF_ERROR(SuspendAllInvocations());
- RETURN_IF_ERROR(UnreadyAllSessions());
- invocations_.push_back(invocation);
- if (sessions_unready_) {
- // Suspend immediately as a debugger is not yet read.
- RETURN_IF_ERROR(invocation->Suspend());
- }
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnInvocationRegistered(invocation));
- }
- RETURN_IF_ERROR(ResumeAllInvocations());
- return OkStatus();
-}
-
-Status DebugService::UnregisterInvocation(Invocation* invocation) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "UnregisterInvocation(" << invocation->id() << ")";
- auto it = std::find(invocations_.begin(), invocations_.end(), invocation);
- if (it == invocations_.end()) {
- return NotFoundErrorBuilder(IREE_LOC) << "Invocation state not registered";
- }
- RETURN_IF_ERROR(SuspendAllInvocations());
- RETURN_IF_ERROR(UnreadyAllSessions());
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnInvocationUnregistered(invocation));
- }
- invocations_.erase(it);
- RETURN_IF_ERROR(ResumeAllInvocations());
- return OkStatus();
-}
-
-StatusOr<Invocation*> DebugService::GetInvocation(int invocation_id) const {
- for (auto* invocation : invocations_) {
- if (invocation->id() == invocation_id) {
- return invocation;
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "Invocation state with ID " << invocation_id
- << " not registered with the debug service";
-}
-
-StatusOr<Offset<rpc::InvocationDef>> DebugService::SerializeInvocation(
- const Invocation& invocation, FlatBufferBuilder* fbb) {
- std::vector<Offset<rpc::StackFrameDef>> frame_offs_list;
- for (const auto& frame : invocation.stack().frames()) {
- ASSIGN_OR_RETURN(auto frame_offs, SerializeStackFrame(frame, fbb));
- frame_offs_list.push_back(frame_offs);
- }
- auto frames_offs = fbb->CreateVector(frame_offs_list);
- rpc::InvocationDefBuilder fsb(*fbb);
- fsb.add_invocation_id(invocation.id());
- fsb.add_frames(frames_offs);
- return fsb.Finish();
-}
-
-Status DebugService::RegisterDebugSession(DebugSession* session) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "RegisterDebugSession(" << session->id() << ")";
- sessions_.push_back(session);
- if (session->is_ready()) {
- ++sessions_ready_;
- } else {
- // Immediately suspend all invocations until the session readies up (or
- // disconnects).
- ++sessions_unready_;
- RETURN_IF_ERROR(SuspendAllInvocations());
- }
- return OkStatus();
-}
-
-Status DebugService::UnregisterDebugSession(DebugSession* session) {
- absl::MutexLock lock(&mutex_);
- VLOG(2) << "UnregisterDebugSession(" << session->id() << ")";
- auto it = std::find(sessions_.begin(), sessions_.end(), session);
- if (it == sessions_.end()) {
- return NotFoundErrorBuilder(IREE_LOC) << "Session not registered";
- }
- sessions_.erase(it);
- if (session->is_ready()) {
- --sessions_ready_;
- } else {
- // If the session never readied up then we still have all invocations
- // suspended waiting for it. We should resume so that we don't block
- // forever.
- --sessions_unready_;
- RETURN_IF_ERROR(ResumeAllInvocations());
- }
- return OkStatus();
-}
-
-Status DebugService::WaitUntilAllSessionsReady() {
- VLOG(1) << "Waiting until all sessions are ready...";
- struct CondState {
- DebugService* service;
- bool had_sessions;
- bool consider_aborted;
- } cond_state;
- {
- absl::MutexLock lock(&mutex_);
- cond_state.service = this;
- cond_state.had_sessions = !sessions_.empty();
- cond_state.consider_aborted = false;
- }
- mutex_.LockWhen(absl::Condition(
- +[](CondState* cond_state) {
- cond_state->service->mutex_.AssertHeld();
- if (cond_state->service->sessions_ready_ > 0) {
- // One or more sessions are ready.
- return true;
- }
- if (cond_state->service->sessions_unready_ > 0) {
- // One or more sessions are connected but not yet ready.
- cond_state->had_sessions = true;
- return false;
- }
- if (cond_state->had_sessions &&
- cond_state->service->sessions_.empty()) {
- // We had sessions but now we don't, consider this an error and bail.
- // This can happen when a session connects but never readies up.
- cond_state->consider_aborted = true;
- return true;
- }
- return false;
- },
- &cond_state));
- mutex_.Unlock();
- if (cond_state.consider_aborted) {
- return AbortedErrorBuilder(IREE_LOC)
- << "At least one session connected but never readied up";
- }
- VLOG(1) << "Sessions ready, resuming";
- return OkStatus();
-}
-
-StatusOr<Offset<rpc::MakeReadyResponse>> DebugService::MakeReady(
- const rpc::MakeReadyRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: MakeReady()";
- // TODO(benvanik): support more than one session.
- CHECK_LE(sessions_.size(), 1) << "Only one session is currently supported";
- if (!sessions_.empty()) {
- RETURN_IF_ERROR(sessions_[0]->OnReady());
- }
- sessions_ready_ = 0;
- sessions_unready_ = 0;
- for (auto* session : sessions_) {
- sessions_ready_ += session->is_ready() ? 1 : 0;
- sessions_unready_ += session->is_ready() ? 0 : 1;
- }
- rpc::MakeReadyResponseBuilder response(*fbb);
- return response.Finish();
-}
-
-Status DebugService::UnreadyAllSessions() {
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnUnready());
- }
- sessions_ready_ = 0;
- sessions_unready_ = sessions_.size();
- return OkStatus();
-}
-
-StatusOr<Offset<rpc::GetStatusResponse>> DebugService::GetStatus(
- const rpc::GetStatusRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: GetStatus()";
- rpc::GetStatusResponseBuilder response(*fbb);
- response.add_protocol(0);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::ListContextsResponse>> DebugService::ListContexts(
- const rpc::ListContextsRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: ListContexts()";
- std::vector<Offset<rpc::ContextDef>> context_offs;
- for (auto* context : contexts_) {
- std::vector<Offset<rpc::NativeFunctionDef>> native_function_offs_list;
- for (const auto& pair : context->native_functions()) {
- auto name_offs = fbb->CreateString(pair.first);
- rpc::NativeFunctionDefBuilder native_function(*fbb);
- native_function.add_name(name_offs);
- native_function_offs_list.push_back(native_function.Finish());
- }
- auto native_functions_offs = fbb->CreateVector(native_function_offs_list);
-
- std::vector<std::string> module_names;
- for (const auto& module : context->modules()) {
- module_names.push_back(std::string(module->name()));
- }
- auto module_names_offs = fbb->CreateVectorOfStrings(module_names);
-
- rpc::ContextDefBuilder context_def(*fbb);
- context_def.add_context_id(context->id());
- context_def.add_native_functions(native_functions_offs);
- context_def.add_module_names(module_names_offs);
- context_offs.push_back(context_def.Finish());
- }
-
- auto contexts_offs = fbb->CreateVector(context_offs);
- rpc::ListContextsResponseBuilder response(*fbb);
- response.add_contexts(contexts_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::GetModuleResponse>> DebugService::GetModule(
- const rpc::GetModuleRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: GetModule(" << request.context_id() << ", "
- << WrapString(request.module_name()) << ")";
- ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(),
- WrapString(request.module_name())));
- // TODO(benvanik): find a way to do this without possibly duping all memory.
- // I suspect that when we make constants poolable then there's only one
- // place to kill and there may be magic we could use to do that during a
- // reflection pass.
- ModuleDefT module_t;
- module->def().UnPackTo(&module_t);
- for (auto& function : module_t.function_table->functions) {
- function->bytecode->contents.clear();
- }
- auto trimmed_module_offs = ModuleDef::Pack(*fbb, &module_t);
- rpc::GetModuleResponseBuilder response(*fbb);
- response.add_module_(trimmed_module_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::GetFunctionResponse>> DebugService::GetFunction(
- const rpc::GetFunctionRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: GetFunction(" << WrapString(request.module_name()) << ", "
- << request.function_ordinal() << ")";
- ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(),
- WrapString(request.module_name())));
- ASSIGN_OR_RETURN(auto& function, module->function_table().LookupFunction(
- request.function_ordinal()));
- Offset<BytecodeDef> bytecode_offs;
- if (function.def().bytecode()) {
- ASSIGN_OR_RETURN(
- bytecode_offs,
- DeepCopyTable("bytecode_def.bfbs", *function.def().bytecode(), fbb));
- }
- rpc::GetFunctionResponseBuilder response(*fbb);
- response.add_bytecode(bytecode_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::ResolveFunctionResponse>> DebugService::ResolveFunction(
- const rpc::ResolveFunctionRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: ResolveFunction(" << WrapString(request.module_name())
- << ", " << WrapString(request.function_name()) << ")";
- std::vector<int32_t> context_ids;
- auto context_ids_offs = fbb->CreateVector(context_ids);
- int function_ordinal = -1;
- for (auto* context : contexts_) {
- for (const auto& module : context->modules()) {
- if (module->name() == WrapString(request.module_name())) {
- ASSIGN_OR_RETURN(function_ordinal,
- module->function_table().LookupFunctionOrdinalByName(
- WrapString(request.function_name())));
- context_ids.push_back(context->id());
- break;
- }
- }
- }
- rpc::ResolveFunctionResponseBuilder response(*fbb);
- response.add_context_ids(context_ids_offs);
- response.add_function_ordinal(function_ordinal);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::ListInvocationsResponse>> DebugService::ListInvocations(
- const rpc::ListInvocationsRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: ListInvocations()";
- std::vector<Offset<rpc::InvocationDef>> invocation_offsets;
- for (auto* invocation : invocations_) {
- ASSIGN_OR_RETURN(auto invocation_offs,
- SerializeInvocation(*invocation, fbb));
- invocation_offsets.push_back(invocation_offs);
- }
- auto invocations_offs = fbb->CreateVector(invocation_offsets);
- rpc::ListInvocationsResponseBuilder response(*fbb);
- response.add_invocations(invocations_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::SuspendInvocationsResponse>>
-DebugService::SuspendInvocations(const rpc::SuspendInvocationsRequest& request,
- FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: SuspendInvocations(invocation_ids=["
- << (request.invocation_ids()
- ? absl::StrJoin(*request.invocation_ids(), ", ")
- : "")
- << "])";
- std::vector<Offset<rpc::InvocationDef>> invocation_offsets;
- if (request.invocation_ids() && request.invocation_ids()->size() > 0) {
- // Suspending a list of invocations.
- std::vector<Invocation*> invocations_to_suspend;
- for (int invocation_id : *request.invocation_ids()) {
- ASSIGN_OR_RETURN(auto* invocation, GetInvocation(invocation_id));
- invocations_to_suspend.push_back(invocation);
- }
- RETURN_IF_ERROR(
- SuspendInvocationsAndWait(absl::MakeSpan(invocations_to_suspend)));
- for (auto* invocation : invocations_to_suspend) {
- ASSIGN_OR_RETURN(auto invocation_offs,
- SerializeInvocation(*invocation, fbb));
- invocation_offsets.push_back(invocation_offs);
- }
- } else {
- // Suspending all invocations.
- RETURN_IF_ERROR(SuspendAllInvocations());
- for (auto* invocation : invocations_) {
- ASSIGN_OR_RETURN(auto invocation_offs,
- SerializeInvocation(*invocation, fbb));
- invocation_offsets.push_back(invocation_offs);
- }
- }
- auto invocations_offs = fbb->CreateVector(invocation_offsets);
- rpc::SuspendInvocationsResponseBuilder response(*fbb);
- response.add_invocations(invocations_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::ResumeInvocationsResponse>>
-DebugService::ResumeInvocations(const rpc::ResumeInvocationsRequest& request,
- FlatBufferBuilder* fbb) {
- VLOG(1) << "RPC: ResumeInvocations(invocation_ids=["
- << (request.invocation_ids()
- ? absl::StrJoin(*request.invocation_ids(), ", ")
- : "")
- << "])";
- absl::MutexLock lock(&mutex_);
- if (request.invocation_ids() && request.invocation_ids()->size() > 0) {
- // Resuming a list of invocations.
- for (int invocation_id : *request.invocation_ids()) {
- ASSIGN_OR_RETURN(auto* invocation, GetInvocation(invocation_id));
- RETURN_IF_ERROR(invocation->Resume());
- }
- } else {
- // Resuming all invocations.
- RETURN_IF_ERROR(ResumeAllInvocations());
- }
- rpc::ResumeInvocationsResponseBuilder response(*fbb);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::StepInvocationResponse>> DebugService::StepInvocation(
- const rpc::StepInvocationRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: StepInvocation(" << request.invocation_id() << ")";
- ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id()));
- Invocation::StepTarget step_target;
- // TODO(benvanik): step settings.
- RETURN_IF_ERROR(invocation->Step(step_target));
- rpc::StepInvocationResponseBuilder response(*fbb);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::GetInvocationLocalResponse>>
-DebugService::GetInvocationLocal(const rpc::GetInvocationLocalRequest& request,
- FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: GetInvocationLocal(" << request.invocation_id() << ", "
- << request.frame_index() << ", " << request.local_index() << ")";
- ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id()));
- ASSIGN_OR_RETURN(auto* local,
- ResolveInvocationLocal(invocation, request.frame_index(),
- request.local_index()));
-
- ASSIGN_OR_RETURN(
- auto value_offs,
- SerializeBufferView(*local, /*include_buffer_contents=*/true, fbb));
- rpc::GetInvocationLocalResponseBuilder response(*fbb);
- response.add_value(value_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::SetInvocationLocalResponse>>
-DebugService::SetInvocationLocal(const rpc::SetInvocationLocalRequest& request,
- FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: SetInvocationLocal(" << request.invocation_id() << ", "
- << request.frame_index() << ", " << request.local_index() << ")";
- ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id()));
- ASSIGN_OR_RETURN(auto* local,
- ResolveInvocationLocal(invocation, request.frame_index(),
- request.local_index()));
-
- if (!request.value()) {
- local->shape.clear();
- local->element_size = 0;
- local->buffer.reset();
- } else {
- const auto& value = *request.value();
- local->shape.clear();
- if (value.shape()) {
- for (int dim : *value.shape()) {
- local->shape.push_back(dim);
- }
- }
- local->element_size = value.element_size();
- // TODO(benvanik): copy buffer data.
- }
-
- ASSIGN_OR_RETURN(
- auto value_offs,
- SerializeBufferView(*local, /*include_buffer_contents=*/true, fbb));
- rpc::SetInvocationLocalResponseBuilder response(*fbb);
- response.add_value(value_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::ListBreakpointsResponse>> DebugService::ListBreakpoints(
- const rpc::ListBreakpointsRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: ListBreakpoints()";
- std::vector<Offset<rpc::BreakpointDef>> breakpoint_offs;
- for (const auto& breakpoint : breakpoints_) {
- breakpoint_offs.push_back(rpc::BreakpointDef::Pack(*fbb, &breakpoint));
- }
- auto breakpoints_offs = fbb->CreateVector(breakpoint_offs);
- rpc::ListBreakpointsResponseBuilder response(*fbb);
- response.add_breakpoints(breakpoints_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::AddBreakpointResponse>> DebugService::AddBreakpoint(
- const rpc::AddBreakpointRequest& request, FlatBufferBuilder* fbb) {
- if (!request.breakpoint()) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No breakpoint specified";
- }
- absl::MutexLock lock(&mutex_);
- int breakpoint_id = NextUniqueBreakpointId();
- VLOG(1) << "RPC: AddBreakpoint(" << breakpoint_id << ")";
-
- RETURN_IF_ERROR(SuspendAllInvocations());
-
- rpc::BreakpointDefT breakpoint;
- request.breakpoint()->UnPackTo(&breakpoint);
- breakpoint.breakpoint_id = breakpoint_id;
- switch (breakpoint.breakpoint_type) {
- case rpc::BreakpointType::BYTECODE_FUNCTION:
- case rpc::BreakpointType::NATIVE_FUNCTION:
- for (auto* context : contexts_) {
- auto module_or = context->LookupModule(breakpoint.module_name);
- if (!module_or.ok()) continue;
- auto* module = module_or.ValueOrDie();
- RETURN_IF_ERROR(
- RegisterFunctionBreakpoint(context, module, &breakpoint));
- }
- break;
- default:
- return UnimplementedErrorBuilder(IREE_LOC) << "Unhandled breakpoint type";
- }
- breakpoints_.push_back(std::move(breakpoint));
-
- RETURN_IF_ERROR(ResumeAllInvocations());
-
- auto breakpoint_offs = rpc::BreakpointDef::Pack(*fbb, &breakpoints_.back());
- rpc::AddBreakpointResponseBuilder response(*fbb);
- response.add_breakpoint(breakpoint_offs);
- return response.Finish();
-}
-
-StatusOr<Offset<rpc::RemoveBreakpointResponse>> DebugService::RemoveBreakpoint(
- const rpc::RemoveBreakpointRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: RemoveBreakpoint(" << request.breakpoint_id() << ")";
- RETURN_IF_ERROR(SuspendAllInvocations());
-
- bool found = false;
- for (auto it = breakpoints_.begin(); it != breakpoints_.end(); ++it) {
- if (it->breakpoint_id == request.breakpoint_id()) {
- auto& breakpoint = *it;
- found = true;
- switch (breakpoint.breakpoint_type) {
- case rpc::BreakpointType::BYTECODE_FUNCTION:
- case rpc::BreakpointType::NATIVE_FUNCTION:
- RETURN_IF_ERROR(UnregisterFunctionBreakpoint(breakpoint));
- break;
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unhandled breakpoint type";
- }
- breakpoints_.erase(it);
- break;
- }
- }
-
- RETURN_IF_ERROR(ResumeAllInvocations());
- if (!found) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Breakpoint ID " << request.breakpoint_id() << " not found";
- }
-
- rpc::RemoveBreakpointResponseBuilder response(*fbb);
- return response.Finish();
-}
-
-Status DebugService::RegisterModuleBreakpoints(Context* context,
- Module* module) {
- for (auto& breakpoint : breakpoints_) {
- switch (breakpoint.breakpoint_type) {
- case rpc::BreakpointType::BYTECODE_FUNCTION:
- if (breakpoint.module_name == module->name()) {
- RETURN_IF_ERROR(
- RegisterFunctionBreakpoint(context, module, &breakpoint));
- }
- break;
- default:
- // Not relevant to modules.
- break;
- }
- }
- return OkStatus();
-}
-
-Status DebugService::RegisterFunctionBreakpoint(
- Context* context, Module* module, rpc::BreakpointDefT* breakpoint) {
- if (!breakpoint->function_name.empty()) {
- ASSIGN_OR_RETURN(breakpoint->function_ordinal,
- module->function_table().LookupFunctionOrdinalByName(
- breakpoint->function_name));
- }
- RETURN_IF_ERROR(module->mutable_function_table()->RegisterBreakpoint(
- breakpoint->function_ordinal, breakpoint->bytecode_offset,
- std::bind(&DebugService::OnFunctionBreakpointHit, this,
- breakpoint->breakpoint_id, std::placeholders::_1)));
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnBreakpointResolved(*breakpoint, context));
- }
- return OkStatus();
-}
-
-Status DebugService::UnregisterFunctionBreakpoint(
- const rpc::BreakpointDefT& breakpoint) {
- for (auto* context : contexts_) {
- auto module_or = context->LookupModule(breakpoint.module_name);
- if (!module_or.ok()) continue;
- auto* module = module_or.ValueOrDie();
- RETURN_IF_ERROR(module->mutable_function_table()->UnregisterBreakpoint(
- breakpoint.function_ordinal, breakpoint.bytecode_offset));
- }
- return OkStatus();
-}
-
-Status DebugService::OnFunctionBreakpointHit(int breakpoint_id,
- const Invocation& invocation) {
- absl::ReleasableMutexLock lock(&mutex_);
- LOG(INFO) << "Breakpoint hit: " << breakpoint_id;
- RETURN_IF_ERROR(UnreadyAllSessions());
- for (auto* session : sessions_) {
- RETURN_IF_ERROR(session->OnBreakpointHit(breakpoint_id, invocation));
- }
- lock.Release();
-
- // TODO(benvanik): on-demand attach if desired?
-
- // Wait until all clients are ready.
- auto wait_status = WaitUntilAllSessionsReady();
- if (IsAborted(wait_status)) {
- // This means we lost all sessions. Just continue.
- VLOG(1) << "No sessions active; ignoring breakpoint and continuing";
- return OkStatus();
- }
- return wait_status;
-}
-
-StatusOr<Offset<rpc::StartProfilingResponse>> DebugService::StartProfiling(
- const rpc::StartProfilingRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: StartProfiling()";
- // TODO(benvanik): implement profiling.
- // ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id()));
- // rpc::StartProfilingResponseBuilder response(*fbb);
- // return response.Finish();
- return UnimplementedErrorBuilder(IREE_LOC)
- << "StartProfiling not yet implemented";
-}
-
-StatusOr<Offset<rpc::StopProfilingResponse>> DebugService::StopProfiling(
- const rpc::StopProfilingRequest& request, FlatBufferBuilder* fbb) {
- absl::MutexLock lock(&mutex_);
- VLOG(1) << "RPC: StopProfiling()";
- // TODO(benvanik): implement profiling.
- // ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id()));
- // rpc::StopProfilingResponseBuilder response(*fbb);
- // return response.Finish();
- return UnimplementedErrorBuilder(IREE_LOC)
- << "StopProfiling not yet implemented";
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_service.h b/iree/rt/debug/debug_service.h
deleted file mode 100644
index 47f17ab..0000000
--- a/iree/rt/debug/debug_service.h
+++ /dev/null
@@ -1,175 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DEBUG_DEBUG_SERVICE_H_
-#define IREE_RT_DEBUG_DEBUG_SERVICE_H_
-
-#include <vector>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/strings/string_view.h"
-#include "absl/synchronization/mutex.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/status.h"
-#include "iree/rt/context.h"
-#include "iree/rt/debug/debug_session.h"
-#include "iree/schemas/debug_service_generated.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// Debugging service used to implement the DebugService RPC methods in a
-// transport-independent way. Specific DebugServer implementations can compose
-// with a DebugService to avoid needing to maintain state themselves. Multiple
-// DebugServer instances could share the same DebugService instance to ensure
-// all clients - regardless of transport - share the same state.
-//
-// Thread-safe.
-class DebugService {
- public:
- // Registers a context with the debug service.
- // Ownership remains with the caller and UnregisterContext must be called
- // prior to the context being destroyed.
- Status RegisterContext(Context* context);
- Status UnregisterContext(Context* context);
-
- // Registers a new module linked into an existing Context.
- Status RegisterContextModule(Context* context, Module* module);
-
- // Registers a invocation state with the debug service.
- // Ownership remains with the caller and UnregisterInvocation must be called
- // prior to the invocation state being destroyed.
- Status RegisterInvocation(Invocation* invocation);
- Status UnregisterInvocation(Invocation* invocation);
-
- // Registers a debug session with the service.
- Status RegisterDebugSession(DebugSession* session);
- Status UnregisterDebugSession(DebugSession* session);
-
- // Blocks the caller until all sessions are ready.
- // Returns AbortedError if a session connects/is already connected but
- // disconnects during the wait.
- Status WaitUntilAllSessionsReady();
-
- StatusOr<::flatbuffers::Offset<rpc::MakeReadyResponse>> MakeReady(
- const rpc::MakeReadyRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- StatusOr<::flatbuffers::Offset<rpc::GetStatusResponse>> GetStatus(
- const rpc::GetStatusRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- StatusOr<::flatbuffers::Offset<rpc::ListContextsResponse>> ListContexts(
- const rpc::ListContextsRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- StatusOr<::flatbuffers::Offset<rpc::GetModuleResponse>> GetModule(
- const rpc::GetModuleRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::GetFunctionResponse>> GetFunction(
- const rpc::GetFunctionRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::ResolveFunctionResponse>> ResolveFunction(
- const rpc::ResolveFunctionRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- StatusOr<::flatbuffers::Offset<rpc::ListInvocationsResponse>> ListInvocations(
- const rpc::ListInvocationsRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::SuspendInvocationsResponse>>
- SuspendInvocations(const rpc::SuspendInvocationsRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::ResumeInvocationsResponse>>
- ResumeInvocations(const rpc::ResumeInvocationsRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::StepInvocationResponse>> StepInvocation(
- const rpc::StepInvocationRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::GetInvocationLocalResponse>>
- GetInvocationLocal(const rpc::GetInvocationLocalRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::SetInvocationLocalResponse>>
- SetInvocationLocal(const rpc::SetInvocationLocalRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- StatusOr<::flatbuffers::Offset<rpc::ListBreakpointsResponse>> ListBreakpoints(
- const rpc::ListBreakpointsRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::AddBreakpointResponse>> AddBreakpoint(
- const rpc::AddBreakpointRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::RemoveBreakpointResponse>>
- RemoveBreakpoint(const rpc::RemoveBreakpointRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- StatusOr<::flatbuffers::Offset<rpc::StartProfilingResponse>> StartProfiling(
- const rpc::StartProfilingRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
- StatusOr<::flatbuffers::Offset<rpc::StopProfilingResponse>> StopProfiling(
- const rpc::StopProfilingRequest& request,
- ::flatbuffers::FlatBufferBuilder* fbb);
-
- // Serializes an invocation and its stack frames.
- StatusOr<::flatbuffers::Offset<rpc::InvocationDef>> SerializeInvocation(
- const Invocation& invocation, ::flatbuffers::FlatBufferBuilder* fbb);
-
- private:
- StatusOr<Context*> GetContext(int context_id) const
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- StatusOr<Module*> GetModule(int context_id,
- absl::string_view module_name) const
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- StatusOr<Invocation*> GetInvocation(int invocation_id) const
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Suspends all invocations on all contexts. Returns only once all invocations
- // have been suspended successfully. Fails if any invocation fails to suspend.
- Status SuspendAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Resumes all invocations on all contexts (the inverse of
- // SuspendAllInvocations). Returns immediately.
- Status ResumeAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Marks all sessions as unready.
- Status UnreadyAllSessions() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Attempts to re-register all breakpoints for a module.
- Status RegisterModuleBreakpoints(Context* context, Module* module)
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- Status RegisterFunctionBreakpoint(Context* context, Module* module,
- rpc::BreakpointDefT* breakpoint)
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- Status UnregisterFunctionBreakpoint(const rpc::BreakpointDefT& breakpoint)
- ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- // Signals that the given breakpoint was hit by the specified invocation.
- // Called without the debug lock held.
- Status OnFunctionBreakpointHit(int breakpoint_id,
- const Invocation& invocation);
-
- absl::Mutex mutex_;
- std::vector<Context*> contexts_ ABSL_GUARDED_BY(mutex_);
- std::vector<Invocation*> invocations_ ABSL_GUARDED_BY(mutex_);
- std::vector<DebugSession*> sessions_ ABSL_GUARDED_BY(mutex_);
- int sessions_unready_ ABSL_GUARDED_BY(mutex_) = 0;
- int sessions_ready_ ABSL_GUARDED_BY(mutex_) = 0;
-
- std::vector<rpc::BreakpointDefT> breakpoints_ ABSL_GUARDED_BY(mutex_);
-};
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_DEBUG_SERVICE_H_
diff --git a/iree/rt/debug/debug_session.cc b/iree/rt/debug/debug_session.cc
deleted file mode 100644
index 83a28da..0000000
--- a/iree/rt/debug/debug_session.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/debug/debug_session.h"
-
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-bool DebugSession::is_ready() const {
- absl::MutexLock lock(&mutex_);
- return ready_ == 0;
-}
-
-Status DebugSession::OnReady() {
- absl::MutexLock lock(&mutex_);
- if (ready_ > 0) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Session has already readied up";
- }
- ++ready_;
- VLOG(2) << "Session " << id() << ": ++ready = " << ready_;
- return OkStatus();
-}
-
-Status DebugSession::OnUnready() {
- absl::MutexLock lock(&mutex_);
- --ready_;
- VLOG(2) << "Session " << id() << ": --ready = " << ready_;
- return OkStatus();
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/debug/debug_session.h b/iree/rt/debug/debug_session.h
deleted file mode 100644
index f6cd87c..0000000
--- a/iree/rt/debug/debug_session.h
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DEBUG_DEBUG_SESSION_H_
-#define IREE_RT_DEBUG_DEBUG_SESSION_H_
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/status.h"
-#include "iree/rt/context.h"
-#include "iree/rt/invocation.h"
-#include "iree/rt/module.h"
-#include "iree/schemas/debug_service_generated.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// An active debugging session maintained by the DebugService.
-// Each connected client gets a session and transport-specific implementations
-// use the event methods to receive signals from the service.
-//
-// All methods are called only while the debug lock is held and may be called
-// from any thread.
-class DebugSession {
- public:
- virtual ~DebugSession() = default;
-
- // Session ID used in all RPCs related to this session.
- // This can be used for attributing RPCs to the originating session when
- // multiple sessions may be active at a time/over the same transport.
- int id() const { return session_id_; }
-
- // Returns true if the session has issued a MakeReady request and is ok if
- // execution resumes.
- bool is_ready() const;
-
- // Signals that the session has readied up and is now active.
- // Called with the global debug lock held.
- virtual Status OnReady();
-
- // Signals that the session has gone unready (from an event/etc) and the
- // service is now awaiting it to ready up.
- // Called with the global debug lock held.
- virtual Status OnUnready();
-
- // Signals that a context has been registered.
- // Called with the global debug lock held.
- virtual Status OnContextRegistered(Context* context) = 0;
- virtual Status OnContextUnregistered(Context* context) = 0;
-
- // Signals that a module has been loaded in a context.
- // Called with the global debug lock held.
- virtual Status OnModuleLoaded(Context* context, Module* module) = 0;
-
- // Signals that a invocation has been registered.
- // Called with the global debug lock held.
- virtual Status OnInvocationRegistered(Invocation* invocation) = 0;
- virtual Status OnInvocationUnregistered(Invocation* invocation) = 0;
-
- // Signals that a breakpoint has been resolved to a particular function in a
- // context.
- // Called with the global debug lock held.
- virtual Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint,
- Context* context) = 0;
-
- // Signals that the given breakpoint has been hit during execution.
- // Called with the global debug lock held.
- virtual Status OnBreakpointHit(int breakpoint_id,
- const Invocation& invocation) = 0;
-
- private:
- mutable absl::Mutex mutex_;
- int session_id_ = 0;
- int ready_ ABSL_GUARDED_BY(mutex_) = -1;
-};
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_DEBUG_SESSION_H_
diff --git a/iree/rt/debug/debug_tcp_util.h b/iree/rt/debug/debug_tcp_util.h
deleted file mode 100644
index d9c33bf..0000000
--- a/iree/rt/debug/debug_tcp_util.h
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Utilities for working with TCP sockets.
-// These are (mostly) portable to systems implementing BSD sockets.
-
-#ifndef IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
-#define IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
-
-#include <fcntl.h>
-#include <netinet/in.h>
-#include <netinet/tcp.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-
-#include <cstddef>
-
-#include "flatbuffers/base.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/status.h"
-#include "iree/schemas/debug_service_generated.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-namespace tcp {
-
-// Toggles address reuse on a socket. Call prior to binding.
-// This is useful if a socket is sitting in close_wait from a previous process
-// while a new one is trying to bind to it.
-inline Status ToggleSocketAddressReuse(int fd, bool is_enabled) {
- int toggle = is_enabled ? 1 : 0;
- ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &toggle, sizeof(toggle));
- return OkStatus();
-}
-
-// Toggles the linger option on a socket.
-// Enabling linger will ensure all data on the socket is sent (if it can be
-// sent within N sec) prior to closing. Disabling linger will cause the socket
-// to close gracefully.
-inline Status ToggleSocketLinger(int fd, bool is_enabled) {
- struct linger linger;
- linger.l_onoff = is_enabled ? 1 : 0;
- linger.l_linger = 1;
- ::setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger));
- return OkStatus();
-}
-
-// Toggles Nagel's algorithm on a socket.
-// Enabled by default, sockets have ~250ms delay for small packets. Disabling
-// the algorithm will make socket flushes actually send data.
-inline Status ToggleSocketNagelsAlgorithm(int fd, bool is_enabled) {
- int toggle = is_enabled ? 1 : 0;
- ::setsockopt(fd, SOL_TCP, TCP_NODELAY, &toggle, sizeof(toggle));
- return OkStatus();
-}
-
-// Toggles TCP keepalive on a socket.
-// Assumes that the remote side is on the local machine/network and that we can
-// spam it with packets.
-//
-// NOTE: we may want to adjust this when real debuggers are attached (to prevent
-// dropping our own connections). Need to figure out how to reliably detect
-// debug suspends vs. actual death.
-inline Status ToggleSocketLocalKeepalive(int fd, bool is_enabled) {
- // Toggle keepalive.
- int keepalive_enable = is_enabled ? 1 : 0;
- ::setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive_enable,
- sizeof(keepalive_enable));
- // Begin sending keepalive probes after N sec.
- int keepalive_idle_delay = 3;
- ::setsockopt(fd, SOL_TCP, TCP_KEEPIDLE, &keepalive_idle_delay,
- sizeof(keepalive_idle_delay));
- // Try one probe and bail (faster detection).
- int keepalive_retry_count = 1;
- ::setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &keepalive_retry_count,
- sizeof(keepalive_retry_count));
- // Send keepalives every N sec.
- int keepalive_interval = 1;
- ::setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &keepalive_interval,
- sizeof(keepalive_interval));
- return OkStatus();
-}
-
-// Toggles the blocking state of a socket.
-// If a socket has been set to non-blocking methods like read and write will
-// return EWOULDBLOCK if they would have blocked on the specific operation.
-inline Status ToggleSocketBlocking(int fd, bool is_blocking) {
- if (is_blocking) {
- ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) & ~O_NONBLOCK);
- } else {
- ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) | O_NONBLOCK);
- }
- return OkStatus();
-}
-
-// RAII wrapper for messages containing flatbuffer roots of type T.
-template <typename T>
-struct MessageBuffer {
- public:
- explicit MessageBuffer(std::vector<uint8_t> buffer)
- : buffer_(std::move(buffer)) {}
- MessageBuffer(const MessageBuffer&) = delete;
- MessageBuffer& operator=(const MessageBuffer&) = delete;
- MessageBuffer(MessageBuffer&&) = default;
- MessageBuffer& operator=(MessageBuffer&&) = default;
-
- const T& GetRoot() const {
- return *::flatbuffers::GetRoot<T>(buffer_.data());
- }
-
- private:
- std::vector<uint8_t> buffer_;
-};
-
-// Reads a size prefix value from the given fd.
-// If |poll_only| is true then the size prefix is not consumed from the stream
-// and the call will return 0 if there is no size prefix available.
-// Returns CancelledError if a (probably) graceful close is detected.
-inline StatusOr<size_t> ReadSizePrefix(int fd, bool poll_only) {
- ::flatbuffers::uoffset_t size_prefix = 0;
- int read_bytes = ::recv(fd, &size_prefix, sizeof(size_prefix),
- poll_only ? (MSG_PEEK | MSG_DONTWAIT) : 0);
- if (read_bytes == 0) {
- // Remote side disconnected.
- return CancelledErrorBuilder(IREE_LOC) << "Graceful remote close";
- } else if (read_bytes < 0) {
- if (errno == ECONNRESET) {
- return CancelledErrorBuilder(IREE_LOC) << "Ungraceful remote close";
- }
- return DataLossErrorBuilder(IREE_LOC)
- << "Failed to read size prefix from socket: (" << errno << ") "
- << ::strerror(errno);
- } else if (read_bytes != sizeof(size_prefix)) {
- if (poll_only) {
- // No data available.
- return 0;
- } else {
- return DataLossErrorBuilder(IREE_LOC)
- << "Failed to read full size prefix (got " << read_bytes << "b of "
- << sizeof(size_prefix) << "b expected)";
- }
- }
- return size_prefix;
-}
-
-// Returns true if ReadBuffer will (likely) not block when called.
-// Returns CancelledError if a (probably) graceful close is detected.
-inline StatusOr<bool> CanReadBuffer(int fd) {
- ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/true));
- return size_prefix != 0;
-}
-
-// Reads a size-prefixed message from the given fd.
-// This will block until the entire message contents are available.
-// Returns a buffer reference that will deallocate the buffer automatically or
-// CancelledError if a (probably) graceful close is detected.
-template <typename T>
-StatusOr<MessageBuffer<T>> ReadBuffer(int fd) {
- // Read the size prefix (written as a uoffset_t by the Write* methods).
- ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/false));
-
- // Allocate the buffer for the entire message.
- // We'll use the BufferRef to free() it when it's no longer required.
- std::vector<uint8_t> buffer(size_prefix);
-
- // Read the entire message contents.
- int full_read_bytes = ::recv(fd, buffer.data(), buffer.size(), 0);
- if (full_read_bytes < 0) {
- return DataLossErrorBuilder(IREE_LOC)
- << "Failed to read full message contents from socket: (" << errno
- << ") " << ::strerror(errno);
- } else if (full_read_bytes != buffer.size()) {
- return DataLossErrorBuilder(IREE_LOC)
- << "Failed to read full message contents (got " << full_read_bytes
- << "b of " << buffer.size() << "b expected)";
- }
-
- // Verify the contents. Not strictly required (as we won't ever ship this to
- // prod), but useful in ensuring our socket code isn't corrupting things.
- ::flatbuffers::Verifier verifier(buffer.data(), buffer.size());
- if (!verifier.VerifyBuffer<T>()) {
- return DataLossErrorBuilder(IREE_LOC)
- << "Verification of input buffer of type " << typeid(T).name()
- << " (" << buffer.size() << "b) failed";
- }
-
- // Wrap the buffer to get some RAII goodness.
- return MessageBuffer<T>(std::move(buffer));
-}
-
-// Writes a buffer to the given fd.
-inline Status WriteBuffer(int fd, ::flatbuffers::DetachedBuffer buffer) {
- if (::send(fd, buffer.data(), buffer.size(), 0) < 0) {
- return UnavailableErrorBuilder(IREE_LOC)
- << "Write failed: (" << errno << ") " << ::strerror(errno);
- }
- return OkStatus();
-}
-
-} // namespace tcp
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
diff --git a/iree/rt/disassembler.h b/iree/rt/disassembler.h
deleted file mode 100644
index 52d91ee..0000000
--- a/iree/rt/disassembler.h
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_DISASSEMBLER_H_
-#define IREE_RT_DISASSEMBLER_H_
-
-#include <cstdint>
-#include <ostream>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "iree/base/status.h"
-#include "iree/rt/function.h"
-#include "iree/rt/source_location.h"
-
-namespace iree {
-namespace rt {
-
-// A single disassembled instruction.
-struct Instruction {
- // Offset of the instruction within the function.
- // The meaning of this is backend-dependent.
- SourceOffset offset;
-
- // The first line of |long_text|.
- absl::string_view short_text;
-
- // Human-readable text of the instruction. May contain multiple lines.
- std::string long_text;
-};
-
-// Disassembles functions into instructions.
-//
-// Thread-safe.
-class Disassembler {
- public:
- virtual ~Disassembler() = default;
-
- // Disassembles one or more instructions within the given function based on
- // source offsets.
- virtual StatusOr<std::vector<Instruction>> DisassembleInstructions(
- const Function& function, SourceOffset offset,
- int32_t instruction_count = INT32_MAX) const = 0;
-
- protected:
- Disassembler() = default;
-};
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_DISASSEMBLER_H_
diff --git a/iree/rt/function.cc b/iree/rt/function.cc
deleted file mode 100644
index c32fbce..0000000
--- a/iree/rt/function.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/function.h"
-
-#include "iree/rt/module.h"
-
-namespace iree {
-namespace rt {
-
-absl::string_view Function::name() const {
- auto result_or = module_->GetFunctionName(linkage_, ordinal_);
- return result_or.ok() ? result_or.ValueOrDie() : absl::string_view();
-}
-
-const FunctionSignature Function::signature() const {
- auto result_or = module_->GetFunctionSignature(linkage_, ordinal_);
- return result_or.ok() ? result_or.ValueOrDie() : FunctionSignature();
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/function.h b/iree/rt/function.h
deleted file mode 100644
index 9acc0e6..0000000
--- a/iree/rt/function.h
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_FUNCTION_H_
-#define IREE_RT_FUNCTION_H_
-
-#include <cstdint>
-#include <ostream>
-
-#include "absl/strings/string_view.h"
-#include "iree/rt/function_signature.h"
-
-namespace iree {
-namespace rt {
-
-class Module;
-
-// Reference to a function within a module.
-// Functions are either visible or hidden from the module interface and may be
-// of one Linkage type. Imports and exports are always visible (as they are
-// required for dynamic linking) however functions with internal linkage may be
-// hidden in optimized builds to reduce the amount of reflection metadata
-// required.
-class Function final {
- public:
- enum class Linkage {
- // Function is internal to the module and may not be reflectable.
- kInternal = 0,
- // Function is an import from another module.
- kImport = 1,
- // Function is an export from the module.
- kExport = 2,
- };
-
- Function() = default;
- Function(const Module* module, Linkage linkage, int32_t ordinal)
- : module_(module), linkage_(linkage), ordinal_(ordinal) {}
-
- // Module the function is contained within.
- const Module* module() const { return module_; }
-
- // Linkage of the function. Note that Linkage::kInternal functions may be
- // missing reflection information.
- Linkage linkage() const { return linkage_; }
-
- // Ordinal within the module in the linkage scope.
- int32_t ordinal() const { return ordinal_; }
-
- // Returns the original name of the function.
- // Internal functions may return empty if debugging info has been stripped.
- absl::string_view name() const;
-
- // Returns the signature of the function.
- // Always present for imports and exports but may be empty for internal
- // functions if debugging info has been stripped.
- const FunctionSignature signature() const;
-
- private:
- const Module* module_ = nullptr;
- Linkage linkage_ = Linkage::kInternal;
- int32_t ordinal_ = -1;
-};
-
-inline std::ostream& operator<<(std::ostream& stream,
- const Function& function) {
- stream << '@' << function.name() << '#' << function.ordinal();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_FUNCTION_H_
diff --git a/iree/rt/instance.cc b/iree/rt/instance.cc
deleted file mode 100644
index d793bc7..0000000
--- a/iree/rt/instance.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/instance.h"
-
-#include "iree/base/tracing.h"
-#include "iree/rt/debug/debug_server.h"
-
-namespace iree {
-namespace rt {
-
-Instance::Instance(std::unique_ptr<debug::DebugServer> debug_server)
- : debug_server_(std::move(debug_server)) {
- IREE_TRACE_SCOPE0("Instance::ctor");
-}
-
-Instance::~Instance() { IREE_TRACE_SCOPE0("Instance::dtor"); }
-
-void Instance::RegisterContext(Context* context) {
- {
- absl::MutexLock lock(&contexts_mutex_);
- contexts_.push_back(context);
- }
- if (debug_server_) {
- CHECK_OK(debug_server_->RegisterContext(context));
- }
-}
-
-void Instance::UnregisterContext(Context* context) {
- if (debug_server_) {
- CHECK_OK(debug_server_->UnregisterContext(context));
- }
- {
- absl::MutexLock lock(&contexts_mutex_);
- contexts_.erase(context);
- }
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/instance.h b/iree/rt/instance.h
deleted file mode 100644
index 61fd544..0000000
--- a/iree/rt/instance.h
+++ /dev/null
@@ -1,70 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_INSTANCE_H_
-#define IREE_RT_INSTANCE_H_
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/hal/device_manager.h"
-#include "iree/rt/context.h"
-#include "iree/rt/debug/debug_server.h"
-
-namespace iree {
-namespace rt {
-
-// Shared runtime instance responsible for routing Context events, enumerating
-// and creating hardware device interfaces, and managing thread pools.
-//
-// A single runtime instance can service multiple contexts and hosting
-// applications should try to reuse instances as much as possible. This ensures
-// that resource allocation across contexts is handled and extraneous device
-// interaction is avoided. For devices that may have exclusive access
-// restrictions it is mandatory to share instances, so plan accordingly.
-//
-// Thread-safe.
-class Instance final : public RefObject<Instance> {
- public:
- // Creates an instance with an optional attached |debug_server|.
- Instance() : Instance(nullptr) {}
- explicit Instance(std::unique_ptr<debug::DebugServer> debug_server);
- ~Instance();
- Instance(const Instance&) = delete;
- Instance& operator=(const Instance&) = delete;
-
- // Optional debug server that has access to contexts in this instance.
- debug::DebugServer* debug_server() const { return debug_server_.get(); }
-
- // Device manager used to enumerate available devices.
- hal::DeviceManager* device_manager() const { return &device_manager_; }
-
- private:
- friend class Context;
- void RegisterContext(Context* context);
- void UnregisterContext(Context* context);
-
- std::unique_ptr<debug::DebugServer> debug_server_;
- mutable hal::DeviceManager device_manager_;
-
- absl::Mutex contexts_mutex_;
- IntrusiveList<Context, offsetof(Context, instance_list_link_)> contexts_
- ABSL_GUARDED_BY(contexts_mutex_);
-};
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_INSTANCE_H_
diff --git a/iree/rt/invocation.cc b/iree/rt/invocation.cc
deleted file mode 100644
index 738c65b..0000000
--- a/iree/rt/invocation.cc
+++ /dev/null
@@ -1,185 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/invocation.h"
-
-#include <atomic>
-#include <iterator>
-
-#include "absl/strings/str_cat.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/rt/context.h"
-
-namespace iree {
-namespace rt {
-
-namespace {
-
-int32_t NextUniqueInvocationId() {
- static std::atomic<int32_t> next_id = {0};
- return ++next_id;
-}
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<Invocation>> Invocation::Create(
- ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
- absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::optional<absl::InlinedVector<hal::BufferView, 8>> results) {
- IREE_TRACE_SCOPE0("Invocation::Create");
-
- const auto& signature = function.signature();
- if (arguments.size() != signature.argument_count()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Argument count mismatch; expected " << signature.argument_count()
- << " but received " << arguments.size();
- } else if (results.has_value() &&
- results.value().size() != signature.result_count()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Result count mismatch; expected " << signature.result_count()
- << " but received " << results.value().size();
- }
-
- absl::InlinedVector<hal::BufferView, 8> results_value;
- if (results.has_value()) {
- results_value = std::move(results.value());
- } else {
- results_value.resize(signature.result_count());
- }
-
- auto invocation = assign_ref(
- new Invocation(std::move(context), function, std::move(policy)));
-
- // TODO(benvanik): grab execution state, insert deps, etc.
- if (!dependencies.empty()) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Dependencies are not yet implemented";
- }
-
- // TODO(benvanik): fiber scheduling and such.
- auto execute_status = function.module()->Execute(
- &invocation->stack_, function, std::move(arguments), &results_value);
- if (execute_status.ok()) {
- invocation->CompleteSuccess(std::move(results_value));
- } else {
- invocation->CompleteFailure(std::move(execute_status), nullptr);
- }
-
- return invocation;
-}
-
-// static
-StatusOr<ref_ptr<Invocation>> Invocation::Create(
- ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
- absl::Span<const ref_ptr<Invocation>> dependencies,
- absl::Span<const hal::BufferView> arguments) {
- absl::InlinedVector<ref_ptr<Invocation>, 4> dependency_list;
- dependency_list.reserve(dependencies.size());
- for (auto& dependency : dependencies) {
- dependency_list.push_back(add_ref(dependency));
- }
- absl::InlinedVector<hal::BufferView, 8> argument_list;
- argument_list.reserve(arguments.size());
- for (auto& buffer_view : arguments) {
- argument_list.push_back(buffer_view);
- }
- return Invocation::Create(std::move(context), function, std::move(policy),
- std::move(dependency_list),
- std::move(argument_list));
-}
-
-Invocation::Invocation(ref_ptr<Context> context, const Function function,
- ref_ptr<Policy> policy)
- : id_(NextUniqueInvocationId()),
- context_(std::move(context)),
- function_(function),
- policy_(std::move(policy)),
- stack_(context_.get()) {
- IREE_TRACE_SCOPE0("Invocation::ctor");
- context_->RegisterInvocation(this);
-}
-
-Invocation::~Invocation() {
- IREE_TRACE_SCOPE0("Invocation::dtor");
- context_->UnregisterInvocation(this);
-}
-
-std::string Invocation::DebugStringShort() const {
- return absl::StrCat("invocation_", id_);
-}
-
-std::string Invocation::DebugString() const { return DebugStringShort(); }
-
-Status Invocation::QueryStatus() {
- IREE_TRACE_SCOPE0("Invocation::QueryStatus");
- absl::MutexLock lock(&status_mutex_);
- return completion_status_;
-}
-
-StatusOr<absl::InlinedVector<hal::BufferView, 8>> Invocation::ConsumeResults() {
- IREE_TRACE_SCOPE0("Invocation::ConsumeResults");
- absl::MutexLock lock(&status_mutex_);
- if (!completion_status_.ok()) {
- return completion_status_;
- }
- return std::move(results_);
-}
-
-Status Invocation::Await(absl::Time deadline) {
- IREE_TRACE_SCOPE0("Invocation::Await");
- absl::MutexLock lock(&status_mutex_);
- // TODO(benvanik): implement async invocation behavior.
- return completion_status_;
-}
-
-Status Invocation::Abort() {
- IREE_TRACE_SCOPE0("Invocation::Abort");
- // TODO(benvanik): implement async invocation behavior.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Async invocations not yet implemented";
-}
-
-void Invocation::CompleteSuccess(
- absl::InlinedVector<hal::BufferView, 8> results) {
- IREE_TRACE_SCOPE0("Invocation::CompleteSuccess");
- absl::MutexLock lock(&status_mutex_);
- if (IsAborted(completion_status_)) {
- // Ignore as the invocation was already aborted prior to completion.
- return;
- }
- DCHECK(IsUnavailable(completion_status_));
- completion_status_ = OkStatus();
- failure_stack_trace_.reset();
- results_ = std::move(results);
-}
-
-void Invocation::CompleteFailure(
- Status completion_status, std::unique_ptr<StackTrace> failure_stack_trace) {
- IREE_TRACE_SCOPE0("Invocation::CompleteFailure");
- absl::MutexLock lock(&status_mutex_);
- if (IsAborted(completion_status_)) {
- // Ignore as the invocation was already aborted prior to completion.
- return;
- }
- DCHECK(IsUnavailable(completion_status_));
- completion_status_ = std::move(completion_status);
- failure_stack_trace_ = std::move(failure_stack_trace);
- results_.clear();
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/invocation.h b/iree/rt/invocation.h
deleted file mode 100644
index 600fd8a..0000000
--- a/iree/rt/invocation.h
+++ /dev/null
@@ -1,157 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_INVOCATION_H_
-#define IREE_RT_INVOCATION_H_
-
-#include <ostream>
-#include <string>
-
-#include "absl/base/thread_annotations.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/rt/function.h"
-#include "iree/rt/policy.h"
-#include "iree/rt/stack.h"
-#include "iree/rt/stack_trace.h"
-
-namespace iree {
-namespace rt {
-
-class Context;
-
-// An asynchronous invocation of a function.
-// Holds the invocation state and allows querying and waiting on completion.
-// Invocations are conceptually fibers and may suspend and resume execution
-// several times before completing.
-//
-// Thread-safe.
-class Invocation final : public RefObject<Invocation> {
- public:
- // TODO(benvanik): define error propagation semantics across dependencies.
- // TODO(benvanik): support more dependency types (semaphores, etc).
- // Creates a new invocation tracking object for invoking the given |function|
- // from |context|. |arguments| will be retained until the invocation is made.
- // If |dependencies| are provided then the invocation will wait until they are
- // resolved before executing. If a |policy| is provided it will override the
- // context-level policy.
- //
- // Optionally |results| may be provided with preallocated buffers that will
- // receive the outputs of the invocation. Invocation will fail if they do not
- // match expected sizes.
- //
- // Note that it's possible for the invocation to complete prior to the return
- // of this function. Any errors that occur will be set on the invocation and
- // callers should query its state prior to assuming it is in-flight.
- static StatusOr<ref_ptr<Invocation>> Create(
- ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
- absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::optional<absl::InlinedVector<hal::BufferView, 8>> results =
- absl::nullopt);
- static StatusOr<ref_ptr<Invocation>> Create(
- ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
- absl::Span<const ref_ptr<Invocation>> dependencies,
- absl::Span<const hal::BufferView> arguments);
-
- ~Invocation();
-
- // A process-unique ID for the invocation.
- int32_t id() const { return id_; }
-
- // Context this invocation is running within.
- const ref_ptr<Context>& context() const { return context_; }
-
- // Function being invoked.
- const Function& function() const { return function_; }
-
- // A single-line human-readable debug string for the invocation.
- std::string DebugStringShort() const;
-
- // A long-form debug string with stack trace (if available).
- std::string DebugString() const;
-
- // Queries the completion status of the invocation.
- // Returns one of the following:
- // StatusCode::kOk: the invocation completed successfully.
- // StatusCode::kUnavailable: the invocation has not yet completed.
- // StatusCode::kCancelled: the invocation was cancelled internally.
- // StatusCode::kAborted: the invocation was aborted.
- // StatusCode::*: an error occurred during invocation.
- Status QueryStatus();
-
- // Returns ownership of the results of the operation to the caller.
- // If the invocation failed then the result will be returned as if Query had
- // been called.
- StatusOr<absl::InlinedVector<hal::BufferView, 8>> ConsumeResults();
-
- // Blocks the caller until the invocation completes (successfully or
- // otherwise).
- //
- // Returns StatusCode::kDeadlineExceeded if |deadline| elapses before the
- // invocation completes and otherwise returns the result of Query().
- Status Await(absl::Time deadline);
-
- // Attempts to abort the invocation if it is in-flight.
- // A no-op if the invocation has already completed.
- Status Abort();
-
- // TODO(benvanik): export a hal::TimelineSemaphore.
-
- private:
- friend class Context;
-
- Invocation(ref_ptr<Context> context, const Function function,
- ref_ptr<Policy> policy);
-
- // Completes the invocation with a successful result.
- void CompleteSuccess(absl::InlinedVector<hal::BufferView, 8> results);
-
- // Completes the invocation with a failure, including an optional stack trace.
- void CompleteFailure(Status completion_status,
- std::unique_ptr<StackTrace> failure_stack_trace);
-
- int32_t id_;
- ref_ptr<Context> context_;
- const Function function_;
- ref_ptr<Policy> policy_;
-
- Stack stack_;
-
- absl::Mutex status_mutex_;
- Status completion_status_ ABSL_GUARDED_BY(status_mutex_) =
- UnavailableErrorBuilder(IREE_LOC);
- std::unique_ptr<StackTrace> failure_stack_trace_
- ABSL_GUARDED_BY(status_mutex_);
- absl::InlinedVector<hal::BufferView, 8> results_
- ABSL_GUARDED_BY(status_mutex_);
-
- friend class Context;
- IntrusiveListLink context_list_link_;
-};
-
-inline std::ostream& operator<<(std::ostream& stream,
- const Invocation& invocation) {
- stream << invocation.DebugStringShort();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_INVOCATION_H_
diff --git a/iree/rt/module.h b/iree/rt/module.h
deleted file mode 100644
index f91a3a5..0000000
--- a/iree/rt/module.h
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_MODULE_H_
-#define IREE_RT_MODULE_H_
-
-#include <ostream>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/rt/function.h"
-#include "iree/rt/module_signature.h"
-
-namespace iree {
-namespace rt {
-
-class Disassembler;
-class SourceResolver;
-class Stack;
-
-// Abstract compiled module interface for resolving functions.
-//
-// Modules are (generally) stateless, immutable, and may exist in multiple
-// contexts at the same time.
-class Module : public RefObject<Module> {
- public:
- virtual ~Module() = default;
-
- // Name of the module used to resolve fully-qualified references.
- // The lifetime of the returned reference is not guaranteed beyond the current
- // calling scope and callers must clone it if they want to retain it.
- virtual absl::string_view name() const = 0;
-
- // A description of the module imports, exports, and other metadata.
- virtual const ModuleSignature signature() const = 0;
-
- // Returns a resolver capable of resolving functions to source and performing
- // basic debugging logic (such as offset calculation).
- // May be nullptr if debugging info has been stripped.
- virtual SourceResolver* source_resolver() const = 0;
-
- // Returns a disassembler that can be used to disassemble functions in the
- // module. May be nullptr if debugging info has been stripped or disassembly
- // has been disabled as a compile option.
- virtual Disassembler* disassembler() const = 0;
-
- // A short human-readable string that matches the compiler formatting.
- virtual std::string DebugStringShort() const = 0;
-
- // Looks up a visible function by ordinal.
- // Internal functions may not be found if debugging info has been stripped.
- virtual StatusOr<const Function> LookupFunctionByOrdinal(
- Function::Linkage linkage, int32_t ordinal) const = 0;
-
- // Looks up a visible function by name.
- // Internal functions may not be found if debugging info has been stripped.
- virtual StatusOr<const Function> LookupFunctionByName(
- Function::Linkage linkage, absl::string_view name) const = 0;
-
- // Returns the name of the visible function as a string reference.
- //
- // May return empty for functions with internal linkage if debugging info has
- // been stripped.
- //
- // The lifetime of the returned reference is not guaranteed beyond the current
- // calling scope and callers must clone it if they want to retain it.
- virtual StatusOr<absl::string_view> GetFunctionName(
- Function::Linkage linkage, int32_t ordinal) const = 0;
-
- // Returns the full function signature for the given |ordinal|.
- //
- // May return empty for functions with internal linkage if the debugging info
- // has been stripped.
- virtual StatusOr<const FunctionSignature> GetFunctionSignature(
- Function::Linkage linkage, int32_t ordinal) const = 0;
-
- // Temporary until scheduler is built.
- virtual Status Execute(
- Stack* stack, const Function function,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::InlinedVector<hal::BufferView, 8>* results) const = 0;
-
- protected:
- Module() = default;
-};
-
-inline std::ostream& operator<<(std::ostream& stream, const Module& module) {
- stream << module.DebugStringShort();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_MODULE_H_
diff --git a/iree/rt/module_printer.cc b/iree/rt/module_printer.cc
deleted file mode 100644
index dd0267e..0000000
--- a/iree/rt/module_printer.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/module_printer.h"
-
-#include <iomanip>
-
-#include "iree/rt/disassembler.h"
-#include "iree/rt/source_resolver.h"
-
-namespace iree {
-namespace rt {
-
-Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags,
- std::ostream* stream) {
- *stream << "Imports:\n";
- for (int i = 0; i < module.signature().import_function_count(); ++i) {
- ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal(
- Function::Linkage::kImport, i));
- *stream << " " << i << ": " << function << "\n";
- }
- *stream << "Exports:\n";
- for (int i = 0; i < module.signature().export_function_count(); ++i) {
- ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal(
- Function::Linkage::kExport, i));
- *stream << " " << i << ": " << function << "\n";
- }
- if (module.signature().internal_function_count()) {
- *stream << "Internal:\n";
- auto* disassembler = module.disassembler();
- for (int i = 0; i < module.signature().internal_function_count(); ++i) {
- ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal(
- Function::Linkage::kInternal, i));
- *stream << " " << i << ": " << function << "\n";
- if (disassembler && AllBitsSet(flags, PrintModuleFlag::kDisassemble)) {
- auto instructions_or =
- disassembler->DisassembleInstructions(function, 0);
- if (IsUnavailable(instructions_or.status())) continue;
- for (const auto& instruction : instructions_or.ValueOrDie()) {
- *stream << " " << std::setw(6) << instruction.offset << ": "
- << instruction.long_text << "\n";
- }
- }
- }
- }
- return OkStatus();
-}
-
-Status PrintModuleToStream(const Module& module, std::ostream* stream) {
- return PrintModuleToStream(module, PrintModuleFlag::kNone, stream);
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/module_printer.h b/iree/rt/module_printer.h
deleted file mode 100644
index 945dbbd..0000000
--- a/iree/rt/module_printer.h
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_MODULE_PRINTER_H_
-#define IREE_RT_MODULE_PRINTER_H_
-
-#include <ostream>
-
-#include "iree/base/bitfield.h"
-#include "iree/base/status.h"
-#include "iree/rt/module.h"
-
-namespace iree {
-namespace rt {
-
-enum class PrintModuleFlag {
- kNone = 0,
- kDisassemble = 1 << 0,
-};
-IREE_BITFIELD(PrintModuleFlag);
-using PrintModuleFlagBitfield = PrintModuleFlag;
-
-// Prints all functions within the module to the given |stream|.
-Status PrintModuleToStream(const Module& module, std::ostream* stream);
-Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags,
- std::ostream* stream);
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_MODULE_PRINTER_H_
diff --git a/iree/rt/policy.h b/iree/rt/policy.h
deleted file mode 100644
index 8bd72fa..0000000
--- a/iree/rt/policy.h
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_POLICY_H_
-#define IREE_RT_POLICY_H_
-
-#include "iree/base/ref_ptr.h"
-
-namespace iree {
-namespace rt {
-
-// Defines how invocation scheduling is to be performed.
-// The policy instance is used by the scheduler to determine when submissions
-// should be flushed to target queues.
-//
-// Thread-safe; the policy may be evaluated from arbitrary threads after an
-// invocation has began processing.
-class Policy : public RefObject<Policy> {
- public:
- virtual ~Policy() = default;
-
- // TODO(benvanik): constraints:
- // - max memory usage
- // - max delay
- // - max in-flight items/etc
- // - allowed device types
-};
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_POLICY_H_
diff --git a/iree/rt/source_location.cc b/iree/rt/source_location.cc
deleted file mode 100644
index 835905b..0000000
--- a/iree/rt/source_location.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/source_location.h"
-
-#include <sstream>
-
-#include "iree/rt/source_resolver.h"
-
-namespace iree {
-namespace rt {
-
-std::string SourceLocation::DebugStringShort() const {
- if (is_unknown()) return "(unknown)";
- std::ostringstream stream;
- resolver_->PrintSourceLocation(resolver_args_, &stream);
- return stream.str();
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/source_resolver.h b/iree/rt/source_resolver.h
deleted file mode 100644
index 3ceb1b9..0000000
--- a/iree/rt/source_resolver.h
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_SOURCE_RESOLVER_H_
-#define IREE_RT_SOURCE_RESOLVER_H_
-
-#include <cstdint>
-#include <ostream>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "iree/base/status.h"
-#include "iree/rt/function.h"
-#include "iree/rt/source_location.h"
-
-namespace iree {
-namespace rt {
-
-// Resolves offsets within functions to SourceLocations and provides source
-// language services.
-//
-// Thread-safe.
-class SourceResolver {
- public:
- virtual ~SourceResolver() = default;
-
- // Resolves a function-relative offset to a source location.
- // Not all offsets within a function may have source mapping information.
- virtual absl::optional<SourceLocation> ResolveFunctionOffset(
- const Function& function, SourceOffset offset) = 0;
-
- // Converts a source location to a human-readable string, commonly in a single
- // line denoting an original source file location (such as path:line:col).
- virtual void PrintSourceLocation(SourceResolverArgs resolver_args,
- std::ostream* stream) const = 0;
-
- // TODO(benvanik): query local variable names.
-
- // TODO(benvanik): step target calculation (relative mapping).
- // TODO(benvanik): step target based on SourceLocation delta.
-
- // TODO(benvanik): expression evaluator? (setting variables)
-
- protected:
- friend class SourceLocation;
-
- SourceResolver() = default;
-
- // TODO(benvanik): get line mapping information.
-};
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_SOURCE_RESOLVER_H_
diff --git a/iree/rt/stack.cc b/iree/rt/stack.cc
deleted file mode 100644
index f4f300f..0000000
--- a/iree/rt/stack.cc
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/stack.h"
-
-#include <iterator>
-
-#include "absl/strings/str_join.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-
-constexpr int Stack::kMaxStackDepth;
-
-Stack::Stack(Context* context) : context_(context) {}
-
-Stack::~Stack() = default;
-
-StatusOr<StackFrame*> Stack::PushFrame(Function function) {
- if (stack_depth_ + 1 > kMaxStackDepth) {
- return InternalErrorBuilder(IREE_LOC)
- << "Max stack depth of " << kMaxStackDepth << " exceeded";
- }
- frames_[stack_depth_++] = StackFrame(function);
-
- // TODO(benvanik): WTF scope enter.
-
- return current_frame();
-}
-
-Status Stack::PopFrame() {
- if (stack_depth_ == 0) {
- return InternalErrorBuilder(IREE_LOC) << "Unbalanced stack pop";
- }
-
- // TODO(benvanik): WTF scope leave.
-
- --stack_depth_;
- frames_[stack_depth_] = {};
- return OkStatus();
-}
-
-std::string Stack::DebugString() const {
- return absl::StrJoin(frames(), "\n", StackFrameFormatter());
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/stack.h b/iree/rt/stack.h
deleted file mode 100644
index a9547fe..0000000
--- a/iree/rt/stack.h
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_STACK_H_
-#define IREE_RT_STACK_H_
-
-#include <functional>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/rt/stack_frame.h"
-
-namespace iree {
-namespace rt {
-
-class Context;
-
-// A runtime call stack for managing stack frames.
-// The frames within a stack may be from different backends and may provide
-// varying levels of information based on capabilities.
-//
-// Thread-compatible. Do not attempt to investigate a stack while another thread
-// may be mutating it!
-class Stack final {
- public:
- static constexpr int kMaxStackDepth = 32;
-
- explicit Stack(Context* context);
- Stack(const Stack&) = delete;
- Stack& operator=(const Stack&) = delete;
- ~Stack();
-
- // Context defining the module and global workspaces.
- Context* context() const { return context_; }
-
- // All stack frames within the stack.
- absl::Span<StackFrame> frames() {
- return absl::MakeSpan(frames_).subspan(0, stack_depth_);
- }
- absl::Span<const StackFrame> frames() const {
- return absl::MakeConstSpan(frames_).subspan(0, stack_depth_);
- }
-
- // The current stack frame.
- StackFrame* current_frame() {
- return stack_depth_ > 0 ? &frames_[stack_depth_ - 1] : nullptr;
- }
-
- // The stack frame of the caller of the current function.
- StackFrame* caller_frame() {
- return stack_depth_ > 1 ? &frames_[stack_depth_ - 2] : nullptr;
- }
-
- StatusOr<StackFrame*> PushFrame(Function function);
- Status PopFrame();
-
- // Returns a full stack frame listing in human-readable form.
- std::string DebugString() const;
-
- private:
- Context* context_ = nullptr;
- std::array<StackFrame, kMaxStackDepth> frames_;
- int stack_depth_ = 0;
-};
-
-inline std::ostream& operator<<(std::ostream& stream, const Stack& stack) {
- stream << stack.DebugString();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_STACK_H_
diff --git a/iree/rt/stack_frame.cc b/iree/rt/stack_frame.cc
deleted file mode 100644
index 12ea5c8..0000000
--- a/iree/rt/stack_frame.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/stack_frame.h"
-
-#include "absl/strings/str_cat.h"
-#include "iree/rt/source_resolver.h"
-
-namespace iree {
-namespace rt {
-
-absl::optional<SourceLocation> StackFrame::source_location() const {
- auto* source_resolver = function_.module()->source_resolver();
- if (!source_resolver) return absl::nullopt;
- return source_resolver->ResolveFunctionOffset(function_, offset_);
-}
-
-std::string StackFrame::DebugStringShort() const {
- return absl::StrCat(module().name(), ":", function().name(), "@", offset());
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/stack_frame.h b/iree/rt/stack_frame.h
deleted file mode 100644
index bb77ef4..0000000
--- a/iree/rt/stack_frame.h
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_STACK_FRAME_H_
-#define IREE_RT_STACK_FRAME_H_
-
-#include <ostream>
-
-#include "absl/types/span.h"
-#include "iree/rt/function.h"
-#include "iree/rt/module.h"
-#include "iree/rt/source_location.h"
-
-namespace iree {
-namespace rt {
-
-// TODO(benvanik): allocate in-place from an arena.
-// Register table used within a stack frame.
-struct Registers {
- std::vector<hal::BufferView> buffer_views;
-};
-
-// A single frame on the call stack containing current execution state and
-// register values.
-//
-// As different backends may support different features this interface exposes
-// only the things we want to view in our debugger/stack dumps. This allows us
-// to ignore the actual implementation (bytecode VM, compiled C code, etc) so
-// long as it can respond to queries for register values. This has the benefit
-// of keeping the actual frame very lightweight as we are not storing the values
-// but instead just routing to the real storage via indirection. If the debugger
-// is not attached and no errors are hit then no additional bookkeeping is done.
-//
-// Thread-compatible, as is the owning Stack/StackTrace.
-class StackFrame final {
- public:
- StackFrame() = default;
- explicit StackFrame(Function function) : function_(function) {}
- StackFrame(Function function, SourceOffset offset, Registers registers)
- : function_(function),
- offset_(offset),
- registers_(std::move(registers)) {}
- StackFrame(const StackFrame&) = delete;
- StackFrame& operator=(const StackFrame&) = delete;
- StackFrame(StackFrame&&) = default;
- StackFrame& operator=(StackFrame&&) = default;
-
- // Module that owns the function this stack frame represents.
- const Module& module() const { return *function_.module(); }
-
- // Function the stack frame represents.
- const Function& function() const { return function_; }
-
- // Current virtual offset within the function.
- // The exact meaning of the offset is backend dependent and callers should
- // treat them as opaque and must use the SourceResolver to compute new
- // offsets (such as 'next offset').
- SourceOffset offset() const { return offset_; }
- SourceOffset* mutable_offset() { return &offset_; }
-
- // Returns a source location, if available, for the current offset within the
- // target function.
- absl::optional<SourceLocation> source_location() const;
-
- // Registers used within the stack frame.
- // Storage is implementation-defined and is valid only for the lifetime of the
- // frame.
- const Registers& registers() const { return registers_; }
- Registers* mutable_registers() { return &registers_; }  // mutable view of registers_
-
- // A short human-readable string for the frame; a single line.
- std::string DebugStringShort() const;
-
- private:
- Function function_;
- SourceOffset offset_ = 0;
- Registers registers_;
-};
-
-struct StackFrameFormatter {
- void operator()(std::string* out, const StackFrame& stack_frame) const {
- out->append(stack_frame.DebugStringShort());
- }
-};
-
-inline std::ostream& operator<<(std::ostream& stream,
- const StackFrame& stack_frame) {
- stream << stack_frame.DebugStringShort();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_STACK_FRAME_H_
diff --git a/iree/rt/stack_trace.cc b/iree/rt/stack_trace.cc
deleted file mode 100644
index ace8803..0000000
--- a/iree/rt/stack_trace.cc
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/rt/stack_trace.h"
-
-#include "absl/strings/str_join.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-
-std::string StackTrace::DebugString() const {
- return absl::StrJoin(frames_, "\n", StackFrameFormatter());
-}
-
-} // namespace rt
-} // namespace iree
diff --git a/iree/rt/stack_trace.h b/iree/rt/stack_trace.h
deleted file mode 100644
index beca466..0000000
--- a/iree/rt/stack_trace.h
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_RT_STACK_TRACE_H_
-#define IREE_RT_STACK_TRACE_H_
-
-#include <ostream>
-#include <vector>
-
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/rt/stack_frame.h"
-
-namespace iree {
-namespace rt {
-
-// A snapshot of a stack at a point in time.
-// The frames within a stack may be from different backends and may provide
-// varying levels of information based on capabilities.
-//
-// Depending on the capture options the trace may contain references to register
-// values (such as buffers) from the time of capture. If the buffers were
-// modified after the capture was taken those results will be reflected!
-class StackTrace final {
- public:
- StackTrace() = default;
- explicit StackTrace(std::vector<StackFrame> frames)
- : frames_(std::move(frames)) {}
- StackTrace(const StackTrace&) = delete;
- StackTrace& operator=(const StackTrace&) = delete;
- ~StackTrace() = default;
-
- // All stack frames within the stack.
- absl::Span<const StackFrame> frames() const {
- return absl::MakeConstSpan(frames_);
- }
-
- // Returns a full stack frame listing in human-readable form.
- std::string DebugString() const;
-
- private:
- std::vector<StackFrame> frames_;
-};
-
-inline std::ostream& operator<<(std::ostream& stream,
- const StackTrace& stack_trace) {
- stream << stack_trace.DebugString();
- return stream;
-}
-
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_RT_STACK_TRACE_H_
diff --git a/iree/samples/hal/BUILD b/iree/samples/hal/BUILD
deleted file mode 100644
index 8fa22bf..0000000
--- a/iree/samples/hal/BUILD
+++ /dev/null
@@ -1,48 +0,0 @@
-# Samples demonstrating use of the HAL API.
-# These do not rely on higher layers of the system (such as the VM or runtime).
-
-load("//iree:build_defs.bzl", "PLATFORM_VULKAN_TEST_DEPS")
-load("//iree/tools:compilation.bzl", "iree_bytecode_module")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_bytecode_module(
- name = "simple_compute_test_module",
- srcs = ["simple_compute_test.mlir"],
- cc_namespace = "iree::hal::samples",
-)
-
-cc_test(
- name = "simple_compute_test",
- srcs = ["simple_compute_test.cc"],
- data = [
- # When building with --config=asan you must specify the following
- # envvar when using Vulkan + a local Nvidia GPU:
- # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
- "//iree/tools:sanitizer_suppressions.txt",
- ],
- deps = [
- ":simple_compute_test_module_cc",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "//iree/base:flatbuffer_util",
- "//iree/base:status_matchers",
- "//iree/hal:command_buffer",
- "//iree/hal:command_queue",
- "//iree/hal:driver_registry",
- "//iree/schemas",
- "//iree/base:status",
-
- # These are the drivers we support running with and can produce
- # executables for from the source MLIR.
- "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep
- "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
-
- # TODO(b/142004903): enable when Dawn HAL implementation is functional
- # "//iree/hal/dawn:dawn_driver_module", # build-cleaner: keep
- ] + PLATFORM_VULKAN_TEST_DEPS,
-)
diff --git a/iree/samples/hal/simple_compute_test.cc b/iree/samples/hal/simple_compute_test.cc
deleted file mode 100644
index 24d7414..0000000
--- a/iree/samples/hal/simple_compute_test.cc
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// A simple backend-agnostic compute test for the HAL API.
-// This will load an IREE module containing one or more executables and attempt
-// to run them against all registered driver backends.
-//
-// The input file, simple_compute_test.mlir, is as generic as possible to ensure
-// we don't need too many variants. This means that it does not use any FFI
-// imports requiring runtime support, uses floats exclusively (as that's assumed
-// available everywhere), etc.
-//
-// The `iree_bytecode_module` build rule is used to translate the MLIR to the
-// module flatbuffer. Additional target support can be defined there.
-
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_replace.h"
-#include "absl/time/time.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/command_buffer.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/samples/hal/simple_compute_test_module.h"
-#include "iree/schemas/module_def_generated.h"
-
-namespace iree {
-namespace hal {
-namespace samples {
-namespace {
-
-using ModuleFile = FlatBufferFile<ModuleDef>;
-
-struct TestParams {
- // HAL driver to use for the test.
- std::string driver_name;
- // Ordinal within the module to execute.
- int executable_ordinal;
- // Name of the executable (just for prettier logging).
- std::string executable_name;
-};
-
-std::ostream& operator<<(std::ostream& os, const TestParams& params) {
- return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}) << "_ex"
- << params.executable_ordinal << "_" << params.executable_name;
-}
-
-// Loads the precompiled module file (from simple_compute_test.mlir).
-std::unique_ptr<ModuleFile> LoadModuleFile() {
- const auto* file_toc = simple_compute_test_module_create();
- return ModuleFile::WrapBuffer(
- ModuleDefIdentifier(),
- absl::MakeSpan(reinterpret_cast<const uint8_t*>(file_toc->data),
- file_toc->size))
- .ValueOrDie();
-}
-
-// Builds a list of tests to run for each [driver x available executable].
-std::vector<TestParams> GetAvailableDriverTestParams() {
- auto module_file = LoadModuleFile();
- auto& executable_table = *module_file->root()->executable_table();
- std::vector<TestParams> all_test_params;
- for (const auto& driver_name :
- DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) {
- int executable_ordinal = 0;  // index into multi_arch_executables(); restarts per driver
- for (const auto* multi_arch_executable_def :
- *executable_table.multi_arch_executables()) {
- TestParams test_params;
- test_params.driver_name = driver_name;
- test_params.executable_ordinal = executable_ordinal++;  // must count up: used with Get()
- test_params.executable_name =
- std::string(WrapString(multi_arch_executable_def->name()));
- all_test_params.push_back(std::move(test_params));
- }
- }
- return all_test_params;
-}
-
-class SimpleComputeTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestParams> {
- protected:
- void SetUp() override { module_file_ = LoadModuleFile(); }
-
- std::unique_ptr<ModuleFile> module_file_;  // assigned in SetUp() before each test
-};
-
-TEST_P(SimpleComputeTest, RunOnce) {
- const auto& test_params = GetParam();
-
- // Create driver for this test (based on params) and then get a default
- // device.
- LOG(INFO) << "Creating driver '" << test_params.driver_name << "'...";
- auto driver_or =
- DriverRegistry::shared_registry()->Create(test_params.driver_name);
- if (IsUnavailable(driver_or.status())) {
- LOG(WARNING) << "Skipping test as driver is unavailable: "
- << driver_or.status();
- GTEST_SKIP();
- return;
- }
- ASSERT_OK_AND_ASSIGN(auto driver, driver_or);
- ASSERT_OK_AND_ASSIGN(auto available_devices,
- driver->EnumerateAvailableDevices());
- for (const auto& device_info : available_devices) {
- LOG(INFO) << " Device: " << device_info.name();
- }
- LOG(INFO) << "Creating default device...";
- ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice());
- LOG(INFO) << "Successfully created device '" << device->info().name() << "'";
-
- // Attempt to compile the appropriate executable. This may fail if there's no
- // executable available in the input file that the driver can load.
- auto executable_cache = device->CreateExecutableCache();
- auto& executable_table = *module_file_->root()->executable_table();
- auto multi_arch_executable_def =
- executable_table.multi_arch_executables()->Get(
- test_params.executable_ordinal);
- ref_ptr<Executable> executable;
- for (auto executable_def : *multi_arch_executable_def->executables()) {
- if (!executable_cache->CanPrepareFormat(executable_def->format())) {
- continue;
- }
- ExecutableSpec spec;
- spec.format = executable_def->format();
- spec.executable_data = *executable_def->contents();
- ASSERT_OK_AND_ASSIGN(executable,
- executable_cache->PrepareExecutable(
- ExecutableCachingMode::kDefault, spec));
- break;
- }
- ASSERT_NE(executable, nullptr)
- << "No executable found that has a supported format for driver "
- << test_params.driver_name;
-
- // Create I/O buffers.
- ASSERT_OK_AND_ASSIGN(auto arg0_buffer,
- device->allocator()->Allocate(
- MemoryType::kHostLocal | MemoryType::kDeviceVisible,
- BufferUsage::kAll, 4 * sizeof(float)));
- ASSERT_OK_AND_ASSIGN(auto arg1_buffer,
- device->allocator()->Allocate(
- MemoryType::kHostLocal | MemoryType::kDeviceVisible,
- BufferUsage::kAll, 4 * sizeof(float)));
- ASSERT_OK_AND_ASSIGN(auto ret0_buffer,
- device->allocator()->Allocate(
- MemoryType::kHostLocal | MemoryType::kDeviceVisible,
- BufferUsage::kAll, 4 * sizeof(float)));
-
- // Populate initial values for 4 * 2 = 8.
- // We scribble into the result buffer so that it's easy to ensure it's
- // overwritten.
- ASSERT_OK(arg0_buffer->Fill32(4.0f));
- ASSERT_OK(arg1_buffer->Fill32(2.0f));
- ASSERT_OK(ret0_buffer->Fill32(99999.0f));
-
- // Record the command buffer that dispatches the executable.
- ASSERT_OK_AND_ASSIGN(
- auto cmd, device->CreateCommandBuffer(
- CommandBufferMode::kOneShot,
- CommandCategory::kTransfer | CommandCategory::kDispatch));
- ASSERT_OK(cmd->Begin());
- DispatchRequest dispatch_request;
- dispatch_request.executable = executable.get();
- dispatch_request.entry_point = 0;
- dispatch_request.workload[0] = 4;
- dispatch_request.workload[1] = 1;
- dispatch_request.workload[2] = 1;
- BufferBinding bindings[3];
- bindings[0].buffer = arg0_buffer.get();
- bindings[0].access = MemoryAccess::kRead;
- bindings[0].element_size = sizeof(float);
- bindings[0].shape = {4};
- bindings[1].buffer = arg1_buffer.get();
- bindings[1].access = MemoryAccess::kRead;
- bindings[1].element_size = sizeof(float);
- bindings[1].shape = {4};
- bindings[2].buffer = ret0_buffer.get();
- bindings[2].access = MemoryAccess::kDiscardWrite;
- bindings[2].element_size = sizeof(float);
- bindings[2].shape = {4};
- dispatch_request.bindings = bindings;
- ASSERT_OK(cmd->Dispatch(dispatch_request));
- ASSERT_OK(cmd->End());
-
- // Schedule and wait for completion.
- ASSERT_FALSE(device->dispatch_queues().empty());
- CommandQueue* queue = device->dispatch_queues().front();
- ASSERT_OK_AND_ASSIGN(auto fence, device->CreateFence(0u));
- ASSERT_OK(
- queue->Submit(SubmissionBatch{{}, {cmd.get()}, {}}, {fence.get(), 1u}));
- ASSERT_OK(device->WaitAllFences({{fence.get(), 1u}}, absl::InfiniteFuture()));
-
- // Read back the results.
- ASSERT_OK_AND_ASSIGN(auto ret0_mapping,
- ret0_buffer->MapMemory<float>(MemoryAccess::kRead));
- EXPECT_THAT(ret0_mapping.contents(),
- ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
-}
-
-INSTANTIATE_TEST_SUITE_P(AllDrivers, SimpleComputeTest,
- ::testing::ValuesIn(GetAvailableDriverTestParams()),
- ::testing::PrintToStringParamName());
-
-} // namespace
-} // namespace samples
-} // namespace hal
-} // namespace iree
diff --git a/iree/samples/rt/BUILD b/iree/samples/rt/BUILD
deleted file mode 100644
index e880a5d..0000000
--- a/iree/samples/rt/BUILD
+++ /dev/null
@@ -1,72 +0,0 @@
-# Samples demonstrating use of the RT API.
-
-load("//iree/tools:compilation.bzl", "iree_bytecode_module")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_bytecode_module(
- name = "simple_module_test_bytecode_module",
- srcs = ["simple_module_test.mlir"],
- cc_namespace = "iree::rt::samples",
-)
-
-cc_test(
- name = "bytecode_module_test",
- srcs = ["bytecode_module_test.cc"],
- data = [
- # When building with --config=asan you must specify the following
- # envvar when using Vulkan + a local Nvidia GPU:
- # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
- "//iree/tools:sanitizer_suppressions.txt",
- ],
- deps = [
- ":simple_module_test_bytecode_module_cc",
- "@com_google_googletest//:gtest_main",
- "@com_google_absl//absl/strings",
- "//iree/base:flatbuffer_util",
- "//iree/base:status",
- "//iree/base:status_matchers",
- "//iree/hal:buffer_view",
- "//iree/hal:driver_registry",
- "//iree/rt",
- "//iree/schemas",
- "//iree/vm:bytecode_module",
- "//iree/vm:sequencer_module",
-
- # These are the drivers we support running with and can produce
- # executables for from the source MLIR.
- "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep
- # TODO(benvanik): include SPIR-V.
- # "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
- ],
-)
-
-cc_test(
- name = "bytecode_module_api_test",
- srcs = ["bytecode_module_api_test.cc"],
- data = [
- # When building with --config=asan you must specify the following
- # envvar when using Vulkan + a local Nvidia GPU:
- # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
- "//iree/tools:sanitizer_suppressions.txt",
- ],
- deps = [
- ":simple_module_test_bytecode_module_cc",
- "@com_google_googletest//:gtest_main",
- "@com_google_absl//absl/strings",
- "//iree/base:api",
- "//iree/hal:api",
- "//iree/hal:driver_registry",
- "//iree/rt:api",
- "//iree/vm:api",
-
- # These are the drivers we support running with and can produce
- # executables for from the source MLIR.
- "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep
- # TODO(benvanik): include SPIR-V.
- # "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
- ],
-)
diff --git a/iree/samples/rt/bytecode_module_api_test.cc b/iree/samples/rt/bytecode_module_api_test.cc
deleted file mode 100644
index 385a4ae..0000000
--- a/iree/samples/rt/bytecode_module_api_test.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "absl/strings/str_replace.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-// C API:
-#include "iree/base/api.h"
-#include "iree/hal/api.h"
-#include "iree/rt/api.h"
-#include "iree/vm/api.h"
-
-// Only temporary, used for test device registration:
-#include "iree/hal/driver_registry.h"
-
-// Compiled module embedded here to avoid file IO:
-#include "iree/samples/rt/simple_module_test_bytecode_module.h"
-
-namespace iree {
-namespace rt {
-namespace samples {
-namespace {
-
-#define ASSERT_API_OK(expr) ASSERT_EQ(IREE_STATUS_OK, (expr))
-
-struct TestParams {
- // HAL driver to use for the test.
- std::string driver_name;
-};
-
-std::ostream& operator<<(std::ostream& os, const TestParams& params) {
- return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}});
-}
-
-// Builds a list of tests to run based on the linked in driver modules.
-std::vector<TestParams> GetAvailableDriverTestParams() {
- std::vector<TestParams> all_test_params;
- for (const auto& driver_name :
- hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) {
- TestParams test_params;
- test_params.driver_name = driver_name;
- all_test_params.push_back(std::move(test_params));
- }
- return all_test_params;
-}
-
-class BytecodeModuleApiTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestParams> {
- protected:
-};
-
-TEST_P(BytecodeModuleApiTest, RunOnce) {
- iree_rt_instance_t* instance = nullptr;
- ASSERT_API_OK(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &instance));
-
- // TEMPORARY: until policies and placement are performed we manually
- // register drivers via a magic function.
- const auto& driver_name = GetParam().driver_name;
- LOG(INFO) << "Creating driver '" << driver_name << "'...";
- ASSERT_API_OK(iree_rt_instance_register_driver_ex(
- instance, iree_string_view_t{driver_name.data(), driver_name.size()}));
-
- // Allocate a context that will hold the module state across invocations.
- iree_rt_policy_t* dummy_policy = nullptr;
- ASSERT_API_OK(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &dummy_policy));
- iree_rt_context_t* context = nullptr;
- ASSERT_API_OK(iree_rt_context_create(instance, dummy_policy,
- IREE_ALLOCATOR_DEFAULT, &context));
- iree_rt_policy_release(dummy_policy);
-
- // Load bytecode module from the embedded data.
- LOG(INFO) << "Loading simple_module_test.mlir...";
- const auto* module_file_toc = simple_module_test_bytecode_module_create();
- iree_rt_module_t* bytecode_module = nullptr;
- ASSERT_API_OK(iree_vm_bytecode_module_create_from_buffer(
- iree_const_byte_span_t{
- reinterpret_cast<const uint8_t*>(module_file_toc->data),
- module_file_toc->size},
- nullptr, nullptr, IREE_ALLOCATOR_DEFAULT, &bytecode_module));
-
- // Register modules that we want to be able to use in the context.
- std::vector<iree_rt_module_t*> modules;
- modules.push_back(bytecode_module);
- ASSERT_API_OK(
- iree_rt_context_register_modules(context, &modules[0], modules.size()));
- iree_rt_module_release(bytecode_module);
- LOG(INFO) << "Module loaded and context is ready for use";
-
- // Lookup the entry point function.
- iree_rt_function_t main_function;
- const char kMainFunctionName[] = "module.simple_mul";
- ASSERT_API_OK(iree_rt_context_resolve_function(
- context,
- iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
- &main_function));
-
- // Allocate buffers that can be mapped on the CPU and that can also be used
- // on the device. Not all devices support this, but the ones we have now do.
- LOG(INFO) << "Creating I/O buffers...";
- constexpr int kElementCount = 4;
- iree_hal_buffer_t* arg0_buffer = nullptr;
- iree_hal_buffer_t* arg1_buffer = nullptr;
- ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer(
- context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount,
- IREE_ALLOCATOR_DEFAULT, &arg0_buffer));
- ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer(
- context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount,
- IREE_ALLOCATOR_DEFAULT, &arg1_buffer));
-
- // Populate initial values for 4 * 2 = 8.
- float kFloat4 = 4.0f;
- float kFloat2 = 2.0f;
- ASSERT_API_OK(iree_hal_buffer_fill(arg0_buffer, 0, IREE_WHOLE_BUFFER,
- &kFloat4, sizeof(float)));
- ASSERT_API_OK(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER,
- &kFloat2, sizeof(float)));
-
- // Wrap buffers in buffer views to provide shape information.
- std::array<iree_hal_buffer_view_t*, 2> arg_buffer_views;
- ASSERT_API_OK(iree_hal_buffer_view_create(
- arg0_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float),
- IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[0]));
- ASSERT_API_OK(iree_hal_buffer_view_create(
- arg1_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float),
- IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[1]));
- iree_hal_buffer_release(arg0_buffer);
- iree_hal_buffer_release(arg1_buffer);
-
- // Call into the @simple_mul function.
- LOG(INFO) << "Calling @simple_mul...";
- iree_rt_invocation_t* invocation = nullptr;
- ASSERT_API_OK(iree_rt_invocation_create(
- context, &main_function, nullptr, nullptr, arg_buffer_views.data(), 2,
- nullptr, 0, IREE_ALLOCATOR_DEFAULT, &invocation));
- ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[0]));
- ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[1]));
- ASSERT_API_OK(
- iree_rt_invocation_await(invocation, IREE_TIME_INFINITE_FUTURE));
-
- // Get the result buffers from the invocation.
- LOG(INFO) << "Retreiving results...";
- std::array<iree_hal_buffer_view_t*, 2> result_buffer_views;
- iree_host_size_t result_count;
- ASSERT_API_OK(iree_rt_invocation_consume_results(
- invocation, result_buffer_views.size(), IREE_ALLOCATOR_DEFAULT,
- result_buffer_views.data(), &result_count));
- iree_rt_invocation_release(invocation);
-
- // Read back the results and ensure we got the right values.
- LOG(INFO) << "Reading back results...";
- iree_hal_buffer_t* result_buffer =
- iree_hal_buffer_view_buffer(result_buffer_views[0]);
- iree_hal_mapped_memory_t mapped_memory;
- ASSERT_API_OK(iree_hal_buffer_map(result_buffer, IREE_HAL_MEMORY_ACCESS_READ,
- 0, IREE_WHOLE_BUFFER, &mapped_memory));
- ASSERT_THAT(absl::Span<const float>(
- reinterpret_cast<const float*>(mapped_memory.contents.data),
- mapped_memory.contents.data_length / sizeof(float)),
- ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
- ASSERT_API_OK(iree_hal_buffer_unmap(result_buffer, &mapped_memory));
- LOG(INFO) << "Results match!";
-
- iree_hal_buffer_view_release(result_buffer_views[0]);
-
- iree_rt_context_release(context);
- iree_rt_instance_release(instance);
-}
-
-INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleApiTest,
- ::testing::ValuesIn(GetAvailableDriverTestParams()),
- ::testing::PrintToStringParamName());
-
-} // namespace
-} // namespace samples
-} // namespace rt
-} // namespace iree
diff --git a/iree/samples/rt/bytecode_module_test.cc b/iree/samples/rt/bytecode_module_test.cc
deleted file mode 100644
index 9671fd9..0000000
--- a/iree/samples/rt/bytecode_module_test.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// A simple sample demonstrating simple synchronous module loading and VM use.
-// This will load an IREE module containing a @simple_mul method that performs
-// an element-wise multiplication. It will invoke @simple_mul in the VM, once
-// for each available HAL driver linked into the binary.
-//
-// The synchronous invocation method (Context::Invoke) used here waits until all
-// asynchronous HAL work completes before returning. It's still possible to
-// get overlapped execution by invoking methods from other threads with their
-// own FiberState, though it's best to use the asynchronous API instead.
-//
-// The `iree_bytecode_module` build rule is used to translate the MLIR to the
-// module flatbuffer. Additional HAL backend targets can be defined there.
-
-#include "iree/vm/bytecode_module.h"
-
-#include "absl/strings/str_replace.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/status.h"
-#include "iree/base/status_matchers.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/rt/context.h"
-#include "iree/rt/instance.h"
-#include "iree/samples/rt/simple_module_test_bytecode_module.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/vm/sequencer_module.h"
-
-namespace iree {
-namespace rt {
-namespace samples {
-namespace {
-
-using ::iree::hal::BufferView;
-using ::iree::vm::ModuleFile;
-
-struct TestParams {
- // HAL driver to use for the test.
- std::string driver_name;
-};
-
-std::ostream& operator<<(std::ostream& os, const TestParams& params) {
- return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}});
-}
-
-// Builds a list of tests to run based on the linked in driver modules.
-std::vector<TestParams> GetAvailableDriverTestParams() {
- std::vector<TestParams> all_test_params;
- for (const auto& driver_name :
- hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) {
- TestParams test_params;
- test_params.driver_name = driver_name;
- all_test_params.push_back(std::move(test_params));
- }
- return all_test_params;
-}
-
-class BytecodeModuleTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestParams> {
- protected:
-};
-
-TEST_P(BytecodeModuleTest, RunOnce) {
- auto instance = make_ref<Instance>();
-
- // Create driver for this test (based on params) and then get a default
- // device.
- const auto& test_params = GetParam();
- LOG(INFO) << "Creating driver '" << test_params.driver_name << "'...";
- auto driver_or =
- hal::DriverRegistry::shared_registry()->Create(test_params.driver_name);
- if (IsUnavailable(driver_or.status())) {
- LOG(WARNING) << "Skipping test as driver is unavailable: "
- << driver_or.status();
- GTEST_SKIP();
- return;
- }
- ASSERT_OK_AND_ASSIGN(auto driver, driver_or);
- ASSERT_OK_AND_ASSIGN(auto available_devices,
- driver->EnumerateAvailableDevices());
- for (const auto& device_info : available_devices) {
- LOG(INFO) << " Device: " << device_info.name();
- }
- LOG(INFO) << "Creating default device...";
- ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice());
- ASSERT_OK(instance->device_manager()->RegisterDevice(device));
- LOG(INFO) << "Successfully created device '" << device->info().name() << "'";
-
- // Make a new context and load the precompiled module file (from
- // simple_module_test.mlir) into it.
- LOG(INFO) << "Loading simple_module_test.mlir...";
- auto policy = make_ref<Policy>();
- Context context(add_ref(instance), add_ref(policy));
- const auto* module_file_toc = simple_module_test_bytecode_module_create();
- ASSERT_OK_AND_ASSIGN(auto module_file,
- vm::ModuleFile::WrapBuffer(
- ModuleDefIdentifier(),
- absl::MakeSpan(reinterpret_cast<const uint8_t*>(
- module_file_toc->data),
- module_file_toc->size)));
- ASSERT_OK_AND_ASSIGN(auto main_module,
- vm::SequencerModule::FromFile(std::move(module_file)));
- ASSERT_OK(context.RegisterModule(std::move(main_module)));
- LOG(INFO) << "Module loaded and context is ready for use";
-
- // Allocate buffers that can be mapped on the CPU and that can also be used
- // on the device. Not all devices support this, but the ones we have now do.
- LOG(INFO) << "Creating I/O buffers...";
- constexpr int kElementCount = 4;
- ASSERT_OK_AND_ASSIGN(
- auto arg0_buffer,
- instance->device_manager()->AllocateDeviceVisibleBuffer(
- hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}}));
- ASSERT_OK_AND_ASSIGN(
- auto arg1_buffer,
- instance->device_manager()->AllocateDeviceVisibleBuffer(
- hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}}));
-
- // Populate initial values for 4 * 2 = 8.
- ASSERT_OK(arg0_buffer->Fill32(4.0f));
- ASSERT_OK(arg1_buffer->Fill32(2.0f));
-
- // Call into the @simple_mul function.
- LOG(INFO) << "Calling @simple_mul...";
- absl::InlinedVector<BufferView, 8> args{
- BufferView{add_ref(arg0_buffer), {kElementCount}, sizeof(float)},
- BufferView{add_ref(arg1_buffer), {kElementCount}, sizeof(float)},
- };
- ASSERT_OK_AND_ASSIGN(auto simple_mul,
- context.ResolveFunction("module.simple_mul"));
- ASSERT_OK_AND_ASSIGN(auto invocation,
- Invocation::Create(add_ref(&context), simple_mul,
- nullptr, {}, std::move(args)));
- ASSERT_OK(invocation->Await(absl::InfiniteFuture()));
- ASSERT_OK_AND_ASSIGN(auto results, invocation->ConsumeResults());
-
- // Read back the results and ensure we got the right values.
- LOG(INFO) << "Reading back results...";
- auto& ret_buffer_view = results[0];
- ASSERT_OK_AND_ASSIGN(
- auto ret_mapping,
- ret_buffer_view.buffer->MapMemory<float>(hal::MemoryAccess::kRead));
- ASSERT_THAT(ret_mapping.contents(),
- ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
- LOG(INFO) << "Results match!";
-}
-
-INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleTest,
- ::testing::ValuesIn(GetAvailableDriverTestParams()),
- ::testing::PrintToStringParamName());
-
-} // namespace
-} // namespace samples
-} // namespace rt
-} // namespace iree
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD
deleted file mode 100644
index ebe0bd6..0000000
--- a/iree/schemas/BUILD
+++ /dev/null
@@ -1,252 +0,0 @@
-load("//iree:build_defs.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library")
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-FLATC_ARGS = [
- # Preserve workspace-relative include paths in generated code.
- "--keep-prefix",
- # Use C++11 'enum class' for enums.
- "--scoped-enums",
- # Include reflection tables used for dumping debug representations.
- "--reflect-names",
- # Generate FooT types for unpack/pack support. Note that this should only
- # be used in tooling as the code size/runtime overhead is non-trivial.
- "--gen-object-api",
-]
-
-iree_flatbuffer_cc_library(
- name = "archive_def_cc_fbs",
- srcs = ["archive_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":bytecode_def_cc_fbs_includes",
- ":device_def_cc_fbs_includes",
- ":device_group_def_cc_fbs_includes",
- ":device_table_def_cc_fbs_includes",
- ":executable_def_cc_fbs_includes",
- ":executable_table_def_cc_fbs_includes",
- ":function_def_cc_fbs_includes",
- ":function_table_def_cc_fbs_includes",
- ":module_def_cc_fbs_includes",
- ":source_map_def_cc_fbs_includes",
- ":type_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "bytecode_def_cc_fbs",
- srcs = ["bytecode_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "debug_service_cc_fbs",
- srcs = ["debug_service.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":bytecode_def_cc_fbs_includes",
- ":device_def_cc_fbs_includes",
- ":device_group_def_cc_fbs_includes",
- ":device_table_def_cc_fbs_includes",
- ":executable_def_cc_fbs_includes",
- ":executable_table_def_cc_fbs_includes",
- ":function_def_cc_fbs_includes",
- ":function_table_def_cc_fbs_includes",
- ":module_def_cc_fbs_includes",
- ":source_map_def_cc_fbs_includes",
- ":type_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "device_def_cc_fbs",
- srcs = ["device_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "device_group_def_cc_fbs",
- srcs = ["device_group_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "device_table_def_cc_fbs",
- srcs = ["device_table_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":device_def_cc_fbs_includes",
- ":device_group_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "executable_def_cc_fbs",
- srcs = ["executable_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "executable_table_def_cc_fbs",
- srcs = ["executable_table_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":executable_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "function_def_cc_fbs",
- srcs = ["function_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":bytecode_def_cc_fbs_includes",
- ":type_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "function_table_def_cc_fbs",
- srcs = ["function_table_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":bytecode_def_cc_fbs_includes",
- ":function_def_cc_fbs_includes",
- ":type_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "module_def_cc_fbs",
- srcs = ["module_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ":bytecode_def_cc_fbs_includes",
- ":device_def_cc_fbs_includes",
- ":device_group_def_cc_fbs_includes",
- ":device_table_def_cc_fbs_includes",
- ":executable_def_cc_fbs_includes",
- ":executable_table_def_cc_fbs_includes",
- ":function_def_cc_fbs_includes",
- ":function_table_def_cc_fbs_includes",
- ":source_map_def_cc_fbs_includes",
- ":type_def_cc_fbs_includes",
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "source_map_def_cc_fbs",
- srcs = ["source_map_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "spirv_executable_def_cc_fbs",
- srcs = ["spirv_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_flatbuffer_cc_library(
- name = "type_def_cc_fbs",
- srcs = ["type_def.fbs"],
- flatc_args = FLATC_ARGS,
- includes = [
- ],
-)
-
-iree_build_test(
- name = "schema_build_test",
- targets = [
- ":archive_def_cc_fbs",
- ":bytecode_def_cc_fbs",
- ":debug_service_cc_fbs",
- ":device_def_cc_fbs",
- ":device_group_def_cc_fbs",
- ":device_table_def_cc_fbs",
- ":executable_def_cc_fbs",
- ":executable_table_def_cc_fbs",
- ":function_def_cc_fbs",
- ":function_table_def_cc_fbs",
- ":module_def_cc_fbs",
- ":source_map_def_cc_fbs",
- ":spirv_executable_def_cc_fbs",
- ":type_def_cc_fbs",
- ],
-)
-
-cc_library(
- name = "schemas",
- hdrs = [
- ":archive_def_generated.h",
- ":bytecode_def_generated.h",
- ":debug_service_generated.h",
- ":device_def_generated.h",
- ":device_group_def_generated.h",
- ":device_table_def_generated.h",
- ":executable_def_generated.h",
- ":executable_table_def_generated.h",
- ":function_def_generated.h",
- ":function_table_def_generated.h",
- ":module_def_generated.h",
- ":source_map_def_generated.h",
- ":type_def_generated.h",
- ],
- deps = [
- ":archive_def_cc_fbs",
- ":bytecode_def_cc_fbs",
- ":debug_service_cc_fbs",
- ":device_def_cc_fbs",
- ":device_group_def_cc_fbs",
- ":device_table_def_cc_fbs",
- ":executable_def_cc_fbs",
- ":executable_table_def_cc_fbs",
- ":function_def_cc_fbs",
- ":function_table_def_cc_fbs",
- ":module_def_cc_fbs",
- ":source_map_def_cc_fbs",
- ":spirv_executable_def_cc_fbs",
- ":type_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
- ],
-)
-
-REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [
- "archive_def.bfbs",
- "bytecode_def.bfbs",
- "debug_service.bfbs",
- "executable_def.bfbs",
- "executable_table_def.bfbs",
- "function_def.bfbs",
- "function_table_def.bfbs",
- "module_def.bfbs",
- "source_map_def.bfbs",
- "spirv_executable_def.bfbs",
- "type_def.bfbs",
- "device_def.bfbs",
- "device_group_def.bfbs",
- "device_table_def.bfbs",
-]
-
-cc_embed_data(
- name = "reflection_data",
- srcs = REFLECTION_SRCS,
- cc_file_output = "reflection_data.cc",
- cpp_namespace = "iree::schemas",
- h_file_output = "reflection_data.h",
-)
diff --git a/iree/schemas/archive_def.fbs b/iree/schemas/archive_def.fbs
deleted file mode 100644
index 2e96683..0000000
--- a/iree/schemas/archive_def.fbs
+++ /dev/null
@@ -1,14 +0,0 @@
-include "iree/schemas/module_def.fbs";
-
-namespace iree;
-
-// 'Executable ARChive'.
-file_identifier "EARC";
-file_extension "earc";
-
-table ArchiveDef {
- name:string;
- modules:[ModuleDef];
-}
-
-root_type ArchiveDef;
diff --git a/iree/schemas/bytecode/BUILD b/iree/schemas/bytecode/BUILD
deleted file mode 100644
index 11f5541..0000000
--- a/iree/schemas/bytecode/BUILD
+++ /dev/null
@@ -1,29 +0,0 @@
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "bytecode_v0",
- hdrs = ["bytecode_v0.h"],
- deps = [
- "//iree/base:bitfield",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "interpreter_bytecode_v0",
- hdrs = ["interpreter_bytecode_v0.h"],
- deps = [
- ":bytecode_v0",
- ],
-)
-
-cc_library(
- name = "sequencer_bytecode_v0",
- hdrs = ["sequencer_bytecode_v0.h"],
- deps = [
- ":bytecode_v0",
- ],
-)
diff --git a/iree/schemas/bytecode/bytecode_v0.h b/iree/schemas/bytecode/bytecode_v0.h
deleted file mode 100644
index 46f56b8..0000000
--- a/iree/schemas/bytecode/bytecode_v0.h
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Opcode table for the V0 binary format.
-// Additions are fine but changing the behavior or order of any opcodes will
-// break pasring of existing files.
-//
-// Opcodes have been selected on frequency of use, general applicability, and
-// relative stability. Experimental ops should be implemented via the FFI fisrt
-// before graduating into the core set. Ops that may only be present on certain
-// targets should also be kept as imports via the FFI.
-//
-// Opcodes may be specified for particular types (int32_t), categories of types
-// (all floating-point types), or implicit types (output matches input). Saving
-// opcode space by sharing a single opcode for multiple types is preferred
-// except where hot operations are performed (for example, comparison used in
-// loop iteratosr).
-
-#ifndef IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
-#define IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
-
-#include <cstdint>
-
-#include "iree/base/bitfield.h"
-
-namespace iree {
-
-#define IREE_CONSTANT_ENCODING_LIST(ENC) \
- ENC(0x00, kDense, "dense") \
- ENC(0x01, kSplat, "splat")
-
-#define IREE_TYPE_LIST(TYP) \
- TYP(0x00, kI8, "i8", 1) \
- TYP(0x01, kI16, "i16", 2) \
- TYP(0x02, kI32, "i32", 4) \
- TYP(0x03, kI64, "i64", 8) \
- TYP(0x04, kF16, "f16", 2) \
- TYP(0x05, kF32, "f32", 4) \
- TYP(0x06, kF64, "f64", 8) \
- TYP(0x80, kDevice, "device", 0) \
- TYP(0x81, kCommandBuffer, "command_buffer", 0) \
- TYP(0x82, kEvent, "event", 0) \
- TYP(0x83, kSemaphore, "semaphore", 0) \
- TYP(0x84, kFence, "fence", 0) \
- TYP(0xFF, kOpaque, "opaque", 0)
-
-#define IREE_CMPI_PREDICATE_LIST(PRED) \
- PRED(0, kEq, "eq") \
- PRED(1, kNe, "ne") \
- PRED(2, kSlt, "slt") \
- PRED(3, kSle, "sle") \
- PRED(4, kSgt, "sgt") \
- PRED(5, kSge, "sge") \
- PRED(6, kUlt, "ult") \
- PRED(7, kUle, "ule") \
- PRED(8, kUgt, "ugt") \
- PRED(9, kUge, "uge")
-
-#define IREE_CMPF_PREDICATE_LIST(PRED) \
- PRED(0, kFalse, "false") \
- PRED(1, kOeq, "oeq") \
- PRED(2, kOgt, "ogt") \
- PRED(3, kOge, "oge") \
- PRED(4, kOlt, "olt") \
- PRED(5, kOle, "ole") \
- PRED(6, kOne, "one") \
- PRED(7, kOrd, "ord") \
- PRED(8, kUeq, "ueq") \
- PRED(9, kUgt, "ugt") \
- PRED(10, kUge, "uge") \
- PRED(11, kUlt, "ult") \
- PRED(12, kUle, "ule") \
- PRED(13, kUne, "une") \
- PRED(14, kUno, "uno") \
- PRED(15, kTrue, "true")
-
-// NOTE: FF is a to-be-defined flag value for encoding/decoding.
-#define FLAG(V) ::iree::OpcodeFlag::V
-
-#define RSV(opcode, RESERVED_OPC) \
- RESERVED_OPC(opcode, kReserved##opcode, "rsv." #opcode, FLAG(kDefault), "", \
- FF)
-
-#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal,
-
-enum class ConstantEncoding : uint8_t {
- IREE_CONSTANT_ENCODING_LIST(DECLARE_ENUM)
-};
-
-enum class BuiltinType : uint8_t { IREE_TYPE_LIST(DECLARE_ENUM) };
-
-enum class CmpIPredicate : uint8_t { IREE_CMPI_PREDICATE_LIST(DECLARE_ENUM) };
-
-enum class CmpFPredicate : uint8_t { IREE_CMPF_PREDICATE_LIST(DECLARE_ENUM) };
-
-#undef DECLARE_ENUM
-
-static constexpr uint8_t kBuiltinTypeCount =
- static_cast<uint8_t>(BuiltinType::kF64) + 1;
-
-enum class OpcodeFlag : uint8_t {
- kDefault = 0,
-};
-IREE_BITFIELD(OpcodeFlag);
-using OpcodeFlagBitfield = OpcodeFlag;
-
-enum class OperandEncoding : char {
- kNone = '\0',
- kInputSlot = 's',
- kVariadicInputSlots = 'S',
- kOutputSlot = 'o',
- kVariadicOutputSlots = 'O',
- kResultSlot = 'r',
- kVariadicResultSlots = 'R',
- kVariadicTransferSlots = 'T',
- kConstant = 'c',
- kFunctionOrdinal = 'f',
- kImportOrdinal = 'F',
- kDispatchOrdinal = 'd',
- kBlockOffset = 'b',
- kTypeIndex = 't',
- kIndex = 'i',
- kIndexList = 'I',
- kCmpIPredicate = 'p',
- kCmpFPredicate = 'P',
-};
-IREE_BITFIELD(OperandEncoding);
-using OperandEncodingBitfield = OperandEncoding;
-
-} // namespace iree
-
-#endif // IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
diff --git a/iree/schemas/bytecode/interpreter_bytecode_v0.h b/iree/schemas/bytecode/interpreter_bytecode_v0.h
deleted file mode 100644
index 05a0c56..0000000
--- a/iree/schemas/bytecode/interpreter_bytecode_v0.h
+++ /dev/null
@@ -1,325 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Opcode table for the V0 binary format.
-// Additions are fine but changing the behavior or order of any opcodes will
-// break parsing of existing files.
-//
-// Opcodes have been selected on frequency of use, general applicability, and
-// relative stability. Experimental ops should be implemented via the Foreign
-// Function Interface (FFI) first before graduating into the core set. Ops that
-// may only be present on certain targets should also be kept as imports via the
-// FFI.
-//
-// Opcodes may be specified for particular types (int32_t), categories of types
-// (all floating-point types), or implicit types (output matches input). Saving
-// opcode space by sharing a single opcode for multiple types is preferred
-// except where hot operations are performed (for example, comparison used in
-// loop iterators).
-
-#ifndef IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
-#define IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
-
-#include "iree/schemas/bytecode/bytecode_v0.h"
-
-namespace iree {
-
-#define IREE_INTERPRETER_OPCODE_LIST(OPC, RESERVED_OPC) \
- OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \
- \
- OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \
- OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \
- OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \
- OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \
- OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \
- OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \
- OPC(0x07, kCmpI, "cmp_i", FLAG(kDefault), "psso", FF) \
- OPC(0x08, kCmpF, "cmp_f", FLAG(kDefault), "Psso", FF) \
- \
- RSV(0x09, RESERVED_OPC) \
- RSV(0x0A, RESERVED_OPC) \
- RSV(0x0B, RESERVED_OPC) \
- RSV(0x0C, RESERVED_OPC) \
- RSV(0x0D, RESERVED_OPC) \
- RSV(0x0E, RESERVED_OPC) \
- RSV(0x0F, RESERVED_OPC) \
- RSV(0x10, RESERVED_OPC) \
- RSV(0x11, RESERVED_OPC) \
- RSV(0x12, RESERVED_OPC) \
- RSV(0x13, RESERVED_OPC) \
- RSV(0x14, RESERVED_OPC) \
- RSV(0x15, RESERVED_OPC) \
- RSV(0x16, RESERVED_OPC) \
- RSV(0x17, RESERVED_OPC) \
- RSV(0x18, RESERVED_OPC) \
- RSV(0x19, RESERVED_OPC) \
- RSV(0x1A, RESERVED_OPC) \
- RSV(0x1B, RESERVED_OPC) \
- RSV(0x1C, RESERVED_OPC) \
- RSV(0x1D, RESERVED_OPC) \
- RSV(0x1E, RESERVED_OPC) \
- RSV(0x1F, RESERVED_OPC) \
- \
- OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \
- OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \
- OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \
- OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \
- OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \
- \
- RSV(0x25, RESERVED_OPC) \
- RSV(0x26, RESERVED_OPC) \
- RSV(0x27, RESERVED_OPC) \
- RSV(0x28, RESERVED_OPC) \
- RSV(0x29, RESERVED_OPC) \
- RSV(0x2A, RESERVED_OPC) \
- RSV(0x2B, RESERVED_OPC) \
- RSV(0x2C, RESERVED_OPC) \
- RSV(0x2D, RESERVED_OPC) \
- RSV(0x2E, RESERVED_OPC) \
- RSV(0x2F, RESERVED_OPC) \
- \
- OPC(0x30, kRank, "rank", FLAG(kDefault), "so", FF) \
- OPC(0x31, kDim, "dim", FLAG(kDefault), "iso", FF) \
- OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \
- OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \
- OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "sssr", FF) \
- OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "sIIr", FF) \
- OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \
- OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sIoII", FF) \
- OPC(0x38, kClone, "clone", FLAG(kDefault), "sr", FF) \
- RSV(0x39, RESERVED_OPC) \
- OPC(0x3A, kSplit, "split", FLAG(kDefault), "isR", FF) \
- OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \
- OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \
- OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \
- OPC(0x3E, kSelect, "select", FLAG(kDefault), "ssso", FF) \
- OPC(0x3F, kTranspose, "transpose", FLAG(kDefault), "sso", FF) \
- OPC(0x40, kBroadcast, "broadcast", FLAG(kDefault), "sso", FF) \
- OPC(0x41, kTile, "tile", FLAG(kDefault), "sso", FF) \
- OPC(0x42, kReverse, "reverse", FLAG(kDefault), "sso", FF) \
- OPC(0x43, kPad, "pad", FLAG(kDefault), "ssssso", FF) \
- \
- RSV(0x44, RESERVED_OPC) \
- RSV(0x45, RESERVED_OPC) \
- RSV(0x46, RESERVED_OPC) \
- RSV(0x47, RESERVED_OPC) \
- RSV(0x48, RESERVED_OPC) \
- RSV(0x49, RESERVED_OPC) \
- RSV(0x4A, RESERVED_OPC) \
- RSV(0x4B, RESERVED_OPC) \
- RSV(0x4C, RESERVED_OPC) \
- RSV(0x4D, RESERVED_OPC) \
- RSV(0x4E, RESERVED_OPC) \
- RSV(0x4F, RESERVED_OPC) \
- \
- OPC(0x50, kNot, "not", FLAG(kDefault), "so", FF) \
- OPC(0x51, kAnd, "and", FLAG(kDefault), "sso", FF) \
- OPC(0x52, kOr, "or", FLAG(kDefault), "sso", FF) \
- OPC(0x53, kXor, "xor", FLAG(kDefault), "sso", FF) \
- OPC(0x54, kShiftLeft, "sll", FLAG(kDefault), "sso", FF) \
- OPC(0x55, kShiftRightLogical, "srl", FLAG(kDefault), "sso", FF) \
- OPC(0x56, kShiftRightArithmetic, "sra", FLAG(kDefault), "sso", FF) \
- \
- RSV(0x57, RESERVED_OPC) \
- RSV(0x58, RESERVED_OPC) \
- RSV(0x59, RESERVED_OPC) \
- RSV(0x5A, RESERVED_OPC) \
- RSV(0x5B, RESERVED_OPC) \
- RSV(0x5C, RESERVED_OPC) \
- RSV(0x5D, RESERVED_OPC) \
- RSV(0x5E, RESERVED_OPC) \
- RSV(0x5F, RESERVED_OPC) \
- RSV(0x60, RESERVED_OPC) \
- RSV(0x61, RESERVED_OPC) \
- RSV(0x62, RESERVED_OPC) \
- RSV(0x63, RESERVED_OPC) \
- RSV(0x64, RESERVED_OPC) \
- RSV(0x65, RESERVED_OPC) \
- RSV(0x66, RESERVED_OPC) \
- RSV(0x67, RESERVED_OPC) \
- RSV(0x68, RESERVED_OPC) \
- RSV(0x69, RESERVED_OPC) \
- RSV(0x6A, RESERVED_OPC) \
- RSV(0x6B, RESERVED_OPC) \
- RSV(0x6C, RESERVED_OPC) \
- RSV(0x6D, RESERVED_OPC) \
- RSV(0x6E, RESERVED_OPC) \
- RSV(0x6F, RESERVED_OPC) \
- \
- /* TODO(benvanik): remove ones we don't need/can emulate */ \
- OPC(0x70, kAddI, "add_i", FLAG(kDefault), "sso", FF) \
- OPC(0x71, kAddF, "add_f", FLAG(kDefault), "sso", FF) \
- OPC(0x72, kSubI, "sub_i", FLAG(kDefault), "sso", FF) \
- OPC(0x73, kSubF, "sub_f", FLAG(kDefault), "sso", FF) \
- OPC(0x74, kAbsI, "abs_i", FLAG(kDefault), "so", FF) \
- OPC(0x75, kAbsF, "abs_f", FLAG(kDefault), "so", FF) \
- OPC(0x76, kMulI, "mul_i", FLAG(kDefault), "sso", FF) \
- OPC(0x77, kMulF, "mul_f", FLAG(kDefault), "sso", FF) \
- OPC(0x78, kDivIS, "div_i_s", FLAG(kDefault), "sso", FF) \
- OPC(0x79, kDivIU, "div_i_u", FLAG(kDefault), "sso", FF) \
- OPC(0x7A, kDivF, "div_f", FLAG(kDefault), "sso", FF) \
- OPC(0x7B, kMulAddI, "madd_i", FLAG(kDefault), "ssso", FF) \
- OPC(0x7C, kMulAddF, "madd_f", FLAG(kDefault), "ssso", FF) \
- OPC(0x7D, kCosF, "cos_f", FLAG(kDefault), "so", FF) \
- OPC(0x7E, kSinF, "sin_f", FLAG(kDefault), "so", FF) \
- OPC(0x7F, kTanhF, "tanh_f", FLAG(kDefault), "so", FF) \
- OPC(0x80, kAtan2F, "atan2_f", FLAG(kDefault), "sso", FF) \
- OPC(0x81, kExpF, "exp_f", FLAG(kDefault), "so", FF) \
- OPC(0x82, kLogF, "log_f", FLAG(kDefault), "so", FF) \
- OPC(0x83, kRsqrtF, "rsqrt_f", FLAG(kDefault), "so", FF) \
- \
- RSV(0x84, RESERVED_OPC) \
- RSV(0x85, RESERVED_OPC) \
- RSV(0x86, RESERVED_OPC) \
- RSV(0x87, RESERVED_OPC) \
- RSV(0x88, RESERVED_OPC) \
- RSV(0x89, RESERVED_OPC) \
- RSV(0x8A, RESERVED_OPC) \
- RSV(0x8B, RESERVED_OPC) \
- RSV(0x8C, RESERVED_OPC) \
- RSV(0x8D, RESERVED_OPC) \
- RSV(0x8E, RESERVED_OPC) \
- RSV(0x8F, RESERVED_OPC) \
- \
- OPC(0x90, kMinIS, "min_i_s", FLAG(kDefault), "sso", FF) \
- OPC(0x91, kMinIU, "min_i_u", FLAG(kDefault), "sso", FF) \
- OPC(0x92, kMinF, "min_f", FLAG(kDefault), "sso", FF) \
- OPC(0x93, kMaxIS, "max_i_s", FLAG(kDefault), "sso", FF) \
- OPC(0x94, kMaxIU, "max_i_u", FLAG(kDefault), "sso", FF) \
- OPC(0x95, kMaxF, "max_f", FLAG(kDefault), "sso", FF) \
- OPC(0x96, kClampIS, "clamp_i_s", FLAG(kDefault), "ssso", FF) \
- OPC(0x97, kClampIU, "clamp_i_u", FLAG(kDefault), "ssso", FF) \
- OPC(0x98, kClampF, "clamp_f", FLAG(kDefault), "ssso", FF) \
- OPC(0x99, kFloorF, "floor_f", FLAG(kDefault), "so", FF) \
- OPC(0x9A, kCeilF, "ceil_f", FLAG(kDefault), "so", FF) \
- \
- OPC(0x9B, kConvertSS, "convert_s_s", FLAG(kDefault), "tsto", FF) \
- OPC(0x9C, kConvertUU, "convert_u_u", FLAG(kDefault), "tsto", FF) \
- OPC(0x9D, kConvertSU, "convert_s_u", FLAG(kDefault), "tsto", FF) \
- OPC(0x9E, kConvertUS, "convert_u_s", FLAG(kDefault), "tsto", FF) \
- \
- RSV(0x9F, RESERVED_OPC) \
- \
- /* TODO(benvanik): reduction/sum/etc */ \
- /* TODO(benvanik): sort */ \
- \
- OPC(0xA0, kMatMulI, "matmul_i", FLAG(kDefault), "sssso", FF) \
- OPC(0xA1, kMatMulF, "matmul_f", FLAG(kDefault), "sso", FF) \
- /* TODO(benvanik): convolution */ \
- \
- OPC(0xA2, kReduceSumI, "reduce_sum_i", FLAG(kDefault), "ssio", FF) \
- OPC(0xA3, kReduceSumF, "reduce_sum_f", FLAG(kDefault), "ssio", FF) \
- OPC(0xA4, kReduceMinI, "reduce_min_i", FLAG(kDefault), "ssio", FF) \
- OPC(0xA5, kReduceMinF, "reduce_min_f", FLAG(kDefault), "ssio", FF) \
- OPC(0xA6, kReduceMaxI, "reduce_max_i", FLAG(kDefault), "ssio", FF) \
- OPC(0xA7, kReduceMaxF, "reduce_max_f", FLAG(kDefault), "ssio", FF) \
- RSV(0xA8, RESERVED_OPC) \
- RSV(0xA9, RESERVED_OPC) \
- RSV(0xAA, RESERVED_OPC) \
- RSV(0xAB, RESERVED_OPC) \
- RSV(0xAC, RESERVED_OPC) \
- RSV(0xAD, RESERVED_OPC) \
- RSV(0xAE, RESERVED_OPC) \
- RSV(0xAF, RESERVED_OPC) \
- RSV(0xB0, RESERVED_OPC) \
- RSV(0xB1, RESERVED_OPC) \
- RSV(0xB2, RESERVED_OPC) \
- RSV(0xB3, RESERVED_OPC) \
- RSV(0xB4, RESERVED_OPC) \
- RSV(0xB5, RESERVED_OPC) \
- RSV(0xB6, RESERVED_OPC) \
- RSV(0xB7, RESERVED_OPC) \
- RSV(0xB8, RESERVED_OPC) \
- RSV(0xB9, RESERVED_OPC) \
- RSV(0xBA, RESERVED_OPC) \
- RSV(0xBB, RESERVED_OPC) \
- RSV(0xBC, RESERVED_OPC) \
- RSV(0xBD, RESERVED_OPC) \
- RSV(0xBE, RESERVED_OPC) \
- RSV(0xBF, RESERVED_OPC) \
- RSV(0xC0, RESERVED_OPC) \
- RSV(0xC1, RESERVED_OPC) \
- RSV(0xC2, RESERVED_OPC) \
- RSV(0xC3, RESERVED_OPC) \
- RSV(0xC4, RESERVED_OPC) \
- RSV(0xC5, RESERVED_OPC) \
- RSV(0xC6, RESERVED_OPC) \
- RSV(0xC7, RESERVED_OPC) \
- RSV(0xC8, RESERVED_OPC) \
- RSV(0xC9, RESERVED_OPC) \
- RSV(0xCA, RESERVED_OPC) \
- RSV(0xCB, RESERVED_OPC) \
- RSV(0xCC, RESERVED_OPC) \
- RSV(0xCD, RESERVED_OPC) \
- RSV(0xCE, RESERVED_OPC) \
- RSV(0xCF, RESERVED_OPC) \
- RSV(0xD0, RESERVED_OPC) \
- RSV(0xD1, RESERVED_OPC) \
- RSV(0xD2, RESERVED_OPC) \
- RSV(0xD3, RESERVED_OPC) \
- RSV(0xD4, RESERVED_OPC) \
- RSV(0xD5, RESERVED_OPC) \
- RSV(0xD6, RESERVED_OPC) \
- RSV(0xD7, RESERVED_OPC) \
- RSV(0xD8, RESERVED_OPC) \
- RSV(0xD9, RESERVED_OPC) \
- RSV(0xDA, RESERVED_OPC) \
- RSV(0xDB, RESERVED_OPC) \
- RSV(0xDC, RESERVED_OPC) \
- RSV(0xDD, RESERVED_OPC) \
- RSV(0xDE, RESERVED_OPC) \
- RSV(0xDF, RESERVED_OPC) \
- RSV(0xE0, RESERVED_OPC) \
- RSV(0xE1, RESERVED_OPC) \
- RSV(0xE2, RESERVED_OPC) \
- RSV(0xE3, RESERVED_OPC) \
- RSV(0xE4, RESERVED_OPC) \
- RSV(0xE5, RESERVED_OPC) \
- RSV(0xE6, RESERVED_OPC) \
- RSV(0xE7, RESERVED_OPC) \
- RSV(0xE8, RESERVED_OPC) \
- RSV(0xE9, RESERVED_OPC) \
- RSV(0xEA, RESERVED_OPC) \
- RSV(0xEB, RESERVED_OPC) \
- RSV(0xEC, RESERVED_OPC) \
- RSV(0xED, RESERVED_OPC) \
- RSV(0xEE, RESERVED_OPC) \
- RSV(0xEF, RESERVED_OPC) \
- RSV(0xF0, RESERVED_OPC) \
- RSV(0xF1, RESERVED_OPC) \
- RSV(0xF2, RESERVED_OPC) \
- RSV(0xF3, RESERVED_OPC) \
- RSV(0xF4, RESERVED_OPC) \
- RSV(0xF5, RESERVED_OPC) \
- RSV(0xF6, RESERVED_OPC) \
- RSV(0xF7, RESERVED_OPC) \
- RSV(0xF8, RESERVED_OPC) \
- RSV(0xF9, RESERVED_OPC) \
- RSV(0xFA, RESERVED_OPC) \
- RSV(0xFB, RESERVED_OPC) \
- RSV(0xFC, RESERVED_OPC) \
- \
- OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \
- OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \
- OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF)
-
-#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal,
-enum class InterpreterOpcode : uint8_t {
- IREE_INTERPRETER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM)
-};
-#undef DECLARE_ENUM
-
-} // namespace iree
-
-#endif // IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
diff --git a/iree/schemas/bytecode/sequencer_bytecode_v0.h b/iree/schemas/bytecode/sequencer_bytecode_v0.h
deleted file mode 100644
index 7cba224..0000000
--- a/iree/schemas/bytecode/sequencer_bytecode_v0.h
+++ /dev/null
@@ -1,313 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Opcode table for the V0 binary format.
-// Additions are fine but changing the behavior or order of any opcodes will
-// break parsing of existing files.
-//
-// Opcodes have been selected on frequency of use, general applicability, and
-// relative stability. Experimental ops should be implemented via the Foreign
-// Function Interface (FFI) first before graduating into the core set. Ops that
-// may only be present on certain targets should also be kept as imports via the
-// FFI.
-//
-// Opcodes may be specified for particular types (int32_t), categories of types
-// (all floating-point types), or implicit types (output matches input). Saving
-// opcode space by sharing a single opcode for multiple types is preferred
-// except where hot operations are performed (for example, comparison used in
-// loop iterators).
-
-#ifndef IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
-#define IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
-
-#include "iree/schemas/bytecode/bytecode_v0.h"
-
-namespace iree {
-
-#define IREE_SEQUENCER_OPCODE_LIST(OPC, RESERVED_OPC) \
- OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \
- \
- OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \
- OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \
- OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \
- OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \
- OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \
- OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \
- \
- RSV(0x07, RESERVED_OPC) \
- RSV(0x08, RESERVED_OPC) \
- RSV(0x09, RESERVED_OPC) \
- RSV(0x0A, RESERVED_OPC) \
- RSV(0x0B, RESERVED_OPC) \
- RSV(0x0C, RESERVED_OPC) \
- RSV(0x0D, RESERVED_OPC) \
- RSV(0x0E, RESERVED_OPC) \
- RSV(0x0F, RESERVED_OPC) \
- \
- OPC(0x10, kDynamicDispatch, "dynamic_dispatch", FLAG(kDefault), "dsSOR", FF) \
- OPC(0x11, kStaticDispatch, "static_dispatch", FLAG(kDefault), "diiiSOR", FF) \
- \
- RSV(0x12, RESERVED_OPC) \
- RSV(0x13, RESERVED_OPC) \
- RSV(0x14, RESERVED_OPC) \
- RSV(0x15, RESERVED_OPC) \
- RSV(0x16, RESERVED_OPC) \
- RSV(0x17, RESERVED_OPC) \
- RSV(0x18, RESERVED_OPC) \
- RSV(0x19, RESERVED_OPC) \
- RSV(0x1A, RESERVED_OPC) \
- RSV(0x1B, RESERVED_OPC) \
- RSV(0x1C, RESERVED_OPC) \
- RSV(0x1D, RESERVED_OPC) \
- RSV(0x1E, RESERVED_OPC) \
- RSV(0x1F, RESERVED_OPC) \
- \
- OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \
- OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \
- OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \
- OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \
- OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \
- \
- RSV(0x25, RESERVED_OPC) \
- RSV(0x26, RESERVED_OPC) \
- RSV(0x27, RESERVED_OPC) \
- RSV(0x28, RESERVED_OPC) \
- RSV(0x29, RESERVED_OPC) \
- RSV(0x2A, RESERVED_OPC) \
- RSV(0x2B, RESERVED_OPC) \
- RSV(0x2C, RESERVED_OPC) \
- RSV(0x2D, RESERVED_OPC) \
- RSV(0x2E, RESERVED_OPC) \
- RSV(0x2F, RESERVED_OPC) \
- RSV(0x30, RESERVED_OPC) \
- \
- OPC(0x31, kComputeRange, "compute_range", FLAG(kDefault), "sissoo", FF) \
- OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \
- OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \
- OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "ssstsr", FF) \
- OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "siitIr", FF) \
- OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \
- OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sioii", FF) \
- OPC(0x38, kDynamicFill, "dynamic_fill", FLAG(kDefault), "soss", FF) \
- OPC(0x39, kStaticFill, "static_fill", FLAG(kDefault), "ioii", FF) \
- OPC(0x3A, kClone, "clone", FLAG(kDefault), "sr", FF) \
- OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \
- OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \
- OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \
- \
- RSV(0x3E, RESERVED_OPC) \
- RSV(0x3F, RESERVED_OPC) \
- RSV(0x40, RESERVED_OPC) \
- RSV(0x41, RESERVED_OPC) \
- RSV(0x42, RESERVED_OPC) \
- RSV(0x43, RESERVED_OPC) \
- RSV(0x44, RESERVED_OPC) \
- RSV(0x45, RESERVED_OPC) \
- RSV(0x46, RESERVED_OPC) \
- RSV(0x47, RESERVED_OPC) \
- RSV(0x48, RESERVED_OPC) \
- RSV(0x49, RESERVED_OPC) \
- RSV(0x4A, RESERVED_OPC) \
- RSV(0x4B, RESERVED_OPC) \
- RSV(0x4C, RESERVED_OPC) \
- RSV(0x4D, RESERVED_OPC) \
- RSV(0x4E, RESERVED_OPC) \
- RSV(0x4F, RESERVED_OPC) \
- RSV(0x50, RESERVED_OPC) \
- RSV(0x51, RESERVED_OPC) \
- RSV(0x52, RESERVED_OPC) \
- RSV(0x53, RESERVED_OPC) \
- RSV(0x54, RESERVED_OPC) \
- RSV(0x55, RESERVED_OPC) \
- RSV(0x56, RESERVED_OPC) \
- RSV(0x57, RESERVED_OPC) \
- RSV(0x58, RESERVED_OPC) \
- RSV(0x59, RESERVED_OPC) \
- RSV(0x5A, RESERVED_OPC) \
- RSV(0x5B, RESERVED_OPC) \
- RSV(0x5C, RESERVED_OPC) \
- RSV(0x5D, RESERVED_OPC) \
- RSV(0x5E, RESERVED_OPC) \
- RSV(0x5F, RESERVED_OPC) \
- RSV(0x60, RESERVED_OPC) \
- RSV(0x61, RESERVED_OPC) \
- RSV(0x62, RESERVED_OPC) \
- RSV(0x63, RESERVED_OPC) \
- RSV(0x64, RESERVED_OPC) \
- RSV(0x65, RESERVED_OPC) \
- RSV(0x66, RESERVED_OPC) \
- RSV(0x67, RESERVED_OPC) \
- RSV(0x68, RESERVED_OPC) \
- RSV(0x69, RESERVED_OPC) \
- RSV(0x6A, RESERVED_OPC) \
- RSV(0x6B, RESERVED_OPC) \
- RSV(0x6C, RESERVED_OPC) \
- RSV(0x6D, RESERVED_OPC) \
- RSV(0x6E, RESERVED_OPC) \
- RSV(0x6F, RESERVED_OPC) \
- RSV(0x70, RESERVED_OPC) \
- RSV(0x71, RESERVED_OPC) \
- RSV(0x72, RESERVED_OPC) \
- RSV(0x73, RESERVED_OPC) \
- RSV(0x74, RESERVED_OPC) \
- RSV(0x75, RESERVED_OPC) \
- RSV(0x76, RESERVED_OPC) \
- RSV(0x77, RESERVED_OPC) \
- RSV(0x78, RESERVED_OPC) \
- RSV(0x79, RESERVED_OPC) \
- RSV(0x7A, RESERVED_OPC) \
- RSV(0x7B, RESERVED_OPC) \
- RSV(0x7C, RESERVED_OPC) \
- RSV(0x7D, RESERVED_OPC) \
- RSV(0x7E, RESERVED_OPC) \
- RSV(0x7F, RESERVED_OPC) \
- RSV(0x80, RESERVED_OPC) \
- RSV(0x81, RESERVED_OPC) \
- RSV(0x82, RESERVED_OPC) \
- RSV(0x83, RESERVED_OPC) \
- RSV(0x84, RESERVED_OPC) \
- RSV(0x85, RESERVED_OPC) \
- RSV(0x86, RESERVED_OPC) \
- RSV(0x87, RESERVED_OPC) \
- RSV(0x88, RESERVED_OPC) \
- RSV(0x89, RESERVED_OPC) \
- RSV(0x8A, RESERVED_OPC) \
- RSV(0x8B, RESERVED_OPC) \
- RSV(0x8C, RESERVED_OPC) \
- RSV(0x8D, RESERVED_OPC) \
- RSV(0x8E, RESERVED_OPC) \
- RSV(0x8F, RESERVED_OPC) \
- RSV(0x90, RESERVED_OPC) \
- RSV(0x91, RESERVED_OPC) \
- RSV(0x92, RESERVED_OPC) \
- RSV(0x93, RESERVED_OPC) \
- RSV(0x94, RESERVED_OPC) \
- RSV(0x95, RESERVED_OPC) \
- RSV(0x96, RESERVED_OPC) \
- RSV(0x97, RESERVED_OPC) \
- RSV(0x98, RESERVED_OPC) \
- RSV(0x99, RESERVED_OPC) \
- RSV(0x9A, RESERVED_OPC) \
- RSV(0x9B, RESERVED_OPC) \
- RSV(0x9C, RESERVED_OPC) \
- RSV(0x9D, RESERVED_OPC) \
- RSV(0x9E, RESERVED_OPC) \
- RSV(0x9F, RESERVED_OPC) \
- RSV(0xA0, RESERVED_OPC) \
- RSV(0xA1, RESERVED_OPC) \
- RSV(0xA2, RESERVED_OPC) \
- RSV(0xA3, RESERVED_OPC) \
- RSV(0xA4, RESERVED_OPC) \
- RSV(0xA5, RESERVED_OPC) \
- RSV(0xA6, RESERVED_OPC) \
- RSV(0xA7, RESERVED_OPC) \
- RSV(0xA8, RESERVED_OPC) \
- RSV(0xA9, RESERVED_OPC) \
- RSV(0xAA, RESERVED_OPC) \
- RSV(0xAB, RESERVED_OPC) \
- RSV(0xAC, RESERVED_OPC) \
- RSV(0xAD, RESERVED_OPC) \
- RSV(0xAE, RESERVED_OPC) \
- RSV(0xAF, RESERVED_OPC) \
- RSV(0xB0, RESERVED_OPC) \
- RSV(0xB1, RESERVED_OPC) \
- RSV(0xB2, RESERVED_OPC) \
- RSV(0xB3, RESERVED_OPC) \
- RSV(0xB4, RESERVED_OPC) \
- RSV(0xB5, RESERVED_OPC) \
- RSV(0xB6, RESERVED_OPC) \
- RSV(0xB7, RESERVED_OPC) \
- RSV(0xB8, RESERVED_OPC) \
- RSV(0xB9, RESERVED_OPC) \
- RSV(0xBA, RESERVED_OPC) \
- RSV(0xBB, RESERVED_OPC) \
- RSV(0xBC, RESERVED_OPC) \
- RSV(0xBD, RESERVED_OPC) \
- RSV(0xBE, RESERVED_OPC) \
- RSV(0xBF, RESERVED_OPC) \
- RSV(0xC0, RESERVED_OPC) \
- RSV(0xC1, RESERVED_OPC) \
- RSV(0xC2, RESERVED_OPC) \
- RSV(0xC3, RESERVED_OPC) \
- RSV(0xC4, RESERVED_OPC) \
- RSV(0xC5, RESERVED_OPC) \
- RSV(0xC6, RESERVED_OPC) \
- RSV(0xC7, RESERVED_OPC) \
- RSV(0xC8, RESERVED_OPC) \
- RSV(0xC9, RESERVED_OPC) \
- RSV(0xCA, RESERVED_OPC) \
- RSV(0xCB, RESERVED_OPC) \
- RSV(0xCC, RESERVED_OPC) \
- RSV(0xCD, RESERVED_OPC) \
- RSV(0xCE, RESERVED_OPC) \
- RSV(0xCF, RESERVED_OPC) \
- RSV(0xD0, RESERVED_OPC) \
- RSV(0xD1, RESERVED_OPC) \
- RSV(0xD2, RESERVED_OPC) \
- RSV(0xD3, RESERVED_OPC) \
- RSV(0xD4, RESERVED_OPC) \
- RSV(0xD5, RESERVED_OPC) \
- RSV(0xD6, RESERVED_OPC) \
- RSV(0xD7, RESERVED_OPC) \
- RSV(0xD8, RESERVED_OPC) \
- RSV(0xD9, RESERVED_OPC) \
- RSV(0xDA, RESERVED_OPC) \
- RSV(0xDB, RESERVED_OPC) \
- RSV(0xDC, RESERVED_OPC) \
- RSV(0xDD, RESERVED_OPC) \
- RSV(0xDE, RESERVED_OPC) \
- RSV(0xDF, RESERVED_OPC) \
- RSV(0xE0, RESERVED_OPC) \
- RSV(0xE1, RESERVED_OPC) \
- RSV(0xE2, RESERVED_OPC) \
- RSV(0xE3, RESERVED_OPC) \
- RSV(0xE4, RESERVED_OPC) \
- RSV(0xE5, RESERVED_OPC) \
- RSV(0xE6, RESERVED_OPC) \
- RSV(0xE7, RESERVED_OPC) \
- RSV(0xE8, RESERVED_OPC) \
- RSV(0xE9, RESERVED_OPC) \
- RSV(0xEA, RESERVED_OPC) \
- RSV(0xEB, RESERVED_OPC) \
- RSV(0xEC, RESERVED_OPC) \
- RSV(0xED, RESERVED_OPC) \
- RSV(0xEE, RESERVED_OPC) \
- RSV(0xEF, RESERVED_OPC) \
- RSV(0xF0, RESERVED_OPC) \
- RSV(0xF1, RESERVED_OPC) \
- RSV(0xF2, RESERVED_OPC) \
- RSV(0xF3, RESERVED_OPC) \
- RSV(0xF4, RESERVED_OPC) \
- RSV(0xF5, RESERVED_OPC) \
- RSV(0xF6, RESERVED_OPC) \
- RSV(0xF7, RESERVED_OPC) \
- RSV(0xF8, RESERVED_OPC) \
- RSV(0xF9, RESERVED_OPC) \
- RSV(0xFA, RESERVED_OPC) \
- RSV(0xFB, RESERVED_OPC) \
- RSV(0xFC, RESERVED_OPC) \
- \
- OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \
- OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \
- OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF)
-
-#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal,
-enum class SequencerOpcode : uint8_t {
- IREE_SEQUENCER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM)
-};
-#undef DECLARE_ENUM
-
-} // namespace iree
-
-#endif // IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
diff --git a/iree/schemas/debug_service.fbs b/iree/schemas/debug_service.fbs
deleted file mode 100644
index 36e84b5..0000000
--- a/iree/schemas/debug_service.fbs
+++ /dev/null
@@ -1,347 +0,0 @@
-include "iree/schemas/function_def.fbs";
-include "iree/schemas/module_def.fbs";
-
-namespace iree.rt.debug.rpc;
-
-table Status {
- code:int;
- message:string;
-}
-
-table CreateSessionRequest {
-}
-table CreateSessionResponse {
- session_id:int;
-}
-
-table MakeReadyRequest {
- session_id:int;
-}
-table MakeReadyResponse {
-}
-
-table GetStatusRequest {
- session_id:int;
- // TODO(benvanik): caps debugger supports? version expected?
-}
-table GetStatusResponse {
- protocol:int;
- // TODO(benvanik): run state.
- // TODO(benvanik): profiling state.
-}
-
-table NativeFunctionDef {
- name:string;
- // TODO(benvanik): more information about the fns (stack trace of registrant?)
-}
-
-table ContextDef {
- context_id:int;
- native_functions:[NativeFunctionDef];
- module_names:[string];
-}
-
-table ListContextsRequest {
- session_id:int;
-}
-
-table ListContextsResponse {
- contexts:[ContextDef];
-}
-
-table GetModuleRequest {
- session_id:int;
- context_id:int;
- module_name:string;
-}
-table GetModuleResponse {
- module:ModuleDef;
-}
-
-table GetFunctionRequest {
- session_id:int;
- context_id:int;
- module_name:string;
- function_ordinal:int;
-}
-table GetFunctionResponse {
- bytecode:BytecodeDef;
- // TODO(benvanik): import info (linked module, etc).
-}
-
-table ResolveFunctionRequest {
- session_id:int;
- module_name:string;
- function_name:string;
-}
-table ResolveFunctionResponse {
- context_ids:[int];
- function_ordinal:int;
-}
-
-table BufferViewDef {
- is_valid:bool;
- shape:[int];
- element_size:int;
- // TODO(benvanik): buffer attrs (type, access, usage).
- // TODO(benvanik): buffer size/allocated_size.
- // TODO(benvanik): buffer data (if accessible).
-}
-
-table StackFrameDef {
- module_name:string;
- function_ordinal:int;
- offset:int;
- locals:[BufferViewDef];
-}
-
-table InvocationDef {
- invocation_id:int;
- frames:[StackFrameDef];
-}
-
-table ListInvocationsRequest {
- session_id:int;
-}
-table ListInvocationsResponse {
- invocations:[InvocationDef];
-}
-
-table SuspendInvocationsRequest {
- session_id:int;
- invocation_ids:[int];
-}
-table SuspendInvocationsResponse {
- invocations:[InvocationDef];
-}
-
-table ResumeInvocationsRequest {
- session_id:int;
- invocation_ids:[int];
-}
-table ResumeInvocationsResponse {
-}
-
-enum StepMode : uint8 {
- STEP_ONCE = 0,
- STEP_TO_OFFSET = 1,
-}
-
-table StepInvocationRequest {
- session_id:int;
- step_id:int;
- invocation_id:int;
- step_mode:StepMode;
- bytecode_offset:int;
-}
-table StepInvocationResponse {}
-
-table GetInvocationLocalRequest {
- session_id:int;
- invocation_id:int;
- frame_index:int;
- local_index:int;
-}
-table GetInvocationLocalResponse {
- value:BufferViewDef;
-}
-
-table SetInvocationLocalRequest {
- session_id:int;
- invocation_id:int;
- frame_index:int;
- local_index:int;
- value:BufferViewDef;
-}
-table SetInvocationLocalResponse {
- value:BufferViewDef;
-}
-
-enum BreakpointType : uint8 {
- BYTECODE_FUNCTION = 0,
- NATIVE_FUNCTION = 1,
-}
-
-table BreakpointDef {
- breakpoint_id:int;
- breakpoint_type:BreakpointType;
-
- module_name:string;
- function_name:string;
- function_ordinal:int;
- bytecode_offset:int;
-}
-
-table ListBreakpointsRequest {
- session_id:int;
-}
-table ListBreakpointsResponse {
- breakpoints:[BreakpointDef];
-}
-
-table AddBreakpointRequest {
- session_id:int;
- breakpoint:BreakpointDef;
-}
-table AddBreakpointResponse {
- breakpoint:BreakpointDef;
-}
-
-table RemoveBreakpointRequest {
- session_id:int;
- breakpoint_id:int;
-}
-table RemoveBreakpointResponse {
-}
-
-table StartProfilingRequest {
- session_id:int;
- context_id:int;
- // TODO(benvanik): profiling mode.
- // mode: sampling_timing, instrumented_coverage, instrumented_log,
- // invoke_log
-}
-table StartProfilingResponse {
- // TODO(benvanik): current/new mode.
-}
-
-table StopProfilingRequest {
- session_id:int;
- context_id:int;
-}
-table StopProfilingResponse {
- // TODO(benvanik): profiling data.
-}
-
-// TODO(benvanik): streaming profiling data query.
-
-table ServiceShutdownEvent {
-}
-
-table ContextRegisteredEvent {
- context_id:int;
-}
-table ContextUnregisteredEvent {
- context_id:int;
-}
-
-table ModuleLoadedEvent {
- context_id:int;
- module_name:string;
-}
-
-table InvocationRegisteredEvent {
- invocation_id:int;
-}
-table InvocationUnregisteredEvent {
- invocation_id:int;
-}
-
-table BreakpointResolvedEvent {
- breakpoint:BreakpointDef;
- context_id:int;
-}
-
-table BreakpointHitEvent {
- breakpoint_id:int;
- invocation:InvocationDef;
-}
-
-table StepCompletedEvent {
- step_id:int;
- invocations:[InvocationDef];
-}
-
-union RequestUnion {
- CreateSessionRequest,
- MakeReadyRequest,
- GetStatusRequest,
- ListContextsRequest,
- GetModuleRequest,
- GetFunctionRequest,
- ResolveFunctionRequest,
- ListInvocationsRequest,
- SuspendInvocationsRequest,
- ResumeInvocationsRequest,
- StepInvocationRequest,
- GetInvocationLocalRequest,
- SetInvocationLocalRequest,
- ListBreakpointsRequest,
- AddBreakpointRequest,
- RemoveBreakpointRequest,
- StartProfilingRequest,
- StopProfilingRequest,
-}
-
-union ResponseUnion {
- CreateSessionResponse,
- MakeReadyResponse,
- GetStatusResponse,
- ListContextsResponse,
- GetModuleResponse,
- GetFunctionResponse,
- ResolveFunctionResponse,
- ListInvocationsResponse,
- SuspendInvocationsResponse,
- ResumeInvocationsResponse,
- StepInvocationResponse,
- GetInvocationLocalResponse,
- SetInvocationLocalResponse,
- ListBreakpointsResponse,
- AddBreakpointResponse,
- RemoveBreakpointResponse,
- StartProfilingResponse,
- StopProfilingResponse,
-}
-
-union EventUnion {
- ServiceShutdownEvent,
- ContextRegisteredEvent,
- ContextUnregisteredEvent,
- ModuleLoadedEvent,
- InvocationRegisteredEvent,
- InvocationUnregisteredEvent,
- BreakpointResolvedEvent,
- BreakpointHitEvent,
- StepCompletedEvent,
-}
-
-table Request {
- message:RequestUnion;
-}
-
-table Response {
- status:Status;
- message:ResponseUnion;
-}
-
-table ServicePacket {
- response:Response;
- event:EventUnion;
-}
-
-// NOTE: we aren't using this yet as the FlatBuffers gRPC code is... suspect.
-rpc_service DebugServiceRpc {
- MakeReady(MakeReadyRequest):MakeReadyResponse;
-
- GetStatus(GetStatusRequest):GetStatusResponse;
-
- ListContexts(ListContextsRequest):ListContextsResponse;
- GetModule(GetModuleRequest):GetModuleResponse;
- GetFunction(GetFunctionRequest):GetFunctionResponse;
- ResolveFunction(ResolveFunctionRequest):ResolveFunctionResponse;
-
- ListInvocations(ListInvocationsRequest):ListInvocationsResponse;
- SuspendInvocations(SuspendInvocationsRequest):SuspendInvocationsResponse;
- ResumeInvocations(ResumeInvocationsRequest):ResumeInvocationsResponse;
- StepInvocation(StepInvocationRequest):StepInvocationResponse;
- GetInvocationLocal(GetInvocationLocalRequest):GetInvocationLocalResponse;
- SetInvocationLocal(SetInvocationLocalRequest):SetInvocationLocalResponse;
-
- ListBreakpoints(ListBreakpointsRequest):ListBreakpointsResponse;
- AddBreakpoint(AddBreakpointRequest):AddBreakpointResponse;
- RemoveBreakpoint(RemoveBreakpointRequest):RemoveBreakpointResponse;
-
- StartProfiling(StartProfilingRequest):StartProfilingResponse;
- StopProfiling(StopProfilingRequest):StopProfilingResponse;
-}
diff --git a/iree/schemas/device_table_def.fbs b/iree/schemas/device_table_def.fbs
deleted file mode 100644
index 9b00e29..0000000
--- a/iree/schemas/device_table_def.fbs
+++ /dev/null
@@ -1,13 +0,0 @@
-include "iree/schemas/device_def.fbs";
-include "iree/schemas/device_group_def.fbs";
-
-namespace iree;
-
-// A table of devices used for runtime device resolution and referencing.
-table DeviceTableDef {
- // One or more virtual devices referenced by ordinal in the sequencer ops.
- devices:[DeviceDef];
-
- // Zero or more device groups that specify which devices must be compatible.
- device_groups:[DeviceGroupDef];
-}
diff --git a/iree/schemas/executable_table_def.fbs b/iree/schemas/executable_table_def.fbs
deleted file mode 100644
index 25174ad..0000000
--- a/iree/schemas/executable_table_def.fbs
+++ /dev/null
@@ -1,28 +0,0 @@
-include "iree/schemas/executable_def.fbs";
-
-namespace iree;
-
-// A fat executable containing multiple format variants for the same logical
-// entry points.
-table MultiArchExecutableDef {
- // Friendly name of the executable used for diagnostics.
- name:string;
-
- // Number of available entry points.
- // This is used for bytecode verification even when the executable is not
- // fully loaded into a device. All executables must have the same entry
- // points.
- entry_point_count:uint;
-
- // A set of executables of various formats and supported feature sets.
- // The runtime will select the appropriate executable based on the dispatch
- // requirements.
- executables:[ExecutableDef];
-}
-
-// A table of executables used for runtime dispatch lookup.
-table ExecutableTableDef {
- // One or more top level executables referenced by sequencer dispatch ops.
- // Ordinal is referenced by dispatch ops to index into the table.
- multi_arch_executables:[MultiArchExecutableDef];
-}
diff --git a/iree/schemas/function_def.fbs b/iree/schemas/function_def.fbs
deleted file mode 100644
index 9547604..0000000
--- a/iree/schemas/function_def.fbs
+++ /dev/null
@@ -1,18 +0,0 @@
-include "iree/schemas/bytecode_def.fbs";
-include "iree/schemas/type_def.fbs";
-
-namespace iree;
-
-table FunctionAttributeDef {
- key:string;
- value:string;
-}
-
-table FunctionDef {
- name:string;
- type:FunctionTypeDef;
-
- attrs:[FunctionAttributeDef];
-
- bytecode:BytecodeDef;
-}
diff --git a/iree/schemas/function_table_def.fbs b/iree/schemas/function_table_def.fbs
deleted file mode 100644
index 2065df7..0000000
--- a/iree/schemas/function_table_def.fbs
+++ /dev/null
@@ -1,9 +0,0 @@
-include "iree/schemas/function_def.fbs";
-
-namespace iree;
-
-table FunctionTableDef {
- functions:[FunctionDef];
- imports:[int];
- exports:[int];
-}
diff --git a/iree/schemas/module_def.fbs b/iree/schemas/module_def.fbs
deleted file mode 100644
index 70bcbe9..0000000
--- a/iree/schemas/module_def.fbs
+++ /dev/null
@@ -1,20 +0,0 @@
-include "iree/schemas/executable_table_def.fbs";
-include "iree/schemas/device_table_def.fbs";
-include "iree/schemas/function_table_def.fbs";
-include "iree/schemas/source_map_def.fbs";
-
-namespace iree;
-
-// 'Executable MODule'.
-file_identifier "EMOD";
-file_extension "emod";
-
-table ModuleDef {
- name:string;
- device_table:DeviceTableDef;
- function_table:FunctionTableDef;
- executable_table:ExecutableTableDef;
- source_map:SourceMapDef;
-}
-
-root_type ModuleDef;
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
deleted file mode 100644
index cca9c86..0000000
--- a/iree/tools/BUILD
+++ /dev/null
@@ -1,124 +0,0 @@
-# Misc tools used to optimize, translate, and evaluate IREE.
-# Most of these are not designed to run on-device.
-
-load("//iree:build_defs.bzl", "PLATFORM_VULKAN_DEPS")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-exports_files([
- "run_lit.sh",
- "sanitizer_suppressions.txt",
-])
-
-cc_binary(
- name = "iree-opt",
- deps = [
- "//iree/compiler/Transforms",
- "//iree/compiler/Transforms/Interpreter",
- "//iree/compiler/Transforms/Sequencer",
- "//iree/compiler/Translation/SPIRV",
- "@llvm//:support",
- "@local_config_mlir//:AffineDialectRegistration",
- "@local_config_mlir//:MlirOptLib",
- "@local_config_mlir//:MlirOptMain",
- "@local_config_mlir//:StandardDialectRegistration",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
- ],
-)
-
-cc_binary(
- name = "iree-run-mlir",
- srcs = ["run_mlir_main.cc"],
- deps = [
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "//iree/base:source_location",
- "//iree/rt",
- "//iree/vm:sequencer_module",
- "@llvm//:support",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Parser",
- "@local_config_mlir//:Support",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/compiler/Translation/Sequencer",
- "//iree/compiler/Translation/Interpreter",
- "//iree/compiler/Translation/SPIRV",
- "//iree/hal:buffer_view_string_util",
- "//iree/hal:driver_registry",
- "//iree/schemas",
- "//iree/rt/debug:debug_server_flags",
- ] + PLATFORM_VULKAN_DEPS + [
- "//iree/hal/interpreter:interpreter_driver_module",
- # TODO(b/142004903): enable when Dawn HAL implementation is functional
- # "//iree/hal/dawn:dawn_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
- ],
-)
-
-cc_binary(
- name = "iree-translate",
- srcs = ["iree_translate_main.cc"],
- deps = [
- "//iree/compiler/Translation/Interpreter",
- "//iree/compiler/Translation/SPIRV",
- "//iree/compiler/Translation/Sequencer",
- "@llvm//:support",
- "@local_config_mlir//:AffineDialectRegistration",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:Support",
- "@local_config_mlir//:TranslateClParser",
- "@local_config_mlir//:Translation",
- "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
- ],
-)
-
-cc_binary(
- name = "run_module",
- srcs = ["run_module_main.cc"],
- deps = [
- "//iree/base:file_io",
- "//iree/base:file_path",
- "//iree/base:init",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/hal:buffer_view_string_util",
- "//iree/hal:driver_registry",
- "//iree/hal/interpreter:interpreter_driver_module",
- "//iree/rt",
- "//iree/rt/debug:debug_server_flags",
- "//iree/schemas",
- "//iree/vm:sequencer_module",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_binary(
- name = "benchmark_module",
- testonly = 1,
- srcs = ["benchmark_module.cc"],
- deps = [
- "//iree/base:file_io",
- "//iree/base:file_path",
- "//iree/base:init",
- "//iree/base:source_location",
- "//iree/base:status",
- "//iree/hal:buffer_view_string_util",
- "//iree/hal:driver_registry",
- "//iree/hal/interpreter:interpreter_driver_module",
- "//iree/rt",
- "//iree/rt/debug:debug_server_flags",
- "//iree/schemas",
- "//iree/vm:sequencer_module",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_benchmark//:benchmark",
- ],
-)
diff --git a/iree/tools/benchmark_module.cc b/iree/tools/benchmark_module.cc
deleted file mode 100644
index 1b62a52..0000000
--- a/iree/tools/benchmark_module.cc
+++ /dev/null
@@ -1,157 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <iostream>
-#include <vector>
-
-#include "absl/flags/flag.h"
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/str_split.h"
-#include "absl/strings/string_view.h"
-#include "benchmark/benchmark.h"
-#include "iree/base/file_io.h"
-#include "iree/base/file_path.h"
-#include "iree/base/init.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view_string_util.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/rt/context.h"
-#include "iree/rt/debug/debug_server_flags.h"
-#include "iree/rt/instance.h"
-#include "iree/rt/module_printer.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/vm/sequencer_module.h"
-
-ABSL_FLAG(std::string, main_module, "", "Main module with entry point.");
-ABSL_FLAG(std::string, main_function, "",
- "Function within the main module to execute.");
-
-ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
-ABSL_FLAG(std::string, input_file, "",
- "Input shapes and optional values serialized in a file.");
-
-namespace iree {
-namespace {
-
-// Parses a list of input shapes and values from a string of newline-separated
-// inputs. Expects the contents to have one value per line with each value
-// listed as
-// [shape]xtype=[value]
-// Example:
-// 4x4xi8=0,1,2,3
-StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags(
- hal::Allocator* allocator) {
- std::string file_contents;
- if (!absl::GetFlag(FLAGS_input_values).empty()) {
- file_contents =
- absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}});
- } else if (!absl::GetFlag(FLAGS_input_file).empty()) {
- ASSIGN_OR_RETURN(file_contents,
- file_io::GetFileContents(absl::GetFlag(FLAGS_input_file)));
- }
- std::vector<hal::BufferView> inputs;
- for (const auto& line :
- absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) {
- ASSIGN_OR_RETURN(auto input,
- hal::ParseBufferViewFromString(line, allocator));
- inputs.push_back(input);
- }
- return inputs;
-}
-
-Status Run(benchmark::State& state) {
- ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
- auto instance = make_ref<rt::Instance>(std::move(debug_server));
- ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
- "interpreter"));
- ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
- RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
- auto policy = make_ref<rt::Policy>();
- auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));
-
- // Load main module.
- ASSIGN_OR_RETURN(
- auto main_module_file,
- vm::ModuleFile::LoadFile(ModuleDefIdentifier(),
- absl::GetFlag(FLAGS_main_module)),
- _ << "while loading module file " << absl::GetFlag(FLAGS_main_module));
- ASSIGN_OR_RETURN(auto main_module,
- vm::SequencerModule::FromFile(std::move(main_module_file)));
-
- // Register the main module with the context.
- // We could add additional modules (specializations, shared libraries, etc).
- // ModuleFiles are stateless so we could have the same module_file used by
- // multiple contexts simultaneously.
- RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));
-
- rt::Function main_function;
- if (!absl::GetFlag(FLAGS_main_function).empty()) {
- // User-specified main function.
- ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName(
- rt::Function::Linkage::kExport,
- absl::GetFlag(FLAGS_main_function)));
- } else {
- // No main function specified; to prevent non-deterministic behavior we
- // require one unless there's exactly one exported function in the module.
- if (main_module->signature().export_function_count() == 1) {
- ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal(
- rt::Function::Linkage::kExport, 0));
- } else {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "--main_function= must be specified to disambiguate the "
- "function to run";
- }
- }
-
- // Call into the main function.
- ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator()));
-
- for (auto _ : state) {
- ASSIGN_OR_RETURN(auto invocation,
- rt::Invocation::Create(add_ref(context), main_function,
- make_ref<rt::Policy>(), {},
- absl::MakeConstSpan(arguments)));
- RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
- }
-
- return OkStatus();
-}
-
-void BM_RunModule(benchmark::State& state) {
- // Delegate to a status-returning function so we can use the status macros.
- CHECK_OK(Run(state));
-}
-
-// By default only the main thread is included in CPU time. Include all the
-// threads instead. To make single and multi-threaded benchmarks more
-// comparable, use the wall time to determine how many iterations to run.
-// See https://github.com/google/benchmark#cpu-timers,
-BENCHMARK(BM_RunModule)->MeasureProcessCPUTime()->UseRealTime();
-
-} // namespace
-
-extern "C" int main(int argc, char** argv) {
- // The benchmark library uses a different mechanism for its flags. This
- // consumes any arguments it understands from argv. It must come before
- // InitializeEnvironment to avoid failures on unknown flags.
- ::benchmark::Initialize(&argc, argv);
- InitializeEnvironment(&argc, &argv);
- size_t run_benchmark_count = ::benchmark::RunSpecifiedBenchmarks();
- CHECK_GT(run_benchmark_count, 0) << "No benchmarks were run";
- return 0;
-}
-
-} // namespace iree
diff --git a/iree/tools/compilation.bzl b/iree/tools/compilation.bzl
deleted file mode 100644
index b232099..0000000
--- a/iree/tools/compilation.bzl
+++ /dev/null
@@ -1,43 +0,0 @@
-"""Rules for compiling IREE executables, modules, and archives."""
-
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
-
-# TODO(benvanik): port to a full starlark rule, document, etc.
-def iree_bytecode_module(
- name,
- srcs,
- cc_namespace = None,
- visibility = None):
- native.genrule(
- name = name,
- srcs = srcs,
- outs = [
- "%s.emod" % (name),
- ],
- cmd = " && ".join([
- " ".join([
- "$(location //iree/tools:iree-translate)",
- "-mlir-to-iree-module",
- "-o $(location %s.emod)" % (name),
- ] + ["$(locations %s)" % (src) for src in srcs]),
- ]),
- tools = [
- "//iree/tools:iree-translate",
- ],
- message = "Compiling IREE module %s..." % (name),
- output_to_bindir = 1,
- )
-
- # Embed the module for use in C++. This avoids the need for file IO in
- # tests and samples that would otherwise complicate execution/porting.
- if cc_namespace:
- cc_embed_data(
- name = "%s_cc" % (name),
- identifier = name,
- srcs = ["%s.emod" % (name)],
- cc_file_output = "%s.cc" % (name),
- h_file_output = "%s.h" % (name),
- cpp_namespace = cc_namespace,
- visibility = visibility,
- flatten = True,
- )
diff --git a/iree/tools/debugger/BUILD b/iree/tools/debugger/BUILD
deleted file mode 100644
index e3b90a4..0000000
--- a/iree/tools/debugger/BUILD
+++ /dev/null
@@ -1,173 +0,0 @@
-# IREE Debugger UIs.
-#
-# The main debugger UI can be used in standalone mode connected to a remote
-# host (via :debugger) or can be directly embedded into the IREE runtime to
-# allow for attaching (--iree_attach_debugger).
-#
-# By default the IREE runtime does not compile in debug support. To link it in
-# pass --define=IREE_DEBUG=1 to bazel builds of the runtime.
-
-# TODO(benvanik): re-enable debugger after refactoring.
-# load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-# TODO(benvanik): re-enable debugger after refactoring.
-# alias(
-# name = "debugger",
-# actual = select({
-# "//tools/cc_target_os:emscripten": ":debug_app_emscripten_files",
-# "//conditions:default": ":debug_app_native",
-# }),
-# )
-#
-# cc_library(
-# name = "debug_app_library",
-# srcs = ["debug_app.cc"],
-# hdrs = ["debug_app.h"],
-# deps = [
-# "//third_party/GL:GLES2_headers",
-# "//third_party/SDL2",
-# "@com_google_absl//absl/flags:flag",
-# "@com_google_absl//absl/memory",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/types:optional",
-# "//third_party/dear_imgui",
-# "//third_party/dear_imgui:imgui_sdl_opengl3",
-# "//iree/base:memory",
-# "//iree/base:source_location",
-# "//iree/base:status",
-# "//iree/rt/debug:debug_client",
-# "//iree/schemas",
-# ],
-# )
-#
-# # NOTE: users must also link in a GL implementation, like:
-# # "//third_party/GL/native:GLESv2", # build-cleaner: keep
-# cc_library(
-# name = "debug_app_embedded",
-# srcs = ["debug_app_embedded.cc"],
-# hdrs = ["debug_app_embedded.h"],
-# deps = [
-# ":debug_app_library",
-# "//third_party/SDL2",
-# "@com_google_absl//absl/base:core_headers",
-# "@com_google_absl//absl/memory",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/synchronization",
-# "//third_party/dear_imgui",
-# "//iree/base:memory",
-# "//iree/base:status",
-# ],
-# )
-#
-# EMSCRIPTEN_LINKOPTS_COMMON = [
-# # Error at compile time on unresolved symbols.
-# "-s ERROR_ON_UNDEFINED_SYMBOLS=1",
-#
-# # Required by SDL.
-# "-s EXTRA_EXPORTED_RUNTIME_METHODS=Pointer_stringify",
-#
-# # TODO(benvanik): tweak to enable support when needed.
-# "-s ALLOW_MEMORY_GROWTH=1",
-# # "-s WASM_MEM_MAX=268435456", # 256MB
-# # "-s TOTAL_MEMORY=268435456", # 256MB
-# ]
-#
-# EMSCRIPTEN_LINKOPTS_DBG = [
-# # Show WASM stack trace in Chrome debugger.
-# "-g2",
-# "-s DEMANGLE_SUPPORT=1",
-#
-# # Enable verbose assertions.
-# "-s ASSERTIONS=2",
-# "-s SAFE_HEAP=1",
-# "-s STACK_OVERFLOW_CHECK=2",
-# ]
-#
-# EMSCRIPTEN_LINKOPTS_OPT = []
-#
-# cc_binary(
-# name = "debug_app_emscripten",
-# srcs = ["debug_app_main_emscripten.cc"],
-# linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({
-# "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG,
-# "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT,
-# "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT,
-# }),
-# tags = [
-# "manual",
-# "notap", # TODO(b/137088911): Build/test on TAP
-# "wasm",
-# ],
-# deps = [
-# ":debug_app_library",
-# "//third_party/SDL2",
-# "@com_google_absl//absl/memory",
-# "//third_party/dear_imgui",
-# "//third_party/dear_imgui:imgui_sdl_opengl3",
-# "//iree/base:init",
-# "//iree/base:source_location",
-# "//iree/base:status",
-# ],
-# )
-#
-# auto_wasm_binary(
-# name = "debug_app_emscripten_binary",
-# cc_target = ":debug_app_emscripten",
-# tags = ["manual"],
-# )
-#
-# Fileset(
-# name = "debug_app_emscripten_files",
-# out = "wasm_files",
-# entries = [
-# FilesetEntry(
-# files = [":debug_app_emscripten_binary"],
-# strip_prefix = "debug_app_emscripten_binary",
-# destdir = "wasm",
-# ),
-# FilesetEntry(
-# files = ["debug_app.html"],
-# destdir = "wasm",
-# ),
-# ],
-# tags = ["manual"],
-# )
-#
-# cc_binary(
-# name = "debug_app_native",
-# srcs = ["debug_app_main_native.cc"],
-# deps = [
-# ":debug_app_embedded",
-# "//third_party/GL/native:EGL", # build-cleaner: keep
-# "//third_party/GL/native:GLESv2", # build-cleaner: keep
-# "//iree/base:init",
-# "//iree/base:status",
-# ],
-# )
-#
-# cc_binary(
-# name = "debug_cli",
-# srcs = ["debug_cli_main.cc"],
-# deps = [
-# ":debug_prompt",
-# "@com_google_absl//absl/flags:flag",
-# "//iree/base:init",
-# "//iree/base:status",
-# ],
-# )
-#
-# cc_library(
-# name = "debug_prompt",
-# srcs = ["debug_prompt.cc"],
-# hdrs = ["debug_prompt.h"],
-# deps = [
-# "@com_google_absl//absl/strings",
-# "//iree/base:status",
-# "//iree/rt/debug:debug_client",
-# ],
-# )
diff --git a/iree/tools/debugger/debug_app.cc b/iree/tools/debugger/debug_app.cc
deleted file mode 100644
index 3e63f53..0000000
--- a/iree/tools/debugger/debug_app.cc
+++ /dev/null
@@ -1,1422 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/tools/debugger/debug_app.h"
-
-#include <GLES2/gl2.h>
-
-#include <algorithm>
-#include <cstdio>
-
-#include "absl/flags/flag.h"
-#include "absl/memory/memory.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/str_split.h"
-#include "absl/types/optional.h"
-#include "iree/base/memory.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/rt/debug/debug_client.h"
-#include "iree/schemas/debug_service_generated.h"
-#include "iree/vm/bytecode_module.h"
-#include "iree/vm/bytecode_tables_sequencer.h"
-#include "third_party/dear_imgui/imgui.h"
-#include "third_party/dear_imgui/imgui_internal.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-namespace {
-
-void PushButtonHue(float hue) {
- ImGui::PushStyleColor(ImGuiCol_Button,
- (ImVec4)ImColor::HSV(hue / 7.0f, 0.6f, 0.6f));
- ImGui::PushStyleColor(ImGuiCol_ButtonHovered,
- (ImVec4)ImColor::HSV(hue / 7.0f, 0.7f, 0.7f));
- ImGui::PushStyleColor(ImGuiCol_ButtonActive,
- (ImVec4)ImColor::HSV(hue / 7.0f, 0.8f, 0.8f));
-}
-
-void PushButtonColor(const ImVec4& color) {
- ImGui::PushStyleColor(ImGuiCol_Button, color);
- ImGui::PushStyleColor(ImGuiCol_ButtonHovered, color);
- ImGui::PushStyleColor(ImGuiCol_ButtonActive, color);
-}
-
-void PopButtonStyle() { ImGui::PopStyleColor(3); }
-
-bool AreBreakpointsEqual(const RemoteBreakpoint& breakpoint,
- const DebugApp::UserBreakpoint& user_breakpoint) {
- if (user_breakpoint.active_breakpoint == &breakpoint) {
- return true;
- } else if (user_breakpoint.type != breakpoint.type()) {
- return false;
- }
- switch (breakpoint.type()) {
- case RemoteBreakpoint::Type::kBytecodeFunction:
- if (user_breakpoint.function_ordinal != -1 &&
- user_breakpoint.function_ordinal != breakpoint.function_ordinal()) {
- return false;
- }
- return breakpoint.module_name() == user_breakpoint.module_name &&
- breakpoint.function_name() == user_breakpoint.function_name &&
- breakpoint.bytecode_offset() == user_breakpoint.bytecode_offset;
- case RemoteBreakpoint::Type::kNativeFunction:
- return breakpoint.function_name() == user_breakpoint.native_function;
- default:
- return false;
- }
-}
-
-} // namespace
-
-// static
-void DebugApp::PumpMainLoopThunk(void* arg) {
- auto status = reinterpret_cast<DebugApp*>(arg)->PumpMainLoop();
- if (IsCancelled(status)) {
- return;
- } else if (!status.ok()) {
- CHECK_OK(status);
- }
-}
-
-DebugApp::DebugApp(SDL_Window* window, SDL_GLContext gl_context,
- const char* glsl_version)
- : window_(window), gl_context_(gl_context) {
- VLOG(1) << "DebugApp initializing...";
- IMGUI_CHECKVERSION();
- ImGui::CreateContext();
- ImGuiIO& io = ImGui::GetIO();
- io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard;
- io.ConfigFlags |= ImGuiConfigFlags_DockingEnable;
-
- // TODO(benvanik): ini file for settings.
- io.IniFilename = nullptr;
- // ImGui::LoadIniSettingsFromMemory()
- // ImGui::SaveIniSettingsToMemory()
-
- // TODO(benvanik): theming.
- ImGui::StyleColorsDark();
-
- // Setup Platform/Renderer bindings
- ImGui_ImplSDL2_InitForOpenGL(window_, gl_context_);
- ImGui_ImplOpenGL3_Init(glsl_version);
- SDL_GL_MakeCurrent(nullptr, nullptr);
- VLOG(1) << "DebugApp initialized";
-}
-
-DebugApp::~DebugApp() {
- VLOG(1) << "DebugApp shutting down...";
- ImGui_ImplOpenGL3_Shutdown();
- ImGui_ImplSDL2_Shutdown();
- ImGui::DestroyContext();
-
- SDL_GL_DeleteContext(gl_context_);
- SDL_GL_MakeCurrent(nullptr, nullptr);
- SDL_DestroyWindow(window_);
- SDL_Quit();
- VLOG(1) << "DebugApp shut down (SDL_Quit)";
-}
-
-Status DebugApp::Connect(absl::string_view service_address) {
- VLOG(1) << "Connecting to debug service at " << service_address << "...";
- ASSIGN_OR_RETURN(debug_client_, DebugClient::Connect(service_address, this));
-
- // TODO(benvanik): load breakpoints from file.
- UserBreakpoint user_breakpoint;
- user_breakpoint.module_name = "module";
- user_breakpoint.function_name = "main";
- user_breakpoint.bytecode_offset = 0;
- user_breakpoint.wants_enabled = true;
- user_breakpoint_list_.push_back(std::move(user_breakpoint));
- RETURN_IF_ERROR(RefreshActiveBreakpoints());
-
- // Set paused so that we need to resume to continue execution.
- is_paused_ = true;
- return OkStatus();
-}
-
-Status DebugApp::Disconnect() {
- VLOG(1) << "Disconnecting from debug service";
- debug_client_.reset();
- return OkStatus();
-}
-
-bool DebugApp::is_paused() const {
- if (!debug_client_) {
- return false;
- }
- if (!hit_breakpoints_.empty()) {
- return true; // One or more breakpoints hit.
- }
- return is_paused_ || !is_stepping_;
-}
-
-RemoteInvocation* DebugApp::GetSelectedInvocation() const {
- if (!debug_client_ || !selected_invocation_id_.has_value()) {
- return nullptr;
- }
- for (auto* invocation : debug_client_->invocations()) {
- if (invocation->id() == selected_invocation_id_.value()) {
- return invocation;
- }
- }
- return nullptr;
-}
-
-Status DebugApp::RefreshActiveBreakpoints() {
- // Set all breakpoints to disabled. We'll re-enable them as we find them
- // below.
- for (auto& user_breakpoint : user_breakpoint_list_) {
- user_breakpoint.active_breakpoint = nullptr;
- }
-
- // If not connected then no breakpoints are active.
- if (!debug_client_) {
- return OkStatus();
- }
-
- // Reconcile the user breakpoint list with the breakpoints available on the
- // server.
- for (auto* breakpoint : debug_client_->breakpoints()) {
- auto it =
- std::find_if(user_breakpoint_list_.begin(), user_breakpoint_list_.end(),
- [breakpoint](const UserBreakpoint& user_breakpoint) {
- return AreBreakpointsEqual(*breakpoint, user_breakpoint);
- });
- if (it == user_breakpoint_list_.end()) {
- // Breakpoint not found - add to user list.
- UserBreakpoint user_breakpoint;
- user_breakpoint.type = breakpoint->type();
- user_breakpoint.active_breakpoint = breakpoint;
- user_breakpoint.module_name = breakpoint->module_name();
- user_breakpoint.function_name = breakpoint->function_name();
- user_breakpoint.function_ordinal = breakpoint->function_ordinal();
- user_breakpoint.bytecode_offset = breakpoint->bytecode_offset();
- user_breakpoint_list_.push_back(std::move(user_breakpoint));
- } else {
- // Breakpoint found - set the active pointer.
- UserBreakpoint& user_breakpoint = *it;
- user_breakpoint.active_breakpoint = breakpoint;
- user_breakpoint.is_enabling = false;
- user_breakpoint.module_name = breakpoint->module_name();
- user_breakpoint.function_name = breakpoint->function_name();
- user_breakpoint.function_ordinal = breakpoint->function_ordinal();
- user_breakpoint.bytecode_offset = breakpoint->bytecode_offset();
- }
- }
-
- // Ensure any breakpoint the user wants enabled is active/otherwise.
- for (auto& user_breakpoint : user_breakpoint_list_) {
- if (user_breakpoint.wants_enabled && !user_breakpoint.is_enabling &&
- !user_breakpoint.active_breakpoint) {
- // Add breakpoint on server.
- switch (user_breakpoint.type) {
- case RemoteBreakpoint::Type::kBytecodeFunction:
- RETURN_IF_ERROR(debug_client_->AddFunctionBreakpoint(
- user_breakpoint.module_name, user_breakpoint.function_name,
- user_breakpoint.bytecode_offset,
- [&user_breakpoint](const RemoteBreakpoint& breakpoint) {
- user_breakpoint.function_ordinal =
- breakpoint.function_ordinal();
- }));
- break;
- case RemoteBreakpoint::Type::kNativeFunction:
- // TODO(benvanik): native breakpoint support.
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Native function breakpoints are TODO";
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented breakpoint type";
- }
- user_breakpoint.is_enabling = true;
- } else if (!user_breakpoint.wants_enabled &&
- user_breakpoint.active_breakpoint) {
- // Remove breakpoint from server.
- RETURN_IF_ERROR(
- debug_client_->RemoveBreakpoint(*user_breakpoint.active_breakpoint));
-
- user_breakpoint.active_breakpoint = nullptr;
- }
- }
-
- return OkStatus();
-}
-
-bool DebugApp::IsStoppedAtBreakpoint(
- const UserBreakpoint& user_breakpoint) const {
- return std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(),
- user_breakpoint.active_breakpoint) != hit_breakpoints_.end();
-}
-
-int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name,
- int function_ordinal,
- int offset) {
- for (int i = 0; i < user_breakpoint_list_.size(); ++i) {
- auto& user_breakpoint = user_breakpoint_list_[i];
- if (user_breakpoint.module_name == module_name &&
- user_breakpoint.function_ordinal == function_ordinal &&
- user_breakpoint.bytecode_offset == offset) {
- return i;
- }
- }
- return -1;
-}
-
-int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name,
- absl::string_view function_name,
- int offset) {
- for (int i = 0; i < user_breakpoint_list_.size(); ++i) {
- auto& user_breakpoint = user_breakpoint_list_[i];
- if (user_breakpoint.module_name == module_name &&
- user_breakpoint.function_name == function_name &&
- user_breakpoint.bytecode_offset == offset) {
- return i;
- }
- }
- return -1;
-}
-
-Status DebugApp::ResumeFromBreakpoint(UserBreakpoint* user_breakpoint) {
- if (!user_breakpoint->active_breakpoint) {
- return FailedPreconditionErrorBuilder(IREE_LOC) << "Breakpoint not active";
- }
- VLOG(1) << "Resuming from breakpoint "
- << user_breakpoint->active_breakpoint->id() << "...";
- auto it = std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(),
- user_breakpoint->active_breakpoint);
- if (it == hit_breakpoints_.end()) {
- return NotFoundErrorBuilder(IREE_LOC) << "Breakpoint not found";
- }
- hit_breakpoints_.erase(it);
- return debug_client_->MakeReady();
-}
-
-Status DebugApp::OnContextRegistered(const RemoteContext& context) {
- // Ack event.
- return debug_client_->MakeReady();
-}
-
-Status DebugApp::OnContextUnregistered(const RemoteContext& context) {
- // Close documents that may reference modules in the context.
- std::vector<CodeViewDocument*> closing_documents;
- for (auto& document : documents_) {
- auto* module = document->function->module();
- if (module->context_id() != context.id()) {
- // Document is not from this context so it's fine.
- continue;
- }
-
- // See if any other live context still has the module loaded. We can change
- // the document over to that.
- RemoteModule* replacement_module = nullptr;
- for (auto* context : debug_client_->contexts()) {
- for (auto* other_module : context->modules()) {
- if (other_module->name() == module->name()) {
- replacement_module = other_module;
- break;
- }
- }
- if (replacement_module) break;
- }
- if (replacement_module && replacement_module->is_loaded()) {
- // Replace document module reference.
- int function_ordinal = document->function->ordinal();
- auto functions = replacement_module->functions();
- if (function_ordinal < functions.size()) {
- document->function = functions[function_ordinal];
- } else {
- document->function = nullptr;
- }
- } else {
- document->function = nullptr;
- }
-
- if (!document->function) {
- // Close the document if we don't have a valid function for it.
- VLOG(1)
- << "Closing document " << document->title
- << " because the last context using the module is being unregistered";
- closing_documents.push_back(document.get());
- }
- }
- for (auto* document : closing_documents) {
- auto it = std::find_if(
- documents_.begin(), documents_.end(),
- [document](const std::unique_ptr<CodeViewDocument>& open_document) {
- return document == open_document.get();
- });
- documents_.erase(it);
- }
-
- // Ack event.
- return debug_client_->MakeReady();
-}
-
-Status DebugApp::OnModuleLoaded(const RemoteContext& context,
- const RemoteModule& module) {
- // Ack event.
- return debug_client_->MakeReady();
-}
-
-Status DebugApp::OnInvocationRegistered(const RemoteInvocation& invocation) {
- if (!selected_invocation_id_.has_value()) {
- selected_invocation_id_ = invocation.id();
- selected_stack_frame_index_ = {};
- }
-
- // Ack event.
- return debug_client_->MakeReady();
-}
-
-Status DebugApp::OnInvocationUnregistered(const RemoteInvocation& invocation) {
- if (selected_invocation_id_.has_value() &&
- selected_invocation_id_.value() == invocation.id()) {
- selected_invocation_id_ = {};
- selected_stack_frame_index_ = {};
- }
-
- // Ack event.
- return debug_client_->MakeReady();
-}
-
-Status DebugApp::OnBreakpointHit(const RemoteBreakpoint& breakpoint,
- const RemoteInvocation& invocation) {
- // Keep track of where we are stopped.
- hit_breakpoints_.push_back(&breakpoint);
- return NavigateToCodeView(invocation, -1, NavigationMode::kMatchDocument);
-}
-
-Status DebugApp::PumpMainLoop() {
- ImGuiIO& io = ImGui::GetIO();
-
- if (debug_client_) {
- RETURN_IF_ERROR(debug_client_->Poll());
- }
- RETURN_IF_ERROR(RefreshActiveBreakpoints());
-
- SDL_GL_MakeCurrent(window_, gl_context_);
-
- SDL_Event event;
- while (SDL_PollEvent(&event)) {
- ImGui_ImplSDL2_ProcessEvent(&event);
- if (event.type == SDL_QUIT) {
- return CancelledErrorBuilder(IREE_LOC) << "Quit hotkey";
- } else if (event.type == SDL_WINDOWEVENT &&
- event.window.event == SDL_WINDOWEVENT_CLOSE &&
- event.window.windowID == SDL_GetWindowID(window_)) {
- return CancelledErrorBuilder(IREE_LOC) << "Window closed";
- }
- }
- ImGui_ImplOpenGL3_NewFrame();
- ImGui_ImplSDL2_NewFrame(window_);
- ImGui::NewFrame();
-
- auto draw_status = DrawUI();
- if (!draw_status.ok()) {
- // TODO(benvanik): show on screen? Probably all messed up.
- LOG(ERROR) << draw_status;
- }
-
- // Blit the entire ImGui UI.
- ImGui::Render();
- SDL_GL_MakeCurrent(window_, gl_context_);
- glViewport(0, 0, (int)io.DisplaySize.x, (int)io.DisplaySize.y);
- glClearColor(0.45f, 0.55f, 0.60f, 1.0f);
- glClear(GL_COLOR_BUFFER_BIT);
- // Workaround for terrible bad SDL/graphics driver leaks.
- IREE_DISABLE_LEAK_CHECKS();
- ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
- IREE_ENABLE_LEAK_CHECKS();
-
- // Render additional viewport windows (desktop only).
- if (io.ConfigFlags & ImGuiConfigFlags_ViewportsEnable) {
- SDL_Window* backup_current_window = SDL_GL_GetCurrentWindow();
- SDL_GLContext backup_current_context = SDL_GL_GetCurrentContext();
- ImGui::UpdatePlatformWindows();
- ImGui::RenderPlatformWindowsDefault();
- SDL_GL_MakeCurrent(backup_current_window, backup_current_context);
- }
-
- SDL_GL_SwapWindow(window_);
- return OkStatus();
-}
-
-Status DebugApp::LayoutInitialDockSpace() {
- dockspace_id_ = ImGui::GetID("MainDockSpace");
- if (ImGui::DockBuilderGetNode(dockspace_id_)) {
- // Already configured.
- return OkStatus();
- }
- ImGui::DockBuilderAddNode(dockspace_id_, ImGuiDockNodeFlags_DockSpace);
-
- dock_content_id_ = dockspace_id_;
- dock_top_id_ = ImGui::DockBuilderSplitNode(dock_content_id_, ImGuiDir_Up,
- 0.05f, nullptr, &dock_content_id_);
- dock_left_id_ = ImGui::DockBuilderSplitNode(
- dock_content_id_, ImGuiDir_Left, 0.20f, nullptr, &dock_content_id_);
- dock_bottom_id_ = ImGui::DockBuilderSplitNode(
- dock_content_id_, ImGuiDir_Down, 0.20f, nullptr, &dock_content_id_);
- dock_right_id_ = ImGui::DockBuilderSplitNode(
- dock_content_id_, ImGuiDir_Right, 0.20f, nullptr, &dock_content_id_);
- dock_bottom_left_id_ = ImGui::DockBuilderSplitNode(
- dock_bottom_id_, ImGuiDir_Left, 0.50f, nullptr, &dock_bottom_right_id_);
-
- ImGui::DockBuilderDockWindow("Toolbar", dock_top_id_);
- auto* dock_top_node = ImGui::DockBuilderGetNode(dock_top_id_);
- dock_top_node->LocalFlags = ImGuiDockNodeFlags_NoSplit |
- ImGuiDockNodeFlags_NoResize |
- ImGuiDockNodeFlags_AutoHideTabBar;
-
- ImGui::DockBuilderDockWindow("Modules", dock_left_id_);
- ImGui::DockBuilderDockWindow("Locals", dock_bottom_left_id_);
- ImGui::DockBuilderDockWindow("Invocations", dock_bottom_right_id_);
- ImGui::DockBuilderDockWindow("Breakpoints", dock_bottom_right_id_);
-
- ImGui::DockBuilderFinish(dockspace_id_);
- return OkStatus();
-}
-
-Status DebugApp::DrawUI() {
- ImGuiWindowFlags window_flags =
- ImGuiWindowFlags_MenuBar | ImGuiWindowFlags_NoDocking;
- window_flags |= ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_NoCollapse |
- ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoMove |
- ImGuiWindowFlags_NoNavFocus;
-
- ImGuiViewport* viewport = ImGui::GetMainViewport();
- ImGui::SetNextWindowPos(viewport->Pos);
- ImGui::SetNextWindowSize(viewport->Size);
- ImGui::SetNextWindowViewport(viewport->ID);
- ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0.0f);
- ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 0.0f);
- ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0.0f, 0.0f));
- ImGui::Begin("IREEDebugRoot", nullptr, window_flags);
- ImGui::PopStyleVar(3);
-
- RETURN_IF_ERROR(LayoutInitialDockSpace());
- ImGui::DockSpace(dockspace_id_, ImVec2(0.0f, 0.0f), ImGuiDockNodeFlags_None);
-
- RETURN_IF_ERROR(DrawMainMenu());
- RETURN_IF_ERROR(DrawToolbar());
-
- ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(2, 2));
- RETURN_IF_ERROR(DrawBreakpointListPanel());
- RETURN_IF_ERROR(DrawModuleListPanel());
- RETURN_IF_ERROR(DrawLocalListPanel());
- RETURN_IF_ERROR(DrawInvocationListPanel());
- ImGui::PopStyleVar();
-
- RETURN_IF_ERROR(DrawCodeViewPanels());
-
- ImGui::End();
- return OkStatus();
-}
-
-Status DebugApp::DrawMainMenu() {
- if (!ImGui::BeginMenuBar()) return OkStatus();
-
- // TODO(benvanik): main menu.
- if (ImGui::BeginMenu("File")) {
- ImGui::EndMenu();
- }
-
- ImGui::EndMenuBar();
- return OkStatus();
-}
-
-Status DebugApp::DrawToolbar() {
- // TODO(benvanik): figure out how to make this not grow.
- ImGui::Begin("Toolbar", nullptr,
- ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoTitleBar |
- ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse |
- ImGuiWindowFlags_NoScrollbar);
- ImGui::BeginGroup();
-
-#if !defined(IMGUI_DISABLE_DEMO_WINDOWS)
- static bool show_demo_window = false;
- if (ImGui::Button("Demo")) {
- show_demo_window = !show_demo_window;
- }
- if (show_demo_window) {
- ImGui::SetNextWindowDockID(dock_content_id_);
- ImGui::ShowDemoWindow(&show_demo_window);
- }
-#endif // !IMGUI_DISABLE_DEMO_WINDOWS
-
- ImGui::SameLine();
- if (!debug_client_) {
- if (ImGui::Button("Connect")) {
- // TODO(benvanik): connection dialog and/or autoconnect.
- }
- } else {
- if (ImGui::Button("Disconnect")) {
- debug_client_.reset();
- }
- }
-
- ImGui::SameLine();
- if (debug_client_) {
- ImGui::Text("<status>");
- } else {
- ImGui::TextDisabled("disconnected");
- }
-
- ImGui::SameLine();
- ImGui::Spacing();
- ImGui::SameLine();
- ImGui::Spacing();
-
- ImGui::SameLine();
- ImGui::BeginGroup();
- ImGui::Text("Invocation: ");
- ImGui::SameLine();
- ImGui::SetNextItemWidth(300);
- auto* selected_invocation = GetSelectedInvocation();
- const std::string& active_invocation_name =
- selected_invocation ? selected_invocation->name() : "";
- if (ImGui::BeginCombo("##active_invocation", active_invocation_name.c_str(),
- ImGuiComboFlags_PopupAlignLeft)) {
- if (debug_client_) {
- for (auto* invocation : debug_client_->invocations()) {
- ImGui::PushID(invocation->id());
- bool is_selected = invocation == selected_invocation;
- if (ImGui::Selectable(invocation->name().c_str(), is_selected)) {
- RETURN_IF_ERROR(NavigateToCodeView(*invocation, -1,
- NavigationMode::kMatchDocument));
- }
- if (is_selected) {
- ImGui::SetItemDefaultFocus();
- }
- ImGui::PopID();
- }
- }
- ImGui::EndCombo();
- }
- ImGui::EndGroup();
-
- ImGui::SameLine();
- ImGui::BeginGroup();
- static const float kPauseButtonHue = 0.0f;
- static const float kResumeButtonHue = 2.0f;
- static const float kStepButtonHue = 1.0f;
- if (debug_client_ && !is_paused()) {
- PushButtonHue(kPauseButtonHue);
- if (ImGui::Button("Pause")) {
- RETURN_IF_ERROR(debug_client_->SuspendAllInvocations());
- }
- PopButtonStyle();
- } else if (debug_client_ && is_paused()) {
- ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666);
- ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA);
- ImGui::ButtonEx("Pause", {}, ImGuiButtonFlags_Disabled);
- ImGui::PopStyleColor(2);
- }
- if (debug_client_ && is_paused()) {
- ImGui::SameLine();
- PushButtonHue(kResumeButtonHue);
- if (ImGui::Button("Resume")) {
- if (is_paused_) {
- is_paused_ = false;
- RETURN_IF_ERROR(debug_client_->MakeReady());
- }
- while (!hit_breakpoints_.empty()) {
- hit_breakpoints_.pop_back();
- RETURN_IF_ERROR(debug_client_->MakeReady());
- }
- }
- PopButtonStyle();
- } else {
- ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666);
- ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA);
- ImGui::SameLine();
- ImGui::ButtonEx("Resume", {}, ImGuiButtonFlags_Disabled);
- ImGui::PopStyleColor(2);
- }
-
- if (debug_client_ && is_paused() && selected_invocation) {
- ImGui::SameLine();
- PushButtonHue(kStepButtonHue);
- if (ImGui::Button("Step Into")) {
- RETURN_IF_ERROR(
- debug_client_->StepInvocation(*selected_invocation, [this]() {
- is_paused_ = true;
- is_stepping_ = false;
- }));
- is_stepping_ = true;
- }
- PopButtonStyle();
- ImGui::SameLine();
- if (ImGui::Button("Step Over")) {
- RETURN_IF_ERROR(
- debug_client_->StepInvocationOver(*selected_invocation, [this]() {
- is_paused_ = true;
- is_stepping_ = false;
- }));
- is_stepping_ = true;
- }
- ImGui::SameLine();
- if (ImGui::Button("Step Out")) {
- RETURN_IF_ERROR(
- debug_client_->StepInvocationOut(*selected_invocation, [this]() {
- is_paused_ = true;
- is_stepping_ = false;
- }));
- is_stepping_ = true;
- }
- if (ImGui::BeginPopup("Step to...")) {
- // TODO(benvanik): step to Invoke exit, next FFI call, etc
- ImGui::MenuItem("(stuff)");
- ImGui::EndPopup();
- }
- ImGui::SameLine();
- if (ImGui::Button("Step to...")) {
- ImGui::OpenPopup("Step to...");
- }
- } else {
- ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666);
- ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA);
- ImGui::SameLine();
- ImGui::ButtonEx("Step Into", {}, ImGuiButtonFlags_Disabled);
- ImGui::SameLine();
- ImGui::ButtonEx("Step Over", {}, ImGuiButtonFlags_Disabled);
- ImGui::SameLine();
- ImGui::ButtonEx("Step Out", {}, ImGuiButtonFlags_Disabled);
- ImGui::SameLine();
- ImGui::ButtonEx("Step to...", {}, ImGuiButtonFlags_Disabled);
- ImGui::PopStyleColor(2);
- }
- ImGui::EndGroup();
-
- ImGui::EndGroup();
- ImGui::End();
- return OkStatus();
-}
-
-Status DebugApp::DrawBreakpointListPanel() {
- static bool is_panel_visible = true;
- if (!ImGui::Begin("Breakpoints", &is_panel_visible, ImGuiWindowFlags_None)) {
- ImGui::End();
- return OkStatus();
- }
-
- ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8));
- absl::optional<RemoteBreakpoint::Type> add_breakpoint_type;
- if (ImGui::BeginPopup("+ Function")) {
- if (ImGui::MenuItem("Bytecode Function")) {
- add_breakpoint_type = RemoteBreakpoint::Type::kBytecodeFunction;
- }
- if (ImGui::MenuItem("Native Function")) {
- add_breakpoint_type = RemoteBreakpoint::Type::kNativeFunction;
- }
- ImGui::EndPopup();
- }
- ImGui::PopStyleVar();
- if (ImGui::Button("+ Function")) {
- ImGui::OpenPopup("+ Function");
- }
- RETURN_IF_ERROR(DrawAddBreakpointDialogs(add_breakpoint_type));
-
- ImGui::SameLine();
- if (ImGui::Button("Remove All")) {
- // TODO(benvanik): removal all is broken - need removebreakpoints or a
- // 'want_removal' flag so that RefreshActiveBreakpoints handles things.
- // Right now if you have 2 breakpoints and hit remove all the second will
- // come back during the next refresh (as the server hasn't removed it yet).
- for (auto& user_breakpoint : user_breakpoint_list_) {
- if (user_breakpoint.active_breakpoint) {
- RETURN_IF_ERROR(debug_client_->RemoveBreakpoint(
- *user_breakpoint.active_breakpoint));
- user_breakpoint.active_breakpoint = nullptr;
- }
- }
- user_breakpoint_list_.clear();
- }
- ImGui::Separator();
-
- ImGui::BeginChild("BreakpointList", ImVec2(-1, -1), false,
- ImGuiWindowFlags_AlwaysVerticalScrollbar);
- std::vector<UserBreakpoint*> dead_breakpoints;
- for (auto& user_breakpoint : user_breakpoint_list_) {
- ASSIGN_OR_RETURN(bool should_keep, DrawBreakpoint(&user_breakpoint));
- if (!should_keep) {
- dead_breakpoints.push_back(&user_breakpoint);
- }
- }
- for (auto* user_breakpoint : dead_breakpoints) {
- for (auto it = user_breakpoint_list_.begin();
- it != user_breakpoint_list_.end(); ++it) {
- if (&*it == user_breakpoint) {
- if (user_breakpoint->active_breakpoint) {
- RETURN_IF_ERROR(debug_client_->RemoveBreakpoint(
- *user_breakpoint->active_breakpoint));
- }
- user_breakpoint_list_.erase(it);
- break;
- }
- }
- }
- ImGui::EndChild();
-
- ImGui::End();
- return OkStatus();
-}
-
-StatusOr<bool> DebugApp::DrawBreakpoint(UserBreakpoint* user_breakpoint) {
- std::string breakpoint_name;
- switch (user_breakpoint->type) {
- case RemoteBreakpoint::Type::kBytecodeFunction:
- breakpoint_name =
- absl::StrCat("[bytecode] ", user_breakpoint->module_name, ":",
- user_breakpoint->function_name, ":",
- user_breakpoint->bytecode_offset);
- if (user_breakpoint->function_ordinal != -1) {
- absl::StrAppend(&breakpoint_name, " @",
- user_breakpoint->function_ordinal);
- }
- break;
- case RemoteBreakpoint::Type::kNativeFunction:
- breakpoint_name =
- absl::StrCat("[native ] ", user_breakpoint->native_function);
- break;
- }
- ImGui::BeginGroup();
- bool is_closing = true;
- bool is_expanded = ImGui::CollapsingHeader(
- ("##" + breakpoint_name).c_str(), &is_closing,
- ImGuiTreeNodeFlags_Framed | ImGuiTreeNodeFlags_NoTreePushOnOpen |
- ImGuiTreeNodeFlags_NoAutoOpenOnLog | ImGuiTreeNodeFlags_OpenOnArrow |
- ImGuiTreeNodeFlags_OpenOnDoubleClick);
- ImGui::SameLine();
- ImGui::Checkbox(breakpoint_name.c_str(), &user_breakpoint->wants_enabled);
- ImGui::EndGroup();
- if (!is_expanded) {
- return is_closing;
- }
- ImGui::PushID(breakpoint_name.c_str());
-
- ImGui::Text("(breakpoint stats/etc)");
-
- ImGui::PopID();
- return is_closing;
-}
-
-Status DebugApp::DrawAddBreakpointDialogs(
- absl::optional<RemoteBreakpoint::Type> add_breakpoint_type) {
- if (add_breakpoint_type.has_value()) {
- switch (add_breakpoint_type.value()) {
- case RemoteBreakpoint::Type::kBytecodeFunction:
- ImGui::OpenPopup("Add Bytecode Function Breakpoint");
- break;
- case RemoteBreakpoint::Type::kNativeFunction:
- ImGui::OpenPopup("Add Native Function Breakpoint");
- break;
- }
- }
- ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8));
- RETURN_IF_ERROR(DrawAddBytecodeFunctionBreakpointDialog());
- RETURN_IF_ERROR(DrawAddNativeFunctionBreakpointDialog());
- ImGui::PopStyleVar();
- return OkStatus();
-}
-
-Status DebugApp::DrawAddBytecodeFunctionBreakpointDialog() {
- ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver);
- bool close_popup = true;
- if (!ImGui::BeginPopupModal("Add Bytecode Function Breakpoint", &close_popup,
- ImGuiWindowFlags_None)) {
- return OkStatus();
- }
- ImGui::BeginGroup();
- ImGui::BeginChild("##data_entry",
- ImVec2(0, -ImGui::GetFrameHeightWithSpacing()));
-
- ImGui::TextWrapped(
- "Adds a breakpoint set on the entry of the function (offset=0).");
- ImGui::Separator();
-
- // TODO(benvanik): fancy list, filtering, etc.
-
- static char module_name[256] = {0};
- ImGui::InputText("Module", module_name, sizeof(module_name));
- ImGui::SetItemDefaultFocus();
-
- static char function_name[256] = {0};
- ImGui::InputText("Function", function_name, sizeof(function_name));
-
- ImGui::EndChild();
- ImGui::Separator();
-
- if (ImGui::Button("Add")) {
- int offset = 0;
- if (FindMatchingUserBreakpointIndex(module_name, function_name, offset) ==
- -1) {
- UserBreakpoint user_breakpoint;
- user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction;
- user_breakpoint.module_name = module_name;
- user_breakpoint.function_name = function_name;
- user_breakpoint.bytecode_offset = offset;
- user_breakpoint.wants_enabled = true;
- user_breakpoint_list_.push_back(std::move(user_breakpoint));
- }
- ImGui::CloseCurrentPopup();
- }
- ImGui::SameLine();
- if (ImGui::Button("Cancel")) {
- ImGui::CloseCurrentPopup();
- }
-
- ImGui::EndGroup();
- ImGui::EndPopup();
- return OkStatus();
-}
-
-Status DebugApp::DrawAddNativeFunctionBreakpointDialog() {
- ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver);
- bool close_popup = true;
- if (!ImGui::BeginPopupModal("Add Native Function Breakpoint", &close_popup,
- ImGuiWindowFlags_None)) {
- return OkStatus();
- }
- ImGui::BeginGroup();
- ImGui::BeginChild("##data_entry",
- ImVec2(0, -ImGui::GetFrameHeightWithSpacing()));
-
- ImGui::TextWrapped(
- "Adds a breakpoint set on any call to the given FFI imported "
- "function.");
- ImGui::Separator();
-
- static char function_name[256] = {0};
- ImGui::InputText("Function", function_name, sizeof(function_name));
- ImGui::SetItemDefaultFocus();
-
- ImGui::EndChild();
- ImGui::Separator();
-
- if (ImGui::Button("Add")) {
- UserBreakpoint user_breakpoint;
- user_breakpoint.type = RemoteBreakpoint::Type::kNativeFunction;
- user_breakpoint.native_function = function_name;
- user_breakpoint.wants_enabled = true;
- user_breakpoint_list_.push_back(std::move(user_breakpoint));
- ImGui::CloseCurrentPopup();
- }
- ImGui::SameLine();
- if (ImGui::Button("Cancel")) {
- ImGui::CloseCurrentPopup();
- }
-
- ImGui::EndGroup();
- ImGui::EndPopup();
- return OkStatus();
-}
-
-Status DebugApp::DrawModuleListPanel() {
- static bool is_panel_visible = true;
- if (!ImGui::Begin("Modules", &is_panel_visible, ImGuiWindowFlags_None)) {
- ImGui::End();
- return OkStatus();
- } else if (!debug_client_) {
- ImGui::TextDisabled("disconnected");
- ImGui::End();
- return OkStatus();
- }
- ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(4, 4));
-
- ImGui::BeginGroup();
- ImGui::SetNextItemWidth(ImGui::GetContentRegionAvailWidth());
- static char function_name_filter_text[256] = {0};
- ImGui::InputTextWithHint(
- "##function_name_filter", "Filter functions", function_name_filter_text,
- sizeof(function_name_filter_text), ImGuiInputTextFlags_AutoSelectAll);
- ImGuiTextFilter function_name_filter(function_name_filter_text);
- ImGui::EndGroup();
-
- ImGui::Separator();
-
- ImGui::BeginGroup();
- ImGui::BeginChild("##context_list", ImVec2(0, -ImGui::GetFrameHeight()));
- for (auto* context : debug_client_->contexts()) {
- RETURN_IF_ERROR(DrawContext(*context, function_name_filter));
- }
- ImGui::EndChild();
- ImGui::EndGroup();
-
- ImGui::PopStyleVar();
- ImGui::End();
- return OkStatus();
-}
-
-Status DebugApp::DrawContext(const RemoteContext& context,
- const ImGuiTextFilter& filter) {
- std::string context_name = absl::StrCat("Context ", context.id());
- if (!ImGui::CollapsingHeader(context_name.c_str(), nullptr,
- ImGuiTreeNodeFlags_DefaultOpen |
- ImGuiTreeNodeFlags_Framed |
- ImGuiTreeNodeFlags_NoTreePushOnOpen |
- ImGuiTreeNodeFlags_NoAutoOpenOnLog |
- ImGuiTreeNodeFlags_OpenOnArrow |
- ImGuiTreeNodeFlags_OpenOnDoubleClick)) {
- return OkStatus();
- }
- ImGui::PushID(context.id());
- for (auto* module : context.modules()) {
- RETURN_IF_ERROR(DrawModule(module, filter));
- }
- ImGui::PopID();
- return OkStatus();
-}
-
-Status DebugApp::DrawModule(RemoteModule* module,
- const ImGuiTextFilter& filter) {
- ImGui::PushID(module->name().c_str());
- if (ImGui::TreeNodeEx(module->name().c_str(),
- ImGuiTreeNodeFlags_Framed |
- ImGuiTreeNodeFlags_DefaultOpen |
- ImGuiTreeNodeFlags_OpenOnDoubleClick |
- ImGuiTreeNodeFlags_OpenOnArrow)) {
- if (module->CheckLoadedOrRequest()) {
- for (auto* function : module->functions()) {
- char function_name[128];
- if (function->name().empty()) {
- std::snprintf(function_name, sizeof(function_name), "@%d",
- function->ordinal());
- } else {
- std::snprintf(function_name, sizeof(function_name), "@%d %s",
- function->ordinal(), function->name().c_str());
- }
- if (filter.IsActive() && !filter.PassFilter(function_name)) {
- continue;
- }
- ImGui::PushID(function->ordinal());
- bool is_selected = false;
- if (ImGui::Selectable("##selectable", &is_selected,
- ImGuiSelectableFlags_AllowDoubleClick |
- ImGuiSelectableFlags_DrawFillAvailWidth)) {
- if (is_selected) {
- RETURN_IF_ERROR(NavigateToCodeView(module->name(),
- function->ordinal(), 0,
- NavigationMode::kMatchDocument));
- }
- }
- ImGui::SameLine();
- // TODO(benvanik): detect if breakpoint active at offset 0.
- ImGui::BulletText("%s", function_name);
- ImGui::PopID();
- }
- } else {
- ImGui::TextDisabled("Loading...");
- }
- ImGui::TreePop();
- }
- ImGui::PopID();
- return OkStatus();
-}
-
-Status DebugApp::DrawLocalListPanel() {
- static bool is_panel_visible = true;
- if (!ImGui::Begin("Locals", &is_panel_visible, ImGuiWindowFlags_None)) {
- ImGui::End();
- return OkStatus();
- } else if (!debug_client_) {
- ImGui::TextDisabled("disconnected");
- ImGui::End();
- return OkStatus();
- }
- auto* invocation = GetSelectedInvocation();
- if (!invocation) {
- ImGui::TextDisabled("select a invocation to view locals");
- ImGui::End();
- return OkStatus();
- } else if (invocation->def().frames.empty()) {
- ImGui::TextDisabled("(invocation has no frames)");
- ImGui::End();
- return OkStatus();
- }
- int stack_frame_index = selected_stack_frame_index_.value_or(-1);
- if (stack_frame_index == -1) {
- stack_frame_index = invocation->def().frames.size() - 1;
- }
- auto& stack_frame = invocation->def().frames[stack_frame_index];
-
- // TODO(benvanik): toggle for IREE VM locals vs. source locals.
- for (int i = 0; i < stack_frame->locals.size(); ++i) {
- auto& local = stack_frame->locals[i];
- RETURN_IF_ERROR(DrawLocal(invocation, stack_frame_index, i, *local));
- }
-
- ImGui::End();
- return OkStatus();
-}
-
-Status DebugApp::DrawLocal(RemoteInvocation* invocation, int stack_frame_index,
- int local_index, const rpc::BufferViewDefT& local) {
- // TODO(benvanik): columns and such in fancy table.
- ImGui::Text("l%d", local_index);
- ImGui::SameLine(50);
- if (local.is_valid) {
- auto shape_str =
- absl::StrCat(absl::StrJoin(local.shape, "x"), "x", local.element_size);
- ImGui::Text("%s", shape_str.c_str());
- } else {
- ImGui::TextDisabled("∅");
- }
- // TODO(benvanik): editing options (change shape, change contents, upload).
- // TODO(benvanik): save/download/log options.
- return OkStatus();
-}
-
-Status DebugApp::DrawInvocationListPanel() {
- static bool is_panel_visible = true;
- if (!ImGui::Begin("Invocations", &is_panel_visible, ImGuiWindowFlags_None)) {
- ImGui::End();
- return OkStatus();
- } else if (!debug_client_) {
- ImGui::TextDisabled("disconnected");
- ImGui::End();
- return OkStatus();
- }
- for (auto* invocation : debug_client_->invocations()) {
- RETURN_IF_ERROR(DrawInvocation(*invocation));
- }
- ImGui::End();
- return OkStatus();
-}
-
-Status DebugApp::DrawInvocation(const RemoteInvocation& invocation) {
- // TODO(benvanik): expand if any breakpoints are stopped in invocation.
- if (selected_invocation_id_.has_value() &&
- selected_invocation_id_.value() == invocation.id()) {
- ImGui::SetNextTreeNodeOpen(true);
- }
- if (!ImGui::CollapsingHeader(invocation.name().c_str())) {
- return OkStatus();
- }
- ImGui::PushID(invocation.id());
-
- for (int i = 0; i < invocation.def().frames.size(); ++i) {
- const auto& stack_frame = invocation.def().frames[i];
- ImGui::PushID(i);
- // TODO(benvanik): highlight frames with breakpoints in them.
- bool is_selected = selected_invocation_id_.has_value() &&
- selected_invocation_id_.value() == invocation.id() &&
- selected_stack_frame_index_.has_value() &&
- selected_stack_frame_index_.value() == i;
- if (ImGui::Selectable("##selectable", &is_selected,
- ImGuiSelectableFlags_AllowDoubleClick |
- ImGuiSelectableFlags_DrawFillAvailWidth)) {
- // TODO(benvanik): detect when clicking but already selected.
- if (is_selected) {
- RETURN_IF_ERROR(
- NavigateToCodeView(invocation, i, NavigationMode::kMatchDocument));
- }
- }
- ImGui::SameLine();
- ImGui::Bullet();
- ImGui::SameLine();
- // TODO(benvanik): better naming/etc (resolve function).
- ImGui::Text("%s:%d:%d", stack_frame->module_name.c_str(),
- stack_frame->function_ordinal, stack_frame->offset);
-
- ImGui::PopID();
- }
-
- ImGui::PopID();
- return OkStatus();
-}
-
-DebugApp::CodeViewDocument* DebugApp::FindMatchingDocument(
- absl::string_view module_name, int function_ordinal) {
- for (auto& document : documents_) {
- if (document->function->module()->name() == module_name &&
- document->function->ordinal() == function_ordinal) {
- return document.get();
- }
- }
- return nullptr;
-}
-
-Status DebugApp::NavigateToCodeView(absl::string_view module_name,
- int function_ordinal, int offset,
- NavigationMode navigation_mode) {
- if (!debug_client_) {
- return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
- }
- VLOG(1) << "NavigateToCodeView(" << module_name << ", " << function_ordinal
- << ", " << offset << ")";
- CodeViewDocument* existing_document = nullptr;
- switch (navigation_mode) {
- case NavigationMode::kNewDocument:
- // Fall through and create below.
- break;
- case NavigationMode::kCurrentDocument:
- // Not yet done - treat as a new document.
- break;
- case NavigationMode::kMatchDocument:
- existing_document = FindMatchingDocument(module_name, function_ordinal);
- break;
- }
- if (existing_document) {
- ImGui::SetWindowFocus(existing_document->title.c_str());
- return OkStatus();
- }
-
- // TODO(benvanik): make this common code.
- RETURN_IF_ERROR(debug_client_->GetFunction(
- std::string(module_name), function_ordinal,
- [this, offset](StatusOr<RemoteFunction*> function_or) {
- if (!function_or.ok()) {
- // TODO(benvanik): error dialog.
- CHECK_OK(function_or.status());
- }
- auto* function = function_or.ValueOrDie();
- auto document = absl::make_unique<CodeViewDocument>();
- document->title =
- absl::StrCat(function->module()->name(), ":", function->name());
- document->function = function;
- document->focus_offset = offset;
- ImGui::SetWindowFocus(document->title.c_str());
- documents_.push_back(std::move(document));
- }));
- return OkStatus();
-}
-
-Status DebugApp::NavigateToCodeView(absl::string_view module_name,
- absl::string_view function_name, int offset,
- NavigationMode navigation_mode) {
- if (!debug_client_) {
- return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
- }
- return debug_client_->ResolveFunction(
- std::string(module_name), std::string(function_name),
- [this, navigation_mode, module_name,
- offset](StatusOr<int> function_ordinal) {
- CHECK_OK(function_ordinal.status());
- CHECK_OK(NavigateToCodeView(module_name, function_ordinal.ValueOrDie(),
- offset, navigation_mode));
- });
-}
-
-Status DebugApp::NavigateToCodeView(const RemoteInvocation& invocation,
- int stack_frame_index,
- NavigationMode navigation_mode) {
- if (!debug_client_) {
- return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
- }
- const auto& stack_frame = stack_frame_index == -1
- ? *invocation.def().frames.back()
- : *invocation.def().frames[stack_frame_index];
- selected_invocation_id_ = invocation.id();
- selected_stack_frame_index_ = stack_frame_index;
- return NavigateToCodeView(stack_frame.module_name,
- stack_frame.function_ordinal, stack_frame.offset,
- NavigationMode::kMatchDocument);
-}
-
-Status DebugApp::NavigateToCodeView(const UserBreakpoint& user_breakpoint,
- NavigationMode navigation_mode) {
- if (!debug_client_) {
- return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
- }
- switch (user_breakpoint.type) {
- case RemoteBreakpoint::Type::kBytecodeFunction:
- if (user_breakpoint.function_ordinal != -1) {
- return NavigateToCodeView(
- user_breakpoint.module_name, user_breakpoint.function_ordinal,
- user_breakpoint.bytecode_offset, navigation_mode);
- } else {
- return NavigateToCodeView(
- user_breakpoint.module_name, user_breakpoint.function_name,
- user_breakpoint.bytecode_offset, navigation_mode);
- }
- case RemoteBreakpoint::Type::kNativeFunction:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Navigation to non-bytecode functions unimplemented";
- }
-}
-
-Status DebugApp::DrawCodeViewPanels() {
- // If we've disconnected then we need to clear bodies.
- // TODO(benvanik): allow documents to persist by caching all required info.
- if (!debug_client_) {
- documents_.clear();
- return OkStatus();
- }
-
- std::vector<CodeViewDocument*> closing_documents;
- for (auto& document : documents_) {
- ASSIGN_OR_RETURN(bool is_open, DrawCodeViewDocument(document.get()));
- if (!is_open) {
- closing_documents.push_back(document.get());
- }
- }
- for (auto* closing_document : closing_documents) {
- auto it = std::find_if(
- documents_.begin(), documents_.end(),
- [closing_document](const std::unique_ptr<CodeViewDocument>& document) {
- return document.get() == closing_document;
- });
- documents_.erase(it);
- }
- return OkStatus();
-}
-
-StatusOr<bool> DebugApp::DrawCodeViewDocument(CodeViewDocument* document) {
- ImGui::SetNextWindowDockID(dockspace_id_, ImGuiCond_FirstUseEver);
- ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0, 0));
- bool is_open = true;
- bool is_visible =
- ImGui::Begin(document->title.c_str(), &is_open, ImGuiWindowFlags_None);
- if (!is_open || !is_visible) {
- ImGui::End();
- ImGui::PopStyleVar();
- return is_open;
- }
- ImGui::PopStyleVar();
-
- auto* remote_module = document->function->module();
- auto* remote_function = document->function;
- if (remote_module->CheckLoadedOrRequest() &&
- remote_function->CheckLoadedOrRequest()) {
- // TODO(benvanik): draw function signature.
- if (remote_function->bytecode()) {
- RETURN_IF_ERROR(DrawBytecodeCodeView(document));
- } else {
- // TODO(benvanik): display native registration info.
- ImGui::TextDisabled("(native)");
- }
- } else {
- ImGui::TextDisabled("loading...");
- }
-
- ImGui::End();
- return true;
-}
-
-Status DebugApp::PrepareBytecodeCodeView(CodeViewDocument* document) {
- auto* remote_module = document->function->module();
- auto* remote_function = document->function;
-
- document->bytecode_info.lines = remote_function->name();
-
- return OkStatus();
-}
-
-Status DebugApp::DrawBytecodeCodeView(CodeViewDocument* document) {
- // Ensure we have cached our line information.
- RETURN_IF_ERROR(PrepareBytecodeCodeView(document));
-
- auto* remote_module = document->function->module();
- auto* remote_function = document->function;
-
- ImGui::BeginGroup();
- ImGui::BeginChild("##bytecode_view", ImVec2(0, 0), false,
- ImGuiWindowFlags_AlwaysVerticalScrollbar);
- ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0, 0));
-
- // TODO(benvanik): cache breakpoints for this function for faster lookup.
-
- auto& bytecode_info = document->bytecode_info;
- ImGuiListClipper clipper(bytecode_info.lines.size(),
- ImGui::GetTextLineHeightWithSpacing());
- while (clipper.Step()) {
- for (int i = clipper.DisplayStart; i < clipper.DisplayEnd; ++i) {
- ImGui::PushID(i);
-
- // TODO(benvanik): lookup line info.
- int bytecode_offset = 0;
- int breakpoint_index = FindMatchingUserBreakpointIndex(
- remote_module->name(), remote_function->ordinal(), bytecode_offset);
- bool has_breakpoint = breakpoint_index != -1;
- bool active_on_any_invocation = false;
- bool active_on_selected_invocation = false;
-
- ImGui::Dummy(ImVec2(4, 0));
-
- // Gutter breakpoint button.
- ImGui::SameLine();
- if (has_breakpoint) {
- PushButtonHue(0.0f); // Red
- if (ImGui::Button(" ##toggle_breakpoint")) {
- CHECK_GE(breakpoint_index, 0);
- auto& user_breakpoint = user_breakpoint_list_[breakpoint_index];
- if (user_breakpoint.active_breakpoint) {
- RETURN_IF_ERROR(debug_client_->RemoveBreakpoint(
- *user_breakpoint.active_breakpoint));
- }
- user_breakpoint_list_.erase(user_breakpoint_list_.begin() +
- breakpoint_index);
- }
- PopButtonStyle();
- if (ImGui::IsItemHovered()) {
- ImGui::SetTooltip("Remove the breakpoint at this offset.");
- }
- } else {
- PushButtonColor(ImGui::GetStyleColorVec4(ImGuiCol_ChildBg));
- if (ImGui::Button(" ##toggle_breakpoint")) {
- UserBreakpoint user_breakpoint;
- user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction;
- user_breakpoint.module_name = remote_module->name();
- user_breakpoint.function_name = remote_function->name();
- user_breakpoint.bytecode_offset = bytecode_offset;
- user_breakpoint.wants_enabled = true;
- user_breakpoint_list_.push_back(std::move(user_breakpoint));
- }
- PopButtonStyle();
- if (ImGui::IsItemHovered()) {
- ImGui::SetTooltip("Add a breakpoint at this offset.");
- }
- }
-
- // Active execution chevron (shows when active or any invocation is
- // executing this region).
- ImGui::SameLine();
- if (active_on_selected_invocation) {
- // The selected invocation is active here.
- ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_SeparatorActive),
- " > ");
- } else if (active_on_any_invocation) {
- // At least one other invocation is active here.
- ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_Separator), " > ");
- } else {
- // Not active.
- ImGui::Text(" ");
- }
-
- // Line contents.
- ImGui::SameLine();
- ImGui::Text("%s", bytecode_info.lines[i].c_str());
-
- if (document->focus_offset.has_value() &&
- bytecode_offset == document->focus_offset.value()) {
- document->bytecode_offset = document->focus_offset.value();
- document->focus_offset = {};
- ImGui::SetScrollHereY();
- }
-
- ImGui::PopID();
- }
- }
-
- ImGui::PopStyleVar();
- ImGui::EndChild();
- ImGui::EndGroup();
-
- return OkStatus();
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/tools/debugger/debug_app.h b/iree/tools/debugger/debug_app.h
deleted file mode 100644
index e970f5f..0000000
--- a/iree/tools/debugger/debug_app.h
+++ /dev/null
@@ -1,200 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
-#define IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
-
-#include <SDL.h>
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "iree/base/status.h"
-#include "iree/rt/debug/debug_client.h"
-
-// NOTE: order matters here, imgui must come first:
-#include "third_party/dear_imgui/imgui.h"
-// NOTE: must follow imgui.h:
-#include "third_party/dear_imgui/examples/imgui_impl_opengl3.h"
-#include "third_party/dear_imgui/examples/imgui_impl_sdl.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// Debug client app UI.
-// Uses a DebugClient to communicate with a remote DebugServer and ImGui to
-// display a nifty UI.
-//
-// See the ImGui site for more info: https://github.com/ocornut/imgui
-// The most useful thing is the imgui_demo.cpp file that contains example usage
-// of most features.
-class DebugApp : private DebugClient::Listener {
- public:
- struct UserBreakpoint {
- RemoteBreakpoint::Type type = RemoteBreakpoint::Type::kBytecodeFunction;
- const RemoteBreakpoint* active_breakpoint = nullptr;
- bool wants_enabled = true;
- bool is_enabling = false;
- // TODO(benvanik): reuse BreakpointDef here?
- std::string module_name;
- std::string function_name;
- int function_ordinal = -1;
- int bytecode_offset = 0;
- std::string native_function;
- };
-
- static void PumpMainLoopThunk(void* arg);
-
- DebugApp(SDL_Window* window, SDL_GLContext gl_context,
- const char* glsl_version);
- ~DebugApp();
-
- // Connects to the service at the specified address.
- Status Connect(absl::string_view service_address);
- // Disconnects from the currently connected service, if any.
- Status Disconnect();
-
- // Returns true if the remote service is paused at our request.
- bool is_paused() const;
-
- // Pumps the main UI loop once.
- // This polls the DebugClient, SDL input, and renders the UI.
- // It should be called as frequently as possible to ensure snappy UI updates.
- // Returns CancelledError if the app is being closed by the user.
- Status PumpMainLoop();
-
- // Defines how NavigationToCodeView methods behave.
- enum class NavigationMode {
- // The target will be opened in a new document tab.
- kNewDocument,
- // The target will be opened in the current document tab, replacing the
- // current contents.
- kCurrentDocument,
- // The target will be opened in a document tab that mostly matches (like
- // the same function in a module at a different offset), otherwise a new
- // document will be opened.
- kMatchDocument,
- };
-
- // Navigates to a particular function offset based on resolution of the given
- // arguments. Navigation may happen asynchronously if targets need to be
- // resolved or contents fetched.
- Status NavigateToCodeView(absl::string_view module_name, int function_ordinal,
- int offset, NavigationMode navigation_mode);
- Status NavigateToCodeView(absl::string_view module_name,
- absl::string_view function_name, int offset,
- NavigationMode navigation_mode);
- Status NavigateToCodeView(const RemoteInvocation& invocation,
- int stack_frame_index,
- NavigationMode navigation_mode);
- Status NavigateToCodeView(const UserBreakpoint& user_breakpoint,
- NavigationMode navigation_mode);
-
- private:
- struct CodeViewDocument {
- // Document display title (and ID).
- std::string title;
- // Function (and offset within the function) being displayed.
- RemoteFunction* function = nullptr;
- int bytecode_offset = 0;
- // Set to a bytecode offset to have the document focus there.
- absl::optional<int> focus_offset;
- // Cached info for bytecode display.
- struct {
- std::vector<std::string> lines;
- } bytecode_info;
- };
-
- CodeViewDocument* FindMatchingDocument(absl::string_view module_name,
- int function_ordinal);
- RemoteInvocation* GetSelectedInvocation() const;
-
- Status RefreshActiveBreakpoints();
- bool IsStoppedAtBreakpoint(const UserBreakpoint& user_breakpoint) const;
- int FindMatchingUserBreakpointIndex(absl::string_view module_name,
- int function_ordinal, int offset);
- int FindMatchingUserBreakpointIndex(absl::string_view module_name,
- absl::string_view function_name,
- int offset);
- Status ResumeFromBreakpoint(UserBreakpoint* user_breakpoint);
-
- Status OnContextRegistered(const RemoteContext& context) override;
- Status OnContextUnregistered(const RemoteContext& context) override;
- Status OnModuleLoaded(const RemoteContext& context,
- const RemoteModule& module) override;
- Status OnInvocationRegistered(const RemoteInvocation& invocation) override;
- Status OnInvocationUnregistered(const RemoteInvocation& invocation) override;
- Status OnBreakpointHit(const RemoteBreakpoint& breakpoint,
- const RemoteInvocation& invocation) override;
-
- Status LayoutInitialDockSpace();
-
- Status DrawUI();
- Status DrawMainMenu();
- Status DrawToolbar();
-
- Status DrawBreakpointListPanel();
- StatusOr<bool> DrawBreakpoint(UserBreakpoint* user_breakpoint);
- Status DrawAddBreakpointDialogs(
- absl::optional<RemoteBreakpoint::Type> add_breakpoint_type);
- Status DrawAddBytecodeFunctionBreakpointDialog();
- Status DrawAddNativeFunctionBreakpointDialog();
-
- Status DrawModuleListPanel();
- Status DrawContext(const RemoteContext& context,
- const ImGuiTextFilter& filter);
- Status DrawModule(RemoteModule* module, const ImGuiTextFilter& filter);
-
- Status DrawLocalListPanel();
- Status DrawLocal(RemoteInvocation* invocation, int stack_frame_index,
- int local_index, const rpc::BufferViewDefT& local);
-
- Status DrawInvocationListPanel();
- Status DrawInvocation(const RemoteInvocation& invocation);
-
- Status DrawCodeViewPanels();
- StatusOr<bool> DrawCodeViewDocument(CodeViewDocument* document);
- Status PrepareBytecodeCodeView(CodeViewDocument* document);
- Status DrawBytecodeCodeView(CodeViewDocument* document);
-
- SDL_Window* window_ = nullptr;
- SDL_GLContext gl_context_ = nullptr;
-
- ImGuiID dockspace_id_;
- ImGuiID dock_top_id_;
- ImGuiID dock_left_id_;
- ImGuiID dock_bottom_id_;
- ImGuiID dock_bottom_left_id_;
- ImGuiID dock_bottom_right_id_;
- ImGuiID dock_right_id_;
- ImGuiID dock_content_id_;
-
- std::unique_ptr<DebugClient> debug_client_;
- std::vector<UserBreakpoint> user_breakpoint_list_;
-
- bool is_paused_ = false;
- std::vector<const RemoteBreakpoint*> hit_breakpoints_;
- bool is_stepping_ = false;
-
- absl::optional<int> selected_invocation_id_;
- absl::optional<int> selected_stack_frame_index_;
-
- std::vector<std::unique_ptr<CodeViewDocument>> documents_;
-};
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
diff --git a/iree/tools/debugger/debug_app_embedded.cc b/iree/tools/debugger/debug_app_embedded.cc
deleted file mode 100644
index f44692c..0000000
--- a/iree/tools/debugger/debug_app_embedded.cc
+++ /dev/null
@@ -1,153 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/tools/debugger/debug_app_embedded.h"
-
-#include <SDL.h>
-
-#include <thread> // NOLINT
-
-#include "absl/base/thread_annotations.h"
-#include "absl/memory/memory.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/tools/debugger/debug_app.h"
-#include "third_party/SDL2/include/SDL_thread.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-class InProcessEmbeddedDebugger : public EmbeddedDebugger {
- public:
- explicit InProcessEmbeddedDebugger(std::unique_ptr<DebugApp> app)
- : app_(std::move(app)) {
- thread_ =
- SDL_CreateThread(&ThreadMainThunk, "InProcessEmbeddedDebugger", this);
- }
-
- ~InProcessEmbeddedDebugger() override {
- VLOG(1) << "Setting shutdown flag and waiting on thread...";
- shutdown_flag_ = true;
- int status = 0;
- SDL_WaitThread(thread_, &status);
- VLOG(1) << "Thread shutdown, killing app...";
- app_.reset();
- }
-
- Status AwaitClose() override {
- await_mutex_.LockWhen(absl::Condition(
- +[](bool* is_shutdown) { return *is_shutdown; }, &is_shutdown_));
- auto status = std::move(shutdown_status_);
- await_mutex_.Unlock();
- return status;
- }
-
- private:
- static int ThreadMainThunk(void* arg) {
- return reinterpret_cast<InProcessEmbeddedDebugger*>(arg)->ThreadMain();
- }
-
- int ThreadMain() {
- VLOG(1) << "Thread entry";
- while (!shutdown_flag_) {
- auto status = app_->PumpMainLoop();
- if (IsCancelled(status)) {
- shutdown_flag_ = true;
- break;
- } else if (!shutdown_flag_ && !status.ok()) {
- absl::MutexLock lock(&await_mutex_);
- shutdown_status_ = std::move(status);
- // TODO(benvanik): don't check unless no one is watching.
- CHECK_OK(shutdown_status_);
- }
- }
- app_.reset();
- {
- absl::MutexLock lock(&await_mutex_);
- is_shutdown_ = true;
- }
- VLOG(1) << "Thread exit";
- return 0;
- }
-
- std::unique_ptr<DebugApp> app_;
- SDL_Thread* thread_;
- std::atomic<bool> shutdown_flag_ = {false};
- absl::Mutex await_mutex_;
- bool is_shutdown_ ABSL_GUARDED_BY(await_mutex_) = false;
- Status shutdown_status_ ABSL_GUARDED_BY(await_mutex_);
-};
-
-StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger() {
- return AttachDebugger("");
-}
-
-StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger(
- absl::string_view service_address) {
- LOG(INFO) << "Launching embedded debugger; service=" << service_address;
- // Workaround for terrible bad SDL/graphics driver leaks.
- IREE_DISABLE_LEAK_CHECKS();
-
- if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
- return InternalErrorBuilder(IREE_LOC)
- << "Unable to init SDL: " << SDL_GetError();
- }
-
-#if __APPLE__
- // GL 3.2 Core + GLSL 150
- const char* glsl_version = "#version 150";
- SDL_GL_SetAttribute(
- SDL_GL_CONTEXT_FLAGS,
- SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG); // Always required on Mac
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 2);
-#else
- // GL 3.0 + GLSL 130
- const char* glsl_version = "#version 130";
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0);
-#endif
-
- SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
- SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24);
- SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8);
- SDL_DisplayMode current;
- SDL_GetCurrentDisplayMode(0, ¤t);
- SDL_WindowFlags window_flags = (SDL_WindowFlags)(
- SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
- SDL_Window* window =
- SDL_CreateWindow("IREE Debugger (embedded)", SDL_WINDOWPOS_CENTERED,
- SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
- SDL_GLContext gl_context = SDL_GL_CreateContext(window);
- SDL_GL_MakeCurrent(nullptr, nullptr);
-
- IREE_ENABLE_LEAK_CHECKS();
-
- auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version);
- if (!service_address.empty()) {
- RETURN_IF_ERROR(app->Connect(service_address));
- }
-
- auto handle = absl::make_unique<InProcessEmbeddedDebugger>(std::move(app));
- return handle;
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/tools/debugger/debug_app_embedded.h b/iree/tools/debugger/debug_app_embedded.h
deleted file mode 100644
index 8e7c15f..0000000
--- a/iree/tools/debugger/debug_app_embedded.h
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
-#define IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
-
-#include <memory>
-
-#include "absl/strings/string_view.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// RAII handle for keeping the debugger alive.
-// When the instance is destroyed the debugger app will be closed.
-class EmbeddedDebugger {
- public:
- virtual ~EmbeddedDebugger() = default;
-
- // Blocks the caller until the debugger is closed by the user.
- virtual Status AwaitClose() = 0;
-};
-
-// Launches the debugger app.
-// Returns a handle that can be used to wait for the debugger to close or
-// force it to close.
-StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger();
-
-// Launches the debugger app and attaches to the given server address.
-// Returns a handle that can be used to wait for the debugger to close or
-// force it to close.
-StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger(
- absl::string_view service_address);
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
diff --git a/iree/tools/debugger/debug_app_main_emscripten.cc b/iree/tools/debugger/debug_app_main_emscripten.cc
deleted file mode 100644
index 1df8d98..0000000
--- a/iree/tools/debugger/debug_app_main_emscripten.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Emscripten debug_app entry point.
-// Though we are using SDL here we need to do some emscripten-specific magic to
-// handle the different main looping mode (as we can't block in main() like on
-// other platforms) as well as support some emscripten-specific features for
-// file upload/download/etc.
-
-#include <SDL.h>
-#include <emscripten.h>
-
-#include "iree/base/init.h"
-#include "iree/tools/debugger/debug_app.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
-
- if (SDL_Init(SDL_INIT_VIDEO) != 0) {
- printf("Error: %s\n", SDL_GetError());
- return -1;
- }
-
- const char* glsl_version = "#version 100";
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_ES);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 2);
- SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0);
-
- SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
- SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24);
- SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8);
- SDL_DisplayMode current;
- SDL_GetCurrentDisplayMode(0, ¤t);
- SDL_WindowFlags window_flags = (SDL_WindowFlags)(
- SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
- SDL_Window* window =
- SDL_CreateWindow("IREE Debugger", SDL_WINDOWPOS_CENTERED,
- SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
- SDL_GLContext gl_context = SDL_GL_CreateContext(window);
- if (!gl_context) {
- printf("Failed to initialize WebGL context!\n");
- return 1;
- }
-
- auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version);
- ::emscripten_set_main_loop_arg(DebugApp::PumpMainLoopThunk, app.release(), 0,
- false);
- return 0;
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/tools/debugger/debug_app_main_native.cc b/iree/tools/debugger/debug_app_main_native.cc
deleted file mode 100644
index d364d14..0000000
--- a/iree/tools/debugger/debug_app_main_native.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Native (linux/etc) debug_app entry point.
-// This should work on any platform with pthreads and SDL support.
-
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/tools/debugger/debug_app_embedded.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-Status Run() {
- ASSIGN_OR_RETURN(auto handle, LaunchDebugger());
- RETURN_IF_ERROR(handle->AwaitClose());
- handle.reset();
- return OkStatus();
-}
-
-extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
- auto status = Run();
- if (!status.ok()) {
- LOG(ERROR) << "Debugger error: " << status;
- return 1;
- }
- return 0;
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/tools/debugger/debug_cli_main.cc b/iree/tools/debugger/debug_cli_main.cc
deleted file mode 100644
index 4970a70..0000000
--- a/iree/tools/debugger/debug_cli_main.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "absl/flags/flag.h"
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/tools/debugger/debug_prompt.h"
-
-ABSL_FLAG(std::string, debug_service_uri, "0.0.0.0:6000",
- "IP/port of debug service to connect to.");
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-Status Run() {
- // TODO(benvanik): retry until connected? would allow auto-build reconnects.
- return AttachDebugPrompt(absl::GetFlag(FLAGS_debug_service_uri));
-}
-
-extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
- CHECK_OK(Run());
- return 0;
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/tools/debugger/debug_prompt.cc b/iree/tools/debugger/debug_prompt.cc
deleted file mode 100644
index aceda80..0000000
--- a/iree/tools/debugger/debug_prompt.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/tools/debugger/debug_prompt.h"
-
-#include "iree/base/status.h"
-#include "iree/rt/debug/debug_client.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-namespace {
-
-class DebugPrompt : private DebugClient::Listener {
- public:
- Status Connect(absl::string_view debug_service_uri) {
- // Connect to the debug service.
- ASSIGN_OR_RETURN(debug_client_,
- DebugClient::Connect(debug_service_uri, this));
- return OkStatus();
- }
-
- Status Run() {
- // Query commands, transmit requests, and dispatch responses.
- while (true) {
- RETURN_IF_ERROR(debug_client_->Poll());
-
- // TODO(benvanik): ask for a command.
- // TODO(benvanik): display stuff.
- }
- }
-
- private:
- Status OnContextRegistered(const RemoteContext& context) override {
- // Ack.
- return debug_client_->MakeReady();
- }
-
- Status OnContextUnregistered(const RemoteContext& context) override {
- // Ack.
- return debug_client_->MakeReady();
- }
-
- Status OnModuleLoaded(const RemoteContext& context,
- const RemoteModule& module) override {
- // Ack.
- return debug_client_->MakeReady();
- }
-
- Status OnInvocationRegistered(const RemoteInvocation& invocation) override {
- // Ack.
- return debug_client_->MakeReady();
- }
-
- Status OnInvocationUnregistered(const RemoteInvocation& invocation) override {
- // Ack.
- return debug_client_->MakeReady();
- }
-
- Status OnBreakpointHit(const RemoteBreakpoint& breakpoint,
- const RemoteInvocation& invocation) override {
- // Ack.
- return debug_client_->MakeReady();
- }
-
- std::unique_ptr<DebugClient> debug_client_;
-};
-
-} // namespace
-
-Status AttachDebugPrompt(absl::string_view debug_service_uri) {
- DebugPrompt debug_prompt;
- RETURN_IF_ERROR(debug_prompt.Connect(debug_service_uri));
- return debug_prompt.Run();
-}
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
diff --git a/iree/tools/debugger/debug_prompt.h b/iree/tools/debugger/debug_prompt.h
deleted file mode 100644
index 2c9bf7a..0000000
--- a/iree/tools/debugger/debug_prompt.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
-#define IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
-
-#include "absl/strings/string_view.h"
-#include "iree/base/status.h"
-
-namespace iree {
-namespace rt {
-namespace debug {
-
-// TODO(benvanik): take stdin/stdout as arguments.
-// Attaches a debug prompt reading stdin for commands and printing results to
-// stdout. The calling thread will block until the debugger is exited or the
-// debug service closes.
-Status AttachDebugPrompt(absl::string_view debug_service_uri);
-
-} // namespace debug
-} // namespace rt
-} // namespace iree
-
-#endif // IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc
deleted file mode 100644
index 3b7b3f9..0000000
--- a/iree/tools/run_mlir_main.cc
+++ /dev/null
@@ -1,360 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// IREE source.mlir -> execution output test runner.
-// This is meant to be called from LIT for FileCheck tests, and tries to match
-// the interface of mlir-opt (featuring -split-input-file, etc) so it's easier
-// to work with there. If you want a more generalized runner for standalone
-// precompiled IREE modules use //third_party/iree/tools:run_module.
-//
-// By default all exported functions in the module will be run in order.
-// All input values, provided via -input-values, will be passed to the
-// functions (this means all input signatures must match). Results from the
-// executed functions will be printed to stdout for checking.
-// Use -output_types to set the function output data types, which like args will
-// be used for all functions executed.
-//
-// Example input:
-// // RUN: iree-run %s | FileCheck %s
-// // CHECK-LABEL: @foo
-// // CHECK: 1xf32: 2
-// func @foo() -> memref<f32> attributes {iree.module.export} {
-// %0 = "iree.constant"() {value: dense<tensor<f32>, 2.0>} : () -> memref<f32>
-// return %0 : memref<f32>
-// }
-
-#include <iostream>
-
-#include "absl/flags/flag.h"
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/str_split.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/init.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h"
-#include "iree/hal/buffer_view_string_util.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/rt/context.h"
-#include "iree/rt/debug/debug_server_flags.h"
-#include "iree/rt/instance.h"
-#include "iree/rt/invocation.h"
-#include "iree/rt/module.h"
-#include "iree/rt/module_printer.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/vm/sequencer_module.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/SourceMgr.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Parser.h"
-#include "mlir/Support/FileUtilities.h"
-
-ABSL_FLAG(bool, split_input_file, true,
- "Split the input file into multiple modules.");
-
-ABSL_FLAG(std::string, target_backends, "",
- "Comma-separated list of target backends to translate executables "
- "into. Omit to translate using all linked-in backend translators.");
-ABSL_FLAG(
- bool, export_all, true,
- "Automatically add the iree.module.export attribute to all functions.");
-
-ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
-ABSL_FLAG(std::string, output_types, "",
- "Output data types (comma delimited list of b/i/u/f for "
- "binary/signed int/unsigned int/float).");
-
-// TODO(benvanik): is there a more canonical flag we can use?
-ABSL_FLAG(bool, print_mlir, true, "Prints MLIR IR during translation.");
-
-ABSL_FLAG(bool, print_bytecode, false,
- "Prints IREE bytecode after translation.");
-
-ABSL_FLAG(bool, run, true,
- "Option to run the file. Setting it to false just compiles it.");
-
-namespace iree {
-namespace {
-
-using ::iree::hal::BufferView;
-using ::iree::rt::Function;
-using ::iree::rt::Module;
-
-// Returns a driver name capable of handling input from the given backend.
-std::string BackendToDriverName(std::string backend) {
- size_t dash = backend.find('-');
- if (dash == std::string::npos) {
- return backend;
- } else {
- return backend.substr(0, dash);
- }
-}
-
-// Prepares a module for evaluation by running MLIR import and IREE translation.
-StatusOr<ref_ptr<Module>> PrepareModule(
- std::string target_backend,
- std::unique_ptr<llvm::MemoryBuffer> file_buffer) {
- mlir::MLIRContext context;
-
- // Parse input MLIR module.
- llvm::SourceMgr source_mgr;
- source_mgr.AddNewSourceBuffer(std::move(file_buffer), llvm::SMLoc());
- mlir::OwningModuleRef mlir_module =
- mlir::parseSourceFile(source_mgr, &context);
-
- if (absl::GetFlag(FLAGS_export_all)) {
- for (auto function : mlir_module->getOps<mlir::FuncOp>()) {
- function.setAttr("iree.module.export", mlir::UnitAttr::get(&context));
- }
- }
-
- // Translate from MLIR to IREE bytecode.
- mlir::iree_compiler::ModuleTranslationOptions options;
- options.print_mlir = absl::GetFlag(FLAGS_print_mlir);
- options.target_backends = {target_backend};
- auto iree_module_bytes =
- mlir::iree_compiler::translateMlirToIreeSequencerModule(mlir_module.get(),
- options);
- if (iree_module_bytes.empty()) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Error translating MLIR to an IREE sequencer module";
- }
-
- if (absl::GetFlag(FLAGS_print_mlir)) {
- mlir_module->dump();
- }
-
- // Wrap module in a file handle.
- ASSIGN_OR_RETURN(auto iree_module_file,
- vm::ModuleFile::FromBuffer(ModuleDefIdentifier(),
- std::move(iree_module_bytes)));
- return vm::SequencerModule::FromFile(std::move(iree_module_file));
-}
-
-// Parses a list of input shapes and values from a string of newline-separated
-// inputs. Expects the contents to have one value per line with each value
-// listed as
-// [shape]xtype=[value]
-// Example:
-// 4x4xi8=0,1,2,3
-StatusOr<std::vector<BufferView>> ParseInputsFromFlags(
- hal::Allocator *allocator) {
- std::string file_contents =
- absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}});
- std::vector<BufferView> inputs;
- std::vector<std::string> lines = absl::StrSplit(
- file_contents, absl::ByAnyChar("\n;"), absl::SkipWhitespace());
- for (const auto &line : lines) {
- ASSIGN_OR_RETURN(auto input,
- hal::ParseBufferViewFromString(line, allocator));
- inputs.push_back(input);
- }
- return inputs;
-}
-
-// Outputs all results from the function to stdout in IREE BufferView format.
-Status OutputFunctionResults(const Function &function,
- absl::Span<BufferView> results) {
- std::vector<std::string> output_types =
- absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "),
- absl::SkipWhitespace());
- if (!output_types.empty() && output_types.size() != results.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "--output_types= specified but has " << output_types.size()
- << " types when the function returns " << results.size();
- }
-
- for (int i = 0; i < results.size(); ++i) {
- const auto &result = results[i];
- auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
- if (!output_types.empty()) {
- ASSIGN_OR_RETURN(print_mode,
- hal::ParseBufferViewPrintMode(output_types[i]));
- }
- ASSIGN_OR_RETURN(auto result_str,
- hal::PrintBufferViewToString(result, print_mode, 1024));
- LOG(INFO) << "result[" << i << "]: " << result.buffer->DebugString();
- std::cout << result_str << "\n";
- }
-
- return OkStatus();
-}
-
-// Evaluates a single function in its own fiber, printing the results to stdout.
-Status EvaluateFunction(const ref_ptr<rt::Context> &context,
- hal::Allocator *allocator, const Function &function) {
- std::cout << "EXEC @" << function.name() << std::endl;
-
- // Create invocation that will perform the execution.
- ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(allocator));
- ASSIGN_OR_RETURN(
- auto invocation,
- rt::Invocation::Create(add_ref(context), function, make_ref<rt::Policy>(),
- {}, absl::MakeConstSpan(arguments)));
-
- // Wait until invocation completes.
- RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
-
- // Print outputs.
- ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults());
- RETURN_IF_ERROR(OutputFunctionResults(function, absl::MakeSpan(results)));
-
- return OkStatus();
-}
-
-// Evaluates all exported functions within given module.
-Status EvaluateFunctions(absl::string_view target_backend,
- ref_ptr<Module> module) {
- // Create the context we'll use for this (ensuring that we can't interfere
- // with other running evaluations, such as when in a multithreaded test
- // runner).
- ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
- auto instance = make_ref<rt::Instance>(std::move(debug_server));
- ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
- target_backend));
- ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
- RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
-
- if (absl::GetFlag(FLAGS_print_bytecode)) {
- RETURN_IF_ERROR(rt::PrintModuleToStream(
- *module, rt::PrintModuleFlag::kDisassemble, &std::cout));
- }
-
- // Evaluate all exported functions.
- auto policy = make_ref<rt::Policy>();
- auto run_function = [&](int ordinal) -> Status {
- // Setup a new context for this invocation.
- auto context = make_ref<rt::Context>(add_ref(instance), add_ref(policy));
- RETURN_IF_ERROR(context->RegisterModule(add_ref(module)));
-
- // Invoke the function and print results.
- ASSIGN_OR_RETURN(auto function,
- module->LookupFunctionByOrdinal(
- rt::Function::Linkage::kExport, ordinal));
- RETURN_IF_ERROR(EvaluateFunction(context, device->allocator(), function));
- return OkStatus();
- };
-
- Status evaluate_status = OkStatus();
- for (int i = 0; i < module->signature().export_function_count(); ++i) {
- evaluate_status = run_function(i);
- if (!evaluate_status.ok()) {
- break;
- }
- }
-
- RETURN_IF_ERROR(instance->device_manager()->UnregisterDevice(device.get()));
- device.reset();
- driver.reset();
-
- return evaluate_status;
-}
-
-// Translates and runs a single LLVM file buffer.
-Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer) {
- std::vector<std::string> target_backends;
- if (absl::GetFlag(FLAGS_target_backends).empty()) {
- target_backends =
- hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers();
- } else {
- // We need to map specific backends names to drivers (like 'vulkan-spirv' to
- // the driver 'vulkan').
- target_backends = absl::StrSplit(absl::GetFlag(FLAGS_target_backends), ',');
- }
-
- for (auto target_backend : target_backends) {
- // Prepare the module for execution and evaluate it.
- auto cloned_file_buffer = llvm::MemoryBuffer::getMemBufferCopy(
- file_buffer->getBuffer(), file_buffer->getBufferIdentifier());
- ASSIGN_OR_RETURN(auto module, PrepareModule(target_backend + '*',
- std::move(cloned_file_buffer)));
- if (!absl::GetFlag(FLAGS_run)) {
- continue;
- }
- RETURN_IF_ERROR(EvaluateFunctions(BackendToDriverName(target_backend),
- std::move(module)));
- }
-
- return OkStatus();
-}
-
-// Runs the given .mlir file based on the current flags.
-Status RunFile(std::string mlir_filename) {
- // Load input file/from stdin.
- std::string error_message;
- auto file = mlir::openInputFile(mlir_filename, &error_message);
- if (!file) {
- return NotFoundErrorBuilder(IREE_LOC)
- << "Unable to open input file " << mlir_filename << ": "
- << error_message;
- }
-
- if (!absl::GetFlag(FLAGS_split_input_file)) {
- // Use entire buffer as a single module.
- return EvaluateFile(std::move(file));
- }
-
- // Split the buffer into separate modules and evaluate independently.
- // This matches the -split-input-file arg to mlir-opt.
- const char kSplitMarker[] = "// -----\n";
- auto *full_buffer = file.get();
- llvm::SmallVector<llvm::StringRef, 8> source_buffers;
- full_buffer->getBuffer().split(source_buffers, kSplitMarker);
-
- // Add the original buffer to the source manager.
- llvm::SourceMgr fileSourceMgr;
- fileSourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
-
- // Process each chunk in turn. Only return the first error (but log all).
- Status any_failure;
- for (auto &sub_source_buffer : source_buffers) {
- auto split_loc = llvm::SMLoc::getFromPointer(sub_source_buffer.data());
- unsigned split_line = fileSourceMgr.getLineAndColumn(split_loc).first;
- auto sub_buffer = llvm::MemoryBuffer::getMemBufferCopy(
- sub_source_buffer, full_buffer->getBufferIdentifier() +
- llvm::Twine(" split at line #") +
- llvm::Twine(split_line));
- auto sub_failure = EvaluateFile(std::move(sub_buffer));
- if (!sub_failure.ok()) {
- LOG(ERROR) << sub_failure;
- if (any_failure.ok()) {
- any_failure = std::move(sub_failure);
- }
- }
- }
-
- return any_failure;
-}
-
-} // namespace
-
-extern "C" int main(int argc, char **argv) {
- InitializeEnvironment(&argc, &argv);
- if (argc < 2) {
- LOG(ERROR) << "Must supply an input .mlir file.";
- return 1;
- }
- auto status = RunFile(argv[1]);
- if (!status.ok()) {
- std::cerr << "ERROR running file (" << argv[1] << "): " << status << "\n";
- return 1;
- }
- return 0;
-}
-
-} // namespace iree
diff --git a/iree/tools/run_module_main.cc b/iree/tools/run_module_main.cc
deleted file mode 100644
index be634ee..0000000
--- a/iree/tools/run_module_main.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <iostream>
-#include <vector>
-
-#include "absl/flags/flag.h"
-#include "absl/strings/numbers.h"
-#include "absl/strings/str_replace.h"
-#include "absl/strings/str_split.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/file_io.h"
-#include "iree/base/file_path.h"
-#include "iree/base/init.h"
-#include "iree/base/source_location.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view_string_util.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/rt/context.h"
-#include "iree/rt/debug/debug_server_flags.h"
-#include "iree/rt/instance.h"
-#include "iree/rt/module_printer.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/vm/sequencer_module.h"
-
-ABSL_FLAG(std::string, main_module, "", "Main module with entry point.");
-ABSL_FLAG(std::string, main_function, "",
- "Function within the main module to execute.");
-
-ABSL_FLAG(bool, print_disassembly, true,
- "Prints bytecode disassembly for the module.");
-
-ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
-ABSL_FLAG(std::string, input_file, "",
- "Input shapes and optional values serialized in a file.");
-
-ABSL_FLAG(std::string, output_types, "",
- "Output data types (comma delimited list of b/i/u/f for "
- "binary/signed int/unsigned int/float).");
-
-namespace iree {
-namespace {
-
-// Parses a list of input shapes and values from a string of newline-separated
-// inputs. Expects the contents to have one value per line with each value
-// listed as
-// [shape]xtype=[value]
-// Example:
-// 4x4xi8=0,1,2,3
-StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags(
- hal::Allocator* allocator) {
- std::string file_contents;
- if (!absl::GetFlag(FLAGS_input_values).empty()) {
- file_contents =
- absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}});
- } else if (!absl::GetFlag(FLAGS_input_file).empty()) {
- ASSIGN_OR_RETURN(file_contents,
- file_io::GetFileContents(absl::GetFlag(FLAGS_input_file)));
- }
- std::vector<hal::BufferView> inputs;
- for (const auto& line :
- absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) {
- ASSIGN_OR_RETURN(auto input,
- hal::ParseBufferViewFromString(line, allocator));
- inputs.push_back(input);
- }
- return inputs;
-}
-
-} // namespace
-
-Status Run() {
- ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
- auto instance = make_ref<rt::Instance>(std::move(debug_server));
- ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
- "interpreter"));
- ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
- RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
- auto policy = make_ref<rt::Policy>();
- auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));
-
- // Load main module.
- ASSIGN_OR_RETURN(
- auto main_module_file,
- vm::ModuleFile::LoadFile(ModuleDefIdentifier(),
- absl::GetFlag(FLAGS_main_module)),
- _ << "while loading module file " << absl::GetFlag(FLAGS_main_module));
- ASSIGN_OR_RETURN(auto main_module,
- vm::SequencerModule::FromFile(std::move(main_module_file)));
-
- // Register the main module with the context.
- // We could add additional modules (specializations, shared libraries, etc).
- // ModuleFioles are stateless so we could have the same module_file used by
- // multiple contexts simultaneously.
- RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));
-
- // Dump the registered modules.
- rt::PrintModuleFlagBitfield print_flags = rt::PrintModuleFlag::kNone;
- if (absl::GetFlag(FLAGS_print_disassembly)) {
- print_flags |= rt::PrintModuleFlag::kDisassemble;
- }
- for (const auto& module : context->modules()) {
- RETURN_IF_ERROR(PrintModuleToStream(*module, print_flags, &std::cout));
- }
-
- rt::Function main_function;
- if (!absl::GetFlag(FLAGS_main_function).empty()) {
- // User-specified main function.
- ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName(
- rt::Function::Linkage::kExport,
- absl::GetFlag(FLAGS_main_function)));
- } else {
- // No main function specified; to prevent non-deterministic behavior we
- // require one unless there's exactly one exported function in the module.
- if (main_module->signature().export_function_count() == 1) {
- ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal(
- rt::Function::Linkage::kExport, 0));
- } else {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "--main_function= must be specified to disambiguate the "
- "function to run";
- }
- }
-
- // Call into the main function.
- ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator()));
- ASSIGN_OR_RETURN(auto invocation,
- rt::Invocation::Create(add_ref(context), main_function,
- make_ref<rt::Policy>(), {},
- absl::MakeConstSpan(arguments)));
-
- // Wait until invocation completes.
- RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
- ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults());
-
- // Dump all results to stdout.
- std::vector<std::string> output_types =
- absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "),
- absl::SkipWhitespace());
- if (!output_types.empty() && output_types.size() != results.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "--output_types= specified but has " << output_types.size()
- << " types when the function returns " << results.size();
- }
- for (int i = 0; i < results.size(); ++i) {
- const auto& result = results[i];
- auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
- if (!output_types.empty()) {
- ASSIGN_OR_RETURN(print_mode,
- hal::ParseBufferViewPrintMode(output_types[i]));
- }
- ASSIGN_OR_RETURN(auto result_str,
- PrintBufferViewToString(result, print_mode, 1024));
- const auto& buffer = result.buffer;
- if (!buffer) {
- return InternalErrorBuilder(IREE_LOC)
- << "result[" << i << "] unexpectedly has no buffer";
- }
- LOG(INFO) << "result[" << i << "]: " << buffer->DebugString();
- std::cout << result_str << "\n";
- }
-
- return OkStatus();
-}
-
-extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
- CHECK_OK(Run());
- return 0;
-}
-
-} // namespace iree
diff --git a/iree/tools/web/BUILD b/iree/tools/web/BUILD
deleted file mode 100644
index f55ef9d..0000000
--- a/iree/tools/web/BUILD
+++ /dev/null
@@ -1,89 +0,0 @@
-# IREE web tools.
-
-load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary")
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-EMSCRIPTEN_LINKOPTS_COMMON = [
- # Error at compile time on unresolved symbols.
- "-s ERROR_ON_UNDEFINED_SYMBOLS=1",
-
- # Note: If pthreads and memory growth are enabled, WASM_MEM_MAX must be set.
- # Also, USE_PTHREADS + ALLOW_MEMORY_GROWTH may run non-wasm code slowly.
- # "-s ALLOW_MEMORY_GROWTH=1",
- # "-s WASM_MEM_MAX=268435456", # 256MB
- # "-s TOTAL_MEMORY=268435456", # 256MB
-
- # Request a prepopulated pool of web workers for pthreads to use.
- # Without this, threads may not start until the javascript thread yields.
- # See considerations at https://emscripten.org/docs/porting/pthreads.html.
- "-s PTHREAD_POOL_SIZE=1",
-]
-
-EMSCRIPTEN_LINKOPTS_DBG = [
- # Show WASM stack trace in Chrome debugger.
- "-g2",
- "-s DEMANGLE_SUPPORT=1",
-
- # Enable verbose assertions.
- "-s ASSERTIONS=2",
- "-s SAFE_HEAP=1",
- "-s STACK_OVERFLOW_CHECK=2",
-]
-
-EMSCRIPTEN_LINKOPTS_OPT = []
-
-# To use run_module_emscripten:
-# > bazel build third_party/iree/tools/web:run_module_emscripten_files
-
-cc_binary(
- name = "run_module_emscripten",
- srcs = ["run_module_emscripten.cc"],
- linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({
- "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG,
- "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT,
- "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT,
- }),
- tags = [
- "manual",
- "notap", # TODO(b/137088911): Build/test on TAP
- "wasm",
- ],
- deps = [
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:buffer_view_string_util",
- "//iree/hal:driver_registry",
- "//iree/hal/interpreter:interpreter_driver_module",
- "//iree/rt",
- "//iree/vm:sequencer_module",
- "//third_party/emscripten:embind",
- ],
-)
-
-auto_wasm_binary(
- name = "run_module_emscripten_binary",
- cc_target = ":run_module_emscripten",
- tags = ["manual"],
- threads = "emscripten",
-)
-
-Fileset(
- name = "run_module_emscripten_files",
- out = "wasm_files",
- entries = [
- FilesetEntry(
- files = [":run_module_emscripten_binary"],
- strip_prefix = "run_module_emscripten_binary",
- destdir = "wasm",
- ),
- FilesetEntry(
- files = ["run_module.html"],
- destdir = "wasm",
- ),
- ],
- tags = ["manual"],
-)
diff --git a/iree/tools/web/run_module_emscripten.cc b/iree/tools/web/run_module_emscripten.cc
deleted file mode 100644
index 49d88a7..0000000
--- a/iree/tools/web/run_module_emscripten.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <emscripten.h>
-#include <emscripten/bind.h>
-
-#include <vector>
-
-#include "absl/strings/str_replace.h"
-#include "absl/strings/str_split.h"
-#include "absl/strings/string_view.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/init.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/buffer_view_string_util.h"
-#include "iree/hal/driver_registry.h"
-#include "iree/rt/context.h"
-#include "iree/rt/instance.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/vm/sequencer_module.h"
-
-namespace iree {
-
-// Parses a list of input shapes and values from a string of newline-separated
-// inputs. Expects the contents to have one value per line with each value
-// listed as
-// [shape]xtype=[value]
-// Example:
-// 4x4xi8=0,1,2,3
-StatusOr<std::vector<hal::BufferView>> ParseInputs(
- absl::string_view inputs_string, hal::Allocator* allocator) {
- std::string input_lines = absl::StrReplaceAll(inputs_string, {{"\\n", "\n"}});
- std::vector<hal::BufferView> input_buffer_views;
- for (const auto& input_line :
- absl::StrSplit(input_lines, '\n', absl::SkipWhitespace())) {
- ASSIGN_OR_RETURN(auto input_buffer_view,
- hal::ParseBufferViewFromString(input_line, allocator));
- input_buffer_views.push_back(input_buffer_view);
- }
- return input_buffer_views;
-}
-
-// Runs an IREE module with the provided inputs and returns its outputs.
-StatusOr<std::string> RunIreeModule(std::string module_file_data,
- absl::string_view inputs_string) {
- auto instance = make_ref<rt::Instance>();
-
- // Create driver and device.
- ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
- "interpreter"));
- ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
- RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
-
- auto policy = make_ref<rt::Policy>();
- auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));
-
- // Load main module FlatBuffer.
- ASSIGN_OR_RETURN(auto main_module_file,
- FlatBufferFile<ModuleDef>::FromString(ModuleDefIdentifier(),
- module_file_data));
- ASSIGN_OR_RETURN(auto main_module,
- vm::SequencerModule::FromFile(std::move(main_module_file)));
-
- // Register the main module with the context.
- RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));
-
- // Setup arguments and storage for results.
- // TODO(scotttodd): Receive main function name from JS.
- ASSIGN_OR_RETURN(auto main_function,
- main_module->LookupFunctionByName(
- rt::Function::Linkage::kExport, "main"));
-
- ASSIGN_OR_RETURN(auto arguments,
- ParseInputs(inputs_string, device->allocator()));
-
- // Call into the main function.
- ASSIGN_OR_RETURN(auto invocation,
- rt::Invocation::Create(add_ref(context), main_function,
- make_ref<rt::Policy>(), {},
- absl::MakeConstSpan(arguments)));
-
- // Wait until invocation completes.
- // TODO(scotttodd): make this an async callback.
- RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
- ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults());
-
- // Dump all results to stdout.
- // TODO(scotttodd): Receive output types / print mode from JS.
- // TODO(scotttodd): Return list of outputs instead of just the first (proto?)
- for (int i = 0; i < results.size(); ++i) {
- const auto& result = results[i];
- auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
- ASSIGN_OR_RETURN(auto result_str,
- PrintBufferViewToString(result, print_mode, 1024));
- const auto& buffer = result.buffer;
- if (!buffer) {
- return InternalErrorBuilder(IREE_LOC)
- << "result[" << i << "] unexpectedly has no buffer";
- }
-
- return result_str;
- }
-
- return InternalErrorBuilder(IREE_LOC) << "Received no results";
-}
-
-std::string RunIreeModuleEntry(std::string module_file_data,
- std::string inputs_string) {
- // TODO(scotttodd): optimize, minimize copies
- // https://groups.google.com/d/msg/emscripten-discuss/CMfYljLWMvY/Di52WB2QAgAJ
- auto result_or = RunIreeModule(std::move(module_file_data), inputs_string);
- if (!result_or.ok()) {
- return "Error: " + result_or.status().ToString();
- } else {
- return result_or.ValueOrDie();
- }
-}
-
-EMSCRIPTEN_BINDINGS(iree) {
- emscripten::function("runIreeModule", &RunIreeModuleEntry);
-}
-
-extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
- return 0;
-}
-
-} // namespace iree
diff --git a/iree/vm/BUILD b/iree/vm/BUILD
deleted file mode 100644
index b197f32..0000000
--- a/iree/vm/BUILD
+++ /dev/null
@@ -1,200 +0,0 @@
-# Bytecode VM used by the IREE sequencer and interpreter.
-
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "api",
- srcs = ["api.cc"],
- hdrs = ["api.h"],
- visibility = ["//visibility:public"],
- deps = [
- ":api_hdrs",
- ":sequencer_module",
- "//iree/base:api",
- "//iree/base:api_util",
- "//iree/base:flatbuffer_util",
- "//iree/base:tracing",
- "//iree/rt:api",
- ],
-)
-
-cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
- deps = [
- "//iree/base:api_hdrs",
- "//iree/rt:api_hdrs",
- ],
-)
-
-cc_library(
- name = "bytecode_module",
- srcs = [
- "bytecode_disassembler.cc",
- "bytecode_module.cc",
- ],
- hdrs = [
- "bytecode_disassembler.h",
- "bytecode_module.h",
- ],
- deps = [
- ":bytecode_util",
- ":opcode_info",
- ":source_map_resolver",
- ":type",
- "//iree/base:flatbuffer_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:buffer_view",
- "//iree/rt",
- "//iree/schemas",
- "//iree/schemas/bytecode:bytecode_v0",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "bytecode_reader",
- srcs = ["bytecode_reader.cc"],
- hdrs = ["bytecode_reader.h"],
- deps = [
- ":bytecode_module",
- ":type",
- "//iree/base:shape",
- "//iree/base:status",
- "//iree/hal:buffer_view",
- "//iree/hal:heap_buffer",
- "//iree/rt",
- "//iree/schemas/bytecode:bytecode_v0",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- ],
-)
-
-cc_library(
- name = "bytecode_tables_interpreter",
- srcs = ["bytecode_tables_interpreter.cc"],
- hdrs = ["bytecode_tables_interpreter.h"],
- deps = [
- ":opcode_info",
- "//iree/schemas/bytecode:interpreter_bytecode_v0",
- ],
-)
-
-cc_library(
- name = "bytecode_tables_sequencer",
- srcs = ["bytecode_tables_sequencer.cc"],
- hdrs = ["bytecode_tables_sequencer.h"],
- deps = [
- ":opcode_info",
- "//iree/schemas/bytecode:sequencer_bytecode_v0",
- ],
-)
-
-cc_library(
- name = "bytecode_util",
- srcs = ["bytecode_util.cc"],
- hdrs = ["bytecode_util.h"],
- deps = [
- "//iree/schemas/bytecode:bytecode_v0",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "bytecode_validator",
- srcs = ["bytecode_validator.cc"],
- hdrs = ["bytecode_validator.h"],
- deps = [
- ":bytecode_module",
- "//iree/base:status",
- "//iree/schemas",
- ],
-)
-
-cc_library(
- name = "opcode_info",
- hdrs = ["opcode_info.h"],
- deps = [
- "//iree/schemas/bytecode:bytecode_v0",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "sequencer_dispatch",
- srcs = ["sequencer_dispatch.cc"],
- hdrs = ["sequencer_dispatch.h"],
- deps = [
- ":bytecode_module",
- ":bytecode_reader",
- ":bytecode_tables_sequencer",
- ":bytecode_util",
- ":opcode_info",
- "//iree/base:logging",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/hal:buffer_view",
- "//iree/hal:command_queue",
- "//iree/hal:device",
- "//iree/hal:device_placement",
- "//iree/hal:heap_buffer",
- "//iree/rt",
- "//iree/schemas/bytecode:sequencer_bytecode_v0",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "sequencer_module",
- srcs = ["sequencer_module.cc"],
- hdrs = ["sequencer_module.h"],
- deps = [
- ":bytecode_module",
- ":bytecode_tables_sequencer",
- ":sequencer_dispatch",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:buffer_view",
- "//iree/rt",
- "@com_google_absl//absl/memory",
- ],
-)
-
-cc_library(
- name = "source_map_resolver",
- srcs = ["source_map_resolver.cc"],
- hdrs = ["source_map_resolver.h"],
- deps = [
- "//iree/base:flatbuffer_util",
- "//iree/base:status",
- "//iree/rt",
- "//iree/schemas",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- ],
-)
-
-cc_library(
- name = "type",
- srcs = ["type.cc"],
- hdrs = ["type.h"],
- deps = [
- "//iree/base:status",
- "//iree/schemas",
- "//iree/schemas/bytecode:bytecode_v0",
- ],
-)
diff --git a/iree/vm/api.cc b/iree/vm/api.cc
deleted file mode 100644
index 4a6cf06..0000000
--- a/iree/vm/api.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/api.h"
-
-#include "iree/base/api.h"
-#include "iree/base/api_util.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/tracing.h"
-#include "iree/vm/sequencer_module.h"
-
-namespace iree {
-namespace vm {
-
-//===----------------------------------------------------------------------===//
-// iree::vm::BytecodeModule
-//===----------------------------------------------------------------------===//
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_bytecode_module_create_from_buffer(
- iree_const_byte_span_t buffer_data,
- void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data),
- void* buffer_free_self, iree_allocator_t allocator,
- iree_rt_module_t** out_module) {
- IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_buffer");
-
- if (!out_module) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_module = nullptr;
-
- if (!buffer_data.data || !buffer_data.data_length) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- IREE_API_ASSIGN_OR_RETURN(
- auto module_file,
- FlatBufferFile<ModuleDef>::FromBuffer(
- ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length},
- [buffer_free_fn, buffer_free_self, buffer_data]() {
- if (buffer_free_fn != nullptr) {
- buffer_free_fn(buffer_free_self,
- {const_cast<uint8_t*>(buffer_data.data),
- buffer_data.data_length});
- }
- }));
-
- IREE_API_ASSIGN_OR_RETURN(auto module,
- SequencerModule::FromFile(std::move(module_file)));
-
- *out_module = reinterpret_cast<iree_rt_module_t*>(module.release());
-
- return IREE_STATUS_OK;
-}
-
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_bytecode_module_create_from_file_mapping(
- iree_file_mapping_t* file_mapping, iree_allocator_t allocator,
- iree_rt_module_t** out_module) {
- IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_file_mapping");
-
- if (!out_module) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
- *out_module = nullptr;
-
- if (!file_mapping) {
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- auto buffer_data = iree_file_mapping_data(file_mapping);
- IREE_API_ASSIGN_OR_RETURN(
- auto module_file,
- FlatBufferFile<ModuleDef>::FromBuffer(
- ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length},
- [file_mapping]() { iree_file_mapping_release(file_mapping); }));
- iree_file_mapping_retain(file_mapping);
-
- IREE_API_ASSIGN_OR_RETURN(auto module,
- SequencerModule::FromFile(std::move(module_file)));
-
- *out_module = reinterpret_cast<iree_rt_module_t*>(module.release());
-
- return IREE_STATUS_OK;
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/api.h b/iree/vm/api.h
deleted file mode 100644
index f724309..0000000
--- a/iree/vm/api.h
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// See iree/base/api.h for documentation on the API conventions used.
-
-#ifndef IREE_VM_API_H_
-#define IREE_VM_API_H_
-
-#include <stdint.h>
-
-#include "iree/base/api.h"
-#include "iree/rt/api.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-//===----------------------------------------------------------------------===//
-// iree::vm::BytecodeModule
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_API_NO_PROTOTYPES
-
-// Creates a VM module from an in-memory ModuleDef FlatBuffer.
-// The provided |buffer_free_fn| will be called when the module is destroyed
-// and only if this creation function succeeds. If ownership remains with the
-// caller then pass nullptr for |buffer_free_fn|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_bytecode_module_create_from_buffer(
- iree_const_byte_span_t buffer_data,
- void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data),
- void* buffer_free_self, iree_allocator_t allocator,
- iree_rt_module_t** out_module);
-
-// Creates a VM module from a mapped ModuleDef FlatBuffer.
-// The provided |file_mapping| will be retained for the life of the module and
-// the contents will be accessed by reference.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_vm_bytecode_module_create_from_file_mapping(
- iree_file_mapping_t* file_mapping, iree_allocator_t allocator,
- iree_rt_module_t** out_module);
-
-#endif // IREE_API_NO_PROTOTYPES
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-
-#endif // IREE_VM_API_H_
diff --git a/iree/vm/bytecode_disassembler.cc b/iree/vm/bytecode_disassembler.cc
deleted file mode 100644
index 4868792..0000000
--- a/iree/vm/bytecode_disassembler.cc
+++ /dev/null
@@ -1,482 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_disassembler.h"
-
-#include <iomanip>
-#include <sstream>
-
-#include "absl/base/macros.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "iree/base/status.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-#include "iree/schemas/source_map_def_generated.h"
-#include "iree/vm/bytecode_module.h"
-#include "iree/vm/bytecode_util.h"
-#include "iree/vm/type.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-using ::iree::rt::SourceOffset;
-
-template <typename T>
-StatusOr<T> ReadValue(absl::Span<const uint8_t> data, SourceOffset* offset) {
- if (*offset + sizeof(T) > data.size()) {
- return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode data underrun";
- }
- auto value = *reinterpret_cast<const T*>(&data[*offset]);
- *offset = *offset + sizeof(T);
- return value;
-}
-
-StatusOr<const Type> ReadType(absl::Span<const uint8_t> data,
- SourceOffset* offset) {
- ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>(data, offset));
- return Type::FromTypeIndex(type_index);
-}
-
-StatusOr<uint8_t> ReadCount(absl::Span<const uint8_t> data,
- SourceOffset* offset) {
- return ReadValue<uint8_t>(data, offset);
-}
-
-StatusOr<uint16_t> ReadValueSlot(absl::Span<const uint8_t> data,
- SourceOffset* offset) {
- return ReadValue<uint16_t>(data, offset);
-}
-
-absl::string_view ConstantEncodingToString(ConstantEncoding encoding) {
- switch (encoding) {
-#define GET_NAME(ordinal, enum_name, str, ...) \
- case ConstantEncoding::enum_name: \
- return str;
- IREE_CONSTANT_ENCODING_LIST(GET_NAME)
-#undef GET_NAME
- default:
- return "unknown";
- }
-}
-
-template <typename T>
-std::string TypedDataToString(absl::Span<const uint8_t> bytes) {
- auto typed_data = absl::Span<const T>{
- reinterpret_cast<const T*>(bytes.data()), bytes.size() / sizeof(T)};
- return absl::StrJoin(typed_data, ",");
-}
-
-std::string ConstantToString(const Type& type,
- absl::Span<const uint8_t> bytes) {
- if (!type.is_builtin()) {
- return absl::StrJoin(bytes, ",");
- }
- switch (type.builtin_type()) {
- case BuiltinType::kI8:
- return TypedDataToString<uint8_t>(bytes);
- case BuiltinType::kI16:
- return TypedDataToString<uint16_t>(bytes);
- case BuiltinType::kI32:
- return TypedDataToString<uint32_t>(bytes);
- case BuiltinType::kI64:
- return TypedDataToString<uint64_t>(bytes);
- case BuiltinType::kF16:
- return TypedDataToString<uint16_t>(bytes);
- case BuiltinType::kF32:
- return TypedDataToString<float>(bytes);
- case BuiltinType::kF64:
- return TypedDataToString<double>(bytes);
- default:
- return "<unsupported>";
- }
-}
-
-} // namespace
-
-StatusOr<std::vector<rt::Instruction>>
-BytecodeDisassembler::DisassembleInstructions(const rt::Function& function,
- SourceOffset offset,
- int32_t instruction_count) const {
- std::vector<rt::Instruction> instructions;
-
- ASSIGN_OR_RETURN(
- auto* function_def,
- static_cast<const BytecodeModule*>(function.module())
- ->GetFunctionDef(function.linkage(), function.ordinal()));
- auto* bytecode_def = function_def->bytecode();
- if (!bytecode_def) {
- return UnavailableErrorBuilder(IREE_LOC) << "Function contains no body";
- }
- auto data = absl::MakeSpan(
- reinterpret_cast<const uint8_t*>(bytecode_def->contents()->data()),
- bytecode_def->contents()->size());
-
- // TODO(benvanik): scan and find all branch offsets to insert labels
-
- while (offset < data.length() && instructions.size() < instruction_count) {
- instructions.push_back({});
- auto& instruction = instructions.back();
- instruction.offset = offset;
-
- uint8_t opcode = data[offset++];
- const auto& opcode_info = GetOpcodeInfo(opcode_table_, opcode);
- if (!opcode_info.mnemonic) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unhandled opcode " << opcode << " at offset " << (offset - 1);
- }
- int payload_offset = offset;
-
- std::ostringstream stream;
-
- // Print out return values, if any.
- int base_result_index = 0;
- int printed_result_count = 0;
- for (int i = base_result_index; i < ABSL_ARRAYSIZE(opcode_info.operands);
- ++i) {
- if (opcode_info.operands[i] == OperandEncoding::kNone) break;
- if (printed_result_count > 0) {
- stream << ", ";
- }
- switch (opcode_info.operands[i]) {
- default:
- case OperandEncoding::kNone:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unhandled op encoding "
- << static_cast<int>(opcode_info.operands[i]) << " at offset "
- << (offset - 1);
- case OperandEncoding::kInputSlot:
- case OperandEncoding::kOutputSlot: {
- // Printing handled below.
- offset += sizeof(uint16_t);
- break;
- }
- case OperandEncoding::kVariadicInputSlots:
- case OperandEncoding::kVariadicOutputSlots: {
- // Printing handled below.
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- offset += count * sizeof(uint16_t);
- break;
- }
- case OperandEncoding::kResultSlot: {
- ++printed_result_count;
- ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset));
- stream << "%" << slot_ordinal;
- break;
- }
- case OperandEncoding::kVariadicResultSlots: {
- ++printed_result_count;
- stream << "[";
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- for (int j = 0; j < count; ++j) {
- ASSIGN_OR_RETURN(uint16_t slot_ordinal,
- ReadValueSlot(data, &offset));
- if (j > 0) stream << ", ";
- stream << "%" << slot_ordinal;
- }
- stream << "]";
- break;
- }
- case OperandEncoding::kVariadicTransferSlots: {
- // Printing handled below.
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- offset += count * 2 * sizeof(uint16_t);
- break;
- }
- case OperandEncoding::kConstant: {
- // Printing handled below.
- ASSIGN_OR_RETURN(auto type, ReadType(data, &offset));
- ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset));
- int element_count = 1;
- for (int j = 0; j < rank; ++j) {
- ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset));
- element_count *= dim;
- }
- offset += sizeof(ConstantEncoding);
- offset += element_count * type.element_size();
- break;
- }
- case OperandEncoding::kFunctionOrdinal: {
- // Printing handled below.
- offset += sizeof(uint32_t);
- break;
- }
- case OperandEncoding::kDispatchOrdinal: {
- // Printing handled below.
- offset += sizeof(uint32_t) + sizeof(uint16_t);
- break;
- }
- case OperandEncoding::kBlockOffset: {
- // Printing handled below.
- offset += sizeof(uint32_t);
- break;
- }
- case OperandEncoding::kTypeIndex: {
- // Printing handled below.
- offset += sizeof(uint8_t);
- break;
- }
- case OperandEncoding::kIndex: {
- // Printing handled below.
- offset += sizeof(int32_t);
- break;
- }
- case OperandEncoding::kIndexList: {
- // Printing handled below.
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- offset += count * sizeof(int32_t);
- break;
- }
- case OperandEncoding::kCmpIPredicate:
- case OperandEncoding::kCmpFPredicate: {
- // Printing handled below.
- offset += sizeof(uint8_t);
- break;
- }
- }
- }
- if (printed_result_count > 0) {
- stream << " = ";
- }
- offset = payload_offset;
-
- stream << opcode_info.mnemonic;
-
- // Print out operands.
- int base_operand_index = 0;
- int printed_operand_count = 0;
- for (int i = base_operand_index; i < ABSL_ARRAYSIZE(opcode_info.operands);
- ++i) {
- if (opcode_info.operands[i] == OperandEncoding::kNone) break;
- if (opcode_info.operands[i] != OperandEncoding::kResultSlot &&
- opcode_info.operands[i] != OperandEncoding::kVariadicResultSlots) {
- if (i == base_operand_index) {
- stream << " ";
- } else if (printed_operand_count > 0) {
- stream << ", ";
- }
- }
- switch (opcode_info.operands[i]) {
- default:
- case OperandEncoding::kNone:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unhandled op encoding "
- << static_cast<int>(opcode_info.operands[i]) << " at offset "
- << (offset - 1);
- case OperandEncoding::kInputSlot: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset));
- stream << "%" << slot_ordinal;
- break;
- }
- case OperandEncoding::kVariadicInputSlots: {
- ++printed_operand_count;
- stream << "[";
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- for (int j = 0; j < count; ++j) {
- ASSIGN_OR_RETURN(uint16_t slot_ordinal,
- ReadValueSlot(data, &offset));
- if (j > 0) stream << ", ";
- stream << "%" << slot_ordinal;
- }
- stream << "]";
- break;
- }
- case OperandEncoding::kOutputSlot: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset));
- stream << "&"
- << "%" << slot_ordinal;
- break;
- }
- case OperandEncoding::kVariadicOutputSlots: {
- ++printed_operand_count;
- stream << "[";
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- for (int j = 0; j < count; ++j) {
- ASSIGN_OR_RETURN(uint16_t slot_ordinal,
- ReadValueSlot(data, &offset));
- if (j > 0) stream << ", ";
- stream << "&"
- << "%" << slot_ordinal;
- }
- stream << "]";
- break;
- }
- case OperandEncoding::kResultSlot: {
- // Printing handled above.
- offset += sizeof(uint16_t);
- break;
- }
- case OperandEncoding::kVariadicResultSlots: {
- // Printing handled above.
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- offset += count * sizeof(uint16_t);
- break;
- }
- case OperandEncoding::kVariadicTransferSlots: {
- ++printed_operand_count;
- stream << "[";
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- for (int j = 0; j < count; ++j) {
- ASSIGN_OR_RETURN(uint16_t src_slot_ordinal,
- ReadValueSlot(data, &offset));
- ASSIGN_OR_RETURN(uint16_t dst_slot_ordinal,
- ReadValueSlot(data, &offset));
- if (j > 0) stream << ", ";
- stream << "%" << src_slot_ordinal << "=>%" << dst_slot_ordinal;
- }
- stream << "]";
- break;
- }
- case OperandEncoding::kConstant: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto type, ReadType(data, &offset));
- ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset));
- absl::InlinedVector<int32_t, 4> shape(rank);
- int element_count = 1;
- for (int j = 0; j < rank; ++j) {
- ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset));
- shape[j] = dim;
- element_count *= dim;
- }
- ASSIGN_OR_RETURN(auto encoding,
- ReadValue<ConstantEncoding>(data, &offset));
- stream << ConstantEncodingToString(encoding);
- int serialized_element_count = 1;
- switch (encoding) {
- case ConstantEncoding::kDense:
- serialized_element_count = element_count;
- break;
- case ConstantEncoding::kSplat:
- serialized_element_count = 1;
- break;
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented constant encoding "
- << static_cast<int>(encoding);
- }
- stream << " buffer_view<";
- if (!shape.empty()) {
- stream << absl::StrJoin(shape, "x") << "x";
- }
- stream << type << ">{";
- size_t element_size = type.element_size();
- auto bytes = data.subspan(
- offset, std::min(serialized_element_count, 1024) * element_size);
- stream << ConstantToString(type, bytes);
- if (serialized_element_count > 1024) stream << "...";
- offset += serialized_element_count * element_size;
- stream << "}";
- break;
- }
- case OperandEncoding::kFunctionOrdinal: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto function_ordinal,
- ReadValue<uint32_t>(data, &offset));
- ASSIGN_OR_RETURN(
- auto target_function,
- function.module()->LookupFunctionByOrdinal(
- rt::Function::Linkage::kInternal, function_ordinal));
- stream << "@" << function_ordinal << " " << target_function.name();
- break;
- }
- case OperandEncoding::kDispatchOrdinal: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto dispatch_ordinal,
- ReadValue<uint32_t>(data, &offset));
- ASSIGN_OR_RETURN(auto export_ordinal,
- ReadValue<uint16_t>(data, &offset));
- // TODO(benvanik): lookup in executable table.
- stream << "@" << dispatch_ordinal << ":" << export_ordinal;
- break;
- }
- case OperandEncoding::kImportOrdinal: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto import_ordinal,
- ReadValue<uint32_t>(data, &offset));
- ASSIGN_OR_RETURN(auto target_function,
- function.module()->LookupFunctionByOrdinal(
- rt::Function::Linkage::kImport, import_ordinal));
- stream << "@i" << import_ordinal << " " << target_function.name();
- break;
- }
- case OperandEncoding::kBlockOffset: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(uint32_t block_offset,
- ReadValue<uint32_t>(data, &offset));
- stream << ":" << block_offset;
- break;
- }
- case OperandEncoding::kTypeIndex: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto type, ReadType(data, &offset));
- stream << type;
- break;
- }
- case OperandEncoding::kIndex: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto index, ReadValue<int32_t>(data, &offset));
- stream << "#" << index;
- break;
- }
- case OperandEncoding::kIndexList: {
- ++printed_operand_count;
- stream << "{";
- ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
- for (int j = 0; j < count; ++j) {
- ASSIGN_OR_RETURN(auto dim, ReadValue<int32_t>(data, &offset));
- if (j > 0) stream << ",";
- stream << dim;
- }
- stream << "}";
- break;
- }
- case OperandEncoding::kCmpIPredicate: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto predicate_value,
- ReadValue<uint8_t>(data, &offset));
- stream << "<"
- << PredicateToString(
- static_cast<CmpIPredicate>(predicate_value))
- << ">";
- break;
- }
- case OperandEncoding::kCmpFPredicate: {
- ++printed_operand_count;
- ASSIGN_OR_RETURN(auto predicate_value,
- ReadValue<uint8_t>(data, &offset));
- stream << "<"
- << PredicateToString(
- static_cast<CmpFPredicate>(predicate_value))
- << ">";
- break;
- }
- }
- }
-
- stream << "\n";
-
- instruction.long_text = stream.str();
- instruction.short_text = instruction.long_text;
- }
-
- return instructions;
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_disassembler.h b/iree/vm/bytecode_disassembler.h
deleted file mode 100644
index 633e290..0000000
--- a/iree/vm/bytecode_disassembler.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_DISASSEMBLER_H_
-#define IREE_VM_BYTECODE_DISASSEMBLER_H_
-
-#include <ostream>
-
-#include "iree/base/status.h"
-#include "iree/rt/disassembler.h"
-#include "iree/schemas/bytecode_def_generated.h"
-#include "iree/schemas/source_map_def_generated.h"
-#include "iree/vm/opcode_info.h"
-
-namespace iree {
-namespace vm {
-
-// Disassembles bytecode with a specific op set to text.
-class BytecodeDisassembler final : public rt::Disassembler {
- public:
- explicit BytecodeDisassembler(OpcodeTable opcode_table)
- : opcode_table_(opcode_table) {}
-
- StatusOr<std::vector<rt::Instruction>> DisassembleInstructions(
- const rt::Function& function, rt::SourceOffset offset,
- int32_t instruction_count) const override;
-
- private:
- OpcodeTable opcode_table_;
-};
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_DISASSEMBLER_H_
diff --git a/iree/vm/bytecode_module.cc b/iree/vm/bytecode_module.cc
deleted file mode 100644
index 7812cb6..0000000
--- a/iree/vm/bytecode_module.cc
+++ /dev/null
@@ -1,310 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_module.h"
-
-#include "absl/memory/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/vm/bytecode_disassembler.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-using ::iree::hal::BufferView;
-using ::iree::rt::Function;
-using ::iree::rt::FunctionSignature;
-using ::iree::rt::Module;
-using ::iree::rt::ModuleSignature;
-
-Status ValidateElementSize(int element_bit_width,
- const ElementTypeDef& expected_element_type) {
- switch (expected_element_type.type_union_type()) {
- case ElementTypeDefUnion::FloatTypeDef: {
- auto expected_bit_width =
- expected_element_type.type_union_as_FloatTypeDef()->width();
- if (element_bit_width != expected_bit_width) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Has element bit width " << element_bit_width
- << " but expected " << expected_bit_width;
- }
- return OkStatus();
- }
- case ElementTypeDefUnion::IntegerTypeDef: {
- auto expected_bit_width =
- expected_element_type.type_union_as_IntegerTypeDef()->width();
- if (element_bit_width != expected_bit_width) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Has element bit width " << element_bit_width
- << " but expected " << expected_bit_width;
- }
- return OkStatus();
- }
- case ElementTypeDefUnion::UnknownTypeDef:
- case ElementTypeDefUnion::NONE: {
- }
- }
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Defined type has unsupported element type "
- << EnumNameElementTypeDefUnion(
- expected_element_type.type_union_type());
-}
-
-Status ValidateTypeStructure(const FunctionTypeDef& type_def) {
- // Ensure all fields are populated.
- return OkStatus();
-}
-
-Status ValidateFunctionTableStructure(
- const FunctionTableDef& function_table_def) {
- if (!function_table_def.functions()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Function table is missing the function listing";
- }
-
- // All functions must contain a valid type.
- const auto& functions = *function_table_def.functions();
- for (int i = 0; i < functions.size(); ++i) {
- const auto* function = functions[i];
- if (!function) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Function ordinal " << i << " is missing its contents";
- }
- if (!function->type()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Function ordinal " << i << " is missing its type";
- }
- RETURN_IF_ERROR(ValidateTypeStructure(*function->type()));
- }
-
- // Imports must also have a name (that we can use to resolve it).
- if (function_table_def.imports()) {
- const auto& imports = *function_table_def.imports();
- for (int i = 0; i < imports.size(); ++i) {
- int function_index = imports[i];
- if (!functions[function_index]->name()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Import ordinal " << i << " is missing its contents";
- }
- }
- }
-
- // Exports must also have a name (that others will use to look it up).
- if (function_table_def.exports()) {
- const auto& exports = *function_table_def.exports();
- for (int i = 0; i < exports.size(); ++i) {
- int function_index = exports[i];
- if (!functions[function_index]->name()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Export ordinal " << i << " is missing its contents";
- }
- }
- }
-
- return OkStatus();
-}
-
-Status ValidateExecutableTableStructure(
- const ExecutableTableDef& executable_table_def) {
- if (!executable_table_def.multi_arch_executables()) {
- // May have sequencer only fns. Fine to not have dispatchable executables.
- return OkStatus();
- }
-
- // All fat executables need at least one device-specific executable.
- const auto& multi_arch_executables =
- *executable_table_def.multi_arch_executables();
- for (int i = 0; i < multi_arch_executables.size(); ++i) {
- const auto* multi_arch_executable = multi_arch_executables[i];
- if (!multi_arch_executable || !multi_arch_executable->executables() ||
- multi_arch_executable->executables()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Multi-arch executable ordinal " << i
- << " is missing its contents";
- }
- }
-
- return OkStatus();
-}
-
-} // namespace
-
-// static
-Status BytecodeModule::ValidateStructure(const ModuleDef& module_def) {
- IREE_TRACE_SCOPE0("BytecodeModule::ValidateStructure");
-
- // Must have a function table.
- if (module_def.function_table()) {
- RETURN_IF_ERROR(
- ValidateFunctionTableStructure(*module_def.function_table()));
- } else {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "ModuleDef is missing a function table";
- }
-
- // Must have an executable table.
- if (module_def.executable_table()) {
- RETURN_IF_ERROR(
- ValidateExecutableTableStructure(*module_def.executable_table()));
- } else {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "ModuleDef is missing an executable table";
- }
-
- return OkStatus();
-}
-
-BytecodeModule::BytecodeModule(std::unique_ptr<ModuleFile> module_file,
- OpcodeTable opcode_table)
- : module_file_(std::move(module_file)),
- module_def_(*module_file_->root()),
- source_resolver_(SourceMapResolver::FromModule(module_def_)),
- disassembler_(absl::make_unique<BytecodeDisassembler>(opcode_table)) {}
-
-BytecodeModule::~BytecodeModule() = default;
-
-const ModuleSignature BytecodeModule::signature() const {
- return ModuleSignature(function_table_def().imports()->size(),
- function_table_def().exports()->size(),
- function_table_def().functions()->size(), 0);
-}
-
-std::string BytecodeModule::DebugStringShort() const {
- return std::string(name());
-}
-
-StatusOr<int32_t> BytecodeModule::MapFunctionOrdinal(Function::Linkage linkage,
- int32_t ordinal) const {
- const auto& function_table = function_table_def();
- switch (linkage) {
- case Function::Linkage::kImport:
- if (ordinal < 0 || ordinal >= function_table.imports()->size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Import ordinal " << ordinal
- << " is outside the valid range [0, "
- << function_table.imports()->size() << ")";
- }
- ordinal = function_table.imports()->Get(ordinal);
- break;
- case Function::Linkage::kExport:
- if (ordinal < 0 || ordinal >= function_table.exports()->size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Export ordinal " << ordinal
- << " is outside the valid range [0, "
- << function_table.exports()->size() << ")";
- }
- ordinal = function_table.exports()->Get(ordinal);
- break;
- default:
- break;
- }
- if (ordinal < 0 || ordinal >= function_table.functions()->size()) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Function ordinal " << ordinal
- << " is outside the valid range [0, "
- << function_table.functions()->size() << ")";
- }
- return ordinal;
-}
-
-StatusOr<const Function> BytecodeModule::LookupFunctionByOrdinal(
- Function::Linkage linkage, int32_t ordinal) const {
- ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));
- return Function(this, Function::Linkage::kInternal, ordinal);
-}
-
-StatusOr<const Function> BytecodeModule::LookupFunctionByName(
- Function::Linkage linkage, absl::string_view name) const {
- const auto& functions = *function_table_def().functions();
- for (int i = 0; i < functions.size(); ++i) {
- const auto* function_def = functions.Get(i);
- if (WrapString(function_def->name()) == name) {
- return LookupFunctionByOrdinal(Function::Linkage::kInternal, i);
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "Function '" << name
- << "' not found in function table (or names have been stripped)";
-}
-
-StatusOr<absl::string_view> BytecodeModule::GetFunctionName(
- Function::Linkage linkage, int32_t ordinal) const {
- ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));
- const auto* function_def = function_table_def().functions()->Get(ordinal);
- return WrapString(function_def->name());
-}
-
-StatusOr<const FunctionSignature> BytecodeModule::GetFunctionSignature(
- Function::Linkage linkage, int32_t ordinal) const {
- ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));
- const auto* function_def = function_table_def().functions()->Get(ordinal);
- const auto* type_def = function_def->type();
- return FunctionSignature(
- type_def->inputs() ? type_def->inputs()->size() : 0,
- type_def->results() ? type_def->results()->size() : 0);
-}
-
-StatusOr<const FunctionDef*> BytecodeModule::GetFunctionDef(
- rt::Function::Linkage linkage, int32_t ordinal) const {
- ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));
- const auto& function_defs = *function_table_def().functions();
- if (ordinal >= function_defs.size()) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Internal function ordinal " << ordinal
- << " out of range of table (" << function_defs.size() << ")";
- }
- return function_defs.Get(ordinal);
-}
-
-StatusOr<const MultiArchExecutableDef*>
-BytecodeModule::LookupMultiArchExecutable(int executable_ordinal) const {
- if (executable_ordinal < 0 ||
- executable_ordinal >=
- executable_table_def().multi_arch_executables()->size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Invalid multi-arch executable ordinal " << executable_ordinal;
- }
- return executable_table_def().multi_arch_executables()->Get(
- executable_ordinal);
-}
-
-// static
-Status BytecodeModule::ValidateArgType(const BufferView& arg,
- const MemRefTypeDef& expected_type) {
- RETURN_IF_ERROR(
- ValidateElementSize(arg.element_size * 8, *expected_type.element_type()));
-
- auto expected_shape = expected_type.shape();
- if (arg.shape.size() != expected_shape->size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Argument should have rank " << expected_shape->size()
- << " but has rank " << arg.shape.size();
- }
- for (int i = 0; i < expected_shape->size(); ++i) {
- auto dim_size = arg.shape[i];
- auto expected_dim_size = expected_shape->Get(i);
- if (dim_size != expected_dim_size) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Argument dimension " << i << " should have size "
- << expected_dim_size << " but has size " << dim_size;
- }
- }
- return OkStatus();
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_module.h b/iree/vm/bytecode_module.h
deleted file mode 100644
index 36d034e..0000000
--- a/iree/vm/bytecode_module.h
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_MODULE_H_
-#define IREE_VM_BYTECODE_MODULE_H_
-
-#include <memory>
-
-#include "iree/base/flatbuffer_util.h"
-#include "iree/rt/function.h"
-#include "iree/rt/module.h"
-#include "iree/schemas/executable_table_def_generated.h"
-#include "iree/schemas/function_table_def_generated.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/vm/opcode_info.h"
-#include "iree/vm/source_map_resolver.h"
-
-namespace iree {
-namespace vm {
-
-using ModuleFile = FlatBufferFile<ModuleDef>;
-
-// A loaded bytecode module backed by a FlatBuffer.
-class BytecodeModule : public rt::Module {
- public:
- static Status ValidateStructure(const ModuleDef& module_def);
-
- ~BytecodeModule() override;
-
- const ModuleDef& def() const { return module_def_; }
- const FunctionTableDef& function_table_def() const {
- return *module_def_.function_table();
- }
- const ExecutableTableDef& executable_table_def() const {
- return *module_def_.executable_table();
- }
-
- absl::string_view name() const override {
- return WrapString(module_def_.name());
- }
-
- const rt::ModuleSignature signature() const override;
-
- rt::SourceResolver* source_resolver() const override {
- return &source_resolver_;
- }
-
- rt::Disassembler* disassembler() const override {
- return disassembler_.get();
- }
-
- std::string DebugStringShort() const override;
-
- StatusOr<const rt::Function> LookupFunctionByOrdinal(
- rt::Function::Linkage linkage, int32_t ordinal) const override;
-
- StatusOr<const rt::Function> LookupFunctionByName(
- rt::Function::Linkage linkage, absl::string_view name) const override;
-
- StatusOr<absl::string_view> GetFunctionName(rt::Function::Linkage linkage,
- int32_t ordinal) const override;
-
- StatusOr<const rt::FunctionSignature> GetFunctionSignature(
- rt::Function::Linkage linkage, int32_t ordinal) const override;
-
- StatusOr<const FunctionDef*> GetFunctionDef(rt::Function::Linkage linkage,
- int32_t ordinal) const;
-
- StatusOr<const MultiArchExecutableDef*> LookupMultiArchExecutable(
- int executable_ordinal) const;
-
- protected:
- BytecodeModule(std::unique_ptr<ModuleFile> module_file,
- OpcodeTable opcode_table);
-
- static Status ValidateArgType(const hal::BufferView& arg,
- const MemRefTypeDef& expected_type);
-
- private:
- StatusOr<int32_t> MapFunctionOrdinal(rt::Function::Linkage linkage,
- int32_t ordinal) const;
-
- std::unique_ptr<ModuleFile> module_file_;
- const ModuleDef& module_def_;
- mutable SourceMapResolver source_resolver_;
- mutable std::unique_ptr<rt::Disassembler> disassembler_;
-};
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_MODULE_H_
diff --git a/iree/vm/bytecode_reader.cc b/iree/vm/bytecode_reader.cc
deleted file mode 100644
index 65e69f3..0000000
--- a/iree/vm/bytecode_reader.cc
+++ /dev/null
@@ -1,289 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_reader.h"
-
-#include "iree/base/shape.h"
-#include "iree/base/status.h"
-#include "iree/hal/heap_buffer.h"
-#include "iree/vm/bytecode_module.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-using ::iree::hal::BufferView;
-using ::iree::rt::StackFrame;
-
-} // namespace
-
-StatusOr<const uint8_t*> BytecodeReader::AdvanceOffset() {
- *stack_frame_->mutable_offset() = offset();
- // TODO(benvanik): make a flag and/or remove.
- DVLOG(1) << "dispatch(" << stack_frame_->function().name() << "@" << offset()
- << "): " << int(*bytecode_pc_);
- for (int i = 0; i < registers_->buffer_views.size(); ++i) {
- DVLOG(1) << "local[" << i << "] "
- << registers_->buffer_views[i].DebugStringShort();
- }
- return bytecode_pc_++;
-}
-
-Status BytecodeReader::SkipLocals(int count) {
- size_t stride = sizeof(uint16_t) * count;
- if (bytecode_pc_ + stride >= bytecode_limit_) {
- return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow";
- }
- bytecode_pc_ += stride;
- return OkStatus();
-}
-
-Status BytecodeReader::ReadShape(Shape* out_shape) {
- ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList());
- *out_shape = Shape(shape_dims);
- return OkStatus();
-}
-
-StatusOr<Shape> BytecodeReader::ReadShapePieces() {
- // TODO(benvanik): rewrite to be faster (multiple offsets to walk both lists).
- ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList());
- if (shape_dims.size() >= kMaxRank) {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Shapes limited to rank " << kMaxRank << " right now";
- }
- int expected_dynamic_dims = 0;
- for (int i = 0; i < shape_dims.size(); ++i) {
- if (shape_dims[i] == -1) {
- ++expected_dynamic_dims;
- }
- }
-
- Shape shape(shape_dims);
- ASSIGN_OR_RETURN(int dynamic_dims, ReadCount());
- if (dynamic_dims != expected_dynamic_dims) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Expected " << expected_dynamic_dims << " dynamic dims but only "
- << dynamic_dims << " provided";
- } else if (dynamic_dims) {
- for (int i = 0; i < shape_dims.size(); ++i) {
- if (shape_dims[i] != -1) {
- continue;
- }
- // TODO(benvanik): kill this embarrassment.
- ASSIGN_OR_RETURN(auto dims_piece, ReadSlotElements<int32_t>());
- if (dims_piece.size() != 1) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Dims piece has rank " << dims_piece.size() << "; must be 1";
- }
- shape[i] = dims_piece[0];
- }
- }
- return shape;
-}
-
-StatusOr<Shape> BytecodeReader::ReadShapePieces(size_t* out_element_count) {
- ASSIGN_OR_RETURN(auto shape, ReadShapePieces());
- *out_element_count = shape.element_count();
- return shape;
-}
-
-StatusOr<absl::Span<const int32_t>> BytecodeReader::ReadIndexList() {
- ASSIGN_OR_RETURN(int count, ReadCount());
- int stride = count * sizeof(int32_t);
- if (bytecode_pc_ + stride >= bytecode_limit_) {
- return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow";
- }
- auto list = absl::Span<const int32_t>(
- reinterpret_cast<const int32_t*>(bytecode_pc_), count);
- bytecode_pc_ += stride;
- return list;
-}
-
-Status BytecodeReader::SwitchStackFrame(StackFrame* new_stack_frame) {
- // Flush old state.
- auto* old_stack_frame = stack_frame_;
- if (old_stack_frame) {
- *old_stack_frame->mutable_offset() = offset();
- }
-
- // Switch the frame. The FiberState holds the full stack, this is just the
- // current one for easy access.
- stack_frame_ = new_stack_frame;
-
- // Setup state pointers for faster dereferencing.
- const auto& function = new_stack_frame->function();
- ASSIGN_OR_RETURN(
- const auto* function_def,
- static_cast<const BytecodeModule*>(function.module())
- ->GetFunctionDef(function.linkage(), function.ordinal()));
- const auto& bytecode = *function_def->bytecode();
- bytecode_base_ = bytecode.contents()->Data();
- bytecode_limit_ = bytecode_base_ + bytecode.contents()->size();
- bytecode_pc_ = bytecode_base_ + new_stack_frame->offset();
- registers_ = new_stack_frame->mutable_registers();
- return OkStatus();
-}
-
-Status BytecodeReader::CopyInputsAndSwitchStackFrame(
- StackFrame* src_stack_frame, StackFrame* dst_stack_frame) {
- ASSIGN_OR_RETURN(size_t src_count, ReadCount());
- auto& dst_buffer_views = dst_stack_frame->mutable_registers()->buffer_views;
- for (int i = 0; i < std::min(src_count, dst_buffer_views.size()); ++i) {
- ASSIGN_OR_RETURN(auto* src_local,
- ReadLocal(src_stack_frame->mutable_registers()));
- dst_buffer_views[i] = *src_local;
- }
- return SwitchStackFrame(dst_stack_frame);
-}
-
-Status BytecodeReader::CopyResultsAndSwitchStackFrame(
- StackFrame* src_stack_frame, StackFrame* dst_stack_frame) {
- ASSIGN_OR_RETURN(int32_t src_count, ReadCount());
- // TODO(benvanik): avoid vector.
- absl::InlinedVector<BufferView*, 8> src_locals(src_count);
- for (int i = 0; i < src_count; ++i) {
- ASSIGN_OR_RETURN(src_locals[i],
- ReadLocal(src_stack_frame->mutable_registers()));
- }
- RETURN_IF_ERROR(SwitchStackFrame(dst_stack_frame));
- ASSIGN_OR_RETURN(int32_t dst_count, ReadCount());
- if (src_count != dst_count) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Src and dst value counts differ: " << src_count << " vs "
- << dst_count;
- }
- for (int i = 0; i < dst_count; ++i) {
- ASSIGN_OR_RETURN(auto* dst_local,
- ReadLocal(dst_stack_frame->mutable_registers()));
- *dst_local = *src_locals[i];
- }
- return OkStatus();
-}
-
-Status BytecodeReader::CopySlots() {
- ASSIGN_OR_RETURN(int32_t count, ReadCount());
- for (int i = 0; i < count; ++i) {
- ASSIGN_OR_RETURN(auto* src_local,
- ReadLocal(stack_frame_->mutable_registers()));
- ASSIGN_OR_RETURN(auto* dst_local,
- ReadLocal(stack_frame_->mutable_registers()));
- *dst_local = *src_local;
- }
- return OkStatus();
-}
-
-Status BytecodeReader::BranchToOffset(int32_t offset) {
- const uint8_t* new_bytecode_pc = bytecode_base_ + offset;
- if (new_bytecode_pc < bytecode_base_ || new_bytecode_pc > bytecode_limit_) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Branch target " << offset
- << " is out of bounds of the function bytecode ("
- << static_cast<size_t>(bytecode_limit_ - bytecode_base_)
- << "b total)";
- }
- bytecode_pc_ = new_bytecode_pc;
- return OkStatus();
-}
-
-StatusOr<BufferView> BytecodeReader::ReadConstant() {
- BufferView buffer_view;
-
- // Element type defines the buffer_view size (but we don't really care about
- // the data format).
- ASSIGN_OR_RETURN(auto element_type, ReadType());
- buffer_view.element_size = element_type.element_size();
-
- // Parse shape - constants always define a full shape.
- RETURN_IF_ERROR(ReadShape(&buffer_view.shape));
-
- // Read encoding to determine how the constant data is stored in the file.
- ASSIGN_OR_RETURN(auto encoding, ReadValue<ConstantEncoding>());
-
- // Get buffer for the constant data.
- switch (encoding) {
- case ConstantEncoding::kDense: {
- // Validate we have all constant data present.
- device_size_t serialized_length = buffer_view.byte_length();
- if (bytecode_pc_ + serialized_length >= bytecode_limit_) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Constant data out of bounds";
- }
-
- buffer_view.buffer = hal::HeapBuffer::Wrap(
- hal::MemoryType::kHostLocal, hal::BufferUsage::kAll, bytecode_pc_,
- serialized_length);
- bytecode_pc_ += serialized_length;
- break;
- }
- case ConstantEncoding::kSplat: {
- // Validate we have at least one element worth of data in the buffer.
- if (bytecode_pc_ + buffer_view.element_size >= bytecode_limit_) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Constant data out of bounds";
- }
-
- // TODO(benvanik): replace with fancy constant pool and such.
- // NOTE: this is not much different than if a alloc_heap+broadcast pair
- // had been in the IR.
- buffer_view.buffer = hal::HeapBuffer::Allocate(
- hal::MemoryType::kHostLocal, hal::BufferUsage::kAll,
- buffer_view.byte_length());
- switch (buffer_view.element_size) {
- case 1: {
- uint8_t value = *reinterpret_cast<const uint8_t*>(bytecode_pc_);
- RETURN_IF_ERROR(buffer_view.buffer->Fill8(value));
- break;
- }
- case 2: {
- uint16_t value = *reinterpret_cast<const uint16_t*>(bytecode_pc_);
- RETURN_IF_ERROR(buffer_view.buffer->Fill16(value));
- break;
- }
- case 4: {
- uint32_t value = *reinterpret_cast<const uint32_t*>(bytecode_pc_);
- RETURN_IF_ERROR(buffer_view.buffer->Fill32(value));
- break;
- }
- case 8: {
- // TODO(benvanik): add Fill64.
- uint64_t value = *reinterpret_cast<const uint64_t*>(bytecode_pc_);
- ASSIGN_OR_RETURN(auto mapping,
- buffer_view.buffer->MapMemory<uint64_t>(
- hal::MemoryAccess::kDiscardWrite));
- auto mapped_data = mapping.mutable_contents();
- for (int i = 0; i < mapping.size(); ++i) {
- mapped_data[i] = value;
- }
- break;
- }
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented splat element stride "
- << buffer_view.element_size;
- }
- bytecode_pc_ += buffer_view.element_size;
- break;
- }
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented constant encoding "
- << static_cast<int>(encoding);
- }
-
- return buffer_view;
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_reader.h b/iree/vm/bytecode_reader.h
deleted file mode 100644
index f2b85fd..0000000
--- a/iree/vm/bytecode_reader.h
+++ /dev/null
@@ -1,169 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_READER_H_
-#define IREE_VM_BYTECODE_READER_H_
-
-#include "absl/base/attributes.h"
-#include "absl/container/inlined_vector.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/rt/context.h"
-#include "iree/rt/stack.h"
-#include "iree/rt/stack_frame.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-#include "iree/vm/type.h"
-
-namespace iree {
-namespace vm {
-
-class BytecodeReader {
- public:
- explicit BytecodeReader(rt::Stack* stack) : stack_(stack) {}
-
- int offset() const { return static_cast<int>(bytecode_pc_ - bytecode_base_); }
-
- StatusOr<const uint8_t*> AdvanceOffset();
-
- Status SwitchStackFrame(rt::StackFrame* new_stack_frame);
- Status BranchToOffset(int32_t offset);
-
- Status CopyInputsAndSwitchStackFrame(rt::StackFrame* src_stack_frame,
- rt::StackFrame* dst_stack_frame);
- Status CopyResultsAndSwitchStackFrame(rt::StackFrame* src_stack_frame,
- rt::StackFrame* dst_stack_frame);
- Status CopySlots();
-
- StatusOr<hal::BufferView> ReadConstant();
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int> ReadCount() {
- return ReadValue<uint8_t>();
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const Type> ReadType() {
- ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>());
- return Type::FromTypeIndex(type_index);
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function> ReadFunction() {
- ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>());
- const auto& module = stack_frame_->module();
- return module.LookupFunctionByOrdinal(rt::Function::Linkage::kInternal,
- value);
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function>
- ReadImportFunction() {
- ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>());
- const auto& module = stack_frame_->module();
- return stack_->context()->ResolveImport(&module, value);
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal(
- rt::Registers* registers) {
- ASSIGN_OR_RETURN(auto value, ReadValue<uint16_t>());
- if (value > registers->buffer_views.size()) {
- return OutOfRangeErrorBuilder(IREE_LOC)
- << "Out of bounds local access " << value << " of "
- << registers->buffer_views.size();
- }
- return ®isters->buffer_views[value];
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal() {
- return ReadLocal(registers_);
- }
-
- Status SkipLocals(int count);
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint8_t> ReadUint8_t() {
- return ReadValue<uint8_t>();
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint16_t> ReadUint16_t() {
- return ReadValue<uint16_t>();
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int32_t> ReadInt32() {
- return ReadValue<int32_t>();
- }
-
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint32_t> ReadBlockOffset() {
- return ReadValue<uint32_t>();
- }
-
- template <typename T, size_t N = 8>
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<absl::InlinedVector<T, N>>
- ReadSlotElements() {
- ASSIGN_OR_RETURN(auto* local, ReadLocal(registers_));
- absl::InlinedVector<T, N> result(local->shape.element_count());
- if (sizeof(T) == local->element_size) {
- // Fast(ish) path: requested element size matches the actual element size.
- RETURN_IF_ERROR(
- local->buffer->ReadData(0, result.data(), result.size() * sizeof(T)));
- } else {
- // Slow path: need to convert the data.
- switch (local->element_size) {
- case 4: {
- ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int32_t>(
- hal::MemoryAccess::kRead));
- for (size_t i = 0; i < result.size(); ++i) {
- result[i] = static_cast<T>(mapping[i]);
- }
- break;
- }
- case 8: {
- ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int64_t>(
- hal::MemoryAccess::kRead));
- for (size_t i = 0; i < result.size(); ++i) {
- result[i] = static_cast<T>(mapping[i]);
- }
- break;
- }
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unsupported local element size: " << local->element_size;
- }
- }
- return result;
- }
-
- Status ReadShape(Shape* out_shape);
-
- StatusOr<Shape> ReadShapePieces();
- StatusOr<Shape> ReadShapePieces(size_t* out_element_count);
-
- StatusOr<absl::Span<const int32_t>> ReadIndexList();
-
- private:
- template <typename T>
- ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<T> ReadValue() {
- // TODO(benvanik): validate bounds.
- T value = *reinterpret_cast<const T*>(bytecode_pc_);
- bytecode_pc_ += sizeof(T);
- return value;
- }
-
- rt::Stack* stack_ = nullptr;
- rt::StackFrame* stack_frame_ = nullptr;
- const uint8_t* bytecode_base_ = nullptr;
- const uint8_t* bytecode_limit_ = nullptr;
- const uint8_t* bytecode_pc_ = nullptr;
- rt::Registers* registers_ = nullptr;
-};
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_READER_H_
diff --git a/iree/vm/bytecode_tables_interpreter.cc b/iree/vm/bytecode_tables_interpreter.cc
deleted file mode 100644
index 23661c5..0000000
--- a/iree/vm/bytecode_tables_interpreter.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_tables_interpreter.h"
-
-#include "iree/schemas/bytecode/interpreter_bytecode_v0.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-// Info table mapping 1:1 with bytecode ops.
-//
-// Note that we ensure the table is 256 elements long exactly to make sure
-// that unused opcodes are handled gracefully.
-static const OpcodeInfo kInfoTable[256] = {
-#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \
- OpcodeInfo{ \
- name, \
- flags, \
- {operand_encodings}, \
- },
- IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)
-#undef DECLARE_INFO
-};
-
-} // namespace
-
-OpcodeTable interpreter_opcode_table() { return kInfoTable; }
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_tables_interpreter.h b/iree/vm/bytecode_tables_interpreter.h
deleted file mode 100644
index ef53902..0000000
--- a/iree/vm/bytecode_tables_interpreter.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
-#define IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
-
-#include "iree/vm/opcode_info.h"
-
-namespace iree {
-namespace vm {
-
-OpcodeTable interpreter_opcode_table();
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
diff --git a/iree/vm/bytecode_tables_sequencer.cc b/iree/vm/bytecode_tables_sequencer.cc
deleted file mode 100644
index 6a64335..0000000
--- a/iree/vm/bytecode_tables_sequencer.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_tables_sequencer.h"
-
-#include "iree/schemas/bytecode/sequencer_bytecode_v0.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-// Info table mapping 1:1 with bytecode ops.
-//
-// Note that we ensure the table is 256 elements long exactly to make sure
-// that unused opcodes are handled gracefully.
-static const OpcodeInfo kInfoTable[256] = {
-#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \
- OpcodeInfo{ \
- name, \
- flags, \
- {operand_encodings}, \
- },
- IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)
-#undef DECLARE_INFO
-};
-
-} // namespace
-
-OpcodeTable sequencer_opcode_table() { return kInfoTable; }
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_tables_sequencer.h b/iree/vm/bytecode_tables_sequencer.h
deleted file mode 100644
index 7e690c5..0000000
--- a/iree/vm/bytecode_tables_sequencer.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
-#define IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
-
-#include "iree/vm/opcode_info.h"
-
-namespace iree {
-namespace vm {
-
-OpcodeTable sequencer_opcode_table();
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
diff --git a/iree/vm/bytecode_util.cc b/iree/vm/bytecode_util.cc
deleted file mode 100644
index 9307773..0000000
--- a/iree/vm/bytecode_util.cc
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_util.h"
-
-namespace iree {
-namespace vm {
-
-absl::string_view PredicateToString(CmpIPredicate p) {
-#define PRED(index, name, str, ...) \
- case CmpIPredicate::name: \
- return str;
- switch (p) {
- IREE_CMPI_PREDICATE_LIST(PRED)
-#undef PRED
- }
- return "<unknown>";
-}
-
-absl::string_view PredicateToString(CmpFPredicate p) {
-#define PRED(index, name, str, ...) \
- case CmpFPredicate::name: \
- return str;
- switch (p) {
- IREE_CMPF_PREDICATE_LIST(PRED)
-#undef PRED
- }
- return "<unknown>";
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_util.h b/iree/vm/bytecode_util.h
deleted file mode 100644
index c663570..0000000
--- a/iree/vm/bytecode_util.h
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_UTIL_H_
-#define IREE_VM_BYTECODE_UTIL_H_
-
-#include "absl/strings/string_view.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-
-namespace iree {
-namespace vm {
-
-absl::string_view PredicateToString(CmpIPredicate predicate);
-
-absl::string_view PredicateToString(CmpFPredicate predicate);
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_UTIL_H_
diff --git a/iree/vm/bytecode_validator.cc b/iree/vm/bytecode_validator.cc
deleted file mode 100644
index 968b193..0000000
--- a/iree/vm/bytecode_validator.cc
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/bytecode_validator.h"
-
-namespace iree {
-namespace vm {
-
-// static
-Status BytecodeValidator::Validate(const BytecodeModule& module,
- const BytecodeDef& bytecode_def) {
- // TODO(benvanik): validate bytecode.
- return OkStatus();
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/bytecode_validator.h b/iree/vm/bytecode_validator.h
deleted file mode 100644
index 429c754..0000000
--- a/iree/vm/bytecode_validator.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_BYTECODE_VALIDATOR_H_
-#define IREE_VM_BYTECODE_VALIDATOR_H_
-
-#include "iree/base/status.h"
-#include "iree/schemas/bytecode_def_generated.h"
-#include "iree/vm/bytecode_module.h"
-
-namespace iree {
-namespace vm {
-
-// Validates bytecode such that success indicates the bytecode does not
-// reference undefined types, functions, or required imports and all imports can
-// be resolved with matching signatures.
-class BytecodeValidator {
- public:
- static Status Validate(const BytecodeModule& module,
- const BytecodeDef& bytecode_def);
-};
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_BYTECODE_VALIDATOR_H_
diff --git a/iree/vm/opcode_info.h b/iree/vm/opcode_info.h
deleted file mode 100644
index 751a6e0..0000000
--- a/iree/vm/opcode_info.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_OPCODE_INFO_H_
-#define IREE_VM_OPCODE_INFO_H_
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-
-namespace iree {
-namespace vm {
-
-struct OpcodeInfo {
- const char* mnemonic;
- OpcodeFlagBitfield flag;
- union {
- const char operands_value[8];
- const OperandEncoding operands[8];
- };
-};
-
-using OpcodeTable = absl::Span<const OpcodeInfo>;
-
-template <typename T>
-inline const OpcodeInfo& GetOpcodeInfo(OpcodeTable opcode_table, T opcode) {
- return opcode_table[static_cast<uint8_t>(opcode)];
-}
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_OPCODE_INFO_H_
diff --git a/iree/vm/sequencer_dispatch.cc b/iree/vm/sequencer_dispatch.cc
deleted file mode 100644
index ce5a408..0000000
--- a/iree/vm/sequencer_dispatch.cc
+++ /dev/null
@@ -1,561 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Implements a full bytecode dispatch system for sequencer ops.
-// TODO(benvanik): rework to be async against CommandBuffers.
-
-#include "iree/vm/sequencer_dispatch.h"
-
-#include <algorithm>
-
-#include "absl/base/attributes.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/strings/str_join.h"
-#include "absl/time/time.h"
-#include "absl/types/span.h"
-#include "iree/base/logging.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/device.h"
-#include "iree/hal/heap_buffer.h"
-#include "iree/schemas/bytecode/sequencer_bytecode_v0.h"
-#include "iree/vm/bytecode_module.h"
-#include "iree/vm/bytecode_reader.h"
-#include "iree/vm/bytecode_tables_sequencer.h"
-#include "iree/vm/bytecode_util.h"
-#include "iree/vm/opcode_info.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-using ::iree::hal::Buffer;
-using ::iree::hal::BufferView;
-
-// TODO(benvanik): remove (this should happen via predication).
-bool BufferViewIsTrue(const BufferView& buffer_view) {
- if (buffer_view.element_size == 0 || !buffer_view.buffer ||
- buffer_view.byte_length() == 0) {
- return false;
- }
- // TODO(benvanik): map more efficiently (based on element size?).
- auto mapping =
- buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead);
- if (!mapping.ok()) {
- return false;
- }
- for (uint8_t value : mapping.ValueOrDie().contents()) {
- if (value) return true;
- }
- return false;
-}
-
-// TODO(benvanik): insert fence callbacks and wait on fence.
-Status CallExternalFunction(rt::Stack* stack, const rt::Function& function) {
- // Marshal inputs and outputs.
- const auto* stack_frame = stack->current_frame();
- auto buffer_views = absl::MakeSpan(stack_frame->registers().buffer_views);
- absl::InlinedVector<hal::BufferView, 8> arguments(
- buffer_views.begin(),
- buffer_views.begin() + function.signature().argument_count());
- absl::InlinedVector<hal::BufferView, 8> results(
- buffer_views.begin() + arguments.size(), buffer_views.end());
- return function.module()->Execute(stack, function, std::move(arguments),
- &results);
-}
-
-// Pretty prints an array, e.g. [1, 2, 3, 4]
-inline std::string PrettyPrint(absl::Span<const int32_t> arr) {
- return "[" + absl::StrJoin(arr, ",") + "]";
-}
-
-// Calculates the byte offset into a buffer corresponding to the indices in the
-// given shape.
-StatusOr<device_size_t> CalculateOffset(absl::Span<const int32_t> indices,
- Shape shape, uint8_t element_size) {
- if (shape.empty() || indices.size() > shape.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Indices " << PrettyPrint(indices) << " out of bounds of shape "
- << PrettyPrint(shape.subspan());
- }
- device_size_t offset = 0;
- for (int i = 0; i < indices.size(); ++i) {
- if (indices[i] >= shape[i]) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Indices[" << i << "]=" << indices[i]
- << " out of bounds of shape " << PrettyPrint(shape.subspan());
- }
- device_size_t axis_offset = indices[i];
- for (int j = i + 1; j < shape.size(); ++j) {
- axis_offset *= shape[j];
- }
- offset += axis_offset;
- }
- offset *= element_size;
- return offset;
-}
-
-} // namespace
-
-Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack,
- rt::StackFrame* entry_stack_frame,
- absl::Span<BufferView> entry_results) {
- // Dispatch table mapping 1:1 with bytecode ops.
- // Each entry is a label within this function that can be used for computed
- // goto. You can find more information on computed goto here:
- // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables
- //
- // Note that we ensure the table is 256 elements long exactly to make sure
- // that unused opcodes are handled gracefully.
- static const void* kDispatchTable[256] = {
-#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name,
-#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled,
- IREE_SEQUENCER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED)
-#undef DECLARE_DISPATCH
-#undef DECLARE_DISPATCH_RESERVED
- };
-
- // Primary dispatch state. This is our 'native stack frame' and really just
- // enough to make dereferencing common addresses (like the current offset)
- // faster. You can think of this like CPU state (like PC).
- //
- // We hope that LLVM decides to keep these in registers (as they are touched
- // for every instruction executed). The stack_frame will change as we call
- // into different functions.
- BytecodeReader reader(stack);
- RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame));
-
-#define DISPATCH_NEXT() \
- { \
- uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \
- DVLOG(1) << "Sequencer dispatching op code: " \
- << GetOpcodeInfo(sequencer_opcode_table(), opcode).mnemonic; \
- goto* kDispatchTable[opcode]; \
- }
-
-#define DISPATCH_CORE_OPCODE(opcode, body) \
- _dispatch_##opcode : {body} DISPATCH_NEXT()
-
- DISPATCH_NEXT();
-
- DISPATCH_CORE_OPCODE(kConstant, {
- ASSIGN_OR_RETURN(auto value, reader.ReadConstant());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- // TODO(b/139121143): until we have full command buffers we need to do this.
- ASSIGN_OR_RETURN(value.buffer,
- placement.device->allocator()->AllocateConstant(
- hal::BufferUsage::kConstant | hal::BufferUsage::kAll,
- std::move(value.buffer)));
- *dst_local = std::move(value);
- });
-
- DISPATCH_CORE_OPCODE(kCall, {
- auto* old_stack_frame = stack->current_frame();
- ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction());
- // TODO(benvanik): rework register storage interface.
- ASSIGN_OR_RETURN(
- const auto* function_def,
- static_cast<const BytecodeModule*>(target_function.module())
- ->GetFunctionDef(target_function.linkage(),
- target_function.ordinal()));
- ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function));
- new_stack_frame->mutable_registers()->buffer_views.resize(
- function_def->bytecode()->local_count());
- RETURN_IF_ERROR(
- reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame));
- DVLOG(1) << "Call; stack now: " << stack->DebugString();
- });
-
- DISPATCH_CORE_OPCODE(kCallImport, {
- auto* old_stack_frame = stack->current_frame();
- ASSIGN_OR_RETURN(const auto& target_function, reader.ReadImportFunction());
- ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function));
- // TODO(benvanik): rework register storage interface.
- const auto& signature = target_function.signature();
- new_stack_frame->mutable_registers()->buffer_views.resize(
- signature.argument_count() + signature.result_count());
- RETURN_IF_ERROR(
- reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame));
- DVLOG(1) << "Call native import; stack now: " << stack->DebugString();
- RETURN_IF_ERROR(CallExternalFunction(stack, target_function));
- RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame,
- new_stack_frame));
- RETURN_IF_ERROR(stack->PopFrame());
- DVLOG(1) << "Return from native; stack now: " << stack->DebugString();
- });
-
- DISPATCH_CORE_OPCODE(kCallIndirect, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect";
- });
-
- DISPATCH_CORE_OPCODE(kReturn, {
- auto* old_stack_frame = stack->current_frame();
- auto* new_stack_frame = stack->caller_frame();
- if (old_stack_frame == entry_stack_frame) {
- // Returning from entry function. Marshal results from the return stmt.
- ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount());
- for (int i = 0; i < src_count; ++i) {
- ASSIGN_OR_RETURN(
- auto* src_local,
- reader.ReadLocal(old_stack_frame->mutable_registers()));
- entry_results[i] = std::move(*src_local);
- }
- DVLOG(1) << "Returning to entry";
- return OkStatus();
- } else if (!new_stack_frame) {
- return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow";
- }
- RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame,
- new_stack_frame));
- RETURN_IF_ERROR(stack->PopFrame());
- DVLOG(1) << "Return; stack now: " << stack->DebugString();
- });
-
- DISPATCH_CORE_OPCODE(kBranch, {
- ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset());
- RETURN_IF_ERROR(reader.CopySlots());
- RETURN_IF_ERROR(reader.BranchToOffset(offset));
- });
-
- DISPATCH_CORE_OPCODE(kCondBranch, {
- // Evaluate condition first so we can do the copies as we read them for
- // which side of the branch we take.
- ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
- bool cond_value = BufferViewIsTrue(*cond_local);
- ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset());
-
- if (cond_value) {
- RETURN_IF_ERROR(reader.CopySlots());
- RETURN_IF_ERROR(reader.BranchToOffset(true_offset));
- } else {
- ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount());
- RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count));
- ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset());
-
- RETURN_IF_ERROR(reader.CopySlots());
- RETURN_IF_ERROR(reader.BranchToOffset(false_offset));
- }
- });
-
- DISPATCH_CORE_OPCODE(kDynamicDispatch, {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented dynamic_dispatch";
- });
-
- DISPATCH_CORE_OPCODE(kStaticDispatch, {
- // TODO(benvanik): the real sequencer :)
- ASSIGN_OR_RETURN(auto dispatch_ordinal, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto export_ordinal, reader.ReadUint16_t());
- ASSIGN_OR_RETURN(
- const auto* multi_arch_executable_def,
- static_cast<const BytecodeModule&>(stack->current_frame()->module())
- .LookupMultiArchExecutable(dispatch_ordinal));
- if (export_ordinal >= multi_arch_executable_def->entry_point_count()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Invalid executable export ordinal " << export_ordinal;
- }
- auto* executable_def = multi_arch_executable_def->executables()->Get(0);
- hal::ExecutableSpec executable_spec;
- executable_spec.format = executable_def->format();
- executable_spec.executable_data = absl::Span<const uint8_t>(
- executable_def->contents()->data(), executable_def->contents()->size());
- auto executable_cache = placement.device->CreateExecutableCache();
- ref_ptr<hal::Executable> executable;
- for (auto* executable_def : *multi_arch_executable_def->executables()) {
- if (!executable_cache->CanPrepareFormat(executable_def->format())) {
- continue;
- }
- hal::ExecutableSpec executable_spec;
- executable_spec.format = executable_def->format();
- executable_spec.executable_data =
- absl::Span<const uint8_t>(executable_def->contents()->data(),
- executable_def->contents()->size());
- ASSIGN_OR_RETURN(executable,
- executable_cache->PrepareExecutable(
- hal::ExecutableCachingMode::kDefault |
- hal::ExecutableCachingMode::kAliasProvidedData,
- executable_spec),
- _.LogError());
- break;
- }
- if (!executable) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "No executable found for the current driver";
- }
-
- ASSIGN_OR_RETURN(int workload_x, reader.ReadInt32());
- ASSIGN_OR_RETURN(int workload_y, reader.ReadInt32());
- ASSIGN_OR_RETURN(int workload_z, reader.ReadInt32());
-
- std::vector<hal::BufferBinding> bindings;
- ASSIGN_OR_RETURN(int input_count, reader.ReadCount());
- for (int i = 0; i < input_count; ++i) {
- ASSIGN_OR_RETURN(auto* input_local, reader.ReadLocal());
- bindings.push_back(hal::BufferBinding(
- input_local->buffer->allowed_access() & hal::MemoryAccess::kAll,
- *input_local));
- }
- ASSIGN_OR_RETURN(int output_count, reader.ReadCount());
- for (int i = 0; i < output_count; ++i) {
- ASSIGN_OR_RETURN(auto* output_local, reader.ReadLocal());
- bindings.push_back(
- hal::BufferBinding(hal::MemoryAccess::kWrite, *output_local));
- }
- ASSIGN_OR_RETURN(int result_count, reader.ReadCount());
- CHECK_EQ(0, result_count) << "Results not yet implemented";
-
- ASSIGN_OR_RETURN(
- auto cmd,
- placement.device->CreateCommandBuffer(
- hal::CommandBufferMode::kOneShot,
- hal::CommandCategory::kTransfer | hal::CommandCategory::kDispatch),
- _.LogError());
- RETURN_IF_ERROR(cmd->Begin());
- hal::DispatchRequest dispatch_request;
- dispatch_request.executable = executable.get();
- dispatch_request.entry_point = export_ordinal;
- dispatch_request.workload[0] = workload_x;
- dispatch_request.workload[1] = workload_y;
- dispatch_request.workload[2] = workload_z;
- dispatch_request.bindings = bindings;
- RETURN_IF_ERROR(cmd->Dispatch(dispatch_request));
- RETURN_IF_ERROR(cmd->End());
- auto* cmd_ptr = cmd.get();
-
- auto* queue = placement.device->dispatch_queues().front();
- hal::SubmissionBatch batch;
- batch.command_buffers = absl::MakeConstSpan(&cmd_ptr, 1);
- ASSIGN_OR_RETURN(auto fence, placement.device->CreateFence(0u));
- RETURN_IF_ERROR(queue->Submit(batch, {fence.get(), 1u}));
- RETURN_IF_ERROR(placement.device->WaitAllFences({{fence.get(), 1u}},
- absl::InfiniteFuture()));
- });
-
- DISPATCH_CORE_OPCODE(kAllocStatic, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static";
- });
-
- DISPATCH_CORE_OPCODE(kAllocStack, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack";
- });
-
- DISPATCH_CORE_OPCODE(kAllocStackInit, {
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unimplemented alloc_stack_init";
- });
-
- DISPATCH_CORE_OPCODE(kAllocHeap, {
- ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto type, reader.ReadType());
- size_t element_size = type.element_size();
-
- // TODO(benvanik): more efficient reading and storage.
- size_t element_count = 0;
- ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count));
- size_t allocation_size = element_size * element_count;
-
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- dst_local->element_size = element_size;
- dst_local->shape = shape;
-
- // TODO(benvanik): pick an allocator and use that instead.
- CHECK_EQ(heap_type, 0);
- auto* allocator = placement.device->allocator();
- ASSIGN_OR_RETURN(
- dst_local->buffer,
- allocator->Allocate(
- hal::MemoryType::kHostLocal | hal::MemoryType::kDeviceVisible,
- hal::BufferUsage::kAll, allocation_size));
- });
-
- DISPATCH_CORE_OPCODE(kDiscard, {
- // NOTE: if we were an encoder we would actually discard the buffer.
- ASSIGN_OR_RETURN(auto* local, reader.ReadLocal());
- *local = {};
- });
-
- DISPATCH_CORE_OPCODE(kComputeRange, {
- ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto element_size, reader.ReadUint8_t());
- ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_offset_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_length_local, reader.ReadLocal());
-
- Shape shape(shape_data);
- ASSIGN_OR_RETURN(device_size_t dst_offset,
- CalculateOffset(indices, shape, element_size));
- RETURN_IF_ERROR(
- dst_offset_local->buffer->WriteData(0, &dst_offset, sizeof(int32_t)));
-
- // A buffer range can only be computed for contiguous memory. To ensure that
- // this only requests such, we validate that the offset in the buffer
- // between the start and end indices is the same as the requested size.
- device_size_t dst_length = element_size;
- for (int i = 0; i < lengths.size(); ++i) {
- dst_length *= lengths[i];
- indices[i] += lengths[i] - 1;
- }
- ASSIGN_OR_RETURN(auto end_offset,
- CalculateOffset(indices, shape, element_size));
- auto offset_based_length = end_offset - dst_offset + element_size;
- if (dst_length != offset_based_length) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Cannot compute range for non-contiguous region of memory;"
- << " shape: " << PrettyPrint(shape.subspan())
- << " indices: " << PrettyPrint(indices)
- << " lengths: " << PrettyPrint(lengths);
- }
- RETURN_IF_ERROR(
- dst_length_local->buffer->WriteData(0, &dst_length, sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kShape, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- RETURN_IF_ERROR(dst_local->buffer->WriteData(
- 0, src_local->shape.subspan().data(),
- src_local->shape.subspan().size() * sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kLength, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- int32_t length = src_local->shape.element_count();
- RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t)));
- });
-
- DISPATCH_CORE_OPCODE(kDynamicSlice, {
- // TODO(b/139299169): implement indirect copies to avoid CPU readback.
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_slice";
- });
-
- DISPATCH_CORE_OPCODE(kStaticSlice, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto offset, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto length, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto type, reader.ReadType());
- ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- Shape new_shape = Shape{shape_data};
- if (new_shape.element_count() * type.element_size() != length) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "New element count " << new_shape.element_count()
- << " != length slice " << length;
- }
- ASSIGN_OR_RETURN(dst_local->buffer,
- Buffer::Subspan(src_local->buffer, offset, length));
- dst_local->shape = new_shape;
- dst_local->element_size = type.element_size();
- });
-
- DISPATCH_CORE_OPCODE(kDynamicCopy, {
- // TODO(b/139299169): implement indirect copies to avoid CPU readback.
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto src_offset_span, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_offset_span, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto length_span, reader.ReadSlotElements<int32_t>());
- RETURN_IF_ERROR(dst_local->buffer->CopyData(
- dst_offset_span.front(), src_local->buffer.get(),
- src_offset_span.front(), length_span.front()));
- });
-
- DISPATCH_CORE_OPCODE(kStaticCopy, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto src_offset, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto length, reader.ReadInt32());
- RETURN_IF_ERROR(dst_local->buffer->CopyData(
- dst_offset, src_local->buffer.get(), src_offset, length));
- });
-
- DISPATCH_CORE_OPCODE(kDynamicFill, {
- // TODO(b/139299169): implement indirect fills to avoid CPU readback.
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_fill";
- });
-
- DISPATCH_CORE_OPCODE(kStaticFill, {
- ASSIGN_OR_RETURN(auto value, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32());
- ASSIGN_OR_RETURN(auto length, reader.ReadInt32());
- RETURN_IF_ERROR(dst_local->buffer->Fill32(dst_offset, length, value));
- });
-
- DISPATCH_CORE_OPCODE(kClone, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- dst_local->element_size = src_local->element_size;
- dst_local->shape = src_local->shape;
- ASSIGN_OR_RETURN(dst_local->buffer, placement.device->allocator()->Allocate(
- src_local->buffer->memory_type(),
- src_local->buffer->usage(),
- src_local->buffer->byte_length()));
- RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get()));
- });
-
- DISPATCH_CORE_OPCODE(kAssign, {
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- *dst_local = *src_local;
- });
-
- DISPATCH_CORE_OPCODE(kCondAssign, {
- ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local;
- });
-
- DISPATCH_CORE_OPCODE(kReshape, {
- // TODO(benvanik): more logic required if strides differ.
- ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
- ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
- ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
- Shape new_shape = Shape{shape_data};
- if (src_local->shape.element_count() != new_shape.element_count()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "New element count " << new_shape.element_count()
- << " != source element count " << src_local->shape.element_count();
- }
- dst_local->shape = new_shape;
- dst_local->buffer = add_ref(src_local->buffer);
- dst_local->element_size = src_local->element_size;
- });
-
- DISPATCH_CORE_OPCODE(kTrace, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace";
- });
-
- DISPATCH_CORE_OPCODE(kBreak, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break";
- });
-
- DISPATCH_CORE_OPCODE(kCondBreak, {
- return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break";
- });
-
-_dispatch_unhandled:
- // TODO(benvanik): better tracing.
- return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode";
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/sequencer_dispatch.h b/iree/vm/sequencer_dispatch.h
deleted file mode 100644
index 0251c17..0000000
--- a/iree/vm/sequencer_dispatch.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_SEQUENCER_DISPATCH_H_
-#define IREE_VM_SEQUENCER_DISPATCH_H_
-
-#include "iree/base/status.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/hal/device_placement.h"
-#include "iree/rt/stack.h"
-#include "iree/rt/stack_frame.h"
-
-namespace iree {
-namespace vm {
-
-// TODO(benvanik): API that supports yielding.
-Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack,
- rt::StackFrame* entry_stack_frame,
- absl::Span<hal::BufferView> entry_results);
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_SEQUENCER_DISPATCH_H_
diff --git a/iree/vm/sequencer_module.cc b/iree/vm/sequencer_module.cc
deleted file mode 100644
index c3cf58f..0000000
--- a/iree/vm/sequencer_module.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/sequencer_module.h"
-
-#include "absl/memory/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/buffer_view.h"
-#include "iree/rt/context.h"
-#include "iree/rt/instance.h"
-#include "iree/vm/bytecode_tables_sequencer.h"
-#include "iree/vm/sequencer_dispatch.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-using ::iree::hal::BufferView;
-using ::iree::rt::Function;
-using ::iree::rt::Module;
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<rt::Module>> SequencerModule::FromDef(
- const ModuleDef& module_def) {
- ASSIGN_OR_RETURN(auto module_file, ModuleFile::Create(&module_def, []() {}));
- return FromFile(std::move(module_file));
-}
-
-// static
-StatusOr<ref_ptr<rt::Module>> SequencerModule::FromFile(
- std::unique_ptr<ModuleFile> module_file) {
- if (module_file->root() == nullptr) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present";
- }
- const auto& module_def = *module_file->root();
-
- // Validates the structure of the module (but not bytecode).
- // This ensures we don't have flatbuffer vectors will null entries, etc.
- RETURN_IF_ERROR(BytecodeModule::ValidateStructure(module_def));
-
- auto module = assign_ref(new SequencerModule(std::move(module_file)));
-
- // TODO(benvanik): validate internals here? or make explicit?
-
- return {std::move(module)};
-}
-
-SequencerModule::SequencerModule(std::unique_ptr<ModuleFile> module_file)
- : BytecodeModule(std::move(module_file), sequencer_opcode_table()) {}
-
-SequencerModule::~SequencerModule() = default;
-
-Status SequencerModule::Execute(
- rt::Stack* stack, const Function function,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::InlinedVector<hal::BufferView, 8>* results) const {
- IREE_TRACE_SCOPE0("SequencerModule::Execute");
-
- // Push stack frame for the function we are calling.
- ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function));
-
- // TODO(benvanik): rework register storage interface.
- ASSIGN_OR_RETURN(const auto* function_def,
- GetFunctionDef(function.linkage(), function.ordinal()));
- auto* registers = callee_stack_frame->mutable_registers();
- registers->buffer_views.resize(function_def->bytecode()->local_count());
-
- // Marshal input arguments.
- for (int i = 0; i < arguments.size(); ++i) {
- auto arg = arguments[i];
- auto expected_arg_type = function_def->type()->inputs()->Get(i);
- RETURN_IF_ERROR(BytecodeModule::ValidateArgType(
- arg, *expected_arg_type->type_union_as_MemRefTypeDef()))
- << "Function " << function.name() << " argument " << i;
- registers->buffer_views[i] = std::move(arg);
- }
-
- // TODO(benvanik): change to:
- // get command queue (any command queue)
- // make command buffer
- // record dispatch
- // submit
- // wait on fence
- ASSIGN_OR_RETURN(
- auto placement,
- stack->context()->instance()->device_manager()->ResolvePlacement({}));
- RETURN_IF_ERROR(DispatchSequence(placement, stack, callee_stack_frame,
- absl::MakeSpan(*results)));
-
- // Pop the callee frame to balance out the stack.
- RETURN_IF_ERROR(stack->PopFrame());
-
- return OkStatus();
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/sequencer_module.h b/iree/vm/sequencer_module.h
deleted file mode 100644
index b9bb176..0000000
--- a/iree/vm/sequencer_module.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_SEQUENCER_MODULE_H_
-#define IREE_VM_SEQUENCER_MODULE_H_
-
-#include <memory>
-
-#include "iree/vm/bytecode_module.h"
-
-namespace iree {
-namespace vm {
-
-// A module using the sequencer bytecode ops.
-class SequencerModule final : public BytecodeModule {
- public:
- static StatusOr<ref_ptr<rt::Module>> FromDef(const ModuleDef& module_def);
- static StatusOr<ref_ptr<rt::Module>> FromFile(
- std::unique_ptr<ModuleFile> module_file);
-
- ~SequencerModule() override;
-
- Status Execute(
- rt::Stack* stack, const rt::Function function,
- absl::InlinedVector<hal::BufferView, 8> arguments,
- absl::InlinedVector<hal::BufferView, 8>* results) const override;
-
- private:
- explicit SequencerModule(std::unique_ptr<ModuleFile> module_file);
-};
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_SEQUENCER_MODULE_H_
diff --git a/iree/vm/source_map_resolver.cc b/iree/vm/source_map_resolver.cc
deleted file mode 100644
index 96025e4..0000000
--- a/iree/vm/source_map_resolver.cc
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/source_map_resolver.h"
-
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/status.h"
-#include "iree/schemas/source_map_def_generated.h"
-
-namespace iree {
-namespace vm {
-
-namespace {
-
-Status PrintLocation(const SourceMapResolver& source_map,
- const FunctionSourceMapDef& function_source_map,
- const LocationDef& location, std::ostream* stream);
-
-Status PrintFileLocation(const SourceMapResolver& source_map,
- const FunctionSourceMapDef& function_source_map,
- const FileLocationDef& location,
- std::ostream* stream) {
- ASSIGN_OR_RETURN(auto filename,
- source_map.GetUniqueString(location.filename()));
- *stream << filename << ":" << location.line() << ":" << location.column();
- return OkStatus();
-}
-
-Status PrintNameLocation(const SourceMapResolver& source_map,
- const FunctionSourceMapDef& function_source_map,
- const NameLocationDef& location,
- std::ostream* stream) {
- ASSIGN_OR_RETURN(auto name, source_map.GetUniqueString(location.name()));
- *stream << "\"" << name << "\"";
- return OkStatus();
-}
-
-Status PrintCallSiteLocation(const SourceMapResolver& source_map,
- const FunctionSourceMapDef& function_source_map,
- const CallSiteLocationDef& location,
- std::ostream* stream) {
- *stream << "(callsites todo)";
- return OkStatus();
-}
-
-Status PrintFusedLocation(const SourceMapResolver& source_map,
- const FunctionSourceMapDef& function_source_map,
- const FusedLocationDef& location,
- std::ostream* stream) {
- *stream << "fused[";
- if (location.locations()) {
- for (int i = 0; i < location.locations()->size(); ++i) {
- if (i > 0) *stream << ", ";
- int location_ordinal = location.locations()->Get(i);
- const auto& child_location =
- *function_source_map.location_table()->Get(location_ordinal);
- RETURN_IF_ERROR(PrintLocation(source_map, function_source_map,
- child_location, stream));
- }
- }
- *stream << "]";
- return OkStatus();
-}
-
-Status PrintLocation(const SourceMapResolver& source_map,
- const FunctionSourceMapDef& function_source_map,
- const LocationDef& location, std::ostream* stream) {
- switch (location.location_union_type()) {
- case LocationDefUnion::FileLocationDef:
- return PrintFileLocation(source_map, function_source_map,
- *location.location_union_as_FileLocationDef(),
- stream);
- case LocationDefUnion::NameLocationDef:
- return PrintNameLocation(source_map, function_source_map,
- *location.location_union_as_NameLocationDef(),
- stream);
- case LocationDefUnion::CallSiteLocationDef:
- return PrintCallSiteLocation(
- source_map, function_source_map,
- *location.location_union_as_CallSiteLocationDef(), stream);
- case LocationDefUnion::FusedLocationDef:
- return PrintFusedLocation(source_map, function_source_map,
- *location.location_union_as_FusedLocationDef(),
- stream);
- default:
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unhandled location type "
- << static_cast<int>(location.location_union_type());
- }
-}
-
-} // namespace
-
-// static
-SourceMapResolver SourceMapResolver::FromModule(const ModuleDef& module_def) {
- if (module_def.source_map()) {
- return SourceMapResolver{*module_def.source_map()};
- }
- return {};
-}
-
-StatusOr<absl::string_view> SourceMapResolver::GetUniqueString(
- int string_index) const {
- if (empty()) {
- return NotFoundErrorBuilder(IREE_LOC) << "No source map present";
- }
- const auto* string_table = source_map_def_->string_table();
- if (string_table && string_table->size() > string_index) {
- return WrapString(string_table->Get(string_index));
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "String index " << string_index << " not present in string table";
-}
-
-StatusOr<const FunctionSourceMapDef*> SourceMapResolver::GetFunctionSourceMap(
- int function_ordinal) const {
- if (empty()) {
- return NotFoundErrorBuilder(IREE_LOC) << "No source map present";
- }
- const auto* function_table = source_map_def_->function_table();
- if (function_table && function_table->size() > function_ordinal) {
- const auto* function_source_map = function_table->Get(function_ordinal);
- if (function_source_map && function_source_map->location_table() &&
- function_source_map->bytecode_map()) {
- return function_source_map;
- }
- }
- return NotFoundErrorBuilder(IREE_LOC)
- << "Function ordinal " << function_ordinal
- << " source map not present in function table";
-}
-
-absl::optional<rt::SourceLocation> SourceMapResolver::ResolveFunctionOffset(
- const rt::Function& function, rt::SourceOffset offset) {
- if (empty()) return absl::nullopt;
- auto function_source_map_or = GetFunctionSourceMap(function.ordinal());
- if (!function_source_map_or.ok()) {
- return absl::nullopt;
- }
- const auto* function_source_map = function_source_map_or.ValueOrDie();
- const auto* bytecode_map = function_source_map->bytecode_map();
- if (!bytecode_map) return absl::nullopt;
-
- // TODO(benvanik): allow fuzzy offset matching/table sparsity.
- int location_ordinal = -1;
- for (const auto* map_loc : *bytecode_map) {
- if (map_loc->offset() == offset) {
- location_ordinal = map_loc->location();
- break;
- }
- }
- if (location_ordinal == -1) {
- return absl::nullopt;
- }
-
- return rt::SourceLocation(this,
- {
- reinterpret_cast<uint64_t>(function_source_map),
- static_cast<uint64_t>(location_ordinal),
- });
-}
-
-void SourceMapResolver::PrintSourceLocation(
- rt::SourceResolverArgs resolver_args, std::ostream* stream) const {
- if (empty()) {
- *stream << "<unknown>";
- return;
- }
-
- auto* function_source_map =
- reinterpret_cast<FunctionSourceMapDef*>(resolver_args[0]);
- int location_ordinal = static_cast<int>(resolver_args[1]);
-
- const auto& location =
- *function_source_map->location_table()->Get(location_ordinal);
- auto status = PrintLocation(*this, *function_source_map, location, stream);
- if (!status.ok()) {
- *stream << status;
- }
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/source_map_resolver.h b/iree/vm/source_map_resolver.h
deleted file mode 100644
index 5c8f7c2..0000000
--- a/iree/vm/source_map_resolver.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_SOURCE_MAP_RESOLVER_H_
-#define IREE_VM_SOURCE_MAP_RESOLVER_H_
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "iree/base/status.h"
-#include "iree/rt/source_resolver.h"
-#include "iree/schemas/module_def_generated.h"
-#include "iree/schemas/source_map_def_generated.h"
-
-namespace iree {
-namespace vm {
-
-class SourceMapResolver final : public rt::SourceResolver {
- public:
- static SourceMapResolver FromModule(const ModuleDef& module_def);
-
- SourceMapResolver() = default;
- explicit SourceMapResolver(const SourceMapDef& source_map_def)
- : source_map_def_(&source_map_def) {}
-
- bool empty() const { return source_map_def_ == nullptr; }
- const SourceMapDef* def() const { return source_map_def_; }
-
- StatusOr<absl::string_view> GetUniqueString(int string_index) const;
-
- StatusOr<const FunctionSourceMapDef*> GetFunctionSourceMap(
- int function_ordinal) const;
-
- absl::optional<rt::SourceLocation> ResolveFunctionOffset(
- const rt::Function& function, rt::SourceOffset offset) override;
-
- void PrintSourceLocation(rt::SourceResolverArgs resolver_args,
- std::ostream* stream) const override;
-
- private:
- const SourceMapDef* source_map_def_ = nullptr;
-};
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_SOURCE_MAP_RESOLVER_H_
diff --git a/iree/vm/type.cc b/iree/vm/type.cc
deleted file mode 100644
index 5b2c372..0000000
--- a/iree/vm/type.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/vm/type.h"
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace vm {
-
-// static
-StatusOr<const Type> Type::FromTypeIndex(uint8_t type_index) {
- // Currently we only support the builtin types.
- if (type_index == static_cast<uint8_t>(BuiltinType::kOpaque)) {
- return Type(type_index);
- } else if (type_index < kBuiltinTypeCount) {
- return Type(type_index);
- }
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Type index " << static_cast<int>(type_index) << " not supported";
-}
-
-// static
-const Type Type::FromBuiltin(BuiltinType type) {
- return Type(static_cast<uint8_t>(type));
-}
-
-std::string Type::DebugString() const {
- switch (type_index_) {
-#define TYPE_NAME(index, name, str, size) \
- case index: \
- return str;
- IREE_TYPE_LIST(TYPE_NAME)
-#undef TYPE_NAME
- default:
- return "<invalid>";
- }
-}
-
-size_t Type::element_size() const {
- switch (type_index_) {
-#define TYPE_SIZE(index, name, str, size) \
- case index: \
- return size;
- IREE_TYPE_LIST(TYPE_SIZE)
-#undef TYPE_SIZE
- default:
- return 0;
- }
-}
-
-} // namespace vm
-} // namespace iree
diff --git a/iree/vm/type.h b/iree/vm/type.h
deleted file mode 100644
index b8ab9bd..0000000
--- a/iree/vm/type.h
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_VM_TYPE_H_
-#define IREE_VM_TYPE_H_
-
-#include "iree/base/status.h"
-#include "iree/schemas/bytecode/bytecode_v0.h"
-#include "iree/schemas/type_def_generated.h"
-
-namespace iree {
-namespace vm {
-
-class Type {
- public:
- static StatusOr<const Type> FromTypeIndex(uint8_t type_index);
- static const Type FromBuiltin(BuiltinType type);
-
- std::string DebugString() const;
-
- uint8_t type_index() const { return type_index_; }
-
- bool is_opaque() const {
- return type_index_ == static_cast<uint8_t>(BuiltinType::kOpaque);
- }
- bool is_builtin() const { return !is_opaque(); }
- BuiltinType builtin_type() const {
- DCHECK(is_builtin());
- return static_cast<BuiltinType>(type_index_);
- }
-
- size_t element_size() const;
-
- private:
- explicit Type(uint8_t type_index) : type_index_(type_index) {}
-
- uint8_t type_index_;
-};
-
-inline bool operator==(const Type& a, const Type& b) {
- return a.type_index() == b.type_index();
-}
-
-inline bool operator!=(const Type& a, const Type& b) { return !(a == b); }
-
-inline std::ostream& operator<<(std::ostream& stream, const Type& type) {
- stream << type.DebugString();
- return stream;
-}
-
-} // namespace vm
-} // namespace iree
-
-#endif // IREE_VM_TYPE_H_
diff --git a/rt/BUILD b/rt/BUILD
new file mode 100644
index 0000000..e06308c
--- /dev/null
+++ b/rt/BUILD
@@ -0,0 +1,83 @@
+# Runtime API for interacting with IREE modules and invoking functions.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+    name = "api",
+    srcs = ["api.cc"],
+    hdrs = ["api.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":api_hdrs",
+        ":rt",
+        "//base:api",
+        "//base:api_util",
+        "//base:tracing",
+        "//hal:api",
+        "//hal:buffer_view",
+        "//hal:driver_registry",
+        "//rt/debug:debug_server_interface",
+        "@com_google_absl//absl/time",
+    ],
+)
+
+cc_library(
+    name = "api_hdrs",
+    hdrs = ["api.h"],
+    deps = [
+        "//base:api_hdrs",
+        "//hal:api_hdrs",
+    ],
+)
+
+cc_library(
+    name = "rt",
+    srcs = [
+        "context.cc",
+        "function.cc",
+        "instance.cc",
+        "invocation.cc",
+        "module_printer.cc",
+        "source_location.cc",
+        "stack.cc",
+        "stack_frame.cc",
+        "stack_trace.cc",
+    ],
+    hdrs = [
+        "context.h",
+        "disassembler.h",
+        "function.h",
+        "function_signature.h",
+        "instance.h",
+        "invocation.h",
+        "module.h",
+        "module_printer.h",
+        "module_signature.h",
+        "policy.h",
+        "source_location.h",
+        "source_resolver.h",
+        "stack.h",
+        "stack_frame.h",
+        "stack_trace.h",
+    ],
+    deps = [
+        "//base:bitfield",
+        "//base:intrusive_list",
+        "//base:ref_ptr",
+        "//base:status",
+        "//base:tracing",
+        "//hal:buffer_view",
+        "//hal:device_manager",
+        "//rt/debug:debug_server_interface",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
diff --git a/iree/rt/CMakeLists.txt b/rt/CMakeLists.txt
similarity index 100%
rename from iree/rt/CMakeLists.txt
rename to rt/CMakeLists.txt
diff --git a/rt/api.cc b/rt/api.cc
new file mode 100644
index 0000000..c99f733
--- /dev/null
+++ b/rt/api.cc
@@ -0,0 +1,788 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/api.h"
+
+#include "absl/time/time.h"
+#include "base/api.h"
+#include "base/api_util.h"
+#include "base/tracing.h"
+#include "hal/api.h"
+#include "hal/api_detail.h"
+#include "hal/buffer_view.h"
+#include "hal/driver_registry.h"
+#include "rt/context.h"
+#include "rt/debug/debug_server.h"
+#include "rt/function.h"
+#include "rt/instance.h"
+#include "rt/invocation.h"
+#include "rt/module.h"
+#include "rt/policy.h"
+
+namespace iree {
+namespace rt {
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Instance
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
+ iree_allocator_t allocator, iree_rt_instance_t** out_instance) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_create");
+
+ if (!out_instance) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_instance = nullptr;
+
+ auto instance = make_ref<Instance>();
+ *out_instance = reinterpret_cast<iree_rt_instance_t*>(instance.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_retain(iree_rt_instance_t* instance) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_retain");
+ auto* handle = reinterpret_cast<Instance*>(instance);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_release(iree_rt_instance_t* instance) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_release");
+ auto* handle = reinterpret_cast<Instance*>(instance);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex(
+ iree_rt_instance_t* instance, iree_string_view_t driver_name) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_register_driver_ex");
+ auto* handle = reinterpret_cast<Instance*>(instance);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ IREE_API_ASSIGN_OR_RETURN(
+ auto driver, hal::DriverRegistry::shared_registry()->Create(
+ absl::string_view{driver_name.data, driver_name.size}));
+ IREE_API_ASSIGN_OR_RETURN(auto available_devices,
+ driver->EnumerateAvailableDevices());
+ for (const auto& device_info : available_devices) {
+ LOG(INFO) << " Device: " << device_info.name();
+ }
+ LOG(INFO) << "Creating default device...";
+ IREE_API_ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
+ IREE_API_RETURN_IF_ERROR(handle->device_manager()->RegisterDevice(device));
+ LOG(INFO) << "Successfully created device '" << device->info().name() << "'";
+
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Module
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class ExternalModule final : public Module {
+ public:
+ ExternalModule(iree_rt_external_module_t impl, iree_allocator_t allocator)
+ : impl_(impl), allocator_(allocator) {
+ IREE_TRACE_SCOPE0("ExternalModule::ctor");
+ }
+
+ ~ExternalModule() override {
+ IREE_TRACE_SCOPE0("ExternalModule::dtor");
+ impl_.destroy(impl_.self);
+ std::memset(&impl_, 0, sizeof(impl_));
+ }
+
+ absl::string_view name() const override {
+ auto result = impl_.name(impl_.self);
+ return absl::string_view{result.data, result.size};
+ }
+
+ const ModuleSignature signature() const override {
+ auto signature = impl_.signature(impl_.self);
+ return ModuleSignature{
+ signature.import_function_count,
+ signature.export_function_count,
+ signature.internal_function_count,
+ signature.state_slot_count,
+ };
+ }
+
+ SourceResolver* source_resolver() const override { return nullptr; }
+
+ Disassembler* disassembler() const override { return nullptr; }
+
+ std::string DebugStringShort() const override { return std::string(name()); }
+
+ StatusOr<const Function> LookupFunctionByOrdinal(
+ Function::Linkage linkage, int32_t ordinal) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByOrdinal");
+ iree_rt_function_t function;
+ auto status = impl_.lookup_function_by_ordinal(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
+ &function);
+ if (status != IREE_STATUS_OK) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return Function{reinterpret_cast<Module*>(function.module),
+ static_cast<Function::Linkage>(function.linkage),
+ function.ordinal};
+ }
+
+ StatusOr<const Function> LookupFunctionByName(
+ Function::Linkage linkage, absl::string_view name) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByName");
+ iree_rt_function_t function;
+ auto status = impl_.lookup_function_by_name(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage),
+ iree_string_view_t{name.data(), name.size()}, &function);
+ if (status != IREE_STATUS_OK) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return Function{reinterpret_cast<Module*>(function.module),
+ static_cast<Function::Linkage>(function.linkage),
+ function.ordinal};
+ }
+
+ StatusOr<absl::string_view> GetFunctionName(Function::Linkage linkage,
+ int32_t ordinal) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::GetFunctionName");
+ iree_string_view_t name;
+ auto status = impl_.get_function_name(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
+ &name);
+ RETURN_IF_ERROR(FromApiStatus(status, IREE_LOC));
+ return absl::string_view{name.data, name.size};
+ }
+
+ StatusOr<const FunctionSignature> GetFunctionSignature(
+ Function::Linkage linkage, int32_t ordinal) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::GetFunctionSignature");
+ iree_rt_function_signature_t signature;
+ auto status = impl_.get_function_signature(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
+ &signature);
+ if (status != IREE_STATUS_OK) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return FunctionSignature{signature.argument_count, signature.result_count};
+ }
+
+ Status Execute(
+ Stack* stack, const Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const override {
+ // TODO(benvanik): fn ptr callback to external code. Waiting on fibers.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "External calls not yet implemented";
+ }
+
+ private:
+ iree_rt_external_module_t impl_;
+ iree_allocator_t allocator_;
+};
+
+} // namespace
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
+ iree_rt_external_module_t impl, iree_allocator_t allocator,
+ iree_rt_module_t** out_module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_create_external");
+
+ if (!out_module) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_module = nullptr;
+
+ auto module = make_ref<ExternalModule>(impl, allocator);
+ *out_module = reinterpret_cast<iree_rt_module_t*>(module.release());
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_retain(iree_rt_module_t* module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_retain");
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_release(iree_rt_module_t* module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_release");
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_module_name(const iree_rt_module_t* module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_name");
+ const auto* handle = reinterpret_cast<const Module*>(module);
+ CHECK(handle) << "NULL module handle";
+ return iree_string_view_t{handle->name().data(), handle->name().size()};
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ int32_t ordinal,
+ iree_rt_function_t* out_function) {
+ IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_ordinal");
+
+ if (!out_function) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ std::memset(out_function, 0, sizeof(*out_function));
+
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto function_or = handle->LookupFunctionByOrdinal(
+ static_cast<Function::Linkage>(linkage), ordinal);
+ if (!function_or.ok()) {
+ // Map this invalid argument to not found, per the API spec.
+ if (IsInvalidArgument(function_or.status())) {
+ return IREE_STATUS_NOT_FOUND;
+ }
+ return ToApiStatus(std::move(function_or).status());
+ }
+ auto function = *function_or;
+
+ out_function->module = module;
+ out_function->linkage =
+ static_cast<iree_rt_function_linkage_t>(function.linkage());
+ out_function->ordinal = function.ordinal();
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
+                                       iree_rt_function_linkage_t linkage,
+                                       iree_string_view_t name,
+                                       iree_rt_function_t* out_function) {
+  IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_name");
+  // Reject a null output slot, then zero it so failures leave no stale data.
+  if (!out_function) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  std::memset(out_function, 0, sizeof(*out_function));
+
+  auto* handle = reinterpret_cast<Module*>(module);
+  if (!handle) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  IREE_API_ASSIGN_OR_RETURN(
+      auto function,
+      handle->LookupFunctionByName(static_cast<Function::Linkage>(linkage),
+                                   absl::string_view{name.data, name.size}));
+
+  // The linkage is stored exactly once, using the caller-requested value: the
+  // same value the earlier pair of back-to-back assignments ended up storing.
+  out_function->module = module;
+  out_function->linkage = linkage;
+  out_function->ordinal = function.ordinal();
+
+  return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Function
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_function_name(const iree_rt_function_t* function) {
+ IREE_TRACE_SCOPE0("iree_rt_function_name");
+ CHECK(function && function->module) << "NULL function handle";
+ auto* module = reinterpret_cast<Module*>(function->module);
+ auto name_or = module->GetFunctionName(
+ static_cast<Function::Linkage>(function->linkage), function->ordinal);
+ if (!name_or.ok()) return {};
+ auto name = name_or.ValueOrDie();
+ return iree_string_view_t{name.data(), name.size()};
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_function_signature(const iree_rt_function_t* function,
+ iree_rt_function_signature_t* out_signature) {
+ IREE_TRACE_SCOPE0("iree_rt_function_signature");
+
+ if (!out_signature) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ std::memset(out_signature, 0, sizeof(*out_signature));
+
+ if (!function || !function->module) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto* module = reinterpret_cast<Module*>(function->module);
+ IREE_API_ASSIGN_OR_RETURN(
+ auto signature, module->GetFunctionSignature(
+ static_cast<Function::Linkage>(function->linkage),
+ function->ordinal));
+ out_signature->argument_count = signature.argument_count();
+ out_signature->result_count = signature.result_count();
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Policy
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_policy_create(
+    iree_allocator_t allocator, iree_rt_policy_t** out_policy) {
+  IREE_TRACE_SCOPE0("iree_rt_policy_create");
+
+  if (!out_policy) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  *out_policy = nullptr;
+
+  auto policy = make_ref<Policy>();
+  // Return the new policy to the caller as an opaque handle.
+  *out_policy = reinterpret_cast<iree_rt_policy_t*>(policy.release());
+
+  return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_retain(iree_rt_policy_t* policy) {
+ IREE_TRACE_SCOPE0("iree_rt_policy_retain");
+ auto* handle = reinterpret_cast<Policy*>(policy);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_release(iree_rt_policy_t* policy) {
+ IREE_TRACE_SCOPE0("iree_rt_policy_release");
+ auto* handle = reinterpret_cast<Policy*>(policy);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Context
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
+ iree_rt_instance_t* instance, iree_rt_policy_t* policy,
+ iree_allocator_t allocator, iree_rt_context_t** out_context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_create");
+
+ if (!out_context) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_context = nullptr;
+
+ if (!instance || !policy) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto context =
+ make_ref<Context>(add_ref(reinterpret_cast<Instance*>(instance)),
+ add_ref(reinterpret_cast<Policy*>(policy)));
+
+ *out_context = reinterpret_cast<iree_rt_context_t*>(context.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_retain(iree_rt_context_t* context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_retain");
+ auto* handle = reinterpret_cast<Context*>(context);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_release(iree_rt_context_t* context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_release");
+ auto* handle = reinterpret_cast<Context*>(context);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT int32_t IREE_API_CALL
+iree_rt_context_id(const iree_rt_context_t* context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_id");
+ const auto* handle = reinterpret_cast<const Context*>(context);
+ CHECK(handle) << "NULL context handle";
+ return handle->id();
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules(
+ iree_rt_context_t* context, iree_rt_module_t** modules,
+ iree_host_size_t module_count) {
+ IREE_TRACE_SCOPE0("iree_rt_context_register_modules");
+ auto* handle = reinterpret_cast<Context*>(context);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ if (module_count && !modules) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ for (size_t i = 0; i < module_count; ++i) {
+ auto* module = reinterpret_cast<Module*>(modules[i]);
+ if (!module) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ IREE_API_RETURN_IF_ERROR(handle->RegisterModule(add_ref(module)));
+ }
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
+iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
+                                      iree_string_view_t module_name) {
+  IREE_TRACE_SCOPE0("iree_rt_context_lookup_module_by_name");
+  const auto* handle = reinterpret_cast<const Context*>(context);
+  CHECK(handle) << "NULL context handle";
+  // A failed lookup is reported to the caller as a null module pointer.
+  const absl::string_view name_view{module_name.data, module_name.size};
+  auto lookup_result = handle->LookupModuleByName(name_view);
+  return lookup_result.ok() ? reinterpret_cast<iree_rt_module_t*>(
+                                  lookup_result.ValueOrDie())
+                            : nullptr;
+}
+
+// Resolves "module.function" (split at the last '.') to an exported function.
+// |out_function| is zeroed before the remaining arguments are checked.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
+    const iree_rt_context_t* context, iree_string_view_t full_name,
+    iree_rt_function_t* out_function) {
+  IREE_TRACE_SCOPE0("iree_rt_context_resolve_function");
+
+  if (out_function == nullptr) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  std::memset(out_function, 0, sizeof(*out_function));
+  if (context == nullptr) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  // A name without a '.' separator is not fully qualified.
+  const absl::string_view qualified_name{full_name.data, full_name.size};
+  const size_t separator = qualified_name.rfind('.');
+  if (separator == absl::string_view::npos) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  const absl::string_view module_part = qualified_name.substr(0, separator);
+  const absl::string_view function_part = qualified_name.substr(separator + 1);
+
+  iree_rt_module_t* module = iree_rt_context_lookup_module_by_name(
+      context, iree_string_view_t{module_part.data(), module_part.size()});
+  if (module == nullptr) {
+    return IREE_STATUS_NOT_FOUND;
+  }
+  return iree_rt_module_lookup_function_by_name(
+      module, IREE_RT_FUNCTION_LINKAGE_EXPORT,
+      iree_string_view_t{function_part.data(), function_part.size()},
+      out_function);
+}
+
+// Allocates a buffer of size |allocation_size| through the device manager of
+// the instance |context| belongs to; a zero size is rejected.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_allocate_device_visible_buffer(
+    iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
+    iree_host_size_t allocation_size, iree_allocator_t allocator,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_visible_buffer");
+
+  if (out_buffer == nullptr) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  *out_buffer = nullptr;
+
+  const auto* handle = reinterpret_cast<const Context*>(context);
+  if (handle == nullptr || allocation_size == 0) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  // TODO(benvanik): reroute to context based on current policy.
+  auto* manager = handle->instance()->device_manager();
+  IREE_API_ASSIGN_OR_RETURN(auto placement,
+                            manager->ResolvePlacement({}));
+  IREE_API_ASSIGN_OR_RETURN(auto allocated,
+                            manager->AllocateDeviceVisibleBuffer(
+                                static_cast<hal::BufferUsage>(buffer_usage),
+                                allocation_size, {placement}));
+
+  // Ownership of the allocated buffer passes to the caller.
+  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(allocated.release());
+  return IREE_STATUS_OK;
+}
+
+// Allocates a device-local buffer of size |allocation_size| through the device
+// manager of the instance |context| belongs to; a zero size is rejected.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_allocate_device_local_buffer(
+    iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
+    iree_host_size_t allocation_size, iree_allocator_t allocator,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_local_buffer");
+
+  if (out_buffer == nullptr) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  *out_buffer = nullptr;
+
+  const auto* handle = reinterpret_cast<const Context*>(context);
+  if (handle == nullptr || allocation_size == 0) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  // TODO(benvanik): reroute to context based on current policy.
+  auto* manager = handle->instance()->device_manager();
+  IREE_API_ASSIGN_OR_RETURN(auto placement,
+                            manager->ResolvePlacement({}));
+  IREE_API_ASSIGN_OR_RETURN(auto allocated,
+                            manager->AllocateDeviceLocalBuffer(
+                                static_cast<hal::BufferUsage>(buffer_usage),
+                                allocation_size, {placement}));
+
+  // Ownership of the allocated buffer passes to the caller.
+  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(allocated.release());
+  return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Invocation
+//===----------------------------------------------------------------------===//
+
+// Creates an invocation of |function| within |context|. The dependency
+// invocations and argument/result buffers are retained and handed to Create.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
+    iree_rt_context_t* context, iree_rt_function_t* function,
+    iree_rt_policy_t* policy,
+    const iree_rt_invocation_dependencies_t* dependencies,
+    iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
+    iree_hal_buffer_view_t** results, iree_host_size_t result_count,
+    iree_allocator_t allocator, iree_rt_invocation_t** out_invocation) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_create");
+
+  if (!out_invocation) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  *out_invocation = nullptr;
+
+  if (!context || !function || !function->module) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  } else if (dependencies &&
+             (dependencies->invocation_count && !dependencies->invocations)) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  } else if ((argument_count && !arguments) || (result_count && !results)) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  // TODO(benvanik): unwrap without needing to retain here.
+  absl::InlinedVector<ref_ptr<Invocation>, 4> dependent_invocations;
+  if (dependencies) {
+    dependent_invocations.resize(dependencies->invocation_count);
+    for (iree_host_size_t i = 0; i < dependencies->invocation_count; ++i) {
+      dependent_invocations[i] =
+          add_ref(reinterpret_cast<Invocation*>(dependencies->invocations[i]));
+    }
+  }
+
+  // TODO(benvanik): unwrap without needing to retain here.
+  absl::InlinedVector<hal::BufferView, 8> argument_views(argument_count);
+  for (iree_host_size_t i = 0; i < argument_count; ++i) {
+    const auto* api_buffer_view =
+        reinterpret_cast<const hal::iree_hal_buffer_view*>(arguments[i]);
+    if (!api_buffer_view) {
+      return IREE_STATUS_INVALID_ARGUMENT;
+    }
+    argument_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer),
+                                        api_buffer_view->impl.shape,
+                                        api_buffer_view->impl.element_size};
+  }
+
+  // TODO(benvanik): unwrap without needing to retain here.
+  absl::InlinedVector<hal::BufferView, 8> result_views(result_count);
+  for (iree_host_size_t i = 0; i < result_count; ++i) {
+    const auto* api_buffer_view =
+        reinterpret_cast<const hal::iree_hal_buffer_view*>(results[i]);
+    if (api_buffer_view) {  // A null entry keeps its default-constructed view.
+      result_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer),
+                                        api_buffer_view->impl.shape,
+                                        api_buffer_view->impl.element_size};
+    }
+  }
+
+  IREE_API_ASSIGN_OR_RETURN(
+      auto invocation,
+      Invocation::Create(
+          add_ref(reinterpret_cast<Context*>(context)),
+          Function{reinterpret_cast<Module*>(function->module),
+                   static_cast<Function::Linkage>(function->linkage),
+                   function->ordinal},
+          add_ref(reinterpret_cast<Policy*>(policy)),
+          std::move(dependent_invocations), std::move(argument_views),
+          result_views.empty()
+              ? absl::optional<absl::InlinedVector<hal::BufferView, 8>>(
+                    absl::nullopt)
+              : std::move(result_views)));
+  *out_invocation =
+      reinterpret_cast<iree_rt_invocation_t*>(invocation.release());
+  return IREE_STATUS_OK;
+}
+
+// Adds a reference to |invocation| on behalf of the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_retain(iree_rt_invocation_t* invocation) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_retain");
+  if (invocation == nullptr) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Invocation*>(invocation)->AddReference();
+  return IREE_STATUS_OK;
+}
+
+// Drops one reference to |invocation|; a null handle is rejected.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_release(iree_rt_invocation_t* invocation) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_release");
+  if (invocation == nullptr) {
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+  reinterpret_cast<Invocation*>(invocation)->ReleaseReference();
+  return IREE_STATUS_OK;
+}
+
+// Translates the invocation's current status into an API status code.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_query_status(iree_rt_invocation_t* invocation) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_query_status");
+  if (invocation == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* target = reinterpret_cast<Invocation*>(invocation);
+  // A non-OK status from QueryStatus is returned by the macro below.
+  IREE_API_RETURN_IF_ERROR(target->QueryStatus());
+  return IREE_STATUS_OK;
+}
+
+// Reports the function's result count through |out_result_count| and, when
+// |out_results| is provided with sufficient |result_capacity|, consumes the
+// invocation's results into newly created buffer views.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
+    iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
+    iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
+    iree_host_size_t* out_result_count) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_consume_results");
+
+  if (!out_result_count) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_result_count = 0;
+  if (out_results) {
+    // Only clear the array when the caller provided one; passing nullptr is
+    // the documented way to query the required capacity.
+    std::memset(out_results, 0, sizeof(*out_results) * result_capacity);
+  }
+
+  auto* handle = reinterpret_cast<Invocation*>(invocation);
+  if (!handle) return IREE_STATUS_INVALID_ARGUMENT;
+
+  const auto& function = handle->function();
+  const int32_t result_count = function.signature().result_count();
+  *out_result_count = result_count;
+  if (!out_results) {
+    return IREE_STATUS_OK;  // Capacity query only; nothing is consumed.
+  } else if (result_capacity < *out_result_count) {
+    return IREE_STATUS_OUT_OF_RANGE;
+  }
+
+  IREE_API_ASSIGN_OR_RETURN(auto results, handle->ConsumeResults());
+  iree_status_t status = IREE_STATUS_OK;
+  size_t i = 0;
+  for (; i < results.size(); ++i) {
+    iree_shape_t shape;
+    status = ToApiShape(results[i].shape, &shape);
+    if (status != IREE_STATUS_OK) break;
+    status = iree_hal_buffer_view_create(
+        reinterpret_cast<iree_hal_buffer_t*>(results[i].buffer.get()), shape,
+        results[i].element_size, allocator, &out_results[i]);
+    if (status != IREE_STATUS_OK) break;
+  }
+  if (status != IREE_STATUS_OK) {
+    // Release already-retained buffer views on failure.
+    for (size_t j = 0; j < i; ++j) {
+      iree_hal_buffer_view_release(out_results[j]);
+      out_results[j] = nullptr;
+    }
+  }
+  return status;
+}
+
+// Waits on |invocation| using |deadline|; a null handle is rejected.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
+    iree_rt_invocation_t* invocation, iree_time_t deadline) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_await");
+  if (invocation == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* target = reinterpret_cast<Invocation*>(invocation);
+  // Any error from Await is returned to the caller as an API status.
+  IREE_API_RETURN_IF_ERROR(target->Await(ToAbslTime(deadline)));
+  return IREE_STATUS_OK;
+}
+
+// Requests that |invocation| be aborted; a null handle is rejected.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_abort(iree_rt_invocation_t* invocation) {
+  IREE_TRACE_SCOPE0("iree_rt_invocation_abort");
+  if (invocation == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
+  auto* target = reinterpret_cast<Invocation*>(invocation);
+  // Errors reported by Abort are surfaced through the API status.
+  IREE_API_RETURN_IF_ERROR(target->Abort());
+  return IREE_STATUS_OK;
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/api.h b/rt/api.h
new file mode 100644
index 0000000..5bc57c4
--- /dev/null
+++ b/rt/api.h
@@ -0,0 +1,400 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_RT_API_H_
+#define IREE_RT_API_H_
+
+#include <stdint.h>
+
+#include "base/api.h"
+#include "hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Types and Enums
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_rt_instance iree_rt_instance_t;
+typedef struct iree_rt_context iree_rt_context_t;
+typedef struct iree_rt_policy iree_rt_policy_t;
+typedef struct iree_rt_module iree_rt_module_t;
+typedef struct iree_rt_invocation iree_rt_invocation_t;
+
+// Describes the type of a function reference.
+typedef enum {
+ // Function is internal to the module and may not be reflectable.
+ IREE_RT_FUNCTION_LINKAGE_INTERNAL = 0,
+ // Function is an import from another module.
+ IREE_RT_FUNCTION_LINKAGE_IMPORT = 1,
+ // Function is an export from the module.
+ IREE_RT_FUNCTION_LINKAGE_EXPORT = 2,
+} iree_rt_function_linkage_t;
+
+// A function reference that can be used with the iree_rt_function_* methods.
+// These should be treated as opaque and the accessor functions should be used
+// instead.
+typedef struct {
+ // Module the function is contained within.
+ iree_rt_module_t* module;
+ // Linkage of the function. Note that IREE_RT_FUNCTION_LINKAGE_INTERNAL
+ // functions may be missing reflection information.
+ iree_rt_function_linkage_t linkage;
+ // Ordinal within the module in the linkage scope.
+ int32_t ordinal;
+} iree_rt_function_t;
+
+// Describes the expected calling convention and arguments/results of a
+// function.
+typedef struct {
+ // Total number of arguments to the function.
+ int32_t argument_count;
+ // Total number of results from the function.
+ int32_t result_count;
+} iree_rt_function_signature_t;
+
+// Describes the imports, exports, and capabilities of a module.
+typedef struct {
+ // Total number of imported functions.
+ int32_t import_function_count;
+ // Total number of exported functions.
+ int32_t export_function_count;
+ // Total number of internal functions, if debugging info is present and they
+ // can be queried.
+ int32_t internal_function_count;
+ // Total number of state block resource slots consumed.
+ int32_t state_slot_count;
+} iree_rt_module_signature_t;
+
+// Dependency information used to order invocations.
+typedef struct {
+ // Prior invocations that must complete before the new invocation begins.
+ iree_rt_invocation_t** invocations;
+ iree_host_size_t invocation_count;
+
+ // TODO(benvanik): wait semaphores/importing.
+} iree_rt_invocation_dependencies_t;
+
+// Defines an external module that can be used to reflect and execute functions.
+// Modules must be thread-safe as lookups and executions may occur in any order
+// from any thread.
+//
+// Modules will have their resolve_imports function called upon registration
+// with a context and may use the provided resolver to find imported functions.
+typedef struct {
+ // User-defined pointer passed to all functions.
+ void* self;
+ // Destroys |self| when all references to the module have been released.
+ iree_status_t(IREE_API_PTR* destroy)(void* self);
+ // Returns the name of the module (used during resolution).
+ iree_string_view_t(IREE_API_PTR* name)(void* self);
+ // Returns the reflected signature of the module.
+ iree_rt_module_signature_t(IREE_API_PTR* signature)(void* self);
+ // Sets |out_function| to a resolved function by ordinal, if found.
+ iree_status_t(IREE_API_PTR* lookup_function_by_ordinal)(
+ void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
+ iree_rt_function_t* out_function);
+ // Sets |out_function| to a resolved function by name, if found.
+ iree_status_t(IREE_API_PTR* lookup_function_by_name)(
+ void* self, iree_rt_function_linkage_t linkage, iree_string_view_t name,
+ iree_rt_function_t* out_function);
+ // Sets |out_name| to the name of the function with the given ordinal, if
+ // found.
+ iree_status_t(IREE_API_PTR* get_function_name)(
+ void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
+ iree_string_view_t* out_name);
+ // Sets |out_signature| to the reflected signature of the given
+ // function, if found.
+ iree_status_t(IREE_API_PTR* get_function_signature)(
+ void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
+ iree_rt_function_signature_t* out_signature);
+} iree_rt_external_module_t;
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Instance
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a new instance. This should be shared with all contexts in an
+// application to ensure that resources are tracked properly and threads are
+// managed correctly.
+// |out_instance| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
+ iree_allocator_t allocator, iree_rt_instance_t** out_instance);
+
+// Retains the given |instance| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_retain(iree_rt_instance_t* instance);
+
+// Releases the given |instance| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_release(iree_rt_instance_t* instance);
+
+// TEMPORARY: until policies and placement are performed this can be used to
+// explicitly create and register drivers by name.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex(
+ iree_rt_instance_t* instance, iree_string_view_t driver_name);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Module
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a module with an external backing implementation.
+// The provided |external_module| definition will be used to query the module
+// state as needed. No caching occurs within the implementation to allow calls
+// to return different values per-invocation.
+//
+// |out_module| must be released by the caller.
+// iree_rt_external_module_t::destroy is called when the last reference to the
+// iree_rt_module_t is released.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
+ iree_rt_external_module_t impl, iree_allocator_t allocator,
+ iree_rt_module_t** out_module);
+
+// Retains the given |module| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_retain(iree_rt_module_t* module);
+
+// Releases the given |module| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_release(iree_rt_module_t* module);
+
+// Returns the name of the module.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_module_name(const iree_rt_module_t* module);
+
+// Sets |out_function| to a function with |ordinal| in the given linkage or
+// returns IREE_STATUS_NOT_FOUND. The function reference is valid for the
+// lifetime of |module|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ int32_t ordinal,
+ iree_rt_function_t* out_function);
+
+// Sets |out_function| to a function with |name| in the given linkage or returns
+// IREE_STATUS_NOT_FOUND. The function reference is valid for the lifetime of
+// |module|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ iree_string_view_t name,
+ iree_rt_function_t* out_function);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Function
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Returns the name of the function as exported from the module.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_function_name(const iree_rt_function_t* function);
+
+// Sets |out_signature| to the reflected signature of the function.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_function_signature(const iree_rt_function_t* function,
+ iree_rt_function_signature_t* out_signature);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Policy
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// TODO(benvanik): define policies. For now they are no-ops.
+IREE_API_EXPORT iree_status_t iree_rt_policy_create(
+ iree_allocator_t allocator, iree_rt_policy_t** out_policy);
+
+// Retains the given |policy| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_retain(iree_rt_policy_t* policy);
+
+// Releases the given |policy| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_release(iree_rt_policy_t* policy);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Context
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a new context that uses the given |instance| for device management.
+// |out_context| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
+ iree_rt_instance_t* instance, iree_rt_policy_t* policy,
+ iree_allocator_t allocator, iree_rt_context_t** out_context);
+
+// Retains the given |context| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_retain(iree_rt_context_t* context);
+
+// Releases the given |context| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_release(iree_rt_context_t* context);
+
+// Returns a process-unique ID for the |context|.
+IREE_API_EXPORT int32_t IREE_API_CALL
+iree_rt_context_id(const iree_rt_context_t* context);
+
+// Registers a list of modules with the context and resolves imports.
+// The modules will be retained by the context until destruction.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules(
+ iree_rt_context_t* context, iree_rt_module_t** modules,
+ iree_host_size_t module_count);
+
+// Returns a reference to the module registered with the given name or nullptr
+// if not found. The caller must retain the returned module if they want to
+// continue using it.
+IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
+iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
+ iree_string_view_t module_name);
+
+// Sets |out_function| to an exported function with the fully-qualified name
+// of |full_name| or returns IREE_STATUS_NOT_FOUND. The function reference is
+// valid for the lifetime of |context|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
+ const iree_rt_context_t* context, iree_string_view_t full_name,
+ iree_rt_function_t* out_function);
+
+// Allocates a host-local buffer that is optimal for use on the host but is
+// usable by the devices resolved for |context| (at a possible performance
+// penalty). The buffer can be used for staging uploads to device-local
+// buffers and is useful for times when the buffer will be used more on the
+// host than the device. If a buffer never needs to be used with a device
+// prefer instead HeapBuffer::Allocate.
+//
+// Fails if it is not possible to allocate and satisfy all placements for the
+// requested |buffer_usage|.
+// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_allocate_device_visible_buffer(
+ iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
+ iree_host_size_t allocation_size, iree_allocator_t allocator,
+ iree_hal_buffer_t** out_buffer);
+
+// Allocates a device-local buffer that is optimal for use with the devices
+// resolved for |context|. The buffer will not be host-visible and can only be
+// used from compatible device queues.
+//
+// Fails if it is not possible to allocate and satisfy all placements for the
+// requested |buffer_usage|.
+// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_allocate_device_local_buffer(
+ iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage,
+ iree_host_size_t allocation_size, iree_allocator_t allocator,
+ iree_hal_buffer_t** out_buffer);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Invocation
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a new invocation tracking object for invoking the given |function|
+// from |context|. |arguments| will be retained until the invocation is made.
+// If |dependencies| are provided then the invocation will wait until they are
+// resolved before executing. If a |policy| is provided it will override the
+// context-level policy.
+//
+// Optionally |results| may be provided with preallocated buffers that will
+// receive the outputs of the invocation. Invocation will fail if they do not
+// match expected sizes.
+//
+// Note that it's possible for the invocation to complete prior to the return of
+// this function. Any errors that occur will be set on the invocation and
+// callers should query its state prior to assuming it is in-flight.
+//
+// |out_invocation| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
+ iree_rt_context_t* context, iree_rt_function_t* function,
+ iree_rt_policy_t* policy,
+ const iree_rt_invocation_dependencies_t* dependencies,
+ iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
+ iree_hal_buffer_view_t** results, iree_host_size_t result_count,
+ iree_allocator_t allocator, iree_rt_invocation_t** out_invocation);
+
+// Retains the given |invocation| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_retain(iree_rt_invocation_t* invocation);
+
+// Releases the given |invocation| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_release(iree_rt_invocation_t* invocation);
+
+// Queries the completion status of the invocation.
+// Returns one of the following:
+// IREE_STATUS_OK: the invocation completed successfully.
+// IREE_STATUS_UNAVAILABLE: the invocation has not yet completed.
+// IREE_STATUS_CANCELLED: the invocation was cancelled internally.
+// IREE_STATUS_ABORTED: the invocation was aborted.
+// IREE_STATUS_*: an error occurred during invocation.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_query_status(iree_rt_invocation_t* invocation);
+
+// Populates |out_results| to the values of the results.
+// |result_capacity| defines the number of elements available in |out_results|
+// and |out_result_count| will be set with the actual number of results
+// available. If |result_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be
+// returned with the required capacity in |out_result_count|. To only query the
+// required capacity |out_results| may be passed as nullptr.
+//
+// Ownership of returned results will be transferred to the caller and they must
+// be released if no longer needed.
+//
+// Returns errors as with iree_rt_invocation_query_status, for example in the
+// case of not-yet-completed or aborted invocations.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
+ iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
+ iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
+ iree_host_size_t* out_result_count);
+
+// Blocks the caller until the invocation completes (successfully or otherwise).
+//
+// Returns IREE_STATUS_DEADLINE_EXCEEDED if |deadline| elapses before the
+// invocation completes and otherwise returns iree_rt_invocation_query_status.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
+ iree_rt_invocation_t* invocation, iree_time_t deadline);
+
+// Attempts to abort the invocation if it is in-flight.
+// A no-op if the invocation has already completed.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_abort(iree_rt_invocation_t* invocation);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_RT_API_H_
diff --git a/rt/context.cc b/rt/context.cc
new file mode 100644
index 0000000..051deca
--- /dev/null
+++ b/rt/context.cc
@@ -0,0 +1,167 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/context.h"
+
+#include <atomic>
+
+#include "absl/strings/str_cat.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "rt/debug/debug_server.h"
+#include "rt/instance.h"
+#include "rt/invocation.h"
+
+namespace iree {
+namespace rt {
+
+namespace {
+
+int32_t NextUniqueContextId() {  // Yields 1, 2, 3, ... on successive calls.
+  static std::atomic<int32_t> context_id_counter{0};
+  return ++context_id_counter;
+}
+
+} // namespace
+
+Context::Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy)
+ : id_(NextUniqueContextId()),  // Fresh id from the shared atomic counter.
+ instance_(std::move(instance)),
+ policy_(std::move(policy)) {
+ IREE_TRACE_SCOPE("Context::ctor", int32_t)(id_);
+ instance_->RegisterContext(this);  // Undone by the destructor.
+}
+
+Context::~Context() {
+ IREE_TRACE_SCOPE("Context::dtor", int32_t)(id_);
+ instance_->UnregisterContext(this);  // Pairs with the ctor's RegisterContext.
+}
+
+std::string Context::DebugStringShort() const {
+ return absl::StrCat("context_", id_);  // "context_" followed by the id.
+}
+
+Status Context::RegisterModule(ref_ptr<Module> module) {
+  IREE_TRACE_SCOPE0("Context::RegisterModule");
+
+  // Module names must be unique within a context; shadowing is unsupported.
+  for (const auto& registered : modules_) {
+    if (!(registered->name() == module->name())) continue;
+    return FailedPreconditionErrorBuilder(IREE_LOC)
+           << "Module '" << module->name()
+           << "' has already been registered in the context";
+  }
+
+  // Resolve every import before recording anything so that a failed
+  // resolution does not leave the module half-registered.
+  ASSIGN_OR_RETURN(auto import_table, ResolveImports(module.get()));
+
+  // Tell an attached debug server (if any) about the new module.
+  if (auto* debug_server = instance_->debug_server()) {
+    CHECK_OK(debug_server->RegisterContextModule(this, module.get()));
+  }
+
+  // Commit: record the module together with its resolved import table.
+  modules_.push_back(std::move(module));
+  module_import_tables_.push_back(std::move(import_table));
+  return OkStatus();
+}
+
+StatusOr<ModuleImportTable> Context::ResolveImports(Module* module) {
+  IREE_TRACE_SCOPE0("Context::ResolveImports");
+
+  // One table slot per import ordinal declared by the module's signature.
+  const int32_t import_count = module->signature().import_function_count();
+  ModuleImportTable import_table;
+  import_table.first = module;
+  std::vector<Function>& resolved_imports = import_table.second;
+  resolved_imports.resize(import_count);
+  for (int32_t ordinal = 0; ordinal < import_count; ++ordinal) {
+    ASSIGN_OR_RETURN(
+        auto import_name,
+        module->GetFunctionName(Function::Linkage::kImport, ordinal));
+    ASSIGN_OR_RETURN(resolved_imports[ordinal], ResolveFunction(import_name));
+  }
+  return import_table;
+}
+
+// Returns the first registered module whose name equals |module_name|, or a
+// NotFound error. Performs a linear scan over the registered modules.
+StatusOr<Module*> Context::LookupModuleByName(
+    absl::string_view module_name) const {
+  for (const auto& candidate : modules_) {
+    if (candidate->name() == module_name) return candidate.get();
+  }
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "No module with the name '" << module_name
+         << "' has been registered";
+}
+
+StatusOr<const Function> Context::ResolveFunction(
+    absl::string_view full_name) const {
+  const size_t separator = full_name.rfind('.');
+  if (separator == absl::string_view::npos) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "'" << full_name
+           << "' is not fully qualified (expected 'module.function')";
+  }
+  // Everything before the last '.' names the module; the rest is the export.
+  const absl::string_view module_part = full_name.substr(0, separator);
+  const absl::string_view export_name = full_name.substr(separator + 1);
+  ASSIGN_OR_RETURN(auto* module, LookupModuleByName(module_part));
+  return module->LookupFunctionByName(Function::Linkage::kExport, export_name);
+}
+
+StatusOr<const Function> Context::ResolveImport(const Module* module,
+                                                int32_t ordinal) const {
+  for (const auto& entry : module_import_tables_) {
+    if (entry.first != module) continue;
+    const std::vector<Function>& imports = entry.second;
+    if (ordinal < 0 || static_cast<size_t>(ordinal) >= imports.size()) {
+      return NotFoundErrorBuilder(IREE_LOC)
+             << "Import ordinal " << ordinal
+             << " out of bounds of import table (" << imports.size() << ")";
+    }
+    // |ordinal| is a valid index into the table at this point.
+    return imports[ordinal];
+  }
+  // No import table has been recorded for |module| in this context.
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "Import ordinal " << ordinal << " not found";
+}
+
+void Context::RegisterInvocation(Invocation* invocation) {
+  {
+    // Hold the lock only while the tracking list is being modified.
+    absl::MutexLock lock(&invocations_mutex_);
+    invocations_.push_back(invocation);
+  }
+  if (auto* debug_server = instance_->debug_server()) {
+    CHECK_OK(debug_server->RegisterInvocation(invocation));
+  }
+}
+
+void Context::UnregisterInvocation(Invocation* invocation) {
+  // Reverse of RegisterInvocation: debug server first, then the tracking list.
+  if (auto* debug_server = instance_->debug_server()) {
+    CHECK_OK(debug_server->UnregisterInvocation(invocation));
+  }
+  {
+    absl::MutexLock lock(&invocations_mutex_);
+    invocations_.erase(invocation);
+  }
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/context.h b/rt/context.h
new file mode 100644
index 0000000..ecda38b
--- /dev/null
+++ b/rt/context.h
@@ -0,0 +1,119 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_CONTEXT_H_
+#define IREE_RT_CONTEXT_H_
+
+#include <ostream>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/optional.h"
+#include "base/intrusive_list.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "rt/invocation.h"
+#include "rt/module.h"
+#include "rt/policy.h"
+
+namespace iree {
+namespace rt {
+
+class Instance;
+
+using ModuleImportTable = std::pair<Module*, std::vector<Function>>;
+
+// An isolated execution context.
+// Effectively a sandbox where modules can be loaded and run with restricted
+// visibility and where they can maintain state.
+//
+// Modules have imports resolved automatically when registered by searching
+// existing modules registered within the context and load order is used for
+// resolution. For example, target-specific modules should be loaded prior to
+// generic modules that may import functions defined there and if a function is
+// not available in the target-specific modules the fallback provided by the
+// generic module will be used.
+//
+// Thread-compatible and must be externally synchronized.
+class Context final : public RefObject<Context> {
+ public:
+ // Constructs a context from the given |instance| and |policy|.
+ Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy);
+ ~Context();
+
+ // A process-unique ID for the context.
+ int32_t id() const { return id_; }
+
+ // Instance this context uses for shared resources.
+ const ref_ptr<Instance>& instance() const { return instance_; }
+
+ // A short human-readable name for the context.
+ std::string DebugStringShort() const;
+
+ // A list of modules registered with the context.
+ absl::Span<const ref_ptr<Module>> modules() const {
+ return absl::MakeConstSpan(modules_);
+ }
+
+ // Registers a new module with the context.
+ // Imports from the module will be resolved using the existing modules in the
+ // context. The module will be retained by the context until destruction.
+ Status RegisterModule(ref_ptr<Module> module);
+
+ // Looks up a module by name.
+ // Returns a NotFound error if no module with that name has been registered.
+ StatusOr<Module*> LookupModuleByName(absl::string_view module_name) const;
+
+ // Resolves an exported function by fully-qualified name. The function
+ // reference is valid for the lifetime of the context.
+ // |full_name| must have the form 'module.function'; the text before the last
+ // '.' is used as the module name.
+ StatusOr<const Function> ResolveFunction(absl::string_view full_name) const;
+
+ // Resolves an imported function by import ordinal. The function reference is
+ // valid for the lifetime of the context.
+ // Returns a NotFound error if |module| has no import table in this context
+ // or if |ordinal| is out of bounds of that table.
+ StatusOr<const Function> ResolveImport(const Module* module,
+ int32_t ordinal) const;
+
+ private:
+ // Resolves imports for the given module.
+ StatusOr<ModuleImportTable> ResolveImports(Module* module);
+
+ // Invocation is a friend so that it can reach the two private methods below,
+ // which add/remove an invocation to/from |invocations_| and forward the event
+ // to the instance's debug server when one is present.
+ friend class Invocation;
+ void RegisterInvocation(Invocation* invocation);
+ void UnregisterInvocation(Invocation* invocation);
+
+ int32_t id_;
+ ref_ptr<Instance> instance_;
+ ref_ptr<Policy> policy_;
+
+ // Modules registered with the context.
+ absl::InlinedVector<ref_ptr<Module>, 4> modules_;
+ // (module, resolved imports) pairs; searched linearly by ResolveImport.
+ absl::InlinedVector<ModuleImportTable, 4> module_import_tables_;
+
+ // Guards |invocations_|.
+ absl::Mutex invocations_mutex_;
+ IntrusiveList<Invocation, offsetof(Invocation, context_list_link_)>
+ invocations_ ABSL_GUARDED_BY(invocations_mutex_);
+
+ // Intrusive list link exposed to Instance through friendship.
+ // NOTE(review): presumably used by Instance to track its contexts - confirm.
+ friend class Instance;
+ IntrusiveListLink instance_list_link_;
+};
+
+// Writes the short debug string of |context| to |stream|.
+inline std::ostream& operator<<(std::ostream& stream, const Context& context) {
+  return stream << context.DebugStringShort();
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_CONTEXT_H_
diff --git a/rt/debug/BUILD b/rt/debug/BUILD
new file mode 100644
index 0000000..8f12b5c
--- /dev/null
+++ b/rt/debug/BUILD
@@ -0,0 +1,192 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# TODO(benvanik): re-enable debugger after refactoring.
+# cc_library(
+# name = "debug_client",
+# srcs = ["debug_client.cc"],
+# hdrs = ["debug_client.h"],
+# deps = [
+# ":debug_client_interface",
+# ":debug_client_tcp", # build-cleaner: keep
+# "@com_google_absl//absl/container:flat_hash_map",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/types:optional",
+# "@com_google_absl//absl/types:span",
+# "///base:source_location",
+# "///base:status",
+# "///schemas",
+# ],
+# )
+#
+# cc_library(
+# name = "debug_client_interface",
+# hdrs = ["debug_client.h"],
+# deps = [
+# "@com_google_absl//absl/container:flat_hash_map",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/types:optional",
+# "@com_google_absl//absl/types:span",
+# "///base:status",
+# "///schemas",
+# ],
+# )
+#
+# cc_library(
+# name = "debug_client_tcp",
+# srcs = ["debug_client_tcp.cc"],
+# deps = [
+# ":debug_client_interface",
+# ":debug_tcp_util",
+# "@com_google_absl//absl/container:flat_hash_map",
+# "@com_google_absl//absl/memory",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/types:span",
+# "@com_github_google_flatbuffers//:flatbuffers",
+# "///base:flatbuffer_util",
+# "///base:status",
+# "///rt",
+# "///schemas",
+# ],
+# )
+#
+# cc_library(
+# name = "debug_server",
+# hdrs = ["debug_server.h"],
+# deps = [
+# ":debug_server_interface",
+# "//third_party/flatbuffers:flatbuffers",
+# "///schemas",
+# "///base:status",
+# ] + select({
+# "//:debug": [":debug_server_tcp"],
+# "//conditions:default": [":debug_server_disabled"],
+# }),
+# )
+
+# Debug server target exposing debug_server.h.
+# It currently always depends on the disabled implementation; the variant that
+# selects the TCP implementation for //:debug builds is commented out above
+# (see the TODO about re-enabling the debugger).
+cc_library(
+ name = "debug_server",
+ hdrs = ["debug_server.h"],
+ deps = [
+ ":debug_server_disabled",
+ ":debug_server_interface",
+ "///base:status",
+ ],
+)
+
+# Header-only target for debug_server.h with no implementation sources.
+# Implementation targets (e.g. :debug_server_disabled) depend on this.
+cc_library(
+ name = "debug_server_interface",
+ hdrs = ["debug_server.h"],
+ deps = ["///base:status"],
+)
+
+# Implementation of the debug server interface built from
+# debug_server_disabled.cc; this is the implementation :debug_server links.
+cc_library(
+ name = "debug_server_disabled",
+ srcs = ["debug_server_disabled.cc"],
+ deps = [
+ ":debug_server_interface",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+# TODO(benvanik): re-enable debugger after refactoring.
+# cc_library(
+# name = "debug_server_tcp",
+# srcs = ["debug_server_tcp.cc"],
+# deps = [
+# ":debug_server_interface",
+# ":debug_service",
+# ":debug_tcp_util",
+# "@com_google_absl//absl/base:core_headers",
+# "@com_google_absl//absl/memory",
+# "@com_google_absl//absl/synchronization",
+# "@com_github_google_flatbuffers//:flatbuffers",
+# "///base:status",
+# "///schemas",
+# ],
+# )
+
+# Flag-related helpers for the debug server (debug_server_flags.{h,cc}).
+# This is the reduced variant; the version with the embedded debugger app
+# dependencies is commented out below until the debugger is re-enabled.
+cc_library(
+ name = "debug_server_flags",
+ srcs = ["debug_server_flags.cc"],
+ hdrs = ["debug_server_flags.h"],
+ deps = [
+ ":debug_server",
+ "///base:memory",
+ "///base:status",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+# TODO(benvanik): re-enable debugger after refactoring.
+# cc_library(
+# name = "debug_server_flags",
+# srcs = ["debug_server_flags.cc"],
+# hdrs = ["debug_server_flags.h"],
+# copts = select({
+# "//:debug": [
+# "-DIREE_DEBUG_EMBEDDED_APP_PRESENT=1",
+# ],
+# "//conditions:default": [],
+# }),
+# deps = [
+# ":debug_server",
+# "@com_google_absl//absl/flags:flag",
+# "@com_google_absl//absl/strings",
+# "///base:memory",
+# "///base:status",
+# ] + select({
+# "//:debug": [
+# "///tools/debugger:debug_app_embedded",
+# "//third_party/GL/native:EGL", # build-cleaner: keep
+# "//third_party/GL/native:GLESv2", # build-cleaner: keep
+# ],
+# "//conditions:default": [],
+# }),
+# )
+#
+# cc_library(
+# name = "debug_service",
+# srcs = ["debug_service.cc"],
+# hdrs = ["debug_service.h"],
+# deps = [
+# ":debug_session",
+# "@com_google_absl//absl/base:core_headers",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/synchronization",
+# "@com_github_google_flatbuffers//:flatbuffers",
+# "///base:flatbuffer_util",
+# "///base:source_location",
+# "///base:status",
+# "///rt",
+# "///schemas",
+# "///schemas:reflection_data",
+# ],
+# )
+#
+# cc_library(
+# name = "debug_session",
+# srcs = ["debug_session.cc"],
+# hdrs = ["debug_session.h"],
+# deps = [
+# "@com_google_absl//absl/base:core_headers",
+# "@com_google_absl//absl/synchronization",
+# "///base:source_location",
+# "///base:status",
+# "///rt",
+# "///schemas",
+# ],
+# )
+#
+# cc_library(
+# name = "debug_tcp_util",
+# hdrs = ["debug_tcp_util.h"],
+# deps = [
+# "@com_github_google_flatbuffers//:flatbuffers",
+# "///base:status",
+# "///schemas",
+# ],
+# )
diff --git a/iree/rt/debug/CMakeLists.txt b/rt/debug/CMakeLists.txt
similarity index 100%
rename from iree/rt/debug/CMakeLists.txt
rename to rt/debug/CMakeLists.txt
diff --git a/rt/debug/debug_adapter.h b/rt/debug/debug_adapter.h
new file mode 100644
index 0000000..f9c149f
--- /dev/null
+++ b/rt/debug/debug_adapter.h
@@ -0,0 +1,106 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DEBUG_ADAPTER_H_
+#define IREE_RT_DEBUG_ADAPTER_H_
+
+#include <functional>
+
+#include "base/status.h"
+#include "rt/invocation.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Describes the target of a DebugAdapter::Step request.
+// Currently an empty placeholder; see the TODO below for the planned contents.
+struct StepTarget {
+ // TODO(benvanik): step target info (matching RPC message).
+ // module / function / offset
+ // relative to current: once, out, return, etc
+};
+
+// TODO(benvanik): move to fiber base.
+// Interface for debugging invocations.
+// This is only accessible in debug builds where such features are compiled in.
+class DebugAdapter {
+ public:
+ // Called when an invocation completes suspending (in response to a Suspend or
+ // Step request). The |suspend_status| will indicate if the suspension was
+ // successful.
+ using SuspendCallback = std::function<void(Status suspend_status)>;
+
+ // Returns true if the invocation is suspended.
+ // This only returns true if the invocation has been requested to suspend with
+ // Suspend and the runtime has acked the suspend. Once suspended (and until
+ // resumed) invocation state will not change and may be observed from any
+ // thread.
+ //
+ // Safe to call from any thread.
+ bool IsSuspended(Invocation* invocation);
+
+ // Suspends the invocation at the next possible chance.
+ //
+ // Fibers have a suspension depth and each call to Suspend must be matched
+ // with a call to Resume. Fibers will only resume execution when all prior
+ // Suspend calls have their matching Resume called.
+ //
+ // Optionally callers may provide a |suspend_callback| that will be called
+ // from a random thread when the invocation is suspended (or fails to
+ // suspend).
+ //
+ // Safe to call from any thread.
+ // Returns StatusCode::kUnavailable if debugging is not supported.
+ Status Suspend(ref_ptr<Invocation> invocation,
+ SuspendCallback suspend_callback = nullptr);
+
+ // Resumes the invocation if it is suspended (or cancels a pending suspend).
+ // This may wake threads if they are currently waiting on the invocation to
+ // execute.
+ //
+ // Safe to call from any thread.
+ // Returns StatusCode::kUnavailable if debugging is not supported.
+ Status Resume(Invocation* invocation);
+
+ // Steps invocation execution.
+ // This will attempt to resume the invocation and will complete
+ // asynchronously. Upon returning the invocation should be assumed resumed and
+ // callers must query IsSuspended to wait until the invocation suspends
+ // again. Optionally callers may provide a |suspend_callback| that will be
+ // called from a random thread when the invocation is suspended (or fails to
+ // suspend).
+ //
+ // Safe to call from any thread while the invocation is suspended.
+ // Returns StatusCode::kUnavailable if debugging is not supported and
+ // StatusCode::kFailedPrecondition if the invocation is not suspended.
+ Status Step(ref_ptr<Invocation> invocation, StepTarget step_target,
+ SuspendCallback suspend_callback = nullptr);
+
+ // Returns a call stack that can be used to query and manipulate the
+ // invocation state. The behaviors supported depend on the stack frames and
+ // the backend support and may be conditionally enabled via compile-time or
+ // run-time flags.
+ //
+ // Safe to call from any thread while the invocation is suspended.
+ // Returns StatusCode::kUnavailable if debugging is not supported and
+ // StatusCode::kFailedPrecondition if the invocation is not suspended.
+ // The returned stack will be invalidated when the invocation is stepped or
+ // resumed.
+ StatusOr<Stack> CaptureStack(Invocation* invocation);
+};
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_ADAPTER_H_
diff --git a/rt/debug/debug_client.cc b/rt/debug/debug_client.cc
new file mode 100644
index 0000000..8ee2cbf
--- /dev/null
+++ b/rt/debug/debug_client.cc
@@ -0,0 +1,64 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/debug/debug_client.h"
+
+#include "base/source_location.h"
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Name-based convenience overload of GetFunction.
+// Resolves |function_name| within |module_name| to an ordinal through
+// ResolveFunction and then forwards to the ordinal-based GetFunction. A failed
+// resolution, or a failure to start the ordinal-based lookup, is reported
+// through |callback|; a failure of ResolveFunction itself is returned directly.
+Status DebugClient::GetFunction(
+    std::string module_name, std::string function_name,
+    std::function<void(StatusOr<RemoteFunction*> function)> callback) {
+  auto on_resolved = [this, module_name,
+                      callback](StatusOr<int> function_ordinal) {
+    if (function_ordinal.ok()) {
+      auto lookup_status =
+          GetFunction(module_name, function_ordinal.ValueOrDie(), callback);
+      if (!lookup_status.ok()) {
+        callback(std::move(lookup_status));
+      }
+      return;
+    }
+    callback(function_ordinal.status());
+  };
+  return ResolveFunction(module_name, function_name, std::move(on_resolved));
+}
+
+// Step-over is not implemented yet: |invocation| and |callback| are ignored and
+// an Unimplemented status is returned.
+Status DebugClient::StepInvocationOver(const RemoteInvocation& invocation,
+                                       std::function<void()> callback) {
+  // TODO(benvanik): implement bytecode stepping search.
+  // int bytecode_offset = 0;
+  // return StepInvocationToOffset(invocation, bytecode_offset,
+  //                               std::move(callback));
+  Status unimplemented = UnimplementedErrorBuilder(IREE_LOC)
+                         << "StepInvocationOver not yet implemented";
+  return unimplemented;
+}
+
+// Step-out is not implemented yet: |invocation| and |callback| are ignored and
+// an Unimplemented status is returned.
+Status DebugClient::StepInvocationOut(const RemoteInvocation& invocation,
+                                      std::function<void()> callback) {
+  // TODO(benvanik): implement bytecode stepping search.
+  // int bytecode_offset = 0;
+  // return StepInvocationToOffset(invocation, bytecode_offset,
+  //                               std::move(callback));
+  Status unimplemented = UnimplementedErrorBuilder(IREE_LOC)
+                         << "StepInvocationOut not yet implemented";
+  return unimplemented;
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_client.h b/rt/debug/debug_client.h
new file mode 100644
index 0000000..ac074b8
--- /dev/null
+++ b/rt/debug/debug_client.h
@@ -0,0 +1,286 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DEBUG_DEBUG_CLIENT_H_
+#define IREE_RT_DEBUG_DEBUG_CLIENT_H_
+
+#include <functional>
+#include <memory>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "schemas/debug_service_generated.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Remote breakpoint currently active on the server.
+class RemoteBreakpoint {
+ public:
+ // Kind of function the breakpoint is placed in.
+ enum class Type {
+ kBytecodeFunction = 0,
+ kNativeFunction = 1,
+ };
+
+ virtual ~RemoteBreakpoint() = default;
+
+ // Identifier and type supplied at construction time.
+ int id() const { return id_; }
+ Type type() const { return type_; }
+
+ // Location of the breakpoint as module / function / bytecode offset.
+ // Implemented by subclasses that hold the breakpoint definition.
+ virtual const std::string& module_name() const = 0;
+ virtual const std::string& function_name() const = 0;
+ virtual int function_ordinal() const = 0;
+ virtual int bytecode_offset() const = 0;
+
+ protected:
+ // Only constructible by subclasses.
+ explicit RemoteBreakpoint(int id, Type type) : id_(id), type_(type) {}
+
+ private:
+ int id_;
+ Type type_;
+};
+
+class RemoteModule;
+
+// Client-side handle for a function of a RemoteModule, identified by its
+// ordinal within that module.
+class RemoteFunction {
+ public:
+ virtual ~RemoteFunction() = default;
+
+ // Module that contains the function and the function's ordinal within it.
+ RemoteModule* module() const { return module_; }
+ int ordinal() const { return function_ordinal_; }
+ virtual const std::string& name() const = 0;
+
+ virtual const FunctionDef& def() = 0;
+
+ // Returns true if the function contents are available locally.
+ virtual bool is_loaded() const = 0;
+ // Returns is_loaded(). The TCP client implementation additionally issues a
+ // request for the contents when they have not been loaded yet.
+ virtual bool CheckLoadedOrRequest() = 0;
+
+ // Registers a |callback| to receive this function once it is loaded.
+ // The TCP client implementation invokes it immediately if already loaded.
+ using LoadCallback = std::function<void(StatusOr<RemoteFunction*>)>;
+ virtual void WhenLoaded(LoadCallback callback) = 0;
+
+ // Bytecode of the function.
+ // NOTE(review): the TCP client implementation CHECK-fails if the function is
+ // not loaded - callers presumably must test is_loaded() first; confirm.
+ virtual const BytecodeDef* bytecode() = 0;
+
+ protected:
+ RemoteFunction(RemoteModule* module, int function_ordinal)
+ : module_(module), function_ordinal_(function_ordinal) {}
+
+ RemoteModule* module_;
+ int function_ordinal_;
+};
+
+// Client-side handle for a module, identified by the ID of the context it
+// belongs to and its name.
+class RemoteModule {
+ public:
+ virtual ~RemoteModule() = default;
+
+ // ID of the owning context and the module name, both fixed at construction.
+ int context_id() const { return context_id_; }
+ const std::string& name() const { return name_; }
+
+ virtual const ModuleDef& def() = 0;
+
+ // Returns true if the module definition is available locally.
+ virtual bool is_loaded() const = 0;
+ // Returns is_loaded(). The TCP client implementation additionally requests
+ // the module definition when it has not been loaded yet.
+ virtual bool CheckLoadedOrRequest() = 0;
+
+ // Registers a |callback| to receive this module once it is loaded.
+ // The TCP client implementation invokes it immediately if already loaded.
+ using LoadCallback = std::function<void(StatusOr<RemoteModule*>)>;
+ virtual void WhenLoaded(LoadCallback callback) = 0;
+
+ // Functions contained in the module.
+ virtual absl::Span<RemoteFunction*> functions() = 0;
+
+ protected:
+ RemoteModule(int context_id, std::string name)
+ : context_id_(context_id), name_(std::move(name)) {}
+
+ private:
+ int context_id_;
+ std::string name_;
+};
+
+// Client-side handle for a context, identified by its ID.
+class RemoteContext {
+ public:
+ virtual ~RemoteContext() = default;
+
+ // ID supplied at construction time.
+ int id() const { return id_; }
+
+ // Modules belonging to the context.
+ virtual absl::Span<RemoteModule* const> modules() const = 0;
+
+ protected:
+ explicit RemoteContext(int id) : id_(id) {}
+
+ private:
+ int id_;
+};
+
+// Client-side handle for an invocation, identified by its ID.
+class RemoteInvocation {
+ public:
+ virtual ~RemoteInvocation() = default;
+
+ // ID supplied at construction time.
+ int id() const { return id_; }
+ // Display name of the form "Invocation <id>", built in the constructor.
+ const std::string& name() const { return name_; }
+
+ virtual const rpc::InvocationDefT& def() const = 0;
+
+ protected:
+ explicit RemoteInvocation(int id)
+ : id_(id), name_(absl::StrCat("Invocation ", id)) {}
+
+ private:
+ int id_;
+ std::string name_;
+};
+
+// Debugger RPC server client.
+// Statefully tracks a DebugServer to provide common client operations and
+// memoized queries.
+//
+// Thread-compatible. Do not use the client from multiple threads concurrently.
+// All remote updates of local state are performed by the Poll function. See
+// Poll for more details.
+class DebugClient {
+ public:
+ // Debug event listener interface.
+ // Event methods will be called from within Poll calls (so on that thread).
+ //
+ // When the server posts an event it will mark the client as unready and
+ // suspend execution of all invocations until MakeReady is used to indicate
+ // that the client is ready for the server to resume. Each event needs a
+ // matching MakeReady ack.
+ //
+ // Listeners can defer acking if they need to perform additional queries or
+ // state changes to the server or wait for user interaction. Multiple events
+ // may come in while unready if there was a series of events pending on the
+ // server.
+ class Listener {
+ public:
+ virtual ~Listener() = default;
+
+ // Signals that a context has been registered on the server.
+ virtual Status OnContextRegistered(const RemoteContext& context) = 0;
+ // Signals that a context has been unregistered on the server.
+ virtual Status OnContextUnregistered(const RemoteContext& context) = 0;
+
+ // Signals that a module has been loaded into a context on the server.
+ virtual Status OnModuleLoaded(const RemoteContext& context,
+ const RemoteModule& module) = 0;
+
+ // Signals that an invocation has been registered on the server.
+ virtual Status OnInvocationRegistered(
+ const RemoteInvocation& invocation) = 0;
+ // Signals that an invocation has been unregistered on the server.
+ virtual Status OnInvocationUnregistered(
+ const RemoteInvocation& invocation) = 0;
+
+ // Signals that a breakpoint has been hit by an invocation on the server.
+ virtual Status OnBreakpointHit(const RemoteBreakpoint& breakpoint,
+ const RemoteInvocation& invocation) = 0;
+ };
+
+ // Connects to a remote debug service at the provided IP:port.
+ // The specified |listener| will receive async event notifications.
+ static StatusOr<std::unique_ptr<DebugClient>> Connect(
+ absl::string_view service_address, Listener* listener);
+
+ virtual ~DebugClient() = default;
+
+ // Returns true if the client is connected to a service.
+ // virtual bool is_connected() const = 0;
+
+ // A list of all contexts registered with the server.
+ virtual absl::Span<RemoteContext* const> contexts() const = 0;
+
+ // A list of all invocations registered with the server.
+ virtual absl::Span<RemoteInvocation* const> invocations() const = 0;
+
+ // A list of all breakpoints registered with the server.
+ virtual absl::Span<RemoteBreakpoint* const> breakpoints() const = 0;
+
+ // Resolves a function to a module ordinal.
+ // This will occur asynchronously and the |callback| will be issued on the
+ // polling thread.
+ virtual Status ResolveFunction(
+ std::string module_name, std::string function_name,
+ std::function<void(StatusOr<int> function_ordinal)> callback) = 0;
+
+ // Gets a function body instance.
+ // The provided |callback| will be issued on the polling thread when the
+ // function is available.
+ virtual Status GetFunction(
+ std::string module_name, int function_ordinal,
+ std::function<void(StatusOr<RemoteFunction*> function)> callback) = 0;
+ // Name-based convenience overload: resolves |function_name| to an ordinal
+ // with ResolveFunction and then calls the ordinal-based GetFunction.
+ Status GetFunction(
+ std::string module_name, std::string function_name,
+ std::function<void(StatusOr<RemoteFunction*> function)> callback);
+
+ // Adds a breakpoint for the given module:function:offset.
+ // The breakpoint will apply to all contexts with the module loaded.
+ virtual Status AddFunctionBreakpoint(
+ std::string module_name, std::string function_name, int offset,
+ std::function<void(const RemoteBreakpoint& breakpoint)> callback =
+ nullptr) = 0;
+
+ // Removes a breakpoint from the server.
+ virtual Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) = 0;
+
+ // Notifies the server that the debug session is ready to continue.
+ // This must be called once on connection and in acknowledgement of any
+ // events posted by the server (read: any call to the Listener::On* methods).
+ virtual Status MakeReady() = 0;
+
+ // Suspends all invocations running on the server.
+ virtual Status SuspendAllInvocations() = 0;
+
+ // Resumes all invocations running on the server.
+ virtual Status ResumeAllInvocations() = 0;
+
+ // Suspends a list of invocations running on the server. Invocations not in
+ // the provided list will not be suspended, such as new invocations created
+ // while the request is pending.
+ virtual Status SuspendInvocations(
+ absl::Span<RemoteInvocation*> invocations) = 0;
+
+ // Resumes a list of invocations running on the server.
+ virtual Status ResumeInvocations(
+ absl::Span<RemoteInvocation*> invocations) = 0;
+
+ // Steps an invocation one bytecode operation.
+ virtual Status StepInvocation(const RemoteInvocation& invocation,
+ std::function<void()> callback) = 0;
+ // Steps an invocation over one bytecode operation, not stopping until it
+ // completes.
+ // Not yet implemented: currently returns an Unimplemented error.
+ Status StepInvocationOver(const RemoteInvocation& invocation,
+ std::function<void()> callback);
+ // Steps an invocation out of the current block.
+ // Not yet implemented: currently returns an Unimplemented error.
+ Status StepInvocationOut(const RemoteInvocation& invocation,
+ std::function<void()> callback);
+ // Steps an invocation to a specific bytecode offset within the current
+ // function.
+ virtual Status StepInvocationToOffset(const RemoteInvocation& invocation,
+ int bytecode_offset,
+ std::function<void()> callback) = 0;
+
+ // TODO(benvanik): profiling modes.
+
+ // Polls for the current state of the debug service and processes incoming
+ // responses. Must be called as frequently as the UI is desired to update.
+ // Returns CancelledError when the service is being shutdown/disconnected.
+ //
+ // Events on the Listener will be called from within this method.
+ virtual Status Poll() = 0;
+};
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_DEBUG_CLIENT_H_
diff --git a/rt/debug/debug_client_tcp.cc b/rt/debug/debug_client_tcp.cc
new file mode 100644
index 0000000..14f90f7
--- /dev/null
+++ b/rt/debug/debug_client_tcp.cc
@@ -0,0 +1,1127 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netdb.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cstring>
+#include <queue>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "absl/types/span.h"
+#include "base/flatbuffer_util.h"
+#include "base/status.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/flatbuffers.h"
+#include "rt/debug/debug_client.h"
+#include "rt/debug/debug_tcp_util.h"
+#include "rt/module.h"
+#include "schemas/debug_service_generated.h"
+#include "schemas/module_def_generated.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+namespace {
+
+using ::flatbuffers::FlatBufferBuilder;
+
+// Parses a host:port address, with support for the RFC 3986 IPv6 [host]:port
+// format. Returns a pair of (hostname, port), with port being 0 if none was
+// specified.
+//
+// Leading/trailing ASCII whitespace is ignored. Returns an InvalidArgument
+// error if the brackets are mismatched, the port is not an integer, or the
+// port lies outside of the valid range [0, 65535].
+//
+// Parses:
+//   foo (port 0) / foo:123
+//   1.2.3.4 (port 0) / 1.2.3.4:123
+//   [foo] (port 0) / [foo]:123
+//   [::1] (port 0) / [::1]:123
+StatusOr<std::pair<std::string, int>> ParseAddress(absl::string_view address) {
+  address = absl::StripAsciiWhitespace(address);
+  absl::string_view hostname;
+  absl::string_view port_str;
+  const size_t bracket_loc = address.find_last_of(']');
+  if (bracket_loc != absl::string_view::npos) {
+    // Has at least a ]. Let's assume it's mostly right: the host is whatever
+    // sits between a leading '[' and the last ']'.
+    if (address.find('[') != 0) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Mismatched brackets in address: " << address;
+    }
+    hostname = address.substr(1, bracket_loc - 1);
+    port_str = address.substr(bracket_loc + 1);
+    if (port_str.find(':') == 0) {
+      port_str.remove_prefix(1);
+    }
+  } else {
+    // Unbracketed form: the last ':' (if any) separates host and port.
+    const size_t colon_loc = address.find_last_of(':');
+    if (colon_loc != absl::string_view::npos) {
+      hostname = address.substr(0, colon_loc);
+      port_str = address.substr(colon_loc + 1);
+    } else {
+      hostname = address;
+      port_str = "";
+    }
+  }
+  int port = 0;
+  if (!port_str.empty() && !absl::SimpleAtoi(port_str, &port)) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Unable to parse port '" << port_str << "' from " << address;
+  }
+  // SimpleAtoi accepts any int; reject values that are not valid port numbers
+  // so that they are not handed on as if they were usable.
+  if (port < 0 || port > 65535) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Port " << port << " from " << address
+           << " is out of range (expected 0-65535)";
+  }
+  return std::make_pair(std::string(hostname), port);
+}
+
+class TcpDebugClient final : public DebugClient {
+ public:
+ // RemoteBreakpoint backed by an unpacked rpc::BreakpointDefT.
+ // All location accessors read from the most recently merged definition.
+ class TcpRemoteBreakpoint : public RemoteBreakpoint {
+  public:
+   // |client| is accepted for parity with the other remote types but is not
+   // used by this class.
+   TcpRemoteBreakpoint(int id, Type type, TcpDebugClient* /*client*/)
+       : RemoteBreakpoint(id, type) {}
+
+   const std::string& module_name() const override {
+     return def_.module_name;
+   }
+   const std::string& function_name() const override {
+     return def_.function_name;
+   }
+   int function_ordinal() const override {
+     return def_.function_ordinal;
+   }
+   int bytecode_offset() const override {
+     return def_.bytecode_offset;
+   }
+
+   // Unpacks |breakpoint_def| into the locally stored definition.
+   // Always returns OkStatus.
+   Status MergeFrom(const rpc::BreakpointDef& breakpoint_def) {
+     breakpoint_def.UnPackTo(&def_);
+     return OkStatus();
+   }
+
+  private:
+   rpc::BreakpointDefT def_;
+ };
+
+ // RemoteFunction whose bytecode contents are requested from the debug server
+ // on demand (see DemandContents) and retained locally once received.
+ class TcpRemoteFunction final : public RemoteFunction {
+  public:
+   TcpRemoteFunction(RemoteModule* module, int function_ordinal,
+                     const FunctionDef* function_def, TcpDebugClient* client)
+       : RemoteFunction(module, function_ordinal),
+         def_(function_def),
+         client_(client) {
+     // Functions whose def carries no name get an empty name.
+     name_ = def_->name() ? std::string(WrapString(def_->name())) : "";
+   }
+
+   const std::string& name() const override { return name_; }
+
+   const FunctionDef& def() override { return *def_; }
+
+   // Returns true once a GetFunction response has been merged in.
+   bool is_loaded() const override {
+     return contents_.flatbuffers_buffer.size() > 0;
+   }
+
+   // Requests the contents if they are not loaded yet and returns whether they
+   // are loaded at the time of the call.
+   bool CheckLoadedOrRequest() override {
+     if (!is_loaded()) {
+       DemandContents();
+     }
+     return is_loaded();
+   }
+
+   // Invokes |callback| immediately if the contents are loaded; otherwise
+   // queues it until a GetFunction response has been merged. This does not
+   // itself issue a request.
+   void WhenLoaded(LoadCallback callback) override {
+     if (is_loaded()) {
+       callback(this);
+       return;
+     }
+     load_callbacks_.push_back(std::move(callback));
+   }
+
+   // Returns the loaded bytecode. CHECK-fails if the contents are not loaded.
+   const BytecodeDef* bytecode() override {
+     CHECK(is_loaded());
+     return contents_.bytecode_def;
+   }
+
+  private:
+   // Issues a GetFunction request for this function unless one has already
+   // been issued successfully. The response handler merges the contents and
+   // then runs all queued load callbacks.
+   void DemandContents() {
+     if (has_requested_contents_) return;
+     VLOG(2) << "Client " << client_->fd() << ": GetFunction("
+             << module()->context_id() << ", " << module()->name() << ", "
+             << ordinal() << ")";
+     FlatBufferBuilder fbb;
+     rpc::GetFunctionRequestT request;
+     request.session_id = client_->session_id();
+     request.context_id = module()->context_id();
+     request.module_name = module()->name();
+     request.function_ordinal = ordinal();
+     auto status =
+         client_->IssueRequest<rpc::GetFunctionRequest,
+                               rpc::ResponseUnion::GetFunctionResponse>(
+             rpc::GetFunctionRequest::Pack(fbb, &request), std::move(fbb),
+             [this](Status status,
+                    const rpc::Response& response_union) -> Status {
+               if (!status.ok()) return status;
+               const auto& response =
+                   *response_union.message_as_GetFunctionResponse();
+               VLOG(2) << "Client " << client_->fd() << ": GetFunction("
+                       << module()->context_id() << ", " << module()->name()
+                       << ", " << ordinal() << ") = ...";
+               RETURN_IF_ERROR(MergeFrom(response));
+               for (auto& callback : load_callbacks_) {
+                 callback(this);
+               }
+               load_callbacks_.clear();
+               return OkStatus();
+             });
+     if (!status.ok()) {
+       // Leave has_requested_contents_ unset so a later call can retry.
+       LOG(ERROR) << "Failed to request function: " << status;
+       return;
+     }
+     has_requested_contents_ = true;
+   }
+
+   // Copies the bytecode carried by |response| into locally owned storage.
+   // Returns an InvalidArgument error if the response carries no bytecode.
+   Status MergeFrom(const rpc::GetFunctionResponse& response) {
+     const auto* response_bytecode = response.bytecode();
+     if (!response_bytecode) {
+       return InvalidArgumentErrorBuilder(IREE_LOC)
+              << "GetFunction response for " << module()->name() << ":"
+              << ordinal() << " contains no bytecode";
+     }
+     // Clone and retain the contents.
+     // TODO(benvanik): find a way to steal to avoid the reserialization.
+     BytecodeDefT bytecode_def_storage;
+     response_bytecode->UnPackTo(&bytecode_def_storage);
+     ::flatbuffers::FlatBufferBuilder fbb;
+     fbb.Finish(response_bytecode->Pack(fbb, &bytecode_def_storage));
+     contents_.flatbuffers_buffer = fbb.Release();
+     contents_.bytecode_def = ::flatbuffers::GetRoot<BytecodeDef>(
+         contents_.flatbuffers_buffer.data());
+     return OkStatus();
+   }
+
+   const FunctionDef* def_;
+   TcpDebugClient* client_;
+   std::string name_;
+   // Set once a GetFunction request has been issued successfully.
+   bool has_requested_contents_ = false;
+   // Callbacks waiting for the contents to be loaded.
+   std::vector<LoadCallback> load_callbacks_;
+   // Locally owned copy of the function contents; |bytecode_def| points into
+   // |flatbuffers_buffer|.
+   struct {
+     ::flatbuffers::DetachedBuffer flatbuffers_buffer;
+     const BytecodeDef* bytecode_def = nullptr;
+   } contents_;
+ };
+
+  // RemoteModule that fetches its ModuleDef from the debug service on demand
+  // through the owning TcpDebugClient. Until the GetModule response has been
+  // merged the module reports itself as not loaded.
+  class TcpRemoteModule final : public RemoteModule {
+   public:
+    TcpRemoteModule(int context_id, std::string module_name,
+                    TcpDebugClient* client)
+        : RemoteModule(context_id, std::move(module_name)), client_(client) {}
+
+    // Returns the loaded module definition. CHECK-fails if not loaded.
+    const ModuleDef& def() override {
+      CHECK(is_loaded());
+      return *module_file_->root();
+    }
+
+    bool is_loaded() const override { return module_file_ != nullptr; }
+
+    // Returns true if the module is loaded; otherwise makes sure a GetModule
+    // request has been issued and returns the (unchanged) loaded state.
+    bool CheckLoadedOrRequest() override {
+      if (!is_loaded()) {
+        DemandModuleDef();
+      }
+      return is_loaded();
+    }
+
+    // Invokes |callback| immediately if the module is loaded; otherwise queues
+    // it until a GetModule response has been merged.
+    void WhenLoaded(LoadCallback callback) override {
+      if (is_loaded()) {
+        callback(this);
+        return;
+      }
+      load_callbacks_.push_back(std::move(callback));
+    }
+
+    // Returns the module's functions, or an empty span while the module is not
+    // loaded (in which case a load request is issued if none is pending).
+    absl::Span<RemoteFunction*> functions() override {
+      auto* module_def = DemandModuleDef();
+      if (!module_def) return {};
+      // A dedicated vector of raw pointers backs the span; the owning
+      // unique_ptr storage must not be reinterpreted as RemoteFunction**.
+      return {function_ptrs_.data(), function_ptrs_.size()};
+    }
+
+   private:
+    // Returns the module definition if it is available. Otherwise issues a
+    // GetModule request (at most one successful issue) and returns nullptr.
+    const ModuleDef* DemandModuleDef() {
+      if (module_file_) {
+        return module_file_->root();
+      }
+      if (has_requested_module_def_) {
+        // A request is already in flight; nothing more to do until it lands.
+        return nullptr;
+      }
+
+      VLOG(2) << "Client " << client_->fd() << ": GetModule(" << context_id()
+              << ", " << name() << ")";
+      FlatBufferBuilder fbb;
+      rpc::GetModuleRequestT request;
+      request.session_id = client_->session_id();
+      request.context_id = context_id();
+      request.module_name = name();
+      auto status =
+          client_->IssueRequest<rpc::GetModuleRequest,
+                                rpc::ResponseUnion::GetModuleResponse>(
+              rpc::GetModuleRequest::Pack(fbb, &request), std::move(fbb),
+              [this](Status status,
+                     const rpc::Response& response_union) -> Status {
+                if (!status.ok()) return status;
+                const auto& response =
+                    *response_union.message_as_GetModuleResponse();
+                VLOG(2) << "Client " << client_->fd() << ": GetModule("
+                        << context_id() << ", " << name() << ") = ...";
+                RETURN_IF_ERROR(MergeFrom(response));
+                for (auto& callback : load_callbacks_) {
+                  callback(this);
+                }
+                load_callbacks_.clear();
+                return OkStatus();
+              });
+      if (!status.ok()) {
+        // The flag stays unset so that a later call can retry the request.
+        LOG(ERROR) << "Failed to request module: " << status;
+        return nullptr;
+      }
+      has_requested_module_def_ = true;
+      return nullptr;
+    }
+
+    // Copies the module carried by |response| into storage owned by this
+    // object and creates one TcpRemoteFunction per function table entry.
+    // State is only committed once everything has been built, so a failure
+    // leaves the module untouched. Returns InvalidArgument if the response
+    // carries no module.
+    Status MergeFrom(const rpc::GetModuleResponse& response) {
+      const auto* response_module = response.module_();
+      if (!response_module) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << "GetModule response contains no module";
+      }
+
+      // Clone and retain the module.
+      // TODO(benvanik): find a way to steal to avoid the reserialization.
+      ModuleDefT module_def_storage;
+      response_module->UnPackTo(&module_def_storage);
+      FlatBufferBuilder fbb;
+      auto module_offs = response_module->Pack(fbb, &module_def_storage);
+      FinishModuleDefBuffer(fbb, module_offs);
+      ASSIGN_OR_RETURN(auto module_file,
+                       ModuleFile::CreateWithBackingBuffer(fbb.Release()));
+
+      // A missing function table or function vector is treated as "no
+      // functions" rather than dereferenced.
+      const auto* function_table = module_file->root()->function_table();
+      const auto* function_defs =
+          function_table ? function_table->functions() : nullptr;
+      const int function_count =
+          function_defs ? static_cast<int>(function_defs->size()) : 0;
+
+      std::vector<std::unique_ptr<RemoteFunction>> functions;
+      std::vector<RemoteFunction*> function_ptrs;
+      functions.reserve(function_count);
+      function_ptrs.reserve(function_count);
+      for (int i = 0; i < function_count; ++i) {
+        const auto* function_def = function_defs->Get(i);
+        functions.push_back(absl::make_unique<TcpRemoteFunction>(
+            this, i, function_def, client_));
+        function_ptrs.push_back(functions.back().get());
+      }
+
+      functions_ = std::move(functions);
+      function_ptrs_ = std::move(function_ptrs);
+      module_file_ = std::move(module_file);
+      return OkStatus();
+    }
+
+    // Client used to issue the GetModule request.
+    TcpDebugClient* client_;
+    // Set once a GetModule request was issued successfully.
+    bool has_requested_module_def_ = false;
+    // Callbacks waiting for the module to finish loading.
+    std::vector<LoadCallback> load_callbacks_;
+    // Owns the module contents once loaded; null until then.
+    std::unique_ptr<ModuleFile> module_file_;
+    // Owns the function objects created by MergeFrom.
+    std::vector<std::unique_ptr<RemoteFunction>> functions_;
+    // Non-owning view of functions_ in the same order; backs functions().
+    std::vector<RemoteFunction*> function_ptrs_;
+  };
+
+  // RemoteContext that tracks the modules announced for a context by the
+  // debug service.
+  class TcpRemoteContext final : public RemoteContext {
+   public:
+    TcpRemoteContext(int context_id, TcpDebugClient* client)
+        : RemoteContext(context_id), client_(client) {}
+
+    // Modules in the order they were added.
+    absl::Span<RemoteModule* const> modules() const override {
+      return absl::MakeConstSpan(modules_);
+    }
+
+    // Takes ownership of |module| and exposes it through modules().
+    // Returns FailedPrecondition if a module with the same name was already
+    // added; inserting it anyway would destroy |module| while leaving a
+    // dangling pointer to it in the module list.
+    Status AddModule(std::unique_ptr<TcpRemoteModule> module) {
+      if (module_map_.find(module->name()) != module_map_.end()) {
+        return FailedPreconditionErrorBuilder(IREE_LOC)
+               << "Module " << module->name() << " already added to context";
+      }
+      modules_.push_back(module.get());
+      module_map_.insert({module->name(), std::move(module)});
+      return OkStatus();
+    }
+
+    // Currently a no-op: nothing from the ContextDef is retained.
+    Status MergeFrom(const rpc::ContextDef& context_def) { return OkStatus(); }
+
+   private:
+    TcpDebugClient* client_;
+    // Non-owning list backing modules().
+    std::vector<RemoteModule*> modules_;
+    // Owns the modules, keyed by module name.
+    absl::flat_hash_map<std::string, std::unique_ptr<TcpRemoteModule>>
+        module_map_;
+  };
+
+  // RemoteInvocation holding a locally unpacked copy of the invocation state
+  // reported by the debug service.
+  class TcpRemoteInvocation final : public RemoteInvocation {
+   public:
+    TcpRemoteInvocation(int invocation_id, TcpDebugClient* client)
+        : RemoteInvocation(invocation_id), client_(client) {}
+
+    // Returns the state stored by the most recent MergeFrom call.
+    const rpc::InvocationDefT& def() const override { return def_; }
+
+    // Unpacks |invocation_def| into the locally held definition.
+    Status MergeFrom(const rpc::InvocationDef& invocation_def) {
+      rpc::InvocationDefT* target = &def_;
+      invocation_def.UnPackTo(target);
+      return OkStatus();
+    }
+
+   private:
+    TcpDebugClient* client_;
+    rpc::InvocationDefT def_;
+  };
+
+  // Creates a client for the already-connected socket |fd| and issues the
+  // initial state refresh requests. Takes ownership of |fd|: the client is
+  // constructed before any fallible step so that its destructor closes the
+  // socket if configuring it (or the refresh) fails.
+  static StatusOr<std::unique_ptr<TcpDebugClient>> Create(int fd,
+                                                          Listener* listener) {
+    auto client = absl::make_unique<TcpDebugClient>(fd, listener);
+
+    VLOG(2) << "Client " << fd << ": Setting up socket options...";
+    // Disable Nagle's algorithm to ensure we have low latency.
+    RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(fd, false));
+    // Enable keepalive assuming the client is local and this high freq is ok.
+    RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(fd, true));
+    // Linger around for a bit to flush all data.
+    RETURN_IF_ERROR(tcp::ToggleSocketLinger(fd, true));
+    // Disable blocking as we are poll based.
+    RETURN_IF_ERROR(tcp::ToggleSocketBlocking(fd, false));
+
+    RETURN_IF_ERROR(client->Refresh());
+    return client;
+  }
+
+  // Wraps the connected socket |fd|; events are reported to |listener|.
+  TcpDebugClient(int fd, Listener* listener) : fd_(fd), listener_(listener) {}
+
+  // Shuts down the write side of the session socket and closes it.
+  ~TcpDebugClient() override {
+    const int socket_fd = fd_;
+    VLOG(2) << "Client " << socket_fd << ": Shutting down session socket...";
+    ::shutdown(socket_fd, SHUT_WR);
+    VLOG(2) << "Client " << socket_fd << ": Closing session socket...";
+    ::close(socket_fd);
+    VLOG(2) << "Client " << socket_fd << ": Closed session socket!";
+    fd_ = -1;
+  }
+
+  // Socket descriptor of the session.
+  int fd() const { return fd_; }
+  // Session id sent with each request.
+  int session_id() const { return session_id_; }
+
+  // Non-owning views over the currently registered contexts, invocations and
+  // breakpoints.
+  absl::Span<RemoteContext* const> contexts() const override {
+    return absl::Span<RemoteContext* const>(contexts_.data(),
+                                            contexts_.size());
+  }
+
+  absl::Span<RemoteInvocation* const> invocations() const override {
+    return absl::Span<RemoteInvocation* const>(invocations_.data(),
+                                               invocations_.size());
+  }
+
+  absl::Span<RemoteBreakpoint* const> breakpoints() const override {
+    return absl::Span<RemoteBreakpoint* const>(breakpoints_.data(),
+                                               breakpoints_.size());
+  }
+
+  // Wraps the typed request at |request_offs| in a size-prefixed rpc::Request
+  // union and writes the finished buffer to |fd|.
+  //
+  // Once a service shutdown is pending, an Unavailable write error is treated
+  // as success.
+  //
+  // Example:
+  //  FlatBufferBuilder fbb;
+  //  rpc::SuspendInvocationRequestBuilder request(fbb);
+  //  RETURN_IF_ERROR(WriteRequest(fd_, request.Finish(), std::move(fbb)));
+  template <typename T>
+  Status WriteRequest(int fd, ::flatbuffers::Offset<T> request_offs,
+                      FlatBufferBuilder fbb) {
+    rpc::RequestBuilder envelope(fbb);
+    envelope.add_message_type(rpc::RequestUnionTraits<T>::enum_value);
+    envelope.add_message(request_offs.Union());
+    fbb.FinishSizePrefixed(envelope.Finish());
+
+    auto write_status = tcp::WriteBuffer(fd, fbb.Release());
+    const bool ignorable_failure =
+        shutdown_pending_ && IsUnavailable(write_status);
+    if (ignorable_failure) {
+      return OkStatus();
+    }
+    return write_status;
+  }
+
+  // Asks the service to resolve |function_name| within |module_name| to a
+  // function ordinal. |callback| receives the ordinal, or the error status if
+  // the request failed; the response handler itself always reports success.
+  Status ResolveFunction(
+      std::string module_name, std::string function_name,
+      std::function<void(StatusOr<int> function_ordinal)> callback) override {
+    VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name << ", "
+            << function_name << ")";
+    FlatBufferBuilder fbb;
+    rpc::ResolveFunctionRequestT request;
+    request.session_id = session_id_;
+    request.module_name = module_name;
+    request.function_name = function_name;
+    return IssueRequest<rpc::ResolveFunctionRequest,
+                        rpc::ResponseUnion::ResolveFunctionResponse>(
+        rpc::ResolveFunctionRequest::Pack(fbb, &request), std::move(fbb),
+        [this, module_name, function_name, callback](
+            Status status, const rpc::Response& response_union) -> Status {
+          if (!status.ok()) {
+            // Failures are forwarded to the caller rather than propagated.
+            callback(std::move(status));
+            return OkStatus();
+          }
+          const auto& response =
+              *response_union.message_as_ResolveFunctionResponse();
+          VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name
+                  << ", " << function_name
+                  << ") = " << response.function_ordinal();
+          callback(response.function_ordinal());
+          return OkStatus();
+        });
+  }
+
+  // Looks up function |function_ordinal| of the first known module named
+  // |module_name| and passes it to |callback|.
+  //
+  // If the module is already loaded the callback runs before this returns;
+  // otherwise it runs once the module has loaded, and a load is requested if
+  // none is in flight. An out-of-range ordinal is reported to |callback| as
+  // InvalidArgument. Returns Unimplemented if no context has a module with
+  // that name.
+  Status GetFunction(std::string module_name, int function_ordinal,
+                     std::function<void(StatusOr<RemoteFunction*> function)>
+                         callback) override {
+    // See if we have the module already. If not, we'll fetch it first.
+    RemoteModule* target_module = nullptr;
+    for (auto* context : contexts_) {
+      for (auto* module : context->modules()) {
+        if (module->name() == module_name) {
+          target_module = module;
+          break;
+        }
+      }
+      if (target_module) break;
+    }
+    if (!target_module) {
+      // TODO(benvanik): fetch contexts first.
+      return UnimplementedErrorBuilder(IREE_LOC)
+             << "Demand fetch contexts not yet implemented";
+    }
+
+    // Resolves the ordinal against a loaded module with a bounds check so a
+    // bad ordinal yields an error instead of an out-of-range read.
+    auto deliver_function = [callback, function_ordinal](RemoteModule* module) {
+      auto functions = module->functions();
+      if (function_ordinal < 0 ||
+          static_cast<size_t>(function_ordinal) >= functions.size()) {
+        Status error = InvalidArgumentErrorBuilder(IREE_LOC)
+                       << "Function ordinal " << function_ordinal
+                       << " is out of range; module has " << functions.size()
+                       << " functions";
+        callback(std::move(error));
+        return;
+      }
+      callback(functions[function_ordinal]);
+    };
+
+    // Found at least one module with the right name.
+    if (target_module->is_loaded()) {
+      deliver_function(target_module);
+      return OkStatus();
+    }
+
+    // Wait until the module completes loading.
+    target_module->WhenLoaded(
+        [callback, deliver_function](StatusOr<RemoteModule*> module_or) {
+          if (!module_or.ok()) {
+            callback(module_or.status());
+            return;
+          }
+          deliver_function(module_or.ValueOrDie());
+        });
+    // Make sure a load request is actually in flight; without this the
+    // callback above would wait until something else demanded the module.
+    target_module->CheckLoadedOrRequest();
+    return OkStatus();
+  }
+
+  // Requests a breakpoint at bytecode |offset| of |function_name| in
+  // |module_name|. When the service responds, the breakpoint is registered
+  // locally and, if |callback| is set, passed to it.
+  Status AddFunctionBreakpoint(
+      std::string module_name, std::string function_name, int offset,
+      std::function<void(const RemoteBreakpoint& breakpoint)> callback)
+      override {
+    VLOG(2) << "Client " << fd_ << ": AddFunctionBreakpoint(" << module_name
+            << ", " << function_name << ", " << offset << ")";
+    FlatBufferBuilder fbb;
+
+    // The ordinal is sent as -1; the function is identified by name.
+    auto breakpoint_def = absl::make_unique<rpc::BreakpointDefT>();
+    breakpoint_def->module_name = module_name;
+    breakpoint_def->function_name = function_name;
+    breakpoint_def->function_ordinal = -1;
+    breakpoint_def->bytecode_offset = offset;
+
+    rpc::AddBreakpointRequestT request;
+    request.session_id = session_id_;
+    request.breakpoint = std::move(breakpoint_def);
+    return IssueRequest<rpc::AddBreakpointRequest,
+                        rpc::ResponseUnion::AddBreakpointResponse>(
+        rpc::AddBreakpointRequest::Pack(fbb, &request), std::move(fbb),
+        [this, callback](Status status,
+                         const rpc::Response& response_union) -> Status {
+          if (!status.ok()) return status;
+          const auto& response =
+              *response_union.message_as_AddBreakpointResponse();
+          RETURN_IF_ERROR(RegisterBreakpoint(*response.breakpoint()));
+          if (!callback) return OkStatus();
+          ASSIGN_OR_RETURN(
+              auto registered_breakpoint,
+              GetBreakpoint(response.breakpoint()->breakpoint_id()));
+          callback(*registered_breakpoint);
+          return OkStatus();
+        });
+  }
+
+  // Unregisters |breakpoint| locally and then asks the service to remove it.
+  // The id is captured up front because unregistering releases the local
+  // breakpoint object.
+  Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) override {
+    VLOG(2) << "Client " << fd_ << ": RemoveBreakpoint(" << breakpoint.id()
+            << ")";
+    const int breakpoint_id = breakpoint.id();
+    ASSIGN_OR_RETURN(auto* registered_breakpoint,
+                     GetBreakpoint(breakpoint_id));
+    RETURN_IF_ERROR(UnregisterBreakpoint(registered_breakpoint));
+
+    FlatBufferBuilder fbb;
+    rpc::RemoveBreakpointRequestBuilder request(fbb);
+    request.add_session_id(session_id_);
+    request.add_breakpoint_id(breakpoint_id);
+    return IssueRequest<rpc::RemoveBreakpointRequest,
+                        rpc::ResponseUnion::RemoveBreakpointResponse>(
+        request.Finish(), std::move(fbb),
+        [](Status status, const rpc::Response& response_union) -> Status {
+          // The response carries no payload; only failures are of interest.
+          if (status.ok()) return OkStatus();
+          return status;
+        });
+  }
+
+  // Sends a MakeReady request for this session. The response handler simply
+  // passes through the status it was given.
+  Status MakeReady() override {
+    FlatBufferBuilder fbb;
+    rpc::MakeReadyRequestBuilder make_ready_request(fbb);
+    make_ready_request.add_session_id(session_id_);
+    auto request_offs = make_ready_request.Finish();
+    return IssueRequest<rpc::MakeReadyRequest,
+                        rpc::ResponseUnion::MakeReadyResponse>(
+        request_offs, std::move(fbb),
+        [](Status status, const rpc::Response& response_union) {
+          return status;
+        });
+  }
+
+  // Requests suspension without listing invocation ids, then refreshes the
+  // invocation list once the service has responded successfully.
+  Status SuspendAllInvocations() override {
+    VLOG(2) << "Client " << fd_ << ": SuspendAllInvocations()";
+    FlatBufferBuilder fbb;
+    rpc::SuspendInvocationsRequestBuilder suspend_request(fbb);
+    suspend_request.add_session_id(session_id_);
+    auto request_offs = suspend_request.Finish();
+    return IssueRequest<rpc::SuspendInvocationsRequest,
+                        rpc::ResponseUnion::SuspendInvocationsResponse>(
+        request_offs, std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (status.ok()) return RefreshInvocations();
+          return status;
+        });
+  }
+
+  // Requests resumption without listing invocation ids, then refreshes the
+  // invocation list once the service has responded successfully.
+  Status ResumeAllInvocations() override {
+    VLOG(2) << "Client " << fd_ << ": ResumeAllInvocations()";
+    FlatBufferBuilder fbb;
+    rpc::ResumeInvocationsRequestBuilder resume_request(fbb);
+    resume_request.add_session_id(session_id_);
+    auto request_offs = resume_request.Finish();
+    return IssueRequest<rpc::ResumeInvocationsRequest,
+                        rpc::ResponseUnion::ResumeInvocationsResponse>(
+        request_offs, std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (status.ok()) return RefreshInvocations();
+          return status;
+        });
+  }
+
+  // Requests suspension of the given |invocations| (sent by id), then
+  // refreshes the invocation list once the service has responded successfully.
+  Status SuspendInvocations(
+      absl::Span<RemoteInvocation*> invocations) override {
+    VLOG(2) << "Client " << fd_ << ": SuspendInvocations(...)";
+    FlatBufferBuilder fbb;
+    // The id vector must be created before the request builder is started.
+    auto id_at = [&invocations](size_t i) { return invocations[i]->id(); };
+    auto invocation_ids_offs =
+        fbb.CreateVector<int32_t>(invocations.size(), id_at);
+    rpc::SuspendInvocationsRequestBuilder suspend_request(fbb);
+    suspend_request.add_session_id(session_id_);
+    suspend_request.add_invocation_ids(invocation_ids_offs);
+    return IssueRequest<rpc::SuspendInvocationsRequest,
+                        rpc::ResponseUnion::SuspendInvocationsResponse>(
+        suspend_request.Finish(), std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (status.ok()) return RefreshInvocations();
+          return status;
+        });
+  }
+
+  // Requests resumption of the given |invocations| (sent by id), then
+  // refreshes the invocation list once the service has responded successfully.
+  Status ResumeInvocations(absl::Span<RemoteInvocation*> invocations) override {
+    VLOG(2) << "Client " << fd_ << ": ResumeInvocations(...)";
+    FlatBufferBuilder fbb;
+    // The id vector must be created before the request builder is started.
+    auto id_at = [&invocations](size_t i) { return invocations[i]->id(); };
+    auto invocation_ids_offs =
+        fbb.CreateVector<int32_t>(invocations.size(), id_at);
+    rpc::ResumeInvocationsRequestBuilder resume_request(fbb);
+    resume_request.add_session_id(session_id_);
+    resume_request.add_invocation_ids(invocation_ids_offs);
+    return IssueRequest<rpc::ResumeInvocationsRequest,
+                        rpc::ResponseUnion::ResumeInvocationsResponse>(
+        resume_request.Finish(), std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (status.ok()) return RefreshInvocations();
+          return status;
+        });
+  }
+
+  // Steps |invocation| once. A fresh step id is allocated for the request;
+  // |callback| is stored under that id by the private StepInvocation overload.
+  Status StepInvocation(const RemoteInvocation& invocation,
+                        std::function<void()> callback) override {
+    const int step_id = next_step_id_++;
+    VLOG(2) << "Client " << fd_ << ": StepInvocation(" << invocation.id()
+            << ") as step_id=" << step_id;
+    rpc::StepInvocationRequestT request;
+    request.step_id = step_id;
+    request.invocation_id = invocation.id();
+    request.step_mode = rpc::StepMode::STEP_ONCE;
+    return StepInvocation(&request, std::move(callback));
+  }
+
+  // Steps |invocation| until it reaches |bytecode_offset|. A fresh step id is
+  // allocated for the request; |callback| is stored under that id by the
+  // private StepInvocation overload.
+  Status StepInvocationToOffset(const RemoteInvocation& invocation,
+                                int bytecode_offset,
+                                std::function<void()> callback) override {
+    const int step_id = next_step_id_++;
+    VLOG(2) << "Client " << fd_ << ": StepInvocationToOffset("
+            << invocation.id() << ", " << bytecode_offset
+            << ") as step_id=" << step_id;
+    rpc::StepInvocationRequestT request;
+    request.step_id = step_id;
+    request.invocation_id = invocation.id();
+    request.step_mode = rpc::StepMode::STEP_TO_OFFSET;
+    request.bytecode_offset = bytecode_offset;
+    return StepInvocation(&request, std::move(callback));
+  }
+
+  // Reads and dispatches every packet that is currently available on the
+  // socket, returning as soon as nothing more can be read. A packet may carry
+  // a response, an event, or both. While a service shutdown is pending, an
+  // Unavailable read error is reported as Cancelled.
+  Status Poll() override {
+    while (tcp::CanReadBuffer(fd_)) {
+      // Read the pending response and dispatch.
+      auto packet_buffer_or = tcp::ReadBuffer<rpc::ServicePacket>(fd_);
+      if (!packet_buffer_or.ok()) {
+        const bool graceful_close =
+            shutdown_pending_ && IsUnavailable(packet_buffer_or.status());
+        if (graceful_close) {
+          // This is a graceful close.
+          return CancelledErrorBuilder(IREE_LOC) << "Service shutdown";
+        }
+        return packet_buffer_or.status();
+      }
+
+      const auto& packet = packet_buffer_or.ValueOrDie().GetRoot();
+      if (packet.response()) {
+        RETURN_IF_ERROR(DispatchResponse(*packet.response()));
+      }
+      if (packet.event()) {
+        RETURN_IF_ERROR(DispatchEvent(packet));
+      }
+    }
+    return OkStatus();
+  }
+
+ using ResponseCallback =
+ std::function<Status(Status status, const rpc::Response& response)>;
+
+ template <typename T, rpc::ResponseUnion response_type>
+ Status IssueRequest(::flatbuffers::Offset<T> request_offs,
+ FlatBufferBuilder fbb, ResponseCallback callback) {
+ RETURN_IF_ERROR(WriteRequest(fd_, request_offs, std::move(fbb)));
+ pending_responses_.push({response_type, std::move(callback)});
+ return OkStatus();
+ }
+
+ private:
+  // Issues refresh requests for contexts, invocations and breakpoints, in
+  // that order, stopping at the first request that cannot be issued.
+  Status Refresh() {
+    RETURN_IF_ERROR(RefreshContexts());
+    RETURN_IF_ERROR(RefreshInvocations());
+    return RefreshBreakpoints();
+  }
+
+  // Requests the list of contexts from the service. When the response
+  // arrives, contexts not yet known are registered and every listed context
+  // is merged with the reported definition.
+  Status RefreshContexts() {
+    VLOG(2) << "Request contexts refresh...";
+    FlatBufferBuilder fbb;
+    rpc::ListContextsRequestBuilder request(fbb);
+    request.add_session_id(session_id_);
+    return IssueRequest<rpc::ListContextsRequest,
+                        rpc::ResponseUnion::ListContextsResponse>(
+        request.Finish(), std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (!status.ok()) return status;
+          VLOG(2) << "Refreshing contexts...";
+          const auto& response =
+              *response_union.message_as_ListContextsResponse();
+          // An absent vector (e.g. an empty list) reads back as nullptr and
+          // is treated as "no contexts".
+          if (const auto* context_defs = response.contexts()) {
+            for (auto* context_def : *context_defs) {
+              auto context_or = GetContext(context_def->context_id());
+              if (!context_or.ok()) {
+                // Not found; add new.
+                RETURN_IF_ERROR(RegisterContext(context_def->context_id()));
+                context_or = GetContext(context_def->context_id());
+              }
+              RETURN_IF_ERROR(context_or.status());
+              RETURN_IF_ERROR(
+                  context_or.ValueOrDie()->MergeFrom(*context_def));
+            }
+          }
+          VLOG(2) << "Refreshed contexts!";
+          return OkStatus();
+        });
+  }
+
+  // Requests the list of invocations from the service. When the response
+  // arrives, invocations not yet known are registered and every listed
+  // invocation is merged with the reported state.
+  Status RefreshInvocations() {
+    VLOG(2) << "Request invocation states refresh...";
+    FlatBufferBuilder fbb;
+    rpc::ListInvocationsRequestBuilder request(fbb);
+    request.add_session_id(session_id_);
+    return IssueRequest<rpc::ListInvocationsRequest,
+                        rpc::ResponseUnion::ListInvocationsResponse>(
+        request.Finish(), std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (!status.ok()) return status;
+          VLOG(2) << "Refreshing invocation states...";
+          const auto& response =
+              *response_union.message_as_ListInvocationsResponse();
+          // An absent vector (e.g. an empty list) reads back as nullptr and
+          // is treated as "no invocations".
+          if (const auto* invocation_defs = response.invocations()) {
+            for (auto* invocation_def : *invocation_defs) {
+              auto invocation_or =
+                  GetInvocation(invocation_def->invocation_id());
+              if (!invocation_or.ok()) {
+                // Not found; add new.
+                RETURN_IF_ERROR(
+                    RegisterInvocation(invocation_def->invocation_id()));
+                invocation_or = GetInvocation(invocation_def->invocation_id());
+              }
+              RETURN_IF_ERROR(invocation_or.status());
+              RETURN_IF_ERROR(
+                  invocation_or.ValueOrDie()->MergeFrom(*invocation_def));
+            }
+          }
+          // TODO(benvanik): handle removals/deaths.
+          VLOG(2) << "Refreshed invocation states!";
+          return OkStatus();
+        });
+  }
+
+  // Requests the list of breakpoints from the service. When the response
+  // arrives, breakpoints not yet known are registered and every listed
+  // breakpoint is merged with the reported definition.
+  Status RefreshBreakpoints() {
+    VLOG(2) << "Requesting breakpoint refresh...";
+    FlatBufferBuilder fbb;
+    rpc::ListBreakpointsRequestBuilder request(fbb);
+    request.add_session_id(session_id_);
+    return IssueRequest<rpc::ListBreakpointsRequest,
+                        rpc::ResponseUnion::ListBreakpointsResponse>(
+        request.Finish(), std::move(fbb),
+        [this](Status status, const rpc::Response& response_union) -> Status {
+          if (!status.ok()) return status;
+          VLOG(2) << "Refreshing breakpoints...";
+          const auto& response =
+              *response_union.message_as_ListBreakpointsResponse();
+          // An absent vector (e.g. an empty list) reads back as nullptr and
+          // is treated as "no breakpoints".
+          if (const auto* breakpoint_defs = response.breakpoints()) {
+            for (auto* breakpoint_def : *breakpoint_defs) {
+              auto breakpoint_or =
+                  GetBreakpoint(breakpoint_def->breakpoint_id());
+              if (!breakpoint_or.ok()) {
+                // Not found; add new.
+                RETURN_IF_ERROR(RegisterBreakpoint(*breakpoint_def));
+                breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id());
+              }
+              RETURN_IF_ERROR(breakpoint_or.status());
+              RETURN_IF_ERROR(
+                  breakpoint_or.ValueOrDie()->MergeFrom(*breakpoint_def));
+            }
+          }
+          // TODO(benvanik): handle removals/deaths.
+          VLOG(2) << "Refreshed breakpoints!";
+          return OkStatus();
+        });
+  }
+
+  // Hands |response| to the oldest pending request callback.
+  //
+  // The pending entry is consumed in every case. A response carrying a server
+  // status is delivered to the callback as an error; a response without a
+  // message body or with an unexpected message type is rejected without
+  // invoking the callback.
+  Status DispatchResponse(const rpc::Response& response) {
+    if (pending_responses_.empty()) {
+      return FailedPreconditionErrorBuilder(IREE_LOC)
+             << "Response received but no request is pending";
+    }
+    auto pending = std::move(pending_responses_.front());
+    pending_responses_.pop();
+    const rpc::ResponseUnion expected_type = pending.first;
+    ResponseCallback& on_response = pending.second;
+
+    if (const auto* server_status = response.status()) {
+      Status client_status =
+          StatusBuilder(static_cast<StatusCode>(server_status->code()),
+                        IREE_LOC)
+          << "Server request failed: " << WrapString(server_status->message());
+      return on_response(std::move(client_status), response);
+    }
+
+    if (!response.message()) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Response contains no message body";
+    }
+
+    if (response.message_type() != expected_type) {
+      return DataLossErrorBuilder(IREE_LOC)
+             << "Out of order response (mismatch pending)";
+    }
+    return on_response(OkStatus(), response);
+  }
+
+  // Routes the event carried by |packet| to the matching On<Event> handler.
+  // Returns Unimplemented for event types without a handler.
+  Status DispatchEvent(const rpc::ServicePacket& packet) {
+    switch (packet.event_type()) {
+// Expands to a case that logs the event name and forwards the typed event to
+// On<event_name>.
+#define DISPATCH_EVENT(event_name)                                 \
+  case rpc::EventUnion::event_name##Event: {                       \
+    VLOG(2) << "EVENT: " << #event_name;                           \
+    return On##event_name(*packet.event_as_##event_name##Event()); \
+  }
+      DISPATCH_EVENT(ServiceShutdown);
+      DISPATCH_EVENT(ContextRegistered);
+      DISPATCH_EVENT(ContextUnregistered);
+      DISPATCH_EVENT(ModuleLoaded);
+      DISPATCH_EVENT(InvocationRegistered);
+      DISPATCH_EVENT(InvocationUnregistered);
+      DISPATCH_EVENT(BreakpointResolved);
+      DISPATCH_EVENT(BreakpointHit);
+      DISPATCH_EVENT(StepCompleted);
+// Keep the helper macro from leaking into the rest of the translation unit.
+#undef DISPATCH_EVENT
+      default:
+        return UnimplementedErrorBuilder(IREE_LOC)
+               << "Unimplemented debug service event: "
+               << static_cast<int>(packet.event_type());
+    }
+  }
+
+  // Returns the registered context with |context_id|, or NotFound.
+  StatusOr<TcpRemoteContext*> GetContext(int context_id) {
+    auto entry = context_map_.find(context_id);
+    if (entry != context_map_.end()) {
+      return entry->second.get();
+    }
+    return NotFoundErrorBuilder(IREE_LOC) << "Context was never registered";
+  }
+
+  // Records that the service announced its shutdown. Afterwards Unavailable
+  // socket errors are treated as a graceful close by WriteRequest and Poll.
+  Status OnServiceShutdown(const rpc::ServiceShutdownEvent& event) {
+    LOG(INFO) << "Service is shutting down; setting pending shutdown flag";
+    shutdown_pending_ = true;
+    return OkStatus();
+  }
+
+  // Creates the local context object for |context_id|, stores it in the
+  // owning map and the public list, and notifies the listener.
+  Status RegisterContext(int context_id) {
+    auto new_context = absl::make_unique<TcpRemoteContext>(context_id, this);
+    VLOG(2) << "RegisterContext(" << context_id << ")";
+    TcpRemoteContext* registered = new_context.get();
+    context_map_.insert({context_id, std::move(new_context)});
+    contexts_.push_back(registered);
+    return listener_->OnContextRegistered(*registered);
+  }
+
+  // Handles a ContextRegistered event. Fails with FailedPrecondition if the
+  // context id is already known; otherwise registers the context.
+  Status OnContextRegistered(const rpc::ContextRegisteredEvent& event) {
+    VLOG(2) << "OnContextRegistered(" << event.context_id() << ")";
+    const bool already_known =
+        context_map_.find(event.context_id()) != context_map_.end();
+    if (already_known) {
+      return FailedPreconditionErrorBuilder(IREE_LOC)
+             << "Context already registered";
+    }
+    return RegisterContext(event.context_id());
+  }
+
+  // Handles a ContextUnregistered event: removes the context from the owning
+  // map and the public list, then notifies the listener. The context object
+  // stays alive until the listener call has returned.
+  Status OnContextUnregistered(const rpc::ContextUnregisteredEvent& event) {
+    VLOG(2) << "OnContextUnregistered(" << event.context_id() << ")";
+    auto entry = context_map_.find(event.context_id());
+    if (entry == context_map_.end()) {
+      return FailedPreconditionErrorBuilder(IREE_LOC)
+             << "Context was never registered";
+    }
+    auto removed_context = std::move(entry->second);
+    context_map_.erase(entry);
+    contexts_.erase(
+        std::find(contexts_.begin(), contexts_.end(), removed_context.get()));
+    return listener_->OnContextUnregistered(*removed_context);
+  }
+
+  // Handles a ModuleLoaded event: creates a not-yet-loaded module object for
+  // the named module, adds it to its context and notifies the listener.
+  // Fails if the context is unknown.
+  Status OnModuleLoaded(const rpc::ModuleLoadedEvent& event) {
+    VLOG(2) << "OnModuleLoaded(" << event.context_id() << ", "
+            << WrapString(event.module_name()) << ")";
+    ASSIGN_OR_RETURN(auto* context, GetContext(event.context_id()));
+    auto loaded_name = WrapString(event.module_name());
+    auto new_module = absl::make_unique<TcpRemoteModule>(
+        event.context_id(), std::string(loaded_name), this);
+    // Keep a raw pointer; ownership moves into the context below.
+    TcpRemoteModule* added_module = new_module.get();
+    RETURN_IF_ERROR(context->AddModule(std::move(new_module)));
+    return listener_->OnModuleLoaded(*context, *added_module);
+  }
+
+  // Returns the registered invocation with |invocation_id|, or NotFound.
+  StatusOr<TcpRemoteInvocation*> GetInvocation(int invocation_id) {
+    auto entry = invocation_map_.find(invocation_id);
+    if (entry != invocation_map_.end()) {
+      return entry->second.get();
+    }
+    return NotFoundErrorBuilder(IREE_LOC)
+           << "Invocation was never registered";
+  }
+
+  // Creates the local invocation object for |invocation_id|, stores it in the
+  // owning map and the public list, issues an invocation refresh request and
+  // finally notifies the listener.
+  Status RegisterInvocation(int invocation_id) {
+    VLOG(2) << "RegisterInvocation(" << invocation_id << ")";
+    auto new_invocation =
+        absl::make_unique<TcpRemoteInvocation>(invocation_id, this);
+    TcpRemoteInvocation* registered = new_invocation.get();
+    invocation_map_.insert({invocation_id, std::move(new_invocation)});
+    invocations_.push_back(registered);
+    RETURN_IF_ERROR(RefreshInvocations());
+    return listener_->OnInvocationRegistered(*registered);
+  }
+
+  // Handles an InvocationRegistered event. Fails with FailedPrecondition if
+  // the invocation id is already known; otherwise registers the invocation.
+  Status OnInvocationRegistered(const rpc::InvocationRegisteredEvent& event) {
+    VLOG(2) << "OnInvocationRegistered(" << event.invocation_id() << ")";
+    const bool already_known =
+        invocation_map_.find(event.invocation_id()) != invocation_map_.end();
+    if (already_known) {
+      return FailedPreconditionErrorBuilder(IREE_LOC)
+             << "Invocation already registered";
+    }
+    return RegisterInvocation(event.invocation_id());
+  }
+
+  // Handles an InvocationUnregistered event: removes the invocation from the
+  // owning map and the public list, then notifies the listener. The
+  // invocation object stays alive until the listener call has returned.
+  Status OnInvocationUnregistered(
+      const rpc::InvocationUnregisteredEvent& event) {
+    VLOG(2) << "OnInvocationUnregistered(" << event.invocation_id() << ")";
+    auto entry = invocation_map_.find(event.invocation_id());
+    if (entry == invocation_map_.end()) {
+      return FailedPreconditionErrorBuilder(IREE_LOC)
+             << "Invocation was never registered";
+    }
+    auto removed_invocation = std::move(entry->second);
+    invocation_map_.erase(entry);
+    invocations_.erase(std::find(invocations_.begin(), invocations_.end(),
+                                 removed_invocation.get()));
+    return listener_->OnInvocationUnregistered(*removed_invocation);
+  }
+
+  // Returns the registered breakpoint with |breakpoint_id|, or NotFound.
+  StatusOr<TcpRemoteBreakpoint*> GetBreakpoint(int breakpoint_id) {
+    auto entry = breakpoint_map_.find(breakpoint_id);
+    if (entry != breakpoint_map_.end()) {
+      return entry->second.get();
+    }
+    return NotFoundErrorBuilder(IREE_LOC)
+           << "Breakpoint " << breakpoint_id << " was never registered";
+  }
+
+  // Registers the breakpoint described by |breakpoint_def|. If a breakpoint
+  // with the same id already exists it is updated in place; otherwise a new
+  // one is created, merged with the definition and stored in the public list
+  // and the owning map.
+  Status RegisterBreakpoint(const rpc::BreakpointDef& breakpoint_def) {
+    auto existing = breakpoint_map_.find(breakpoint_def.breakpoint_id());
+    const bool is_update = existing != breakpoint_map_.end();
+    if (is_update) {
+      VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id()
+              << ") (update)";
+      return existing->second->MergeFrom(breakpoint_def);
+    }
+
+    VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id() << ")";
+    const auto breakpoint_type =
+        static_cast<RemoteBreakpoint::Type>(breakpoint_def.breakpoint_type());
+    auto new_breakpoint = absl::make_unique<TcpRemoteBreakpoint>(
+        breakpoint_def.breakpoint_id(), breakpoint_type, this);
+    RETURN_IF_ERROR(new_breakpoint->MergeFrom(breakpoint_def));
+    breakpoints_.push_back(new_breakpoint.get());
+    breakpoint_map_.insert({new_breakpoint->id(), std::move(new_breakpoint)});
+    return OkStatus();
+  }
+
+  // Removes the breakpoint with |breakpoint|'s id from the public list and
+  // then from the owning map, which destroys the breakpoint object. Returns
+  // NotFound if no breakpoint with that id is registered.
+  Status UnregisterBreakpoint(RemoteBreakpoint* breakpoint) {
+    VLOG(2) << "UnregisterBreakpoint(" << breakpoint->id() << ")";
+    auto it = breakpoint_map_.find(breakpoint->id());
+    if (it == breakpoint_map_.end()) {
+      return NotFoundErrorBuilder(IREE_LOC)
+             << "Breakpoint was never registered";
+    }
+    // Drop the non-owning list entry first, while the object is still alive,
+    // and only erase when it is actually present; erasing end() is undefined.
+    RemoteBreakpoint* registered = it->second.get();
+    auto list_it =
+        std::find(breakpoints_.begin(), breakpoints_.end(), registered);
+    if (list_it != breakpoints_.end()) {
+      breakpoints_.erase(list_it);
+    }
+    breakpoint_map_.erase(it);
+    return OkStatus();
+  }
+
+  // Handles a BreakpointResolved event: registers the breakpoint if it is not
+  // known yet, otherwise merges the reported definition into the existing
+  // one. Returns InvalidArgument if the event carries no breakpoint.
+  Status OnBreakpointResolved(const rpc::BreakpointResolvedEvent& event) {
+    const auto* breakpoint_def = event.breakpoint();
+    if (!breakpoint_def) {
+      // Table fields are optional on the wire; do not dereference a missing
+      // one.
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "BreakpointResolved event contains no breakpoint";
+    }
+    VLOG(2) << "OnBreakpointResolved(" << breakpoint_def->breakpoint_id()
+            << ")";
+    auto it = breakpoint_map_.find(breakpoint_def->breakpoint_id());
+    if (it == breakpoint_map_.end()) {
+      RETURN_IF_ERROR(RegisterBreakpoint(*breakpoint_def));
+    } else {
+      RETURN_IF_ERROR(it->second->MergeFrom(*breakpoint_def));
+    }
+    return OkStatus();
+  }
+
+  // Handles a BreakpointHit event: looks up the breakpoint, registers the
+  // reported invocation if it is not known yet, merges the reported
+  // invocation state and notifies the listener. Returns InvalidArgument if
+  // the event carries no invocation.
+  Status OnBreakpointHit(const rpc::BreakpointHitEvent& event) {
+    VLOG(2) << "OnBreakpointHit(" << event.breakpoint_id() << ")";
+    ASSIGN_OR_RETURN(auto* breakpoint, GetBreakpoint(event.breakpoint_id()));
+    auto* invocation_def = event.invocation();
+    if (!invocation_def) {
+      // Table fields are optional on the wire; do not dereference a missing
+      // one.
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "BreakpointHit event contains no invocation";
+    }
+    auto invocation_or = GetInvocation(invocation_def->invocation_id());
+    if (!invocation_or.ok()) {
+      // Not found; add new.
+      RETURN_IF_ERROR(RegisterInvocation(invocation_def->invocation_id()));
+      invocation_or = GetInvocation(invocation_def->invocation_id());
+    }
+    RETURN_IF_ERROR(invocation_or.status());
+    RETURN_IF_ERROR(invocation_or.ValueOrDie()->MergeFrom(*invocation_def));
+    return listener_->OnBreakpointHit(*breakpoint, *invocation_or.ValueOrDie());
+  }
+
+  // Sends |step_request| and, if it was issued successfully, stores |callback|
+  // under the request's step id. The stored callback is invoked from
+  // OnStepCompleted, not from the request's response handler.
+  Status StepInvocation(rpc::StepInvocationRequestT* step_request,
+                        std::function<void()> callback) {
+    FlatBufferBuilder fbb;
+    auto request_offs = rpc::StepInvocationRequest::Pack(fbb, step_request);
+    auto issue_status =
+        IssueRequest<rpc::StepInvocationRequest,
+                     rpc::ResponseUnion::StepInvocationResponse>(
+            request_offs, std::move(fbb),
+            [](Status status, const rpc::Response& response_union) -> Status {
+              return status;
+            });
+    RETURN_IF_ERROR(issue_status);
+    pending_step_callbacks_[step_request->step_id] = std::move(callback);
+    return OkStatus();
+  }
+
+  // Handles a StepCompleted event: merges the invocation states carried by
+  // the event and then runs the callback stored for the step id. If no
+  // callback is stored for the step, a MakeReady request is issued instead.
+  Status OnStepCompleted(const rpc::StepCompletedEvent& event) {
+    VLOG(2) << "OnStepCompleted(" << event.step_id() << ")";
+
+    // Update all invocation states that are contained.
+    // This may only be a subset of relevant states. An absent vector reads
+    // back as nullptr and is treated as "no states".
+    if (const auto* invocation_defs = event.invocations()) {
+      for (auto* invocation_def : *invocation_defs) {
+        ASSIGN_OR_RETURN(auto invocation,
+                         GetInvocation(invocation_def->invocation_id()));
+        RETURN_IF_ERROR(invocation->MergeFrom(*invocation_def));
+      }
+    }
+
+    // Dispatch step callback. Note that it may have been cancelled and that's
+    // ok. We'll just make ready to resume execution.
+    auto it = pending_step_callbacks_.find(event.step_id());
+    if (it == pending_step_callbacks_.end()) {
+      LOG(WARNING) << "Step " << event.step_id()
+                   << " not found; was cancelled?";
+      RETURN_IF_ERROR(MakeReady());
+      return OkStatus();
+    }
+
+    // Take the callback out of the map before running it: a callback that
+    // starts another step inserts into the map, which can invalidate |it|.
+    auto step_callback = std::move(it->second);
+    pending_step_callbacks_.erase(it);
+    if (step_callback) {
+      step_callback();
+    }
+    return OkStatus();
+  }
+
+ int session_id_ = 123;  // Session id attached to each request; fixed at this default in the code shown.
+
+ int fd_ = -1;  // Session socket; shut down and closed by the destructor.
+ Listener* listener_;  // Receives context/module/invocation/breakpoint notifications.
+ bool shutdown_pending_ = false;  // Set by OnServiceShutdown; Unavailable socket errors are then tolerated.
+ std::queue<std::pair<rpc::ResponseUnion, ResponseCallback>>
+ pending_responses_;  // FIFO of expected response type + handler; pushed by IssueRequest, popped by DispatchResponse.
+
+ std::vector<RemoteContext*> contexts_;  // Non-owning list backing contexts().
+ absl::flat_hash_map<int, std::unique_ptr<TcpRemoteContext>> context_map_;  // Owns contexts, keyed by context id.
+ std::vector<RemoteInvocation*> invocations_;  // Non-owning list backing invocations().
+ absl::flat_hash_map<int, std::unique_ptr<TcpRemoteInvocation>>
+ invocation_map_;  // Owns invocations, keyed by invocation id.
+ std::vector<RemoteBreakpoint*> breakpoints_;  // Non-owning list backing breakpoints().
+ absl::flat_hash_map<int, std::unique_ptr<TcpRemoteBreakpoint>>
+ breakpoint_map_;  // Owns breakpoints, keyed by breakpoint id.
+
+ int next_step_id_ = 1;  // Next id handed out by StepInvocation/StepInvocationToOffset.
+ absl::flat_hash_map<int, std::function<void()>> pending_step_callbacks_;  // Step callbacks keyed by step id; consumed by OnStepCompleted.
+};
+
+} // namespace
+
+// Connects to the debug service at |service_address| and returns a client
+// reporting to |listener|.
+//
+// The address must include a non-zero port. The host is resolved with
+// getaddrinfo and each returned address is tried in order until one connects.
+// Returns InvalidArgument if no port is given and Unavailable if resolution
+// or every connection attempt fails. The connected socket is handed to
+// TcpDebugClient::Create.
+// static
+StatusOr<std::unique_ptr<DebugClient>> DebugClient::Connect(
+    absl::string_view service_address, Listener* listener) {
+  // Parse address into hostname and port.
+  ASSIGN_OR_RETURN(auto hostname_port, ParseAddress(service_address));
+  if (hostname_port.second == 0) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "No port specified in service address; port must match the "
+              "server: "
+           << service_address;
+  }
+
+  // Attempt to resolve the address.
+  // Note that if we only wanted local debugging we could remove the dep on
+  // getaddrinfo/having a valid DNS setup.
+  addrinfo hints = {0};
+  hints.ai_family = AF_UNSPEC;
+  hints.ai_socktype = SOCK_STREAM;
+  addrinfo* resolved_address = nullptr;
+  auto port_str = std::to_string(hostname_port.second);
+  int getaddrinfo_ret = ::getaddrinfo(
+      hostname_port.first.c_str(), port_str.c_str(), &hints, &resolved_address);
+  if (getaddrinfo_ret != 0) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Unable to resolve debug service address for " << service_address
+           << ": (" << getaddrinfo_ret << ") "
+           << ::gai_strerror(getaddrinfo_ret);
+  }
+
+  // Attempt to connect with each address returned from the query.
+  // errno is captured right after each failing call because the close() and
+  // freeaddrinfo() calls that follow may overwrite it.
+  int fd = -1;
+  int last_errno = 0;
+  for (addrinfo* rp = resolved_address; rp != nullptr; rp = rp->ai_next) {
+    fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
+    if (fd == -1) {
+      last_errno = errno;
+      continue;
+    }
+    if (::connect(fd, rp->ai_addr, rp->ai_addrlen) == 0) {
+      break;  // Success!
+    }
+    last_errno = errno;
+    ::close(fd);
+    fd = -1;
+  }
+  ::freeaddrinfo(resolved_address);
+  if (fd == -1) {
+    return UnavailableErrorBuilder(IREE_LOC)
+           << "Unable to connect to " << service_address << " on any address: ("
+           << last_errno << ") " << ::strerror(last_errno);
+  }
+
+  LOG(INFO) << "Connected to debug service at " << service_address;
+
+  return TcpDebugClient::Create(fd, listener);
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_server.h b/rt/debug/debug_server.h
new file mode 100644
index 0000000..6fd416d
--- /dev/null
+++ b/rt/debug/debug_server.h
@@ -0,0 +1,88 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DEBUG_DEBUG_SERVER_H_
+#define IREE_RT_DEBUG_DEBUG_SERVER_H_
+
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+class Context;
+class Instance;
+class Invocation;
+class Module;
+} // namespace rt
+} // namespace iree
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Runtime debugging server.
+// Enabled only when compiled in (by defining IREE_DEBUG=1), this provides an
+// RPC server that allows debuggers to attach, query, and manipulate contexts.
+// This interface is used by various parts of the runtime such as dispatch to
+// query the current debug state and signal events.
+//
+// Thread-safe. Contexts may be registered and unregistered from any thread.
+class DebugServer {
+ public:
+ // Creates a new debug server listening on the provided |listen_port|.
+ // The disabled implementation (debug_server_disabled.cc) returns a null
+ // server instead of an error, so callers must null-check the result.
+ static StatusOr<std::unique_ptr<DebugServer>> Create(int listen_port);
+
+ // TODO(benvanik): ensure this gets optimized out when disabled.
+ // Seems to be the case: https://gcc.godbolt.org/z/0zf-L4
+ virtual ~DebugServer() = default;
+
+ // Attaches a callback that will be made when the debug server is shutting
+ // down. This can be used to keep resources alive that require the debugger.
+ // The callback will be made from a random thread.
+ virtual void AtExit(std::function<void()> callback) = 0;
+
+ // Blocks the caller until a client session connects and resumes all fibers.
+ // Returns AbortedError if a session connects/is connected but disconnects
+ // during the wait.
+ virtual Status WaitUntilSessionReady() = 0;
+
+ protected:
+ friend class ::iree::rt::Instance;
+
+ // Registers a context with the debug service.
+ // Ownership remains with the caller and UnregisterContext must be called
+ // prior to the context being destroyed.
+ virtual Status RegisterContext(Context* context) = 0;
+ virtual Status UnregisterContext(Context* context) = 0;
+
+ friend class ::iree::rt::Context;
+
+ // Registers a new module linked into an existing Context.
+ virtual Status RegisterContextModule(Context* context, Module* module) = 0;
+
+ friend class ::iree::rt::Invocation;
+
+ // Registers an invocation with the debug service.
+ // Ownership remains with the caller and UnregisterInvocation must be called
+ // prior to the invocation being destroyed.
+ virtual Status RegisterInvocation(Invocation* invocation) = 0;
+ virtual Status UnregisterInvocation(Invocation* invocation) = 0;
+};
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_DEBUG_SERVER_H_
diff --git a/rt/debug/debug_server_disabled.cc b/rt/debug/debug_server_disabled.cc
new file mode 100644
index 0000000..28306a9
--- /dev/null
+++ b/rt/debug/debug_server_disabled.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/debug/debug_server.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// static; disabled-server stub: ignores |listen_port| and yields nullptr.
+StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) {
+ return {nullptr};
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_server_flags.cc b/rt/debug/debug_server_flags.cc
new file mode 100644
index 0000000..4455bf7
--- /dev/null
+++ b/rt/debug/debug_server_flags.cc
@@ -0,0 +1,80 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/debug/debug_server_flags.h"
+
+#include "absl/flags/flag.h"
+#include "absl/strings/str_cat.h"
+#include "base/memory.h"
+#include "base/status.h"
+
+#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT)
+#include "tools/debugger/debug_app_embedded.h"
+#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT
+
+ABSL_FLAG(int32_t, iree_debug_service_port, 6000,
+ "TCP port to listen for debug service connections.");
+ABSL_FLAG(bool, iree_wait_for_debugger, false,
+ "Waits until a debugger connects to continue startup.");
+ABSL_FLAG(bool, iree_attach_debugger, false, "Attaches a debugger at startup.");
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags() {
+ // Create the server based on whatever version is compiled in.
+ // Create yields a null server when none is available; pass that through.
+ ASSIGN_OR_RETURN(
+ auto debug_server,
+ DebugServer::Create(absl::GetFlag(FLAGS_iree_debug_service_port)));
+ if (!debug_server) {
+ return nullptr;
+ }
+
+#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT)
+ // Launch the embedded debug UI; AtExit below keeps it alive until shutdown.
+ std::unique_ptr<EmbeddedDebugger> debugger;
+ if (absl::GetFlag(FLAGS_iree_attach_debugger)) {
+ LOG(INFO) << "Attaching debugger at startup...";
+ ASSIGN_OR_RETURN(
+ debugger,
+ AttachDebugger(absl::StrCat(
+ "localhost:", absl::GetFlag(FLAGS_iree_debug_service_port))));
+ RETURN_IF_ERROR(debug_server->WaitUntilSessionReady());
+ LOG(INFO) << "Debugger attached";
+ // TODO(benvanik): C++14 init-capture would avoid this move-to-lambda baton.
+ auto debugger_baton = IreeMoveToLambda(debugger);
+ debug_server->AtExit([debugger_baton]() { debugger_baton.value.reset(); });
+ }
+#else
+ if (absl::GetFlag(FLAGS_iree_attach_debugger)) {
+ LOG(WARNING) << "--iree_attach_debugger specified but no embedded debugger "
+ "is present. Build with --define=IREE_DEBUG=1.";
+ }
+#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT
+
+ // Optionally block startup until a debugger session reports ready.
+ if (absl::GetFlag(FLAGS_iree_wait_for_debugger)) {
+ LOG(INFO) << "Waiting for a debugger to connect...";
+ RETURN_IF_ERROR(debug_server->WaitUntilSessionReady());
+ LOG(INFO) << "Debugger ready, resuming...";
+ }
+
+ return std::move(debug_server);
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_server_flags.h b/rt/debug/debug_server_flags.h
new file mode 100644
index 0000000..13dc1b2
--- /dev/null
+++ b/rt/debug/debug_server_flags.h
@@ -0,0 +1,33 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
+#define IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
+
+#include "base/status.h"
+#include "rt/debug/debug_server.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Creates a debug server based on the current --iree_* debug flags.
+// Returns nullptr if no server is compiled in or the flags are not set.
+StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags();
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
diff --git a/rt/debug/debug_server_tcp.cc b/rt/debug/debug_server_tcp.cc
new file mode 100644
index 0000000..a4bf868
--- /dev/null
+++ b/rt/debug/debug_server_tcp.cc
@@ -0,0 +1,459 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cerrno>
+#include <exception>
+#include <thread> // NOLINT
+
+#include "absl/base/thread_annotations.h"
+#include "absl/memory/memory.h"
+#include "absl/synchronization/mutex.h"
+#include "base/status.h"
+#include "flatbuffers/flatbuffers.h"
+#include "rt/debug/debug_server.h"
+#include "rt/debug/debug_service.h"
+#include "rt/debug/debug_tcp_util.h"
+#include "schemas/debug_service_generated.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+namespace {
+
+// Writes the given typed response message to the given fd by wrapping it in
+// an rpc::Response union inside a size-prefixed rpc::ServicePacket.
+//
+// Example:
+// ::flatbuffers::FlatBufferBuilder fbb;
+// rpc::SuspendInvocationResponseBuilder response(fbb);
+// RETURN_IF_ERROR(WriteResponse(fd_, response.Finish(), std::move(fbb)));
+template <typename T>
+Status WriteResponse(int fd, ::flatbuffers::Offset<T> message_offs,
+ ::flatbuffers::FlatBufferBuilder fbb) {
+ rpc::ResponseBuilder response_builder(fbb);
+ response_builder.add_message_type(rpc::ResponseUnionTraits<T>::enum_value);
+ response_builder.add_message(message_offs.Union());
+ auto response_offs = response_builder.Finish();
+ rpc::ServicePacketBuilder packet_builder(fbb);
+ packet_builder.add_response(response_offs);
+ fbb.FinishSizePrefixed(packet_builder.Finish());
+ return tcp::WriteBuffer(fd, fbb.Release());
+}
+
+class TcpDebugSession : public DebugSession {
+ public:
+ using ClosedCallback =
+ std::function<void(TcpDebugSession* session, Status status)>;
+
+ static StatusOr<std::unique_ptr<TcpDebugSession>> Accept(
+ DebugService* debug_service, int client_fd,
+ ClosedCallback closed_callback) {
+ VLOG(2) << "Client " << client_fd << ": Setting up socket options...";
+ // Disable Nagle's algorithm to ensure we have low latency.
+ RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(client_fd, false));
+ // Enable keepalive assuming the client is local and this high freq is ok.
+ RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(client_fd, true));
+ // Linger around for a bit to flush all data.
+ RETURN_IF_ERROR(tcp::ToggleSocketLinger(client_fd, true));
+
+ return absl::make_unique<TcpDebugSession>(debug_service, client_fd,
+ std::move(closed_callback));
+ }
+
+ TcpDebugSession(DebugService* debug_service, int client_fd,
+ ClosedCallback closed_callback)
+ : debug_service_(debug_service),
+ client_fd_(client_fd),
+ closed_callback_(std::move(closed_callback)) {
+ CHECK_OK(debug_service_->RegisterDebugSession(this));
+ session_thread_ = std::thread([this]() { SessionThread(); });
+ }
+
+ ~TcpDebugSession() override {
+ CHECK_OK(debug_service_->UnregisterDebugSession(this));
+ VLOG(2) << "Client " << client_fd_ << ": Shutting down session socket...";
+ ::shutdown(client_fd_, SHUT_RD);
+ if (session_thread_.joinable() &&
+ session_thread_.get_id() != std::this_thread::get_id()) {
+ VLOG(2) << "Client " << client_fd_ << ": Joining socket thread...";
+ session_thread_.join();
+ VLOG(2) << "Client " << client_fd_ << ": Joined socket thread!";
+ } else {
+ VLOG(2) << "Client " << client_fd_ << ": Detaching socket thread...";
+ session_thread_.detach();
+ }
+ VLOG(2) << "Client " << client_fd_ << ": Closing session socket...";
+ ::close(client_fd_);
+ VLOG(2) << "Client " << client_fd_ << ": Closed session socket!";
+ client_fd_ = -1;
+ }
+
+ Status OnServiceShutdown() {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnServiceShutdown()";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ rpc::ServiceShutdownEventBuilder event(fbb);
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+
+ Status OnContextRegistered(Context* context) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnContextRegistered("
+ << context->id() << ")";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ rpc::ContextRegisteredEventBuilder event(fbb);
+ event.add_context_id(context->id());
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+ Status OnContextUnregistered(Context* context) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnContextUnregistered("
+ << context->id() << ")";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ rpc::ContextUnregisteredEventBuilder event(fbb);
+ event.add_context_id(context->id());
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+
+ Status OnModuleLoaded(Context* context, Module* module) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnModuleLoaded("
+ << context->id() << ", " << module->name() << ")";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ auto module_name_offs =
+ fbb.CreateString(module->name().data(), module->name().size());
+ rpc::ModuleLoadedEventBuilder event(fbb);
+ event.add_context_id(context->id());
+ event.add_module_name(module_name_offs);
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+
+ Status OnInvocationRegistered(Invocation* invocation) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationRegistered("
+ << invocation->id() << ")";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ rpc::InvocationRegisteredEventBuilder event(fbb);
+ event.add_invocation_id(invocation->id());
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+ Status OnInvocationUnregistered(Invocation* invocation) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationUnregistered("
+ << invocation->id() << ")";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ rpc::InvocationUnregisteredEventBuilder event(fbb);
+ event.add_invocation_id(invocation->id());
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+
+ Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint,
+ Context* context) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointResolved("
+ << breakpoint.breakpoint_id << ", " << context->id() << ", "
+ << breakpoint.function_ordinal << ")";
+ rpc::BreakpointResolvedEventT event;
+ event.breakpoint = absl::make_unique<rpc::BreakpointDefT>();
+ *event.breakpoint = breakpoint;
+ event.context_id = context->id();
+ ::flatbuffers::FlatBufferBuilder fbb;
+ return PostEvent(rpc::BreakpointResolvedEvent::Pack(fbb, &event),
+ std::move(fbb));
+ }
+
+ Status OnBreakpointHit(int breakpoint_id,
+ const Invocation& invocation) override {
+ VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointHit("
+ << breakpoint_id << ", " << invocation.id() << ")";
+ ::flatbuffers::FlatBufferBuilder fbb;
+ ASSIGN_OR_RETURN(auto invocation_offs,
+ debug_service_->SerializeInvocation(invocation, &fbb));
+ rpc::BreakpointHitEventBuilder event(fbb);
+ event.add_breakpoint_id(breakpoint_id);
+ event.add_invocation(invocation_offs);
+ return PostEvent(event.Finish(), std::move(fbb));
+ }
+
+ private:
+ void SessionThread() {
+ VLOG(2) << "Client " << client_fd_ << ": Thread entry";
+ Status session_status = OkStatus();
+ while (session_status.ok()) {
+ auto buffer_or = tcp::ReadBuffer<rpc::Request>(client_fd_);
+ if (!buffer_or.ok()) {
+ if (IsCancelled(buffer_or.status())) {
+ // Graceful shutdown.
+ VLOG(2) << "Client " << client_fd_ << ": Graceful shutdown requested";
+ break;
+ }
+ // Error reading.
+ session_status = std::move(buffer_or).status();
+ LOG(ERROR) << "Client " << client_fd_
+ << ": Error reading request buffer: " << session_status;
+ break;
+ }
+ auto request_buffer = std::move(buffer_or).ValueOrDie();
+ session_status = DispatchRequest(request_buffer.GetRoot());
+ if (!session_status.ok()) {
+ LOG(ERROR) << "Client " << client_fd_
+ << ": Error dispatching request: " << session_status;
+ break;
+ }
+ }
+ VLOG(2) << "Client " << client_fd_ << ": Thread exit";
+ AbortSession(session_status);
+ }
+
+ void AbortSession(Status status) {
+ if (status.ok()) {
+ VLOG(2) << "Debug client disconnected";
+ } else {
+ LOG(ERROR) << "Debug session aborted; " << status;
+ ::flatbuffers::FlatBufferBuilder fbb;
+ auto message_offs =
+ fbb.CreateString(status.message().data(), status.message().size());
+ rpc::StatusBuilder status_builder(fbb);
+ status_builder.add_code(static_cast<int>(status.code()));
+ status_builder.add_message(message_offs);
+ auto status_offs = status_builder.Finish();
+ rpc::ResponseBuilder response(fbb);
+ response.add_status(status_offs);
+ fbb.FinishSizePrefixed(response.Finish());  // NOTE(review): bare rpc::Response; WriteResponse/PostEvent send a ServicePacket - confirm.
+ tcp::WriteBuffer(client_fd_, fbb.Release()).IgnoreError();  // Best effort.
+ }
+ closed_callback_(this, std::move(status));  // Invoked for ok and error alike.
+ }
+
+ template <typename T>
+ Status PostEvent(::flatbuffers::Offset<T> event_offs,
+ ::flatbuffers::FlatBufferBuilder fbb) {
+ rpc::ServicePacketBuilder packet_builder(fbb);
+ packet_builder.add_event_type(rpc::EventUnionTraits<T>::enum_value);
+ packet_builder.add_event(event_offs.Union());
+ fbb.FinishSizePrefixed(packet_builder.Finish());
+ return tcp::WriteBuffer(client_fd_, fbb.Release());
+ }
+
+ Status DispatchRequest(const rpc::Request& request) {
+ ::flatbuffers::FlatBufferBuilder fbb;
+ switch (request.message_type()) {
+#define DISPATCH_REQUEST(method_name) \
+ case rpc::RequestUnion::method_name##Request: { \
+ VLOG(2) << "Client " << client_fd_ \
+ << ": DispatchRequest(" #method_name ")..."; \
+ ASSIGN_OR_RETURN(auto response_offs, \
+ debug_service_->method_name( \
+ *request.message_as_##method_name##Request(), &fbb)); \
+ return WriteResponse(client_fd_, response_offs, std::move(fbb)); \
+ }
+ DISPATCH_REQUEST(MakeReady);
+ DISPATCH_REQUEST(GetStatus);
+ DISPATCH_REQUEST(ListContexts);
+ DISPATCH_REQUEST(GetModule);
+ DISPATCH_REQUEST(GetFunction);
+ DISPATCH_REQUEST(ListInvocations);
+ DISPATCH_REQUEST(SuspendInvocations);
+ DISPATCH_REQUEST(ResumeInvocations);
+ DISPATCH_REQUEST(StepInvocation);
+ DISPATCH_REQUEST(GetInvocationLocal);
+ DISPATCH_REQUEST(SetInvocationLocal);
+ DISPATCH_REQUEST(ListBreakpoints);
+ DISPATCH_REQUEST(AddBreakpoint);
+ DISPATCH_REQUEST(RemoveBreakpoint);
+ DISPATCH_REQUEST(StartProfiling);
+ DISPATCH_REQUEST(StopProfiling);
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented debug service request: "
+ << static_cast<int>(request.message_type());
+ }
+ }
+
+ DebugService* debug_service_;
+ int client_fd_;
+ ClosedCallback closed_callback_;
+ std::thread session_thread_;
+};
+
+class TcpDebugServer final : public DebugServer {
+ public:
+ static StatusOr<std::unique_ptr<TcpDebugServer>> Listen(int port) {
+   // Creates an AF_INET6 socket bound to in6addr_any with address reuse enabled
+   // so that both IPv4 and IPv6 clients can reach |port|.
+   int listen_fd = ::socket(AF_INET6, SOCK_STREAM, 0);
+   RETURN_IF_ERROR(tcp::ToggleSocketAddressReuse(listen_fd, true));
+   struct sockaddr_in6 socket_addr = {0};
+   socket_addr.sin6_family = AF_INET6;
+   socket_addr.sin6_port = htons(port);
+   socket_addr.sin6_addr = in6addr_any;
+   const char* failed_op = nullptr;  // Names the step that failed, if any.
+   if (::bind(listen_fd, reinterpret_cast<struct sockaddr*>(&socket_addr),
+              sizeof(socket_addr)) < 0) {
+     failed_op = "bind socket to";
+   } else if (::listen(listen_fd, 1)) {
+     failed_op = "listen on";
+   }
+   if (failed_op != nullptr) {
+     const int saved_errno = errno;  // Read before ::close() can clobber it.
+     ::close(listen_fd);  // Both failure paths release the socket.
+     return AlreadyExistsErrorBuilder(IREE_LOC)
+            << "Unable to " << failed_op << " port " << port << ": ("
+            << saved_errno << ") " << ::strerror(saved_errno);
+   }
+   return absl::make_unique<TcpDebugServer>(listen_fd);
+ }
+
+ TcpDebugServer(int listen_fd) : listen_fd_(listen_fd) {
+ server_thread_ = std::thread([this]() { ListenThread(); });
+ }
+
+ ~TcpDebugServer() ABSL_LOCKS_EXCLUDED(mutex_) override {
+   LOG(INFO) << "Shutting down debug server...";
+   // Shut down listen socket first so that we can't accept new connections.
+   // This is done before taking |mutex_|: the listen thread locks it in
+   // AcceptNewSession, so joining while holding it could deadlock.
+   VLOG(2) << "Shutting down listen socket...";
+   ::shutdown(listen_fd_, SHUT_RDWR);
+   if (server_thread_.joinable()) {
+     VLOG(2) << "Joining listen thread...";
+     server_thread_.join();
+     VLOG(2) << "Joined listen thread!";
+   }
+   VLOG(2) << "Closing listen socket...";
+   ::close(listen_fd_);
+   listen_fd_ = -1;
+   VLOG(2) << "Closed listen socket!";
+
+   // Notify all sessions.
+   absl::ReleasableMutexLock lock(&mutex_);
+   for (auto& session : sessions_) {
+     session->OnServiceShutdown().IgnoreError();
+   }
+   // Kill all active sessions. Note that we must do this outside of our lock.
+   std::vector<std::unique_ptr<TcpDebugSession>> sessions =
+       std::move(sessions_);
+   std::vector<std::function<void()>> at_exit_callbacks =
+       std::move(at_exit_callbacks_);
+   lock.Release();
+   VLOG(2) << "Clearing live sessions...";
+   sessions.clear();
+   VLOG(2) << "Calling AtExit callbacks...";
+   for (auto& callback : at_exit_callbacks) {
+     callback();
+   }
+   LOG(INFO) << "Debug server shutdown!";
+ }
+
+ DebugService* debug_service() { return &debug_service_; }
+
+ Status AcceptNewSession(int client_fd) {
+   LOG(INFO) << "Accepting new client session as " << client_fd;
+   // Runs when the session closes: drops it from |sessions_|, which destroys
+   // it because the vector holds the owning pointer.
+   auto on_closed = [this](TcpDebugSession* closed_session, Status status) {
+     absl::MutexLock lock(&mutex_);
+     auto it = std::find_if(
+         sessions_.begin(), sessions_.end(),
+         [closed_session](const std::unique_ptr<TcpDebugSession>& entry) {
+           return entry.get() == closed_session;
+         });
+     if (it != sessions_.end()) sessions_.erase(it);
+   };
+   ASSIGN_OR_RETURN(auto session,
+                    TcpDebugSession::Accept(&debug_service_, client_fd,
+                                            std::move(on_closed)));
+   // Take ownership so the session lives until it closes or we shut down.
+   absl::MutexLock lock(&mutex_);
+   sessions_.push_back(std::move(session));
+   return OkStatus();
+ }
+
+ void AtExit(std::function<void()> callback) override {
+ absl::MutexLock lock(&mutex_);
+ at_exit_callbacks_.push_back(std::move(callback));
+ }
+
+ Status WaitUntilSessionReady() override {
+ return debug_service_.WaitUntilAllSessionsReady();
+ }
+
+ protected:
+ Status RegisterContext(Context* context) override {
+ return debug_service_.RegisterContext(context);
+ }
+ Status UnregisterContext(Context* context) override {
+ return debug_service_.UnregisterContext(context);
+ }
+ Status RegisterContextModule(Context* context, Module* module) override {
+ return debug_service_.RegisterContextModule(context, module);
+ }
+ Status RegisterInvocation(Invocation* invocation) override {
+ return debug_service_.RegisterInvocation(invocation);
+ }
+ Status UnregisterInvocation(Invocation* invocation) override {
+ return debug_service_.UnregisterInvocation(invocation);
+ }
+
+ private:
+ void ListenThread() {
+   VLOG(2) << "Listen thread entry";
+   while (true) {
+     // The peer address is never used so accept() is not asked to return it.
+     int accepted_fd = ::accept(listen_fd_, nullptr, nullptr);
+     if (accepted_fd < 0) {
+       if (errno == EINVAL) {
+         // Shutting down gracefully.
+         break;
+       }
+       // Interrupted or aborted-by-peer accepts are transient: just retry.
+       if (errno == EINTR || errno == ECONNABORTED) continue;
+       // We may be able to recover from some of these cases, but... shrug.
+       LOG(FATAL) << "Failed to accept client socket: (" << errno << ") "
+                  << ::strerror(errno);
+       break;
+     }
+     auto accept_status = AcceptNewSession(accepted_fd);
+     if (!accept_status.ok()) {
+       LOG(ERROR) << "Failed to accept incoming debug client: "
+                  << accept_status;
+       ::close(accepted_fd);  // No session took ownership of the socket.
+     }
+   }
+   VLOG(2) << "Listen thread exit";
+ }
+
+ int listen_fd_;
+ std::thread server_thread_;
+
+ absl::Mutex mutex_;
+ std::vector<std::unique_ptr<TcpDebugSession>> sessions_
+ ABSL_GUARDED_BY(mutex_);
+ std::vector<std::function<void()>> at_exit_callbacks_ ABSL_GUARDED_BY(mutex_);
+
+ DebugService debug_service_;
+};
+
+} // namespace
+
+// static
+StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) {
+ ASSIGN_OR_RETURN(auto debug_server, TcpDebugServer::Listen(listen_port));
+ LOG(INFO) << "Debug server listening on localhost:" << listen_port;
+ return debug_server;
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_service.cc b/rt/debug/debug_service.cc
new file mode 100644
index 0000000..5a21ded
--- /dev/null
+++ b/rt/debug/debug_service.cc
@@ -0,0 +1,850 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/debug/debug_service.h"
+
+#include <algorithm>
+#include <memory>
+
+#include "absl/strings/str_join.h"
+#include "absl/synchronization/mutex.h"
+#include "base/flatbuffer_util.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+#include "rt/instance.h"
+#include "schemas/debug_service_generated.h"
+#include "schemas/reflection_data.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+namespace {
+
+using ::flatbuffers::FlatBufferBuilder;
+using ::flatbuffers::Offset;
+using ::iree::hal::BufferView;
+
+int32_t NextUniqueBreakpointId() {
+  static std::atomic<int32_t> next_id{0};  // Shared counter; first id is 1.
+  return ++next_id;
+}
+
+// Gets an embedded flatbuffers reflection schema by name; fatal if missing.
+const ::reflection::Schema& GetSchema(const char* schema_name) {
+  // Stop at the null-named sentinel assumed to end the TOC (TODO confirm).
+  for (const auto* file_toc = schemas::reflection_data_create();
+       file_toc != nullptr && file_toc->name != nullptr; ++file_toc) {
+    if (std::strcmp(file_toc->name, schema_name) != 0) continue;
+    return *::reflection::GetSchema(file_toc->data);
+  }
+  LOG(FATAL) << "FlatBuffer schema '" << schema_name
+             << "' not found in binary; ensure it is in :reflection_data";
+}
+
+// Recursively copies a flatbuffer table, returning the root offset in |fbb|.
+template <typename T>
+StatusOr<Offset<T>> DeepCopyTable(const char* schema_name, const T& table_def,
+ FlatBufferBuilder* fbb) {
+ const auto* root_table =
+ reinterpret_cast<const ::flatbuffers::Table*>(std::addressof(table_def));
+ const auto& schema = GetSchema(schema_name);
+ return {::flatbuffers::CopyTable(*fbb, schema, *schema.root_table(),
+ *root_table,
+ /*use_string_pooling=*/false)
+ .o};
+}
+
+// Serializes a buffer_view value, optionally including the entire buffer
+// contents.
+StatusOr<Offset<rpc::BufferViewDef>> SerializeBufferView(
+ const BufferView& buffer_view, bool include_buffer_contents,
+ FlatBufferBuilder* fbb) {
+ auto shape_offs = fbb->CreateVector(buffer_view.shape.subspan().data(),
+ buffer_view.shape.subspan().size());
+ rpc::BufferViewDefBuilder value(*fbb);
+ value.add_is_valid(buffer_view.buffer != nullptr);
+ value.add_shape(shape_offs);
+ value.add_element_size(buffer_view.element_size);
+ if (include_buffer_contents) {
+ // TODO(benvanik): add buffer data.
+ }
+ return value.Finish();
+}
+
+// Serializes a stack frame.
+StatusOr<Offset<rpc::StackFrameDef>> SerializeStackFrame(
+ const StackFrame& stack_frame, FlatBufferBuilder* fbb) {
+ ASSIGN_OR_RETURN(int function_ordinal,
+ stack_frame.module().function_table().LookupFunctionOrdinal(
+ stack_frame.function()));
+ auto module_name_offs = fbb->CreateString(stack_frame.module().name().data(),
+ stack_frame.module().name().size());
+ std::vector<Offset<rpc::BufferViewDef>> local_offs_list;
+ for (const auto& local : stack_frame.locals()) {
+ ASSIGN_OR_RETURN(
+ auto local_offs,
+ SerializeBufferView(local, /*include_buffer_contents=*/false, fbb));
+ local_offs_list.push_back(local_offs);
+ }
+ auto locals_offs = fbb->CreateVector(local_offs_list);
+ rpc::StackFrameDefBuilder sfb(*fbb);
+ sfb.add_module_name(module_name_offs);
+ sfb.add_function_ordinal(function_ordinal);
+ sfb.add_offset(stack_frame.offset());
+ sfb.add_locals(locals_offs);
+ return sfb.Finish();
+}
+
+// Resolves invocation:frame_index:local_index to a BufferView, bounds-checked.
+StatusOr<BufferView*> ResolveInvocationLocal(Invocation* invocation,
+                                             int frame_index, int local_index) {
+  auto frames = invocation->mutable_stack()->mutable_frames();
+  if (frame_index < 0 || frame_index >= frames.size()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Frame index " << frame_index << " out of bounds ("
+           << frames.size() << ")";
+  }
+  auto locals = frames[frame_index].mutable_locals();
+  if (local_index < 0 || local_index >= locals.size()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Local index " << local_index << " out of bounds ("
+           << locals.size() << ")";
+  }
+  return &locals[local_index];  // Valid indices are [0, size).
+}
+
+// Suspends a set of invocations and blocks until all have been suspended (or
+// one or more fails to suspend). This works only when the caller is *not* one
+// of the threads executing an invocation in |invocations| (this normally
+// shouldn't happen, but may if we support eval()-like semantics).
+Status SuspendInvocationsAndWait(absl::Span<Invocation*> invocations) {
+ absl::Mutex suspend_mutex;
+ Status one_suspend_status = OkStatus();  // Last failure seen by a callback.
+ std::list<int> pending_suspend_ids;  // Ids whose callback has not run yet.
+ for (auto* invocation : invocations) {
+ pending_suspend_ids.push_back(invocation->id());
+ }
+ for (auto* invocation : invocations) {
+ auto suspend_callback = [&, invocation](Status suspend_status) {
+ absl::MutexLock lock(&suspend_mutex);
+ auto it = std::find(pending_suspend_ids.begin(),
+ pending_suspend_ids.end(), invocation->id());
+ CHECK(it != pending_suspend_ids.end());
+ pending_suspend_ids.erase(it);
+ if (!suspend_status.ok()) {
+ one_suspend_status = std::move(suspend_status);
+ }
+ };
+ RETURN_IF_ERROR(invocation->Suspend(suspend_callback));  // NOTE(review): an early return may leave earlier [&] callbacks referencing dead locals - confirm.
+ }
+ suspend_mutex.LockWhen(absl::Condition(  // Blocks until every id is removed.
+ +[](std::list<int>* pending_suspend_ids) {
+ return pending_suspend_ids->empty();
+ },
+ &pending_suspend_ids));
+ suspend_mutex.Unlock();
+ return one_suspend_status;
+}
+
+} // namespace
+
+Status DebugService::SuspendAllInvocations() {
+ VLOG(2) << "SuspendAllInvocations";
+ for (auto* invocation : invocations_) {
+ RETURN_IF_ERROR(invocation->Suspend());
+ }
+ return OkStatus();
+}
+
+Status DebugService::ResumeAllInvocations() {
+ VLOG(2) << "ResumeAllInvocations";
+ for (auto* invocation : invocations_) {
+ RETURN_IF_ERROR(invocation->Resume());
+ }
+ return OkStatus();
+}
+
+Status DebugService::RegisterContext(Context* context) {
+ absl::MutexLock lock(&mutex_);
+ VLOG(2) << "RegisterContext(" << context->id() << ")";
+ RETURN_IF_ERROR(SuspendAllInvocations());
+ RETURN_IF_ERROR(UnreadyAllSessions());
+ contexts_.push_back(context);
+ for (auto* session : sessions_) {
+ RETURN_IF_ERROR(session->OnContextRegistered(context));
+ }
+ RETURN_IF_ERROR(ResumeAllInvocations());
+ return OkStatus();
+}
+
+// Stops tracking |context| after telling every session it is going away.
+// Returns NotFound if the context was never registered. Invocations are
+// suspended and sessions marked unready for the duration of the update.
+Status DebugService::UnregisterContext(Context* context) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(2) << "UnregisterContext(" << context->id() << ")";
+
+  auto registered_it = std::find(contexts_.begin(), contexts_.end(), context);
+  const bool is_registered = registered_it != contexts_.end();
+  if (!is_registered) {
+    return NotFoundErrorBuilder(IREE_LOC) << "Context not registered";
+  }
+
+  RETURN_IF_ERROR(SuspendAllInvocations());
+  RETURN_IF_ERROR(UnreadyAllSessions());
+  for (DebugSession* listener : sessions_) {
+    RETURN_IF_ERROR(listener->OnContextUnregistered(context));
+  }
+  // Sessions are notified first so the context is still listed while they
+  // handle the event.
+  contexts_.erase(registered_it);
+
+  RETURN_IF_ERROR(ResumeAllInvocations());
+  return OkStatus();
+}
+
+// Returns the first registered context whose id() equals |context_id|, or
+// NotFound if there is none.
+StatusOr<Context*> DebugService::GetContext(int context_id) const {
+  auto match = std::find_if(
+      contexts_.begin(), contexts_.end(),
+      [context_id](Context* candidate) { return candidate->id() == context_id; });
+  if (match != contexts_.end()) {
+    Context* found_context = *match;
+    return found_context;
+  }
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "Context with ID " << context_id
+         << " not registered with the debug service";
+}
+
+// Handles a module newly loaded into |context|: re-registers any breakpoints
+// that target the module and reports the load to every session.
+// Invocations are suspended and sessions marked unready while this happens;
+// invocations are resumed before returning.
+Status DebugService::RegisterContextModule(Context* context, Module* module) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(2) << "RegisterContextModule(" << context->id() << ", " << module->name()
+          << ")";
+
+  RETURN_IF_ERROR(SuspendAllInvocations());
+  RETURN_IF_ERROR(UnreadyAllSessions());
+
+  RETURN_IF_ERROR(RegisterModuleBreakpoints(context, module));
+  for (DebugSession* listener : sessions_) {
+    RETURN_IF_ERROR(listener->OnModuleLoaded(context, module));
+  }
+
+  RETURN_IF_ERROR(ResumeAllInvocations());
+  return OkStatus();
+}
+
+// Looks up the module named |module_name| within the context identified by
+// |context_id|. Propagates the error from GetContext if the context is
+// unknown and returns NotFound if the context has no module with that name.
+StatusOr<Module*> DebugService::GetModule(int context_id,
+                                          absl::string_view module_name) const {
+  ASSIGN_OR_RETURN(auto* owning_context, GetContext(context_id));
+  for (const auto& candidate : owning_context->modules()) {
+    const bool name_matches = candidate->name() == module_name;
+    if (name_matches) {
+      return candidate.get();
+    }
+  }
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "Module '" << module_name << "' not found on context "
+         << context_id;
+}
+
+// Starts tracking |invocation| and announces it to every session.
+// Existing invocations are suspended and sessions marked unready first; all
+// invocations (including the new one) are resumed before returning.
+Status DebugService::RegisterInvocation(Invocation* invocation) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(2) << "RegisterInvocation(" << invocation->id() << ")";
+
+  RETURN_IF_ERROR(SuspendAllInvocations());
+  RETURN_IF_ERROR(UnreadyAllSessions());
+
+  invocations_.push_back(invocation);
+  const bool has_unready_sessions = sessions_unready_ != 0;
+  if (has_unready_sessions) {
+    // Suspend immediately as a debugger is not yet ready.
+    RETURN_IF_ERROR(invocation->Suspend());
+  }
+  for (DebugSession* listener : sessions_) {
+    RETURN_IF_ERROR(listener->OnInvocationRegistered(invocation));
+  }
+
+  RETURN_IF_ERROR(ResumeAllInvocations());
+  return OkStatus();
+}
+
+// Stops tracking |invocation| after telling every session it is going away.
+// Returns NotFound if the invocation was never registered. Invocations are
+// suspended and sessions marked unready for the duration of the update.
+Status DebugService::UnregisterInvocation(Invocation* invocation) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(2) << "UnregisterInvocation(" << invocation->id() << ")";
+
+  auto registered_it =
+      std::find(invocations_.begin(), invocations_.end(), invocation);
+  const bool is_registered = registered_it != invocations_.end();
+  if (!is_registered) {
+    return NotFoundErrorBuilder(IREE_LOC) << "Invocation state not registered";
+  }
+
+  RETURN_IF_ERROR(SuspendAllInvocations());
+  RETURN_IF_ERROR(UnreadyAllSessions());
+  for (DebugSession* listener : sessions_) {
+    RETURN_IF_ERROR(listener->OnInvocationUnregistered(invocation));
+  }
+  // Removed only after the sessions have been notified; the remaining
+  // invocations are then resumed.
+  invocations_.erase(registered_it);
+
+  RETURN_IF_ERROR(ResumeAllInvocations());
+  return OkStatus();
+}
+
+// Returns the first registered invocation whose id() equals |invocation_id|,
+// or NotFound if there is none.
+StatusOr<Invocation*> DebugService::GetInvocation(int invocation_id) const {
+  auto match = std::find_if(invocations_.begin(), invocations_.end(),
+                            [invocation_id](Invocation* candidate) {
+                              return candidate->id() == invocation_id;
+                            });
+  if (match != invocations_.end()) {
+    Invocation* found_invocation = *match;
+    return found_invocation;
+  }
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "Invocation state with ID " << invocation_id
+         << " not registered with the debug service";
+}
+
+// Writes |invocation| into |fbb| as an rpc::InvocationDef holding the
+// invocation ID and one serialized StackFrameDef per frame of its stack.
+// Propagates the first error returned by SerializeStackFrame.
+StatusOr<Offset<rpc::InvocationDef>> DebugService::SerializeInvocation(
+    const Invocation& invocation, FlatBufferBuilder* fbb) {
+  // All child tables are written before the InvocationDef builder is opened.
+  std::vector<Offset<rpc::StackFrameDef>> serialized_frames;
+  for (const auto& stack_frame : invocation.stack().frames()) {
+    ASSIGN_OR_RETURN(auto serialized_frame,
+                     SerializeStackFrame(stack_frame, fbb));
+    serialized_frames.push_back(serialized_frame);
+  }
+  auto frames_vector = fbb->CreateVector(serialized_frames);
+
+  rpc::InvocationDefBuilder invocation_def(*fbb);
+  invocation_def.add_invocation_id(invocation.id());
+  invocation_def.add_frames(frames_vector);
+  return invocation_def.Finish();
+}
+
+// Adds |session| to the service and accounts for it in the ready/unready
+// counters. A session that is not ready yet causes all invocations to be
+// suspended.
+Status DebugService::RegisterDebugSession(DebugSession* session) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(2) << "RegisterDebugSession(" << session->id() << ")";
+  sessions_.push_back(session);
+  if (!session->is_ready()) {
+    // Hold every invocation until the session readies up (or disconnects).
+    ++sessions_unready_;
+    RETURN_IF_ERROR(SuspendAllInvocations());
+  } else {
+    ++sessions_ready_;
+  }
+  return OkStatus();
+}
+
+// Removes |session| from the service and updates the ready/unready counters.
+// Returns NotFound if the session was never registered. If the session never
+// readied up, all invocations are resumed.
+Status DebugService::UnregisterDebugSession(DebugSession* session) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(2) << "UnregisterDebugSession(" << session->id() << ")";
+
+  auto registered_it = std::find(sessions_.begin(), sessions_.end(), session);
+  if (registered_it == sessions_.end()) {
+    return NotFoundErrorBuilder(IREE_LOC) << "Session not registered";
+  }
+  sessions_.erase(registered_it);
+
+  if (!session->is_ready()) {
+    // Invocations are still parked waiting for this session to ready up;
+    // release them so that nothing blocks forever.
+    --sessions_unready_;
+    RETURN_IF_ERROR(ResumeAllInvocations());
+  } else {
+    --sessions_ready_;
+  }
+  return OkStatus();
+}
+
+// Blocks the calling thread on mutex_ until either at least one session is
+// counted as ready, or sessions were present at some point and the session
+// list has since become empty. The latter case returns an Aborted error.
+Status DebugService::WaitUntilAllSessionsReady() {
+  VLOG(1) << "Waiting until all sessions are ready...";
+
+  // Bookkeeping shared with the wait predicate below. The predicate is only
+  // evaluated with mutex_ held.
+  struct WaitState {
+    DebugService* service;
+    bool had_sessions;
+    bool consider_aborted;
+  } wait_state;
+  {
+    absl::MutexLock lock(&mutex_);
+    wait_state.service = this;
+    wait_state.had_sessions = !sessions_.empty();
+    wait_state.consider_aborted = false;
+  }
+
+  mutex_.LockWhen(absl::Condition(
+      +[](WaitState* state) {
+        DebugService* self = state->service;
+        self->mutex_.AssertHeld();
+        if (self->sessions_ready_ > 0) {
+          // At least one session has readied up: stop waiting.
+          return true;
+        }
+        if (self->sessions_unready_ > 0) {
+          // Sessions are connected but none is ready; remember that we saw
+          // them and keep waiting.
+          state->had_sessions = true;
+          return false;
+        }
+        if (state->had_sessions && self->sessions_.empty()) {
+          // Every session we were waiting on went away without readying up;
+          // stop waiting and report it as an abort.
+          state->consider_aborted = true;
+          return true;
+        }
+        return false;
+      },
+      &wait_state));
+  mutex_.Unlock();
+
+  if (wait_state.consider_aborted) {
+    return AbortedErrorBuilder(IREE_LOC)
+           << "At least one session connected but never readied up";
+  }
+  VLOG(1) << "Sessions ready, resuming";
+  return OkStatus();
+}
+
+// RPC handler: marks the connected session (at most one is supported) as
+// ready, then recomputes the ready/unready counters from the current state of
+// every session. Returns an empty MakeReadyResponse.
+StatusOr<Offset<rpc::MakeReadyResponse>> DebugService::MakeReady(
+    const rpc::MakeReadyRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: MakeReady()";
+
+  // TODO(benvanik): support more than one session.
+  CHECK_LE(sessions_.size(), 1) << "Only one session is currently supported";
+  if (!sessions_.empty()) {
+    RETURN_IF_ERROR(sessions_[0]->OnReady());
+  }
+
+  // Rebuild both counters from scratch.
+  sessions_ready_ = 0;
+  sessions_unready_ = 0;
+  for (DebugSession* connected : sessions_) {
+    sessions_ready_ += connected->is_ready() ? 1 : 0;
+    sessions_unready_ += connected->is_ready() ? 0 : 1;
+  }
+
+  rpc::MakeReadyResponseBuilder reply(*fbb);
+  return reply.Finish();
+}
+
+// Calls OnUnready() on every session and then counts all sessions as
+// unready. Returns the first error from OnUnready(), in which case the
+// counters are left untouched.
+Status DebugService::UnreadyAllSessions() {
+  for (DebugSession* connected : sessions_) {
+    RETURN_IF_ERROR(connected->OnUnready());
+  }
+  sessions_unready_ = sessions_.size();
+  sessions_ready_ = 0;
+  return OkStatus();
+}
+
+// RPC handler: returns a GetStatusResponse whose protocol field is set to 0.
+StatusOr<Offset<rpc::GetStatusResponse>> DebugService::GetStatus(
+    const rpc::GetStatusRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: GetStatus()";
+  rpc::GetStatusResponseBuilder reply(*fbb);
+  reply.add_protocol(0);
+  return reply.Finish();
+}
+
+// RPC handler: describes every registered context. Each ContextDef carries
+// the context ID, the names of the context's native functions and the names
+// of the modules loaded into it.
+StatusOr<Offset<rpc::ListContextsResponse>> DebugService::ListContexts(
+    const rpc::ListContextsRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: ListContexts()";
+
+  std::vector<Offset<rpc::ContextDef>> context_defs;
+  for (auto* context : contexts_) {
+    // Native functions: one NativeFunctionDef per entry, holding its name.
+    std::vector<Offset<rpc::NativeFunctionDef>> native_function_defs;
+    for (const auto& entry : context->native_functions()) {
+      auto name_str = fbb->CreateString(entry.first);
+      rpc::NativeFunctionDefBuilder native_function_def(*fbb);
+      native_function_def.add_name(name_str);
+      native_function_defs.push_back(native_function_def.Finish());
+    }
+    auto native_functions_vector = fbb->CreateVector(native_function_defs);
+
+    // Module names, copied into owned strings for CreateVectorOfStrings.
+    std::vector<std::string> loaded_module_names;
+    for (const auto& loaded_module : context->modules()) {
+      loaded_module_names.emplace_back(loaded_module->name());
+    }
+    auto module_names_vector = fbb->CreateVectorOfStrings(loaded_module_names);
+
+    // All children are written; now assemble the ContextDef itself.
+    rpc::ContextDefBuilder context_def(*fbb);
+    context_def.add_context_id(context->id());
+    context_def.add_native_functions(native_functions_vector);
+    context_def.add_module_names(module_names_vector);
+    context_defs.push_back(context_def.Finish());
+  }
+
+  auto contexts_vector = fbb->CreateVector(context_defs);
+  rpc::ListContextsResponseBuilder reply(*fbb);
+  reply.add_contexts(contexts_vector);
+  return reply.Finish();
+}
+
+// RPC handler: returns the definition of the module named
+// |request.module_name()| in context |request.context_id()| with the bytecode
+// contents of every function stripped out.
+// Propagates the lookup error if the context or module is unknown.
+StatusOr<Offset<rpc::GetModuleResponse>> DebugService::GetModule(
+    const rpc::GetModuleRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: GetModule(" << request.context_id() << ", "
+          << WrapString(request.module_name()) << ")";
+  ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(),
+                                           WrapString(request.module_name())));
+  // TODO(benvanik): find a way to do this without possibly duping all memory.
+  // I suspect that when we make constants poolable then there's only one
+  // place to kill and there may be magic we could use to do that during a
+  // reflection pass.
+  ModuleDefT module_t;
+  module->def().UnPackTo(&module_t);
+  // Unpacked sub-tables are null when the corresponding field is absent from
+  // the buffer (GetFunction likewise treats a function's bytecode as
+  // optional), so guard each level before clearing the bytecode contents.
+  if (module_t.function_table) {
+    for (auto& function : module_t.function_table->functions) {
+      if (function && function->bytecode) {
+        function->bytecode->contents.clear();
+      }
+    }
+  }
+  auto trimmed_module_offs = ModuleDef::Pack(*fbb, &module_t);
+  rpc::GetModuleResponseBuilder response(*fbb);
+  response.add_module_(trimmed_module_offs);
+  return response.Finish();
+}
+
+// RPC handler: returns a copy of the bytecode definition of the function with
+// ordinal |request.function_ordinal()| in the requested module. The bytecode
+// field of the response is left unset when the function has no bytecode.
+// Propagates lookup errors for the context, module or function.
+StatusOr<Offset<rpc::GetFunctionResponse>> DebugService::GetFunction(
+    const rpc::GetFunctionRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: GetFunction(" << WrapString(request.module_name()) << ", "
+          << request.function_ordinal() << ")";
+
+  ASSIGN_OR_RETURN(auto* owning_module,
+                   GetModule(request.context_id(),
+                             WrapString(request.module_name())));
+  ASSIGN_OR_RETURN(auto& target_function,
+                   owning_module->function_table().LookupFunction(
+                       request.function_ordinal()));
+
+  // Stays a null offset unless the function actually carries bytecode.
+  Offset<BytecodeDef> bytecode_copy;
+  if (target_function.def().bytecode()) {
+    ASSIGN_OR_RETURN(bytecode_copy,
+                     DeepCopyTable("bytecode_def.bfbs",
+                                   *target_function.def().bytecode(), fbb));
+  }
+
+  rpc::GetFunctionResponseBuilder reply(*fbb);
+  reply.add_bytecode(bytecode_copy);
+  return reply.Finish();
+}
+
+// RPC handler: resolves |request.function_name()| within every registered
+// context that has a module named |request.module_name()| loaded.
+//
+// The response lists the IDs of all such contexts plus the ordinal of the
+// function. The ordinal is -1 when no context has the module; when several
+// contexts match, the ordinal found in the last one is reported. Any error
+// from LookupFunctionOrdinalByName is propagated.
+StatusOr<Offset<rpc::ResolveFunctionResponse>> DebugService::ResolveFunction(
+    const rpc::ResolveFunctionRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: ResolveFunction(" << WrapString(request.module_name())
+          << ", " << WrapString(request.function_name()) << ")";
+  std::vector<int32_t> context_ids;
+  int function_ordinal = -1;
+  for (auto* context : contexts_) {
+    for (const auto& module : context->modules()) {
+      if (module->name() == WrapString(request.module_name())) {
+        ASSIGN_OR_RETURN(function_ordinal,
+                         module->function_table().LookupFunctionOrdinalByName(
+                             WrapString(request.function_name())));
+        context_ids.push_back(context->id());
+        // At most one module per context is considered.
+        break;
+      }
+    }
+  }
+  // CreateVector copies the elements into the buffer when it is called, so it
+  // has to run after |context_ids| has been filled (previously it ran before
+  // the loop and always serialized an empty list). It also has to run before
+  // the response builder below is opened.
+  auto context_ids_offs = fbb->CreateVector(context_ids);
+  rpc::ResolveFunctionResponseBuilder response(*fbb);
+  response.add_context_ids(context_ids_offs);
+  response.add_function_ordinal(function_ordinal);
+  return response.Finish();
+}
+
+// RPC handler: returns a serialized InvocationDef for every registered
+// invocation. Propagates the first serialization error.
+StatusOr<Offset<rpc::ListInvocationsResponse>> DebugService::ListInvocations(
+    const rpc::ListInvocationsRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: ListInvocations()";
+
+  std::vector<Offset<rpc::InvocationDef>> invocation_defs;
+  for (Invocation* registered : invocations_) {
+    ASSIGN_OR_RETURN(auto invocation_def,
+                     SerializeInvocation(*registered, fbb));
+    invocation_defs.push_back(invocation_def);
+  }
+
+  auto invocations_vector = fbb->CreateVector(invocation_defs);
+  rpc::ListInvocationsResponseBuilder reply(*fbb);
+  reply.add_invocations(invocations_vector);
+  return reply.Finish();
+}
+
+// RPC handler: suspends either the invocations listed in
+// |request.invocation_ids()| or, when that list is missing or empty, every
+// registered invocation. The response contains the serialized state of each
+// invocation that was targeted.
+StatusOr<Offset<rpc::SuspendInvocationsResponse>>
+DebugService::SuspendInvocations(const rpc::SuspendInvocationsRequest& request,
+                                 FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: SuspendInvocations(invocation_ids=["
+          << (request.invocation_ids()
+                  ? absl::StrJoin(*request.invocation_ids(), ", ")
+                  : "")
+          << "])";
+
+  std::vector<Offset<rpc::InvocationDef>> invocation_defs;
+  const auto* requested_ids = request.invocation_ids();
+  const bool has_explicit_ids = requested_ids && requested_ids->size() > 0;
+  if (!has_explicit_ids) {
+    // No explicit list: act on everything that is registered.
+    RETURN_IF_ERROR(SuspendAllInvocations());
+    for (Invocation* registered : invocations_) {
+      ASSIGN_OR_RETURN(auto invocation_def,
+                       SerializeInvocation(*registered, fbb));
+      invocation_defs.push_back(invocation_def);
+    }
+  } else {
+    // Resolve every requested ID first so an unknown ID fails the request
+    // before any invocation is suspended.
+    std::vector<Invocation*> targets;
+    for (int requested_id : *requested_ids) {
+      ASSIGN_OR_RETURN(auto* target, GetInvocation(requested_id));
+      targets.push_back(target);
+    }
+    RETURN_IF_ERROR(SuspendInvocationsAndWait(absl::MakeSpan(targets)));
+    for (Invocation* target : targets) {
+      ASSIGN_OR_RETURN(auto invocation_def, SerializeInvocation(*target, fbb));
+      invocation_defs.push_back(invocation_def);
+    }
+  }
+
+  auto invocations_vector = fbb->CreateVector(invocation_defs);
+  rpc::SuspendInvocationsResponseBuilder reply(*fbb);
+  reply.add_invocations(invocations_vector);
+  return reply.Finish();
+}
+
+// RPC handler: resumes either the invocations listed in
+// |request.invocation_ids()| or, when that list is missing or empty, every
+// registered invocation. With an explicit list, processing stops at the first
+// unknown ID or failed Resume(); invocations earlier in the list have already
+// been resumed at that point.
+StatusOr<Offset<rpc::ResumeInvocationsResponse>>
+DebugService::ResumeInvocations(const rpc::ResumeInvocationsRequest& request,
+                                FlatBufferBuilder* fbb) {
+  VLOG(1) << "RPC: ResumeInvocations(invocation_ids=["
+          << (request.invocation_ids()
+                  ? absl::StrJoin(*request.invocation_ids(), ", ")
+                  : "")
+          << "])";
+  absl::MutexLock lock(&mutex_);
+
+  const auto* requested_ids = request.invocation_ids();
+  const bool has_explicit_ids = requested_ids && requested_ids->size() > 0;
+  if (!has_explicit_ids) {
+    // No explicit list: act on everything that is registered.
+    RETURN_IF_ERROR(ResumeAllInvocations());
+  } else {
+    for (int requested_id : *requested_ids) {
+      ASSIGN_OR_RETURN(auto* target, GetInvocation(requested_id));
+      RETURN_IF_ERROR(target->Resume());
+    }
+  }
+
+  rpc::ResumeInvocationsResponseBuilder reply(*fbb);
+  return reply.Finish();
+}
+
+// RPC handler: steps the invocation identified by |request.invocation_id()|
+// using a default-constructed step target. Returns an empty response.
+StatusOr<Offset<rpc::StepInvocationResponse>> DebugService::StepInvocation(
+    const rpc::StepInvocationRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: StepInvocation(" << request.invocation_id() << ")";
+
+  ASSIGN_OR_RETURN(auto* target, GetInvocation(request.invocation_id()));
+  // TODO(benvanik): step settings.
+  Invocation::StepTarget default_step_target;
+  RETURN_IF_ERROR(target->Step(default_step_target));
+
+  rpc::StepInvocationResponseBuilder reply(*fbb);
+  return reply.Finish();
+}
+
+// RPC handler: returns the value (including buffer contents) of one local,
+// addressed by invocation ID, stack frame index and local index.
+// Propagates lookup and serialization errors.
+StatusOr<Offset<rpc::GetInvocationLocalResponse>>
+DebugService::GetInvocationLocal(const rpc::GetInvocationLocalRequest& request,
+                                 FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: GetInvocationLocal(" << request.invocation_id() << ", "
+          << request.frame_index() << ", " << request.local_index() << ")";
+
+  ASSIGN_OR_RETURN(auto* target, GetInvocation(request.invocation_id()));
+  ASSIGN_OR_RETURN(auto* local_view,
+                   ResolveInvocationLocal(target, request.frame_index(),
+                                          request.local_index()));
+  ASSIGN_OR_RETURN(
+      auto serialized_value,
+      SerializeBufferView(*local_view, /*include_buffer_contents=*/true, fbb));
+
+  rpc::GetInvocationLocalResponseBuilder reply(*fbb);
+  reply.add_value(serialized_value);
+  return reply.Finish();
+}
+
+// RPC handler: overwrites the metadata of one local, addressed by invocation
+// ID, stack frame index and local index, and returns the local's resulting
+// value. Without a value in the request the local is reset (empty shape, zero
+// element size, no buffer). With a value, only the shape and element size are
+// applied; buffer contents are not copied yet.
+StatusOr<Offset<rpc::SetInvocationLocalResponse>>
+DebugService::SetInvocationLocal(const rpc::SetInvocationLocalRequest& request,
+                                 FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: SetInvocationLocal(" << request.invocation_id() << ", "
+          << request.frame_index() << ", " << request.local_index() << ")";
+
+  ASSIGN_OR_RETURN(auto* target, GetInvocation(request.invocation_id()));
+  ASSIGN_OR_RETURN(auto* local_view,
+                   ResolveInvocationLocal(target, request.frame_index(),
+                                          request.local_index()));
+
+  // The shape is rebuilt from scratch in both cases.
+  local_view->shape.clear();
+  if (const auto* new_value = request.value()) {
+    if (new_value->shape()) {
+      for (int dim : *new_value->shape()) {
+        local_view->shape.push_back(dim);
+      }
+    }
+    local_view->element_size = new_value->element_size();
+    // TODO(benvanik): copy buffer data.
+  } else {
+    local_view->element_size = 0;
+    local_view->buffer.reset();
+  }
+
+  ASSIGN_OR_RETURN(
+      auto serialized_value,
+      SerializeBufferView(*local_view, /*include_buffer_contents=*/true, fbb));
+  rpc::SetInvocationLocalResponseBuilder reply(*fbb);
+  reply.add_value(serialized_value);
+  return reply.Finish();
+}
+
+// RPC handler: returns every breakpoint currently held by the service, packed
+// as rpc::BreakpointDef tables.
+StatusOr<Offset<rpc::ListBreakpointsResponse>> DebugService::ListBreakpoints(
+    const rpc::ListBreakpointsRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: ListBreakpoints()";
+
+  std::vector<Offset<rpc::BreakpointDef>> breakpoint_defs;
+  for (const auto& stored_breakpoint : breakpoints_) {
+    auto packed = rpc::BreakpointDef::Pack(*fbb, &stored_breakpoint);
+    breakpoint_defs.push_back(packed);
+  }
+
+  auto breakpoints_vector = fbb->CreateVector(breakpoint_defs);
+  rpc::ListBreakpointsResponseBuilder reply(*fbb);
+  reply.add_breakpoints(breakpoints_vector);
+  return reply.Finish();
+}
+
+// RPC handler: adds the breakpoint described by |request.breakpoint()|.
+//
+// The breakpoint is assigned a fresh ID, registered with the matching module
+// of every context that has one, stored in the service and echoed back in the
+// response. Invocations are suspended while it is being registered.
+//
+// Errors:
+//   InvalidArgument - the request carries no breakpoint.
+//   Unimplemented   - the breakpoint type is not a function breakpoint.
+//   Any error from suspending, registering or resuming. A registration error
+//   takes precedence over a resume error, and the breakpoint is not stored
+//   when registration fails.
+StatusOr<Offset<rpc::AddBreakpointResponse>> DebugService::AddBreakpoint(
+    const rpc::AddBreakpointRequest& request, FlatBufferBuilder* fbb) {
+  if (!request.breakpoint()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC) << "No breakpoint specified";
+  }
+  absl::MutexLock lock(&mutex_);
+  int breakpoint_id = NextUniqueBreakpointId();
+  VLOG(1) << "RPC: AddBreakpoint(" << breakpoint_id << ")";
+
+  rpc::BreakpointDefT breakpoint;
+  request.breakpoint()->UnPackTo(&breakpoint);
+  breakpoint.breakpoint_id = breakpoint_id;
+
+  // Reject unsupported types before touching any invocation. Previously this
+  // check ran after SuspendAllInvocations() and returned without resuming,
+  // leaving every invocation suspended.
+  switch (breakpoint.breakpoint_type) {
+    case rpc::BreakpointType::BYTECODE_FUNCTION:
+    case rpc::BreakpointType::NATIVE_FUNCTION:
+      break;
+    default:
+      return UnimplementedErrorBuilder(IREE_LOC) << "Unhandled breakpoint type";
+  }
+
+  RETURN_IF_ERROR(SuspendAllInvocations());
+
+  // Register with each context that has the target module loaded. Contexts
+  // without the module are skipped.
+  Status register_status = OkStatus();
+  for (auto* context : contexts_) {
+    auto module_or = context->LookupModule(breakpoint.module_name);
+    if (!module_or.ok()) continue;
+    auto* module = module_or.ValueOrDie();
+    register_status = RegisterFunctionBreakpoint(context, module, &breakpoint);
+    if (!register_status.ok()) break;
+  }
+  if (register_status.ok()) {
+    breakpoints_.push_back(std::move(breakpoint));
+  }
+
+  // Always try to resume, even if registration failed, so that an RPC error
+  // does not leave the invocations suspended.
+  Status resume_status = ResumeAllInvocations();
+  RETURN_IF_ERROR(std::move(register_status));
+  RETURN_IF_ERROR(std::move(resume_status));
+
+  auto breakpoint_offs = rpc::BreakpointDef::Pack(*fbb, &breakpoints_.back());
+  rpc::AddBreakpointResponseBuilder response(*fbb);
+  response.add_breakpoint(breakpoint_offs);
+  return response.Finish();
+}
+
+// RPC handler: removes the breakpoint with ID |request.breakpoint_id()|.
+//
+// Invocations are suspended while the breakpoint is unregistered from its
+// modules and dropped from the service, and are resumed afterwards on every
+// path (the error paths previously returned with them still suspended).
+//
+// Errors:
+//   Unimplemented   - the stored breakpoint is not a function breakpoint.
+//   InvalidArgument - no breakpoint with the given ID exists.
+//   Any error from suspending, unregistering or resuming. An unregister or
+//   Unimplemented error takes precedence over a resume error, and the
+//   breakpoint stays stored in that case.
+StatusOr<Offset<rpc::RemoveBreakpointResponse>> DebugService::RemoveBreakpoint(
+    const rpc::RemoveBreakpointRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: RemoveBreakpoint(" << request.breakpoint_id() << ")";
+  RETURN_IF_ERROR(SuspendAllInvocations());
+
+  bool found = false;
+  bool unhandled_type = false;
+  Status unregister_status = OkStatus();
+  for (auto it = breakpoints_.begin(); it != breakpoints_.end(); ++it) {
+    if (it->breakpoint_id != request.breakpoint_id()) continue;
+    found = true;
+    switch (it->breakpoint_type) {
+      case rpc::BreakpointType::BYTECODE_FUNCTION:
+      case rpc::BreakpointType::NATIVE_FUNCTION:
+        unregister_status = UnregisterFunctionBreakpoint(*it);
+        break;
+      default:
+        unhandled_type = true;
+        break;
+    }
+    if (unregister_status.ok() && !unhandled_type) {
+      breakpoints_.erase(it);
+    }
+    // |it| may have been invalidated by the erase; stop iterating.
+    break;
+  }
+
+  // Resume before reporting any failure so that an RPC error does not leave
+  // the invocations suspended.
+  Status resume_status = ResumeAllInvocations();
+  RETURN_IF_ERROR(std::move(unregister_status));
+  if (unhandled_type) {
+    return UnimplementedErrorBuilder(IREE_LOC) << "Unhandled breakpoint type";
+  }
+  RETURN_IF_ERROR(std::move(resume_status));
+  if (!found) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Breakpoint ID " << request.breakpoint_id() << " not found";
+  }
+
+  rpc::RemoveBreakpointResponseBuilder response(*fbb);
+  return response.Finish();
+}
+
+// Registers, on |module|, every stored bytecode-function breakpoint whose
+// module name equals the module's name. Breakpoints of any other type are
+// ignored. Returns the first registration error.
+Status DebugService::RegisterModuleBreakpoints(Context* context,
+                                               Module* module) {
+  for (auto& stored_breakpoint : breakpoints_) {
+    const bool is_bytecode_breakpoint =
+        stored_breakpoint.breakpoint_type ==
+        rpc::BreakpointType::BYTECODE_FUNCTION;
+    if (!is_bytecode_breakpoint) {
+      // Not relevant to modules.
+      continue;
+    }
+    if (stored_breakpoint.module_name == module->name()) {
+      RETURN_IF_ERROR(
+          RegisterFunctionBreakpoint(context, module, &stored_breakpoint));
+    }
+  }
+  return OkStatus();
+}
+
+// Installs |breakpoint| in the function table of |module| and reports it as
+// resolved to every session.
+// When the breakpoint names a function, |breakpoint->function_ordinal| is
+// first overwritten with the ordinal looked up by that name. The installed
+// callback forwards hits to OnFunctionBreakpointHit with the breakpoint's ID.
+Status DebugService::RegisterFunctionBreakpoint(
+    Context* context, Module* module, rpc::BreakpointDefT* breakpoint) {
+  const bool has_function_name = !breakpoint->function_name.empty();
+  if (has_function_name) {
+    ASSIGN_OR_RETURN(breakpoint->function_ordinal,
+                     module->function_table().LookupFunctionOrdinalByName(
+                         breakpoint->function_name));
+  }
+
+  RETURN_IF_ERROR(module->mutable_function_table()->RegisterBreakpoint(
+      breakpoint->function_ordinal, breakpoint->bytecode_offset,
+      std::bind(&DebugService::OnFunctionBreakpointHit, this,
+                breakpoint->breakpoint_id, std::placeholders::_1)));
+
+  for (DebugSession* listener : sessions_) {
+    RETURN_IF_ERROR(listener->OnBreakpointResolved(*breakpoint, context));
+  }
+  return OkStatus();
+}
+
+// Removes |breakpoint| from the function table of the module named
+// |breakpoint.module_name| in every context where that module can be looked
+// up. Contexts without the module are skipped. Returns the first error from
+// UnregisterBreakpoint.
+Status DebugService::UnregisterFunctionBreakpoint(
+    const rpc::BreakpointDefT& breakpoint) {
+  for (Context* registered_context : contexts_) {
+    auto lookup = registered_context->LookupModule(breakpoint.module_name);
+    if (lookup.ok()) {
+      auto* target_module = lookup.ValueOrDie();
+      RETURN_IF_ERROR(
+          target_module->mutable_function_table()->UnregisterBreakpoint(
+              breakpoint.function_ordinal, breakpoint.bytecode_offset));
+    }
+  }
+  return OkStatus();
+}
+
+// Callback run when |invocation| hits the breakpoint |breakpoint_id|.
+// Marks all sessions unready, notifies each of the hit and then, with the
+// service lock released, blocks until the sessions are ready again. Losing
+// every session during the wait is not treated as an error.
+Status DebugService::OnFunctionBreakpointHit(int breakpoint_id,
+                                             const Invocation& invocation) {
+  absl::ReleasableMutexLock lock(&mutex_);
+  LOG(INFO) << "Breakpoint hit: " << breakpoint_id;
+  RETURN_IF_ERROR(UnreadyAllSessions());
+  for (DebugSession* listener : sessions_) {
+    RETURN_IF_ERROR(listener->OnBreakpointHit(breakpoint_id, invocation));
+  }
+  // The wait below acquires mutex_ itself, so it must not be held here.
+  lock.Release();
+
+  // TODO(benvanik): on-demand attach if desired?
+
+  // Wait until all clients are ready.
+  auto ready_status = WaitUntilAllSessionsReady();
+  if (!IsAborted(ready_status)) {
+    return ready_status;
+  }
+  // Aborted means every session went away; there is nobody left to wait for.
+  VLOG(1) << "No sessions active; ignoring breakpoint and continuing";
+  return OkStatus();
+}
+
+// RPC handler: profiling is not implemented; always returns Unimplemented.
+StatusOr<Offset<rpc::StartProfilingResponse>> DebugService::StartProfiling(
+    const rpc::StartProfilingRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: StartProfiling()";
+  // TODO(benvanik): implement profiling. Sketch of the intended body:
+  //   ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id()));
+  //   rpc::StartProfilingResponseBuilder response(*fbb);
+  //   return response.Finish();
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "StartProfiling not yet implemented";
+}
+
+// RPC handler: profiling is not implemented; always returns Unimplemented.
+StatusOr<Offset<rpc::StopProfilingResponse>> DebugService::StopProfiling(
+    const rpc::StopProfilingRequest& request, FlatBufferBuilder* fbb) {
+  absl::MutexLock lock(&mutex_);
+  VLOG(1) << "RPC: StopProfiling()";
+  // TODO(benvanik): implement profiling. Sketch of the intended body:
+  //   ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id()));
+  //   rpc::StopProfilingResponseBuilder response(*fbb);
+  //   return response.Finish();
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "StopProfiling not yet implemented";
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_service.h b/rt/debug/debug_service.h
new file mode 100644
index 0000000..54a4744
--- /dev/null
+++ b/rt/debug/debug_service.h
@@ -0,0 +1,175 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DEBUG_DEBUG_SERVICE_H_
+#define IREE_RT_DEBUG_DEBUG_SERVICE_H_
+
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "base/status.h"
+#include "flatbuffers/flatbuffers.h"
+#include "rt/context.h"
+#include "rt/debug/debug_session.h"
+#include "schemas/debug_service_generated.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Debugging service used to implement the DebugService RPC methods in a
+// transport-independent way. Specific DebugServer implementations can compose
+// with a DebugService to avoid needing to maintain state themselves. Multiple
+// DebugServer instances could share the same DebugService instance to ensure
+// all clients - regardless of transport - share the same state.
+//
+// Thread-safe.
+class DebugService {
+ public:
+ // Registers a context with the debug service.
+ // Ownership remains with the caller and UnregisterContext must be called
+ // prior to the context being destroyed.
+ Status RegisterContext(Context* context);
+ // Unregisters a context previously passed to RegisterContext.
+ // Returns NotFound if the context is not registered.
+ Status UnregisterContext(Context* context);
+
+ // Registers a new module linked into an existing Context.
+ Status RegisterContextModule(Context* context, Module* module);
+
+ // Registers an invocation state with the debug service.
+ // Ownership remains with the caller and UnregisterInvocation must be called
+ // prior to the invocation state being destroyed.
+ Status RegisterInvocation(Invocation* invocation);
+ // Unregisters an invocation previously passed to RegisterInvocation.
+ // Returns NotFound if the invocation is not registered.
+ Status UnregisterInvocation(Invocation* invocation);
+
+ // Registers a debug session with the service.
+ // If the session is not yet ready all invocations are suspended.
+ Status RegisterDebugSession(DebugSession* session);
+ // Unregisters a debug session. Returns NotFound if it is not registered.
+ // If the session never became ready all invocations are resumed.
+ Status UnregisterDebugSession(DebugSession* session);
+
+ // Blocks the caller until all sessions are ready.
+ // Returns AbortedError if a session connects/is already connected but
+ // disconnects during the wait.
+ // NOTE(review): the implementation stops waiting as soon as at least one
+ // session is counted as ready - confirm whether "all" is intended.
+ Status WaitUntilAllSessionsReady();
+
+ // RPC handlers. Each one takes the decoded request and writes its response
+ // into |fbb|, returning the offset of the response table.
+
+ // Marks the connected session as ready (at most one session is supported).
+ StatusOr<::flatbuffers::Offset<rpc::MakeReadyResponse>> MakeReady(
+ const rpc::MakeReadyRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Returns the service status; the protocol field is currently always 0.
+ StatusOr<::flatbuffers::Offset<rpc::GetStatusResponse>> GetStatus(
+ const rpc::GetStatusRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Lists all registered contexts with their native function names and the
+ // names of their loaded modules.
+ StatusOr<::flatbuffers::Offset<rpc::ListContextsResponse>> ListContexts(
+ const rpc::ListContextsRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Returns a module definition with the bytecode contents of its functions
+ // removed.
+ StatusOr<::flatbuffers::Offset<rpc::GetModuleResponse>> GetModule(
+ const rpc::GetModuleRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Returns a copy of the bytecode definition of a single function, if it has
+ // one.
+ StatusOr<::flatbuffers::Offset<rpc::GetFunctionResponse>> GetFunction(
+ const rpc::GetFunctionRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Looks up a function ordinal by module name and function name across the
+ // registered contexts.
+ StatusOr<::flatbuffers::Offset<rpc::ResolveFunctionResponse>> ResolveFunction(
+ const rpc::ResolveFunctionRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Returns the serialized state of every registered invocation.
+ StatusOr<::flatbuffers::Offset<rpc::ListInvocationsResponse>> ListInvocations(
+ const rpc::ListInvocationsRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Suspends the listed invocations, or all invocations if no IDs are given,
+ // and returns their serialized state.
+ StatusOr<::flatbuffers::Offset<rpc::SuspendInvocationsResponse>>
+ SuspendInvocations(const rpc::SuspendInvocationsRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Resumes the listed invocations, or all invocations if no IDs are given.
+ StatusOr<::flatbuffers::Offset<rpc::ResumeInvocationsResponse>>
+ ResumeInvocations(const rpc::ResumeInvocationsRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Steps a single invocation.
+ StatusOr<::flatbuffers::Offset<rpc::StepInvocationResponse>> StepInvocation(
+ const rpc::StepInvocationRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Returns the value of a local, addressed by invocation ID, frame index and
+ // local index, including its buffer contents.
+ StatusOr<::flatbuffers::Offset<rpc::GetInvocationLocalResponse>>
+ GetInvocationLocal(const rpc::GetInvocationLocalRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Updates the shape and element size of a local (buffer data is not copied
+ // yet) and returns the resulting value.
+ StatusOr<::flatbuffers::Offset<rpc::SetInvocationLocalResponse>>
+ SetInvocationLocal(const rpc::SetInvocationLocalRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Returns all breakpoints currently held by the service.
+ StatusOr<::flatbuffers::Offset<rpc::ListBreakpointsResponse>> ListBreakpoints(
+ const rpc::ListBreakpointsRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Adds a function breakpoint and returns it with its assigned ID.
+ StatusOr<::flatbuffers::Offset<rpc::AddBreakpointResponse>> AddBreakpoint(
+ const rpc::AddBreakpointRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ // Removes the breakpoint with the requested ID.
+ // Returns InvalidArgument if no such breakpoint exists.
+ StatusOr<::flatbuffers::Offset<rpc::RemoveBreakpointResponse>>
+ RemoveBreakpoint(const rpc::RemoveBreakpointRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Profiling is not implemented yet; both calls return Unimplemented.
+ StatusOr<::flatbuffers::Offset<rpc::StartProfilingResponse>> StartProfiling(
+ const rpc::StartProfilingRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+ StatusOr<::flatbuffers::Offset<rpc::StopProfilingResponse>> StopProfiling(
+ const rpc::StopProfilingRequest& request,
+ ::flatbuffers::FlatBufferBuilder* fbb);
+
+ // Serializes an invocation and its stack frames.
+ StatusOr<::flatbuffers::Offset<rpc::InvocationDef>> SerializeInvocation(
+ const Invocation& invocation, ::flatbuffers::FlatBufferBuilder* fbb);
+
+ private:
+ // Lookup helpers. Each returns NotFound if no registered object matches.
+ StatusOr<Context*> GetContext(int context_id) const
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ StatusOr<Module*> GetModule(int context_id,
+ absl::string_view module_name) const
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ StatusOr<Invocation*> GetInvocation(int invocation_id) const
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Suspends all invocations on all contexts. Returns only once all invocations
+ // have been suspended successfully. Fails if any invocation fails to suspend.
+ Status SuspendAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Resumes all invocations on all contexts (the inverse of
+ // SuspendAllInvocations). Returns immediately.
+ Status ResumeAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Marks all sessions as unready.
+ Status UnreadyAllSessions() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Attempts to re-register all breakpoints for a module.
+ Status RegisterModuleBreakpoints(Context* context, Module* module)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ // Installs a function breakpoint on |module| and notifies all sessions that
+ // it was resolved. May update |breakpoint->function_ordinal|.
+ Status RegisterFunctionBreakpoint(Context* context, Module* module,
+ rpc::BreakpointDefT* breakpoint)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ // Removes a function breakpoint from the matching module of every context.
+ Status UnregisterFunctionBreakpoint(const rpc::BreakpointDefT& breakpoint)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ // Signals that the given breakpoint was hit by the specified invocation.
+ // Called without the debug lock held.
+ Status OnFunctionBreakpointHit(int breakpoint_id,
+ const Invocation& invocation);
+
+ // Guards all of the state below.
+ absl::Mutex mutex_;
+ // Registered objects; none of them are owned by the service.
+ std::vector<Context*> contexts_ ABSL_GUARDED_BY(mutex_);
+ std::vector<Invocation*> invocations_ ABSL_GUARDED_BY(mutex_);
+ std::vector<DebugSession*> sessions_ ABSL_GUARDED_BY(mutex_);
+ // Number of registered sessions counted as not ready / ready.
+ int sessions_unready_ ABSL_GUARDED_BY(mutex_) = 0;
+ int sessions_ready_ ABSL_GUARDED_BY(mutex_) = 0;
+
+ // Breakpoints added through AddBreakpoint that have not been removed.
+ std::vector<rpc::BreakpointDefT> breakpoints_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_DEBUG_SERVICE_H_
diff --git a/rt/debug/debug_session.cc b/rt/debug/debug_session.cc
new file mode 100644
index 0000000..43d8197
--- /dev/null
+++ b/rt/debug/debug_session.cc
@@ -0,0 +1,49 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/debug/debug_session.h"
+
+#include "base/source_location.h"
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+bool DebugSession::is_ready() const {
+  absl::MutexLock guard(&mutex_);
+  return 0 == ready_;  // Only a counter of exactly zero counts as ready.
+}
+
+Status DebugSession::OnReady() {
+  absl::MutexLock lock(&mutex_);
+  if (ready_ >= 0) {  // 0 already means ready; counting past it breaks is_ready().
+    return FailedPreconditionErrorBuilder(IREE_LOC)
+           << "Session has already readied up";
+  }
+  ++ready_;  // Counts up toward 0; each OnUnready() counts back down.
+  VLOG(2) << "Session " << id() << ": ++ready = " << ready_;
+  return OkStatus();
+}
+
+Status DebugSession::OnUnready() {
+  absl::MutexLock guard(&mutex_);
+  const int ready_count = --ready_;  // No lower bound is enforced here.
+  VLOG(2) << "Session " << id() << ": --ready = " << ready_count;
+  return OkStatus();
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/rt/debug/debug_session.h b/rt/debug/debug_session.h
new file mode 100644
index 0000000..3cdf903
--- /dev/null
+++ b/rt/debug/debug_session.h
@@ -0,0 +1,93 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DEBUG_DEBUG_SESSION_H_
+#define IREE_RT_DEBUG_DEBUG_SESSION_H_
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "base/status.h"
+#include "rt/context.h"
+#include "rt/invocation.h"
+#include "rt/module.h"
+#include "schemas/debug_service_generated.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// An active debugging session maintained by the DebugService.
+// Each connected client gets a session and transport-specific implementations
+// use the event methods to receive signals from the service.
+//
+// All methods are called only while the debug lock is held and may be called
+// from any thread.
+class DebugSession {
+ public:
+ virtual ~DebugSession() = default;
+
+ // Session ID used in all RPCs related to this session.
+ // This can be used for attributing RPCs to the originating session when
+ // multiple sessions may be active at a time/over the same transport.
+ int id() const { return session_id_; }
+
+ // Returns true if the session has issued a MakeReady request and is ok with
+ // execution resuming (i.e. its internal ready counter is exactly zero).
+ bool is_ready() const;
+
+ // Signals that the session has readied up and is now active.
+ // Called with the global debug lock held.
+ virtual Status OnReady();
+
+ // Signals that the session has gone unready (from an event/etc) and the
+ // service is now awaiting it to ready up.
+ // Called with the global debug lock held.
+ virtual Status OnUnready();
+
+ // Signals that a context has been registered or unregistered.
+ // Called with the global debug lock held.
+ virtual Status OnContextRegistered(Context* context) = 0;
+ virtual Status OnContextUnregistered(Context* context) = 0;
+
+ // Signals that a module has been loaded in a context.
+ // Called with the global debug lock held.
+ virtual Status OnModuleLoaded(Context* context, Module* module) = 0;
+
+ // Signals that an invocation has been registered or unregistered.
+ // Called with the global debug lock held.
+ virtual Status OnInvocationRegistered(Invocation* invocation) = 0;
+ virtual Status OnInvocationUnregistered(Invocation* invocation) = 0;
+
+ // Signals that a breakpoint has been resolved to a particular function in a
+ // context.
+ // Called with the global debug lock held.
+ virtual Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint,
+ Context* context) = 0;
+
+ // Signals that the given breakpoint has been hit during execution.
+ // Called with the global debug lock held.
+ virtual Status OnBreakpointHit(int breakpoint_id,
+ const Invocation& invocation) = 0;
+
+ private:
+ mutable absl::Mutex mutex_;  // mutable so the const is_ready() can lock it.
+ int session_id_ = 0;  // NOTE(review): private and never assigned here, so id() is 0 — intended?
+ int ready_ ABSL_GUARDED_BY(mutex_) = -1;  // 0 == ready; starts at -1 (unready).
+};
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_DEBUG_SESSION_H_
diff --git a/rt/debug/debug_tcp_util.h b/rt/debug/debug_tcp_util.h
new file mode 100644
index 0000000..67b6e31
--- /dev/null
+++ b/rt/debug/debug_tcp_util.h
@@ -0,0 +1,217 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utilities for working with TCP sockets.
+// These are (mostly) portable to systems implementing BSD sockets.
+
+#ifndef IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
+#define IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
+
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <cstddef>
+
+#include "base/status.h"
+#include "flatbuffers/base.h"
+#include "flatbuffers/flatbuffers.h"
+#include "schemas/debug_service_generated.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+namespace tcp {
+
+// Turns SO_REUSEADDR on or off for a socket. Call prior to binding.
+// This is useful if a socket from a previous process is still holding on to
+// the address a new one is trying to bind. The setsockopt() result is unchecked.
+inline Status ToggleSocketAddressReuse(int fd, bool is_enabled) {
+  const int reuse = is_enabled ? 1 : 0;
+  ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
+  return OkStatus();
+}
+
+// Toggles the SO_LINGER option on a socket, always with a timeout value of 1.
+// Enabling linger will ensure all data on the socket is sent (if it can be
+// sent within N sec) prior to closing. Disabling linger will cause the socket
+// to close gracefully. The setsockopt() result is not checked.
+inline Status ToggleSocketLinger(int fd, bool is_enabled) {
+  struct linger options;
+  options.l_linger = 1;
+  options.l_onoff = is_enabled ? 1 : 0;
+  ::setsockopt(fd, SOL_SOCKET, SO_LINGER, &options, sizeof(options));
+  return OkStatus();
+}
+
+// Toggles Nagle's algorithm on a socket (best effort; the result is unchecked).
+// Enabled by default, sockets have ~250ms delay for small packets. Disabling
+// the algorithm will make socket flushes actually send data.
+inline Status ToggleSocketNagelsAlgorithm(int fd, bool is_enabled) {
+  int no_delay = is_enabled ? 0 : 1;  // TCP_NODELAY=1 turns Nagle *off*.
+  ::setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &no_delay, sizeof(no_delay));
+  return OkStatus();
+}
+
+// Toggles TCP keepalive on a socket.
+// Assumes that the remote side is on the local machine/network and that we can
+// spam it with packets.
+//
+// NOTE: we may want to adjust this when real debuggers are attached (to prevent
+// dropping our own connections). Need to figure out how to reliably detect
+// debug suspends vs. actual death.
+inline Status ToggleSocketLocalKeepalive(int fd, bool is_enabled) {
+  // Toggle keepalive. Every option below is best effort (results unchecked).
+  int keepalive_enable = is_enabled ? 1 : 0;
+  ::setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive_enable,
+               sizeof(keepalive_enable));
+  // Begin sending keepalive probes after N sec.
+  int keepalive_idle_delay = 3;
+  ::setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &keepalive_idle_delay,
+               sizeof(keepalive_idle_delay));
+  // Try one probe and bail (faster detection). TCP_KEEPCNT is the probe count.
+  int keepalive_retry_count = 1;
+  ::setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &keepalive_retry_count,
+               sizeof(keepalive_retry_count));
+  // Send keepalives every N sec.
+  int keepalive_interval = 1;
+  ::setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &keepalive_interval,
+               sizeof(keepalive_interval));
+  return OkStatus();
+}
+
+// Switches a socket between blocking and non-blocking mode.
+// If a socket has been set to non-blocking, methods like read and write will
+// return EWOULDBLOCK if they would have blocked on the specific operation.
+// The fcntl() results are not checked.
+inline Status ToggleSocketBlocking(int fd, bool is_blocking) {
+  const int current_flags = ::fcntl(fd, F_GETFL);
+  const int updated_flags = is_blocking ? (current_flags & ~O_NONBLOCK)
+                                        : (current_flags | O_NONBLOCK);
+  ::fcntl(fd, F_SETFL, updated_flags);
+  return OkStatus();
+}
+
+// Move-only owner of the bytes of a message whose flatbuffer root is a T.
+template <typename T>
+struct MessageBuffer {
+ public:
+  explicit MessageBuffer(std::vector<uint8_t> buffer)
+      : buffer_(std::move(buffer)) {}
+  MessageBuffer(MessageBuffer&&) = default;
+  MessageBuffer& operator=(MessageBuffer&&) = default;
+  MessageBuffer(const MessageBuffer&) = delete;
+  MessageBuffer& operator=(const MessageBuffer&) = delete;
+
+  const T& GetRoot() const {
+    return *::flatbuffers::GetRoot<T>(buffer_.data());
+  }
+
+ private:
+  std::vector<uint8_t> buffer_;
+};
+
+// Reads a size prefix value from the given fd.
+// If |poll_only| is true then the size prefix is not consumed from the stream
+// and the call will return 0 if there is no size prefix available.
+// Returns CancelledError if a (probably) graceful close is detected.
+inline StatusOr<size_t> ReadSizePrefix(int fd, bool poll_only) {
+  ::flatbuffers::uoffset_t size_prefix = 0;
+  const int flags = poll_only ? (MSG_PEEK | MSG_DONTWAIT) : MSG_WAITALL;
+  const int read_bytes = ::recv(fd, &size_prefix, sizeof(size_prefix), flags);
+  if (read_bytes == 0) {
+    // Remote side disconnected.
+    return CancelledErrorBuilder(IREE_LOC) << "Graceful remote close";
+  } else if (read_bytes < 0) {
+    const int error = errno;  // Snapshot before another call can overwrite it.
+    if (poll_only && (error == EAGAIN || error == EWOULDBLOCK)) {
+      return 0;  // Nothing queued yet; when polling that is not an error.
+    } else if (error == ECONNRESET) {
+      return CancelledErrorBuilder(IREE_LOC) << "Ungraceful remote close";
+    }
+    return DataLossErrorBuilder(IREE_LOC)
+           << "Failed to read size prefix from socket: (" << error << ") "
+           << ::strerror(error);
+  } else if (read_bytes != static_cast<int>(sizeof(size_prefix))) {
+    // Only part of the prefix has arrived; a poll reports "not available yet".
+    if (poll_only) return 0;
+    return DataLossErrorBuilder(IREE_LOC)
+           << "Failed to read full size prefix (got " << read_bytes << "b of "
+           << sizeof(size_prefix) << "b expected)";
+  }
+  return size_prefix;
+}
+
+// Peeks at the stream and returns true if ReadBuffer will (likely) not block.
+// Returns CancelledError if a (probably) graceful close is detected.
+inline StatusOr<bool> CanReadBuffer(int fd) {
+  ASSIGN_OR_RETURN(size_t peeked_size, ReadSizePrefix(fd, /*poll_only=*/true));
+  return peeked_size > 0;
+}
+
+// Reads a size-prefixed message from the given fd.
+// This will block until the entire message contents are available.
+// Returns an owning MessageBuffer for the verified message, or CancelledError
+// if a (probably) graceful close is detected while reading the size prefix.
+template <typename T>
+StatusOr<MessageBuffer<T>> ReadBuffer(int fd) {
+  // Read the size prefix (written as a uoffset_t by the Write* methods).
+  ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/false));
+
+  // Allocate the buffer for the entire message; the vector owns the storage.
+  std::vector<uint8_t> buffer(size_prefix);
+
+  // Read the entire message contents. MSG_WAITALL keeps recv() from returning
+  // early with only the part of a large message that has arrived so far.
+  auto read_bytes = ::recv(fd, buffer.data(), buffer.size(), MSG_WAITALL);
+  if (read_bytes < 0) {
+    return DataLossErrorBuilder(IREE_LOC)
+           << "Failed to read full message contents from socket: (" << errno
+           << ") " << ::strerror(errno);
+  } else if (static_cast<size_t>(read_bytes) != buffer.size()) {
+    return DataLossErrorBuilder(IREE_LOC)
+           << "Failed to read full message contents (got " << read_bytes
+           << "b of " << buffer.size() << "b expected)";
+  }
+
+  // Verify the contents. Not strictly required (as we won't ever ship this to
+  // prod), but useful in ensuring our socket code isn't corrupting things.
+  ::flatbuffers::Verifier verifier(buffer.data(), buffer.size());
+  if (!verifier.VerifyBuffer<T>()) {
+    return DataLossErrorBuilder(IREE_LOC)
+           << "Verification of input buffer of type " << typeid(T).name()
+           << " (" << buffer.size() << "b) failed";
+  }
+
+  // Hand ownership of the bytes to the returned wrapper.
+  return MessageBuffer<T>(std::move(buffer));
+}
+
+// Writes a buffer to the given fd with one send(); short writes aren't retried.
+inline Status WriteBuffer(int fd, ::flatbuffers::DetachedBuffer buffer) {
+  const auto sent_bytes = ::send(fd, buffer.data(), buffer.size(), 0);
+  if (sent_bytes >= 0) return OkStatus();
+  // send() reported a failure; surface errno in the status message.
+  return UnavailableErrorBuilder(IREE_LOC)
+         << "Write failed: (" << errno << ") " << ::strerror(errno);
+}
+
+} // namespace tcp
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
diff --git a/rt/disassembler.h b/rt/disassembler.h
new file mode 100644
index 0000000..cb3e0de
--- /dev/null
+++ b/rt/disassembler.h
@@ -0,0 +1,64 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_DISASSEMBLER_H_
+#define IREE_RT_DISASSEMBLER_H_
+
+#include <cstdint>
+#include <ostream>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "base/status.h"
+#include "rt/function.h"
+#include "rt/source_location.h"
+
+namespace iree {
+namespace rt {
+
+// A single disassembled instruction.
+struct Instruction {
+ // Offset of the instruction within the function.
+ // The meaning of this is backend-dependent.
+ SourceOffset offset;
+
+ // The first line of |long_text|; a non-owning view.
+ absl::string_view short_text;  // NOTE(review): confirm it can't dangle when the struct is copied.
+
+ // Human-readable text of the instruction. May contain multiple lines.
+ std::string long_text;
+};
+
+// Disassembles functions into instructions.
+//
+// Thread-safe.
+class Disassembler {
+ public:
+ virtual ~Disassembler() = default;
+
+ // Disassembles one or more instructions within |function| based on source
+ // |offset|; |instruction_count| defaults to INT32_MAX (presumably "no limit").
+ virtual StatusOr<std::vector<Instruction>> DisassembleInstructions(
+ const Function& function, SourceOffset offset,
+ int32_t instruction_count = INT32_MAX) const = 0;
+
+ protected:
+ Disassembler() = default;  // Protected: only constructible through a subclass.
+};
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_DISASSEMBLER_H_
diff --git a/rt/function.cc b/rt/function.cc
new file mode 100644
index 0000000..b3ec339
--- /dev/null
+++ b/rt/function.cc
@@ -0,0 +1,33 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/function.h"
+
+#include "rt/module.h"
+
+namespace iree {
+namespace rt {
+
+absl::string_view Function::name() const {  // Empty when the lookup fails.
+  auto name_or = module_->GetFunctionName(linkage_, ordinal_);
+  return !name_or.ok() ? absl::string_view() : name_or.ValueOrDie();
+}
+
+const FunctionSignature Function::signature() const {  // Default if lookup fails.
+  auto signature_or = module_->GetFunctionSignature(linkage_, ordinal_);
+  return !signature_or.ok() ? FunctionSignature() : signature_or.ValueOrDie();
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/function.h b/rt/function.h
new file mode 100644
index 0000000..9c3a2ca
--- /dev/null
+++ b/rt/function.h
@@ -0,0 +1,84 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_FUNCTION_H_
+#define IREE_RT_FUNCTION_H_
+
+#include <cstdint>
+#include <ostream>
+
+#include "absl/strings/string_view.h"
+#include "rt/function_signature.h"
+
+namespace iree {
+namespace rt {
+
+class Module;
+
+// Reference to a function within a module.
+// Functions are either visible or hidden from the module interface and may be
+// of one Linkage type. Imports and exports are always visible (as they are
+// required for dynamic linking) however functions with internal linkage may be
+// hidden in optimized builds to reduce the amount of reflection metadata
+// required.
+class Function final {
+ public:
+ enum class Linkage {
+ // Function is internal to the module and may not be reflectable.
+ kInternal = 0,
+ // Function is an import from another module.
+ kImport = 1,
+ // Function is an export from the module.
+ kExport = 2,
+ };
+
+ Function() = default;  // Null module, Linkage::kInternal, ordinal -1.
+ Function(const Module* module, Linkage linkage, int32_t ordinal)
+ : module_(module), linkage_(linkage), ordinal_(ordinal) {}
+
+ // Module the function is contained within; null if default-constructed.
+ const Module* module() const { return module_; }
+
+ // Linkage of the function. Note that Linkage::kInternal functions may be
+ // missing reflection information.
+ Linkage linkage() const { return linkage_; }
+
+ // Ordinal within the module in the linkage scope.
+ int32_t ordinal() const { return ordinal_; }
+
+ // Returns the function's original name, or empty if the module cannot supply
+ // it (e.g. stripped debug info). Must not be called when module() is null.
+ absl::string_view name() const;
+
+ // Returns the signature of the function; always present for imports and
+ // exports, but default-constructed if the lookup fails (e.g. internal
+ // functions with stripped debugging info). module() must be non-null.
+ const FunctionSignature signature() const;
+
+ private:
+ const Module* module_ = nullptr;
+ Linkage linkage_ = Linkage::kInternal;
+ int32_t ordinal_ = -1;
+};
+
+// Prints |function| in the form "@<name>#<ordinal>".
+inline std::ostream& operator<<(std::ostream& out, const Function& function) {
+  const auto ordinal = function.ordinal();
+  return out << '@' << function.name() << '#' << ordinal;
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_FUNCTION_H_
diff --git a/iree/rt/function_signature.h b/rt/function_signature.h
similarity index 100%
rename from iree/rt/function_signature.h
rename to rt/function_signature.h
diff --git a/rt/instance.cc b/rt/instance.cc
new file mode 100644
index 0000000..3ad8b35
--- /dev/null
+++ b/rt/instance.cc
@@ -0,0 +1,51 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/instance.h"
+
+#include "base/tracing.h"
+#include "rt/debug/debug_server.h"
+
+namespace iree {
+namespace rt {
+
+Instance::Instance(std::unique_ptr<debug::DebugServer> debug_server)
+    : debug_server_{std::move(debug_server)} {  // May be null (no debugging).
+  IREE_TRACE_SCOPE0("Instance::ctor");
+}
+
+Instance::~Instance() { IREE_TRACE_SCOPE0("Instance::dtor"); }  // Trace only.
+
+void Instance::RegisterContext(Context* context) {
+  {
+    // Hold the list lock only for the insertion itself.
+    absl::MutexLock list_lock(&contexts_mutex_);
+    contexts_.push_back(context);
+  }
+  if (debug_server_ == nullptr) return;
+  CHECK_OK(debug_server_->RegisterContext(context));
+}
+
+void Instance::UnregisterContext(Context* context) {
+  // Notify the debug server first, while the context is still in the list.
+  if (debug_server_ != nullptr) {
+    CHECK_OK(debug_server_->UnregisterContext(context));
+  }
+  // The lock is taken only after the debug server call has returned.
+  absl::MutexLock list_lock(&contexts_mutex_);
+  contexts_.erase(context);
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/instance.h b/rt/instance.h
new file mode 100644
index 0000000..6578ca4
--- /dev/null
+++ b/rt/instance.h
@@ -0,0 +1,70 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_INSTANCE_H_
+#define IREE_RT_INSTANCE_H_
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "base/intrusive_list.h"
+#include "base/ref_ptr.h"
+#include "hal/device_manager.h"
+#include "rt/context.h"
+#include "rt/debug/debug_server.h"
+
+namespace iree {
+namespace rt {
+
+// Shared runtime instance responsible for routing Context events, enumerating
+// and creating hardware device interfaces, and managing thread pools.
+//
+// A single runtime instance can service multiple contexts and hosting
+// applications should try to reuse instances as much as possible. This ensures
+// that resource allocation across contexts is handled and extraneous device
+// interaction is avoided. For devices that may have exclusive access
+// restrictions it is mandatory to share instances, so plan accordingly.
+//
+// Thread-safe.
+class Instance final : public RefObject<Instance> {
+ public:
+ // Creates an instance with an optional attached |debug_server| (may be null).
+ Instance() : Instance(nullptr) {}
+ explicit Instance(std::unique_ptr<debug::DebugServer> debug_server);
+ ~Instance();
+ Instance(const Instance&) = delete;
+ Instance& operator=(const Instance&) = delete;
+
+ // Optional debug server that has access to contexts in this instance.
+ debug::DebugServer* debug_server() const { return debug_server_.get(); }
+
+ // Device manager used to enumerate available devices.
+ hal::DeviceManager* device_manager() const { return &device_manager_; }
+
+ private:
+ friend class Context;  // Lets Context call the private methods below.
+ void RegisterContext(Context* context);  // Adds to contexts_, then notifies the debug server.
+ void UnregisterContext(Context* context);  // Notifies the debug server, then removes from contexts_.
+
+ std::unique_ptr<debug::DebugServer> debug_server_;
+ mutable hal::DeviceManager device_manager_;  // mutable: the const accessor returns a non-const pointer.
+
+ absl::Mutex contexts_mutex_;
+ IntrusiveList<Context, offsetof(Context, instance_list_link_)> contexts_
+ ABSL_GUARDED_BY(contexts_mutex_);
+};
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_INSTANCE_H_
diff --git a/rt/invocation.cc b/rt/invocation.cc
new file mode 100644
index 0000000..004c7bf
--- /dev/null
+++ b/rt/invocation.cc
@@ -0,0 +1,185 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/invocation.h"
+
+#include <atomic>
+#include <iterator>
+
+#include "absl/strings/str_cat.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "rt/context.h"
+
+namespace iree {
+namespace rt {
+
+namespace {
+
+int32_t NextUniqueInvocationId() {
+  static std::atomic<int32_t> id_counter{0};
+  return id_counter.fetch_add(1) + 1;  // The first ID handed out is 1.
+}
+
+} // namespace
+
+// static
+StatusOr<ref_ptr<Invocation>> Invocation::Create(
+    ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
+    absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies,
+    absl::InlinedVector<hal::BufferView, 8> arguments,
+    absl::optional<absl::InlinedVector<hal::BufferView, 8>> results) {
+  IREE_TRACE_SCOPE0("Invocation::Create");
+
+  const auto& signature = function.signature();
+  if (arguments.size() != signature.argument_count()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Argument count mismatch; expected " << signature.argument_count()
+           << " but received " << arguments.size();
+  } else if (results.has_value() &&
+             results.value().size() != signature.result_count()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Result count mismatch; expected " << signature.result_count()
+           << " but received " << results.value().size();
+  }
+
+  // TODO(benvanik): grab execution state, insert deps, etc.
+  // Checked before the Invocation exists so nothing registers with the context.
+  if (!dependencies.empty()) {
+    return UnimplementedErrorBuilder(IREE_LOC)
+           << "Dependencies are not yet implemented";
+  }
+
+  absl::InlinedVector<hal::BufferView, 8> results_value;
+  if (results.has_value()) {
+    results_value = std::move(results.value());
+  } else {
+    results_value.resize(signature.result_count());
+  }
+  auto invocation = assign_ref(
+      new Invocation(std::move(context), function, std::move(policy)));
+
+  // TODO(benvanik): fiber scheduling and such. For now runs synchronously.
+  auto execute_status = function.module()->Execute(
+      &invocation->stack_, function, std::move(arguments), &results_value);
+  if (execute_status.ok()) {
+    invocation->CompleteSuccess(std::move(results_value));
+  } else {
+    invocation->CompleteFailure(std::move(execute_status), nullptr);
+  }
+
+  return invocation;
+}
+
+// static
+StatusOr<ref_ptr<Invocation>> Invocation::Create(
+    ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
+    absl::Span<const ref_ptr<Invocation>> dependencies,
+    absl::Span<const hal::BufferView> arguments) {
+  // Materialize owning copies of the borrowed spans, then forward to the
+  // overload that takes its inputs by value.
+  absl::InlinedVector<ref_ptr<Invocation>, 4> owned_dependencies;
+  owned_dependencies.reserve(dependencies.size());
+  for (const auto& dependency : dependencies) {
+    owned_dependencies.push_back(add_ref(dependency));
+  }
+  // Arguments are copied element by element.
+  absl::InlinedVector<hal::BufferView, 8> owned_arguments(arguments.begin(),
+                                                          arguments.end());
+  return Invocation::Create(std::move(context), function, std::move(policy),
+                            std::move(owned_dependencies),
+                            std::move(owned_arguments));
+}
+
+Invocation::Invocation(ref_ptr<Context> context, const Function function,
+ ref_ptr<Policy> policy)
+ : id_(NextUniqueInvocationId()),
+ context_(std::move(context)),
+ function_(function),
+ policy_(std::move(policy)),
+ stack_(context_.get()) {  // Reads the member context_, not the moved-from parameter.
+ IREE_TRACE_SCOPE0("Invocation::ctor");
+ context_->RegisterInvocation(this);  // Undone by UnregisterInvocation in the destructor.
+}
+
+Invocation::~Invocation() {
+ IREE_TRACE_SCOPE0("Invocation::dtor");
+ context_->UnregisterInvocation(this);  // Pairs with RegisterInvocation in the constructor.
+}
+
+std::string Invocation::DebugStringShort() const {
+ return absl::StrCat("invocation_", id_);  // "invocation_" followed by the ID.
+}
+// Currently produces the same text as DebugStringShort().
+std::string Invocation::DebugString() const { return DebugStringShort(); }
+
+Status Invocation::QueryStatus() {
+  IREE_TRACE_SCOPE0("Invocation::QueryStatus");
+  absl::MutexLock status_lock(&status_mutex_);
+  return completion_status_;  // Returned by value: a copy made under the lock.
+}
+
+StatusOr<absl::InlinedVector<hal::BufferView, 8>> Invocation::ConsumeResults() {
+  IREE_TRACE_SCOPE0("Invocation::ConsumeResults");
+  absl::MutexLock status_lock(&status_mutex_);
+  // On success the results are handed over by move, which leaves results_ in
+  // a moved-from state; otherwise the stored non-OK status is returned.
+  if (completion_status_.ok()) return std::move(results_);
+  return completion_status_;
+}
+
+Status Invocation::Await(absl::Time deadline) {
+  IREE_TRACE_SCOPE0("Invocation::Await");
+  // TODO(benvanik): implement async invocation behavior; |deadline| is unused.
+  absl::MutexLock status_lock(&status_mutex_);
+  return completion_status_;
+}
+
+Status Invocation::Abort() {
+ IREE_TRACE_SCOPE0("Invocation::Abort");
+ // TODO(benvanik): implement async behavior; until then always Unimplemented.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Async invocations not yet implemented";
+}
+
+void Invocation::CompleteSuccess(
+    absl::InlinedVector<hal::BufferView, 8> results) {
+  IREE_TRACE_SCOPE0("Invocation::CompleteSuccess");
+  absl::MutexLock status_lock(&status_mutex_);
+  // An abort that landed before completion wins; the results are dropped.
+  // Otherwise the status is expected to still be Unavailable at this point.
+  if (!IsAborted(completion_status_)) {
+    DCHECK(IsUnavailable(completion_status_));
+    results_ = std::move(results);
+    failure_stack_trace_.reset();
+    completion_status_ = OkStatus();
+  }
+}
+
+void Invocation::CompleteFailure(
+    Status completion_status, std::unique_ptr<StackTrace> failure_stack_trace) {
+  IREE_TRACE_SCOPE0("Invocation::CompleteFailure");
+  absl::MutexLock status_lock(&status_mutex_);
+  // An abort that landed before completion wins; the failure is dropped.
+  // Otherwise the status is expected to still be Unavailable at this point.
+  if (!IsAborted(completion_status_)) {
+    DCHECK(IsUnavailable(completion_status_));
+    results_.clear();
+    failure_stack_trace_ = std::move(failure_stack_trace);
+    completion_status_ = std::move(completion_status);
+  }
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/invocation.h b/rt/invocation.h
new file mode 100644
index 0000000..3226ea3
--- /dev/null
+++ b/rt/invocation.h
@@ -0,0 +1,157 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_INVOCATION_H_
+#define IREE_RT_INVOCATION_H_
+
+#include <ostream>
+#include <string>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "base/intrusive_list.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "rt/function.h"
+#include "rt/policy.h"
+#include "rt/stack.h"
+#include "rt/stack_trace.h"
+
+namespace iree {
+namespace rt {
+
+class Context;
+
+// An asynchronous invocation of a function.
+// Holds the invocation state and allows querying and waiting on completion.
+// Invocations are conceptually fibers and may suspend and resume execution
+// several times before completing.
+//
+// Thread-safe.
+class Invocation final : public RefObject<Invocation> {
+ public:
+  // TODO(benvanik): define error propagation semantics across dependencies.
+  // TODO(benvanik): support more dependency types (semaphores, etc).
+  // Creates a new invocation tracking object for invoking the given |function|
+  // from |context|. |arguments| will be retained until the invocation is made.
+  // If |dependencies| are provided then the invocation will wait until they are
+  // resolved before executing. If a |policy| is provided it will override the
+  // context-level policy.
+  //
+  // Optionally |results| may be provided with preallocated buffers that will
+  // receive the outputs of the invocation. Invocation will fail if they do not
+  // match expected sizes.
+  //
+  // Note that it's possible for the invocation to complete prior to the return
+  // of this function. Any errors that occur will be set on the invocation and
+  // callers should query its state prior to assuming it is in-flight.
+  static StatusOr<ref_ptr<Invocation>> Create(
+      ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
+      absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies,
+      absl::InlinedVector<hal::BufferView, 8> arguments,
+      absl::optional<absl::InlinedVector<hal::BufferView, 8>> results =
+          absl::nullopt);
+  static StatusOr<ref_ptr<Invocation>> Create(
+      ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy,
+      absl::Span<const ref_ptr<Invocation>> dependencies,
+      absl::Span<const hal::BufferView> arguments);
+
+  ~Invocation();
+
+  // A process-unique ID for the invocation.
+  int32_t id() const { return id_; }
+
+  // Context this invocation is running within.
+  const ref_ptr<Context>& context() const { return context_; }
+
+  // Function being invoked.
+  const Function& function() const { return function_; }
+
+  // A single-line human-readable debug string for the invocation.
+  std::string DebugStringShort() const;
+
+  // A long-form debug string with stack trace (if available).
+  std::string DebugString() const;
+
+  // Queries the completion status of the invocation.
+  // Returns one of the following:
+  //   StatusCode::kOk: the invocation completed successfully.
+  //   StatusCode::kUnavailable: the invocation has not yet completed.
+  //   StatusCode::kCancelled: the invocation was cancelled internally.
+  //   StatusCode::kAborted: the invocation was aborted.
+  //   StatusCode::*: an error occurred during invocation.
+  Status QueryStatus();
+
+  // Returns ownership of the results of the operation to the caller.
+  // If the invocation has not completed successfully then the status is
+  // returned as if QueryStatus() had been called.
+  StatusOr<absl::InlinedVector<hal::BufferView, 8>> ConsumeResults();
+
+  // Blocks the caller until the invocation completes (successfully or
+  // otherwise).
+  //
+  // Returns StatusCode::kDeadlineExceeded if |deadline| elapses before the
+  // invocation completes and otherwise returns the result of QueryStatus().
+  Status Await(absl::Time deadline);
+
+  // Attempts to abort the invocation if it is in-flight.
+  // A no-op if the invocation has already completed.
+  Status Abort();
+
+  // TODO(benvanik): export a hal::TimelineSemaphore.
+
+ private:
+  friend class Context;
+
+  Invocation(ref_ptr<Context> context, const Function function,
+             ref_ptr<Policy> policy);
+
+  // Completes the invocation with a successful result.
+  void CompleteSuccess(absl::InlinedVector<hal::BufferView, 8> results);
+
+  // Completes the invocation with a failure, including an optional stack trace.
+  void CompleteFailure(Status completion_status,
+                       std::unique_ptr<StackTrace> failure_stack_trace);
+
+  int32_t id_;
+  ref_ptr<Context> context_;
+  const Function function_;
+  ref_ptr<Policy> policy_;
+
+  Stack stack_;
+
+  absl::Mutex status_mutex_;
+  Status completion_status_ ABSL_GUARDED_BY(status_mutex_) =
+      UnavailableErrorBuilder(IREE_LOC);
+  std::unique_ptr<StackTrace> failure_stack_trace_
+      ABSL_GUARDED_BY(status_mutex_);
+  absl::InlinedVector<hal::BufferView, 8> results_
+      ABSL_GUARDED_BY(status_mutex_);
+
+  // Intrusive list link; presumably used by Context to track invocations.
+  IntrusiveListLink context_list_link_;
+};
+
+// Streams the single-line debug string of |invocation|.
+inline std::ostream& operator<<(std::ostream& stream,
+                                const Invocation& invocation) {
+  return stream << invocation.DebugStringShort();
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_INVOCATION_H_
diff --git a/rt/module.h b/rt/module.h
new file mode 100644
index 0000000..777d454
--- /dev/null
+++ b/rt/module.h
@@ -0,0 +1,109 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_MODULE_H_
+#define IREE_RT_MODULE_H_
+
+#include <ostream>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/string_view.h"
+#include "base/ref_ptr.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "rt/function.h"
+#include "rt/module_signature.h"
+
+namespace iree {
+namespace rt {
+
+class Disassembler;
+class SourceResolver;
+class Stack;
+
+// Abstract compiled module interface for resolving functions.
+//
+// Modules are (generally) stateless, immutable, and may exist in multiple
+// contexts at the same time.
+class Module : public RefObject<Module> {
+ public:
+ virtual ~Module() = default;
+
+ // Name of the module used to resolve fully-qualified references.
+ // The lifetime of the returned reference is not guaranteed beyond the current
+ // calling scope and callers must clone it if they want to retain it.
+ virtual absl::string_view name() const = 0;
+
+ // A description of the module imports, exports, and other metadata.
+ virtual const ModuleSignature signature() const = 0;
+
+ // Returns a resolver capable of resolving functions to source and performing
+ // basic debugging logic (such as offset calculation).
+ // May be nullptr if debugging info has been stripped.
+ virtual SourceResolver* source_resolver() const = 0;
+
+ // Returns a disassembler that can be used to disassemble functions in the
+ // module. May be nullptr if debugging info has been stripped or disassembly
+ // has been disabled as a compile option.
+ virtual Disassembler* disassembler() const = 0;
+
+ // A short human-readable string that matches the compiler formatting.
+ virtual std::string DebugStringShort() const = 0;
+
+ // Looks up a visible function by ordinal.
+ // Internal functions may not be found if debugging info has been stripped.
+ virtual StatusOr<const Function> LookupFunctionByOrdinal(
+ Function::Linkage linkage, int32_t ordinal) const = 0;
+
+ // Looks up a visible function by name.
+ // Internal functions may not be found if debugging info has been stripped.
+ virtual StatusOr<const Function> LookupFunctionByName(
+ Function::Linkage linkage, absl::string_view name) const = 0;
+
+ // Returns the name of the visible function as a string reference.
+ //
+ // May return empty for functions with internal linkage if debugging info has
+ // been stripped.
+ //
+ // The lifetime of the returned reference is not guaranteed beyond the current
+ // calling scope and callers must clone it if they want to retain it.
+ virtual StatusOr<absl::string_view> GetFunctionName(
+ Function::Linkage linkage, int32_t ordinal) const = 0;
+
+ // Returns the full function signature for the given |ordinal|.
+ //
+ // May return empty for functions with internal linkage if the debugging info
+ // has been stripped.
+ virtual StatusOr<const FunctionSignature> GetFunctionSignature(
+ Function::Linkage linkage, int32_t ordinal) const = 0;
+
+ // Executes |function|. Temporary entry point until the scheduler is built.
+ virtual Status Execute(
+ Stack* stack, const Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const = 0;
+
+ protected:
+ Module() = default;
+};
+
+// Streams the short debug string of |module|.
+inline std::ostream& operator<<(std::ostream& stream, const Module& module) {
+  return stream << module.DebugStringShort();
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_MODULE_H_
diff --git a/rt/module_printer.cc b/rt/module_printer.cc
new file mode 100644
index 0000000..a5ff3e6
--- /dev/null
+++ b/rt/module_printer.cc
@@ -0,0 +1,65 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/module_printer.h"
+
+#include <iomanip>
+
+#include "rt/disassembler.h"
+#include "rt/source_resolver.h"
+
+namespace iree {
+namespace rt {
+
+Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags,
+                           std::ostream* stream) {
+  *stream << "Imports:\n";
+  for (int i = 0; i < module.signature().import_function_count(); ++i) {
+    ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal(
+                                        Function::Linkage::kImport, i));
+    *stream << " " << i << ": " << function << "\n";
+  }
+  *stream << "Exports:\n";
+  for (int i = 0; i < module.signature().export_function_count(); ++i) {
+    ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal(
+                                        Function::Linkage::kExport, i));
+    *stream << " " << i << ": " << function << "\n";
+  }
+  if (module.signature().internal_function_count()) {
+    *stream << "Internal:\n";
+    auto* disassembler = module.disassembler();  // Null if stripped/disabled.
+    for (int i = 0; i < module.signature().internal_function_count(); ++i) {
+      ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal(
+                                          Function::Linkage::kInternal, i));
+      *stream << " " << i << ": " << function << "\n";
+      if (disassembler && AllBitsSet(flags, PrintModuleFlag::kDisassemble)) {
+        auto insts_or = disassembler->DisassembleInstructions(function, 0);
+        if (IsUnavailable(insts_or.status())) continue;  // No listing; skip.
+        if (!insts_or.ok()) return insts_or.status();  // Report, don't crash.
+        for (const auto& instruction : insts_or.ValueOrDie()) {
+          *stream << " " << std::setw(6) << instruction.offset << ": "
+                  << instruction.long_text << "\n";
+        }
+      }
+    }
+  }
+  return OkStatus();
+}
+
+Status PrintModuleToStream(const Module& module, std::ostream* stream) {
+ return PrintModuleToStream(module, PrintModuleFlag::kNone, stream);
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/module_printer.h b/rt/module_printer.h
new file mode 100644
index 0000000..48abcff
--- /dev/null
+++ b/rt/module_printer.h
@@ -0,0 +1,42 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_MODULE_PRINTER_H_
+#define IREE_RT_MODULE_PRINTER_H_
+
+#include <ostream>
+
+#include "base/bitfield.h"
+#include "base/status.h"
+#include "rt/module.h"
+
+namespace iree {
+namespace rt {
+
+enum class PrintModuleFlag {
+ kNone = 0,
+ kDisassemble = 1 << 0,
+};
+IREE_BITFIELD(PrintModuleFlag);
+using PrintModuleFlagBitfield = PrintModuleFlag;
+
+// Prints all functions within the module to the given |stream|.
+Status PrintModuleToStream(const Module& module, std::ostream* stream);
+Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags,
+ std::ostream* stream);
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_MODULE_PRINTER_H_
diff --git a/iree/rt/module_signature.h b/rt/module_signature.h
similarity index 100%
rename from iree/rt/module_signature.h
rename to rt/module_signature.h
diff --git a/rt/policy.h b/rt/policy.h
new file mode 100644
index 0000000..0c3c8f9
--- /dev/null
+++ b/rt/policy.h
@@ -0,0 +1,43 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_POLICY_H_
+#define IREE_RT_POLICY_H_
+
+#include "base/ref_ptr.h"
+
+namespace iree {
+namespace rt {
+
+// Defines how invocation scheduling is to be performed.
+// The policy instance is used by the scheduler to determine when submissions
+// should be flushed to target queues.
+//
+// Thread-safe; the policy may be evaluated from arbitrary threads after an
+// invocation has begun processing.
+class Policy : public RefObject<Policy> {
+ public:
+ virtual ~Policy() = default;
+
+ // TODO(benvanik): constraints:
+ // - max memory usage
+ // - max delay
+ // - max in-flight items/etc
+ // - allowed device types
+};
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_POLICY_H_
diff --git a/rt/source_location.cc b/rt/source_location.cc
new file mode 100644
index 0000000..0cf035b
--- /dev/null
+++ b/rt/source_location.cc
@@ -0,0 +1,32 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/source_location.h"
+
+#include <sstream>
+
+#include "rt/source_resolver.h"
+
+namespace iree {
+namespace rt {
+
+std::string SourceLocation::DebugStringShort() const {
+  if (is_unknown()) return "(unknown)";
+  std::ostringstream formatted;
+  resolver_->PrintSourceLocation(resolver_args_, &formatted);
+  return formatted.str();
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/iree/rt/source_location.h b/rt/source_location.h
similarity index 100%
rename from iree/rt/source_location.h
rename to rt/source_location.h
diff --git a/rt/source_resolver.h b/rt/source_resolver.h
new file mode 100644
index 0000000..3b27aaf
--- /dev/null
+++ b/rt/source_resolver.h
@@ -0,0 +1,67 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_SOURCE_RESOLVER_H_
+#define IREE_RT_SOURCE_RESOLVER_H_
+
+#include <cstdint>
+#include <ostream>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "base/status.h"
+#include "rt/function.h"
+#include "rt/source_location.h"
+
+namespace iree {
+namespace rt {
+
+// Resolves offsets within functions to SourceLocations and provides source
+// language services.
+//
+// Thread-safe.
+class SourceResolver {
+ public:
+ virtual ~SourceResolver() = default;
+
+ // Resolves a function-relative offset to a source location.
+ // Not all offsets within a function may have source mapping information.
+ virtual absl::optional<SourceLocation> ResolveFunctionOffset(
+ const Function& function, SourceOffset offset) = 0;
+
+ // Writes a human-readable form of the location described by |resolver_args|
+ // to |stream|, commonly a single line such as path:line:col.
+ virtual void PrintSourceLocation(SourceResolverArgs resolver_args,
+ std::ostream* stream) const = 0;
+
+ // TODO(benvanik): query local variable names.
+
+ // TODO(benvanik): step target calculation (relative mapping).
+ // TODO(benvanik): step target based on SourceLocation delta.
+
+ // TODO(benvanik): expression evaluator? (setting variables)
+
+ protected:
+ friend class SourceLocation;
+
+ SourceResolver() = default;
+
+ // TODO(benvanik): get line mapping information.
+};
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_SOURCE_RESOLVER_H_
diff --git a/rt/stack.cc b/rt/stack.cc
new file mode 100644
index 0000000..6e0cd13
--- /dev/null
+++ b/rt/stack.cc
@@ -0,0 +1,60 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/stack.h"
+
+#include <iterator>
+
+#include "absl/strings/str_join.h"
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+
+constexpr int Stack::kMaxStackDepth;
+
+Stack::Stack(Context* context) : context_(context) {}
+
+Stack::~Stack() = default;
+
+StatusOr<StackFrame*> Stack::PushFrame(Function function) {
+  // Refuse to grow past the fixed-size frame storage.
+  if (stack_depth_ >= kMaxStackDepth) {
+    return InternalErrorBuilder(IREE_LOC)
+           << "Max stack depth of " << kMaxStackDepth << " exceeded";
+  }
+  frames_[stack_depth_] = StackFrame(function);
+  ++stack_depth_;
+  // TODO(benvanik): WTF scope enter.
+  return current_frame();
+}
+
+Status Stack::PopFrame() {
+  // Popping an empty stack means push/pop calls were mismatched.
+  if (stack_depth_ == 0) {
+    return InternalErrorBuilder(IREE_LOC) << "Unbalanced stack pop";
+  }
+
+  // TODO(benvanik): WTF scope leave.
+  // Reset the vacated slot to a default-constructed frame.
+  frames_[--stack_depth_] = {};
+  return OkStatus();
+}
+
+std::string Stack::DebugString() const {
+ return absl::StrJoin(frames(), "\n", StackFrameFormatter());
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/stack.h b/rt/stack.h
new file mode 100644
index 0000000..34a2c8d
--- /dev/null
+++ b/rt/stack.h
@@ -0,0 +1,85 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_STACK_H_
+#define IREE_RT_STACK_H_
+
+#include <functional>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "rt/stack_frame.h"
+
+namespace iree {
+namespace rt {
+
+class Context;
+
+// A runtime call stack for managing stack frames.
+// The frames within a stack may be from different backends and may provide
+// varying levels of information based on capabilities.
+//
+// Thread-compatible. Do not attempt to investigate a stack while another thread
+// may be mutating it!
+class Stack final {
+ public:
+ static constexpr int kMaxStackDepth = 32;
+
+ explicit Stack(Context* context);
+ Stack(const Stack&) = delete;
+ Stack& operator=(const Stack&) = delete;
+ ~Stack();
+
+ // Context defining the module and global workspaces.
+ Context* context() const { return context_; }
+
+ // All stack frames within the stack.
+ absl::Span<StackFrame> frames() {
+ return absl::MakeSpan(frames_).subspan(0, stack_depth_);
+ }
+ absl::Span<const StackFrame> frames() const {
+ return absl::MakeConstSpan(frames_).subspan(0, stack_depth_);
+ }
+
+ // The current stack frame, or nullptr if the stack is empty.
+ StackFrame* current_frame() {
+ return stack_depth_ > 0 ? &frames_[stack_depth_ - 1] : nullptr;
+ }
+
+ // The stack frame of the caller of the current function, or nullptr if none.
+ StackFrame* caller_frame() {
+ return stack_depth_ > 1 ? &frames_[stack_depth_ - 2] : nullptr;
+ }
+
+ StatusOr<StackFrame*> PushFrame(Function function);  // Fails at max depth.
+ Status PopFrame();  // Fails if no frame is pushed.
+
+ // Returns a full stack frame listing in human-readable form.
+ std::string DebugString() const;
+
+ private:
+ Context* context_ = nullptr;
+ std::array<StackFrame, kMaxStackDepth> frames_;
+ int stack_depth_ = 0;
+};
+
+// Streams the full frame listing of |stack|.
+inline std::ostream& operator<<(std::ostream& stream, const Stack& stack) {
+  return stream << stack.DebugString();
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_STACK_H_
diff --git a/rt/stack_frame.cc b/rt/stack_frame.cc
new file mode 100644
index 0000000..6faf9e7
--- /dev/null
+++ b/rt/stack_frame.cc
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/stack_frame.h"
+
+#include "absl/strings/str_cat.h"
+#include "rt/source_resolver.h"
+
+namespace iree {
+namespace rt {
+
+absl::optional<SourceLocation> StackFrame::source_location() const {
+  SourceResolver* resolver = module().source_resolver();
+  if (resolver == nullptr) return absl::nullopt;  // Debug info unavailable.
+  return resolver->ResolveFunctionOffset(function_, offset_);
+}
+
+std::string StackFrame::DebugStringShort() const {
+ return absl::StrCat(module().name(), ":", function().name(), "@", offset());
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/stack_frame.h b/rt/stack_frame.h
new file mode 100644
index 0000000..f2014b8
--- /dev/null
+++ b/rt/stack_frame.h
@@ -0,0 +1,106 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_STACK_FRAME_H_
+#define IREE_RT_STACK_FRAME_H_
+
+#include <ostream>
+
+#include "absl/types/span.h"
+#include "rt/function.h"
+#include "rt/module.h"
+#include "rt/source_location.h"
+
+namespace iree {
+namespace rt {
+
+// TODO(benvanik): allocate in-place from an arena.
+// Register table used within a stack frame.
+struct Registers {
+ std::vector<hal::BufferView> buffer_views;
+};
+
+// A single frame on the call stack containing current execution state and
+// register values.
+//
+// As different backends may support different features this interface exposes
+// only the things we want to view in our debugger/stack dumps. This allows us
+// to ignore the actual implementation (bytecode VM, compiled C code, etc) so
+// long as it can respond to queries for register values. This has the benefit
+// of keeping the actual frame very lightweight as we are not storing the values
+// but instead just routing to the real storage via indirection. If the debugger
+// is not attached and no errors are hit then no additional bookkeeping is done.
+//
+// Thread-compatible, as is the owning Stack/StackTrace.
+class StackFrame final {
+ public:
+  StackFrame() = default;
+  explicit StackFrame(Function function) : function_(function) {}
+  StackFrame(Function function, SourceOffset offset, Registers registers)
+      : function_(function),
+        offset_(offset),
+        registers_(std::move(registers)) {}
+  StackFrame(const StackFrame&) = delete;
+  StackFrame& operator=(const StackFrame&) = delete;
+  StackFrame(StackFrame&&) = default;
+  StackFrame& operator=(StackFrame&&) = default;
+
+  // Module that owns the function this stack frame represents.
+  const Module& module() const { return *function_.module(); }
+
+  // Function the stack frame represents.
+  const Function& function() const { return function_; }
+
+  // Current virtual offset within the function.
+  // The exact meaning of the offset is backend dependent and callers should
+  // treat them as opaque and must use the SourceResolver to compute new
+  // offsets (such as 'next offset').
+  SourceOffset offset() const { return offset_; }
+  SourceOffset* mutable_offset() { return &offset_; }
+
+  // Returns a source location, if available, for the current offset within the
+  // target function.
+  absl::optional<SourceLocation> source_location() const;
+
+  // Registers used within the stack frame.
+  // Storage is implementation-defined and is valid only for the lifetime of the
+  // frame.
+  const Registers& registers() const { return registers_; }
+  Registers* mutable_registers() { return &registers_; }
+
+  // A short human-readable string for the frame; a single line.
+  std::string DebugStringShort() const;
+
+ private:
+  Function function_;
+  SourceOffset offset_ = 0;
+  Registers registers_;
+};
+
+struct StackFrameFormatter {
+ void operator()(std::string* out, const StackFrame& stack_frame) const {
+ out->append(stack_frame.DebugStringShort());
+ }
+};
+
+// Streams the single-line debug string of |stack_frame|.
+inline std::ostream& operator<<(std::ostream& stream,
+                                const StackFrame& stack_frame) {
+  return stream << stack_frame.DebugStringShort();
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_STACK_FRAME_H_
diff --git a/rt/stack_trace.cc b/rt/stack_trace.cc
new file mode 100644
index 0000000..a3ad17f
--- /dev/null
+++ b/rt/stack_trace.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rt/stack_trace.h"
+
+#include "absl/strings/str_join.h"
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+
+std::string StackTrace::DebugString() const {
+ return absl::StrJoin(frames_, "\n", StackFrameFormatter());
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/stack_trace.h b/rt/stack_trace.h
new file mode 100644
index 0000000..56d0d47
--- /dev/null
+++ b/rt/stack_trace.h
@@ -0,0 +1,65 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_RT_STACK_TRACE_H_
+#define IREE_RT_STACK_TRACE_H_
+
+#include <ostream>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "rt/stack_frame.h"
+
+namespace iree {
+namespace rt {
+
+// A snapshot of a stack at a point in time.
+// The frames within a stack may be from different backends and may provide
+// varying levels of information based on capabilities.
+//
+// Depending on the capture options the trace may contain references to register
+// values (such as buffers) from the time of capture. If the buffers were
+// modified after the capture was taken those results will be reflected!
+class StackTrace final {
+ public:
+ StackTrace() = default;
+ explicit StackTrace(std::vector<StackFrame> frames)
+ : frames_(std::move(frames)) {}
+ StackTrace(const StackTrace&) = delete;
+ StackTrace& operator=(const StackTrace&) = delete;
+ ~StackTrace() = default;
+
+ // All stack frames captured in the trace.
+ absl::Span<const StackFrame> frames() const {
+ return absl::MakeConstSpan(frames_);
+ }
+
+ // Returns a full stack frame listing in human-readable form.
+ std::string DebugString() const;
+
+ private:
+ std::vector<StackFrame> frames_;
+};
+
+// Streams the full frame listing of |stack_trace|.
+inline std::ostream& operator<<(std::ostream& stream,
+                                const StackTrace& stack_trace) {
+  return stream << stack_trace.DebugString();
+}
+
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_RT_STACK_TRACE_H_
diff --git a/iree/samples/CMakeLists.txt b/samples/CMakeLists.txt
similarity index 100%
rename from iree/samples/CMakeLists.txt
rename to samples/CMakeLists.txt
diff --git a/samples/hal/BUILD b/samples/hal/BUILD
new file mode 100644
index 0000000..0976831
--- /dev/null
+++ b/samples/hal/BUILD
@@ -0,0 +1,48 @@
+# Samples demonstrating use of the HAL API.
+# These do not rely on higher layers of the system (such as the VM or runtime).
+
+load("//:build_defs.google.bzl", "PLATFORM_VULKAN_TEST_DEPS")
+load("///tools:compilation.bzl", "iree_bytecode_module")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_bytecode_module(
+ name = "simple_compute_test_module",
+ srcs = ["simple_compute_test.mlir"],
+ cc_namespace = "iree::hal::samples",
+)
+
+cc_test(
+ name = "simple_compute_test",
+ srcs = ["simple_compute_test.cc"],
+ data = [
+ # When building with --config=asan you must specify the following
+ # envvar when using Vulkan + a local Nvidia GPU:
+ # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
+ "///tools:sanitizer_suppressions.txt",
+ ],
+ deps = [
+ ":simple_compute_test_module_cc",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "///base:flatbuffer_util",
+ "///base:status_matchers",
+ "///hal:command_buffer",
+ "///hal:command_queue",
+ "///hal:driver_registry",
+ "///schemas",
+ "///base:status",
+
+ # These are the drivers we support running with and can produce
+ # executables for from the source MLIR.
+ "///hal/interpreter:interpreter_driver_module", # build-cleaner: keep
+ "///hal/vulkan:vulkan_driver_module", # build-cleaner: keep
+
+ # TODO(b/142004903): enable when Dawn HAL implementation is functional
+ # "///hal/dawn:dawn_driver_module", # build-cleaner: keep
+ ] + PLATFORM_VULKAN_TEST_DEPS,
+)
diff --git a/iree/samples/hal/CMakeLists.txt b/samples/hal/CMakeLists.txt
similarity index 100%
rename from iree/samples/hal/CMakeLists.txt
rename to samples/hal/CMakeLists.txt
diff --git a/samples/hal/simple_compute_test.cc b/samples/hal/simple_compute_test.cc
new file mode 100644
index 0000000..1c34905
--- /dev/null
+++ b/samples/hal/simple_compute_test.cc
@@ -0,0 +1,221 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// A simple backend-agnostic compute test for the HAL API.
+// This will load an IREE module containing one or more executables and attempt
+// to run them against all registered driver backends.
+//
+// The input file, simple_compute_test.mlir, is as generic as possible to ensure
+// we don't need too many variants. This means that it does not use any FFI
+// imports requiring runtime support, uses floats exclusively (as that's assumed
+// available everywhere), etc.
+//
+// The `iree_bytecode_module` build rule is used to translate the MLIR to the
+// module flatbuffer. Additional target support can be defined there.
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_replace.h"
+#include "absl/time/time.h"
+#include "base/flatbuffer_util.h"
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/command_buffer.h"
+#include "hal/command_queue.h"
+#include "hal/driver_registry.h"
+#include "samples/hal/simple_compute_test_module.h"
+#include "schemas/module_def_generated.h"
+
+namespace iree {
+namespace hal {
+namespace samples {
+namespace {
+
+using ModuleFile = FlatBufferFile<ModuleDef>;
+
+struct TestParams {
+ // HAL driver to use for the test.
+ std::string driver_name;
+ // Ordinal within the module to execute.
+ int executable_ordinal;
+ // Name of the executable (just for prettier logging).
+ std::string executable_name;
+};
+
+std::ostream& operator<<(std::ostream& os, const TestParams& params) {
+ return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}) << "_ex"
+ << params.executable_ordinal << "_" << params.executable_name;
+}
+
+// Loads the precompiled module file (from simple_compute_test.mlir).
+std::unique_ptr<ModuleFile> LoadModuleFile() {
+ const auto* file_toc = simple_compute_test_module_create();
+ return ModuleFile::WrapBuffer(
+ ModuleDefIdentifier(),
+ absl::MakeSpan(reinterpret_cast<const uint8_t*>(file_toc->data),
+ file_toc->size))
+ .ValueOrDie();
+}
+
+// Builds one TestParams per [available driver x executable in the module].
+std::vector<TestParams> GetAvailableDriverTestParams() {
+  auto module_file = LoadModuleFile();
+  auto& executable_table = *module_file->root()->executable_table();
+  std::vector<TestParams> all_test_params;
+  for (const auto& driver_name :
+       DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) {
+    int executable_ordinal = 0;  // Later passed to Get(); must count up.
+    for (const auto* multi_arch_executable_def :
+         *executable_table.multi_arch_executables()) {
+      TestParams test_params;
+      test_params.driver_name = driver_name;
+      test_params.executable_ordinal = executable_ordinal++;
+      test_params.executable_name =
+          std::string(WrapString(multi_arch_executable_def->name()));
+      all_test_params.push_back(std::move(test_params));
+    }
+  }
+  return all_test_params;
+}
+
+class SimpleComputeTest : public ::testing::Test,
+                          public ::testing::WithParamInterface<TestParams> {
+ protected:
+  void SetUp() override { module_file_ = LoadModuleFile(); }
+  // Precompiled module (from simple_compute_test.mlir); reloaded in SetUp().
+  std::unique_ptr<ModuleFile> module_file_;
+};
+
+TEST_P(SimpleComputeTest, RunOnce) {
+ const auto& test_params = GetParam();
+
+ // Create driver for this test (based on params) and then get a default
+ // device.
+ LOG(INFO) << "Creating driver '" << test_params.driver_name << "'...";
+ auto driver_or =
+ DriverRegistry::shared_registry()->Create(test_params.driver_name);
+ if (IsUnavailable(driver_or.status())) {
+ LOG(WARNING) << "Skipping test as driver is unavailable: "
+ << driver_or.status();
+ GTEST_SKIP();
+ return;
+ }
+ ASSERT_OK_AND_ASSIGN(auto driver, driver_or);
+ ASSERT_OK_AND_ASSIGN(auto available_devices,
+ driver->EnumerateAvailableDevices());
+ for (const auto& device_info : available_devices) {
+ LOG(INFO) << " Device: " << device_info.name();
+ }
+ LOG(INFO) << "Creating default device...";
+ ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice());
+ LOG(INFO) << "Successfully created device '" << device->info().name() << "'";
+
+ // Attempt to compile the appropriate executable. This may fail if there's no
+ // executable available in the input file that the driver can load.
+ auto executable_cache = device->CreateExecutableCache();
+ auto& executable_table = *module_file_->root()->executable_table();
+ auto multi_arch_executable_def =
+ executable_table.multi_arch_executables()->Get(
+ test_params.executable_ordinal);
+ ref_ptr<Executable> executable;
+ for (auto executable_def : *multi_arch_executable_def->executables()) {
+ if (!executable_cache->CanPrepareFormat(executable_def->format())) {
+ continue;
+ }
+ ExecutableSpec spec;
+ spec.format = executable_def->format();
+ spec.executable_data = *executable_def->contents();
+ ASSERT_OK_AND_ASSIGN(executable,
+ executable_cache->PrepareExecutable(
+ ExecutableCachingMode::kDefault, spec));
+ break;
+ }
+ ASSERT_NE(executable, nullptr)
+ << "No executable found that has a supported format for driver "
+ << test_params.driver_name;
+
+ // Create I/O buffers.
+ ASSERT_OK_AND_ASSIGN(auto arg0_buffer,
+ device->allocator()->Allocate(
+ MemoryType::kHostLocal | MemoryType::kDeviceVisible,
+ BufferUsage::kAll, 4 * sizeof(float)));
+ ASSERT_OK_AND_ASSIGN(auto arg1_buffer,
+ device->allocator()->Allocate(
+ MemoryType::kHostLocal | MemoryType::kDeviceVisible,
+ BufferUsage::kAll, 4 * sizeof(float)));
+ ASSERT_OK_AND_ASSIGN(auto ret0_buffer,
+ device->allocator()->Allocate(
+ MemoryType::kHostLocal | MemoryType::kDeviceVisible,
+ BufferUsage::kAll, 4 * sizeof(float)));
+
+ // Populate initial values for 4 * 2 = 8.
+ // We scribble into the result buffer so that it's easy to ensure it's
+ // overwritten.
+ ASSERT_OK(arg0_buffer->Fill32(4.0f));
+ ASSERT_OK(arg1_buffer->Fill32(2.0f));
+ ASSERT_OK(ret0_buffer->Fill32(99999.0f));
+
+ // Record the command buffer that dispatches the executable.
+ ASSERT_OK_AND_ASSIGN(
+ auto cmd, device->CreateCommandBuffer(
+ CommandBufferMode::kOneShot,
+ CommandCategory::kTransfer | CommandCategory::kDispatch));
+ ASSERT_OK(cmd->Begin());
+ DispatchRequest dispatch_request;
+ dispatch_request.executable = executable.get();
+ dispatch_request.entry_point = 0;
+ dispatch_request.workload[0] = 4;
+ dispatch_request.workload[1] = 1;
+ dispatch_request.workload[2] = 1;
+ BufferBinding bindings[3];
+ bindings[0].buffer = arg0_buffer.get();
+ bindings[0].access = MemoryAccess::kRead;
+ bindings[0].element_size = sizeof(float);
+ bindings[0].shape = {4};
+ bindings[1].buffer = arg1_buffer.get();
+ bindings[1].access = MemoryAccess::kRead;
+ bindings[1].element_size = sizeof(float);
+ bindings[1].shape = {4};
+ bindings[2].buffer = ret0_buffer.get();
+ bindings[2].access = MemoryAccess::kDiscardWrite;
+ bindings[2].element_size = sizeof(float);
+ bindings[2].shape = {4};
+ dispatch_request.bindings = bindings;
+ ASSERT_OK(cmd->Dispatch(dispatch_request));
+ ASSERT_OK(cmd->End());
+
+ // Schedule and wait for completion.
+ ASSERT_FALSE(device->dispatch_queues().empty());
+ CommandQueue* queue = device->dispatch_queues().front();
+ ASSERT_OK_AND_ASSIGN(auto fence, device->CreateFence(0u));
+ ASSERT_OK(
+ queue->Submit(SubmissionBatch{{}, {cmd.get()}, {}}, {fence.get(), 1u}));
+ ASSERT_OK(device->WaitAllFences({{fence.get(), 1u}}, absl::InfiniteFuture()));
+
+ // Read back the results.
+ ASSERT_OK_AND_ASSIGN(auto ret0_mapping,
+ ret0_buffer->MapMemory<float>(MemoryAccess::kRead));
+ EXPECT_THAT(ret0_mapping.contents(),
+ ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllDrivers, SimpleComputeTest,
+ ::testing::ValuesIn(GetAvailableDriverTestParams()),
+ ::testing::PrintToStringParamName());
+
+} // namespace
+} // namespace samples
+} // namespace hal
+} // namespace iree
diff --git a/iree/samples/hal/simple_compute_test.mlir b/samples/hal/simple_compute_test.mlir
similarity index 100%
rename from iree/samples/hal/simple_compute_test.mlir
rename to samples/hal/simple_compute_test.mlir
diff --git a/samples/rt/BUILD b/samples/rt/BUILD
new file mode 100644
index 0000000..72b7343
--- /dev/null
+++ b/samples/rt/BUILD
@@ -0,0 +1,72 @@
+# Samples demonstrating use of the RT API.
+
+load("///tools:compilation.bzl", "iree_bytecode_module")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_bytecode_module(
+ name = "simple_module_test_bytecode_module",
+ srcs = ["simple_module_test.mlir"],
+ cc_namespace = "iree::rt::samples",
+)
+
+cc_test(
+ name = "bytecode_module_test",
+ srcs = ["bytecode_module_test.cc"],
+ data = [
+ # When building with --config=asan you must specify the following
+ # envvar when using Vulkan + a local Nvidia GPU:
+ # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
+ "///tools:sanitizer_suppressions.txt",
+ ],
+ deps = [
+ ":simple_module_test_bytecode_module_cc",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_absl//absl/strings",
+ "///base:flatbuffer_util",
+ "///base:status",
+ "///base:status_matchers",
+ "///hal:buffer_view",
+ "///hal:driver_registry",
+ "///rt",
+ "///schemas",
+ "///vm:bytecode_module",
+ "///vm:sequencer_module",
+
+ # These are the drivers we support running with and can produce
+ # executables for from the source MLIR.
+ "///hal/interpreter:interpreter_driver_module", # build-cleaner: keep
+ # TODO(benvanik): include SPIR-V.
+ # "///hal/vulkan:vulkan_driver_module", # build-cleaner: keep
+ ],
+)
+
+cc_test(
+ name = "bytecode_module_api_test",
+ srcs = ["bytecode_module_api_test.cc"],
+ data = [
+ # When building with --config=asan you must specify the following
+ # envvar when using Vulkan + a local Nvidia GPU:
+ # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
+ "///tools:sanitizer_suppressions.txt",
+ ],
+ deps = [
+ ":simple_module_test_bytecode_module_cc",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_absl//absl/strings",
+ "///base:api",
+ "///hal:api",
+ "///hal:driver_registry",
+ "///rt:api",
+ "///vm:api",
+
+ # These are the drivers we support running with and can produce
+ # executables for from the source MLIR.
+ "///hal/interpreter:interpreter_driver_module", # build-cleaner: keep
+ # TODO(benvanik): include SPIR-V.
+ # "///hal/vulkan:vulkan_driver_module", # build-cleaner: keep
+ ],
+)
diff --git a/iree/samples/rt/CMakeLists.txt b/samples/rt/CMakeLists.txt
similarity index 100%
rename from iree/samples/rt/CMakeLists.txt
rename to samples/rt/CMakeLists.txt
diff --git a/samples/rt/bytecode_module_api_test.cc b/samples/rt/bytecode_module_api_test.cc
new file mode 100644
index 0000000..fbb980b
--- /dev/null
+++ b/samples/rt/bytecode_module_api_test.cc
@@ -0,0 +1,188 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "absl/strings/str_replace.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+// C API:
+#include "base/api.h"
+#include "hal/api.h"
+#include "rt/api.h"
+#include "vm/api.h"
+
+// Only temporary, used for test device registration:
+#include "hal/driver_registry.h"
+
+// Compiled module embedded here to avoid file IO:
+#include "samples/rt/simple_module_test_bytecode_module.h"
+
+namespace iree {
+namespace rt {
+namespace samples {
+namespace {
+
+#define ASSERT_API_OK(expr) ASSERT_EQ(IREE_STATUS_OK, (expr))
+
+struct TestParams {
+ // HAL driver to use for the test.
+ std::string driver_name;
+};
+
+std::ostream& operator<<(std::ostream& os, const TestParams& params) {
+ return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}});
+}
+
+// Builds a list of tests to run based on the linked in driver modules.
+std::vector<TestParams> GetAvailableDriverTestParams() {
+ std::vector<TestParams> all_test_params;
+ for (const auto& driver_name :
+ hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) {
+ TestParams test_params;
+ test_params.driver_name = driver_name;
+ all_test_params.push_back(std::move(test_params));
+ }
+ return all_test_params;
+}
+
+class BytecodeModuleApiTest : public ::testing::Test,
+ public ::testing::WithParamInterface<TestParams> {
+ protected:
+};
+
+TEST_P(BytecodeModuleApiTest, RunOnce) {
+ iree_rt_instance_t* instance = nullptr;
+ ASSERT_API_OK(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &instance));
+
+  // TEMPORARY: until policies and placement are implemented, we manually
+  // register drivers via a magic function.
+ const auto& driver_name = GetParam().driver_name;
+ LOG(INFO) << "Creating driver '" << driver_name << "'...";
+ ASSERT_API_OK(iree_rt_instance_register_driver_ex(
+ instance, iree_string_view_t{driver_name.data(), driver_name.size()}));
+
+ // Allocate a context that will hold the module state across invocations.
+ iree_rt_policy_t* dummy_policy = nullptr;
+ ASSERT_API_OK(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &dummy_policy));
+ iree_rt_context_t* context = nullptr;
+ ASSERT_API_OK(iree_rt_context_create(instance, dummy_policy,
+ IREE_ALLOCATOR_DEFAULT, &context));
+ iree_rt_policy_release(dummy_policy);
+
+ // Load bytecode module from the embedded data.
+ LOG(INFO) << "Loading simple_module_test.mlir...";
+ const auto* module_file_toc = simple_module_test_bytecode_module_create();
+ iree_rt_module_t* bytecode_module = nullptr;
+ ASSERT_API_OK(iree_vm_bytecode_module_create_from_buffer(
+ iree_const_byte_span_t{
+ reinterpret_cast<const uint8_t*>(module_file_toc->data),
+ module_file_toc->size},
+ nullptr, nullptr, IREE_ALLOCATOR_DEFAULT, &bytecode_module));
+
+ // Register modules that we want to be able to use in the context.
+ std::vector<iree_rt_module_t*> modules;
+ modules.push_back(bytecode_module);
+ ASSERT_API_OK(
+ iree_rt_context_register_modules(context, &modules[0], modules.size()));
+ iree_rt_module_release(bytecode_module);
+ LOG(INFO) << "Module loaded and context is ready for use";
+
+ // Lookup the entry point function.
+ iree_rt_function_t main_function;
+ const char kMainFunctionName[] = "module.simple_mul";
+ ASSERT_API_OK(iree_rt_context_resolve_function(
+ context,
+ iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
+ &main_function));
+
+ // Allocate buffers that can be mapped on the CPU and that can also be used
+ // on the device. Not all devices support this, but the ones we have now do.
+ LOG(INFO) << "Creating I/O buffers...";
+ constexpr int kElementCount = 4;
+ iree_hal_buffer_t* arg0_buffer = nullptr;
+ iree_hal_buffer_t* arg1_buffer = nullptr;
+ ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer(
+ context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount,
+ IREE_ALLOCATOR_DEFAULT, &arg0_buffer));
+ ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer(
+ context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount,
+ IREE_ALLOCATOR_DEFAULT, &arg1_buffer));
+
+ // Populate initial values for 4 * 2 = 8.
+ float kFloat4 = 4.0f;
+ float kFloat2 = 2.0f;
+ ASSERT_API_OK(iree_hal_buffer_fill(arg0_buffer, 0, IREE_WHOLE_BUFFER,
+ &kFloat4, sizeof(float)));
+ ASSERT_API_OK(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER,
+ &kFloat2, sizeof(float)));
+
+ // Wrap buffers in buffer views to provide shape information.
+ std::array<iree_hal_buffer_view_t*, 2> arg_buffer_views;
+ ASSERT_API_OK(iree_hal_buffer_view_create(
+ arg0_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float),
+ IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[0]));
+ ASSERT_API_OK(iree_hal_buffer_view_create(
+ arg1_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float),
+ IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[1]));
+ iree_hal_buffer_release(arg0_buffer);
+ iree_hal_buffer_release(arg1_buffer);
+
+ // Call into the @simple_mul function.
+ LOG(INFO) << "Calling @simple_mul...";
+ iree_rt_invocation_t* invocation = nullptr;
+ ASSERT_API_OK(iree_rt_invocation_create(
+ context, &main_function, nullptr, nullptr, arg_buffer_views.data(), 2,
+ nullptr, 0, IREE_ALLOCATOR_DEFAULT, &invocation));
+ ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[0]));
+ ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[1]));
+ ASSERT_API_OK(
+ iree_rt_invocation_await(invocation, IREE_TIME_INFINITE_FUTURE));
+
+ // Get the result buffers from the invocation.
+ LOG(INFO) << "Retreiving results...";
+ std::array<iree_hal_buffer_view_t*, 2> result_buffer_views;
+ iree_host_size_t result_count;
+ ASSERT_API_OK(iree_rt_invocation_consume_results(
+ invocation, result_buffer_views.size(), IREE_ALLOCATOR_DEFAULT,
+ result_buffer_views.data(), &result_count));
+ iree_rt_invocation_release(invocation);
+
+ // Read back the results and ensure we got the right values.
+ LOG(INFO) << "Reading back results...";
+ iree_hal_buffer_t* result_buffer =
+ iree_hal_buffer_view_buffer(result_buffer_views[0]);
+ iree_hal_mapped_memory_t mapped_memory;
+ ASSERT_API_OK(iree_hal_buffer_map(result_buffer, IREE_HAL_MEMORY_ACCESS_READ,
+ 0, IREE_WHOLE_BUFFER, &mapped_memory));
+ ASSERT_THAT(absl::Span<const float>(
+ reinterpret_cast<const float*>(mapped_memory.contents.data),
+ mapped_memory.contents.data_length / sizeof(float)),
+ ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
+ ASSERT_API_OK(iree_hal_buffer_unmap(result_buffer, &mapped_memory));
+ LOG(INFO) << "Results match!";
+
+ iree_hal_buffer_view_release(result_buffer_views[0]);
+
+ iree_rt_context_release(context);
+ iree_rt_instance_release(instance);
+}
+
+INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleApiTest,
+ ::testing::ValuesIn(GetAvailableDriverTestParams()),
+ ::testing::PrintToStringParamName());
+
+} // namespace
+} // namespace samples
+} // namespace rt
+} // namespace iree
diff --git a/samples/rt/bytecode_module_test.cc b/samples/rt/bytecode_module_test.cc
new file mode 100644
index 0000000..67c2dc9
--- /dev/null
+++ b/samples/rt/bytecode_module_test.cc
@@ -0,0 +1,170 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// A simple sample demonstrating simple synchronous module loading and VM use.
+// This will load an IREE module containing a @simple_mul method that performs
+// an element-wise multiplication. It will invoke @simple_mul in the VM, once
+// for each available HAL driver linked into the binary.
+//
+// The synchronous invocation method (Context::Invoke) used here waits until all
+// asynchronous HAL work completes before returning. It's still possible to get
+// overlapped execution by invoking methods from other threads with their own
+// FiberState, though it's best to use the asynchronous API instead.
+//
+// The `iree_bytecode_module` build rule is used to translate the MLIR to the
+// module flatbuffer. Additional HAL backend target support can be added there.
+
+#include "vm/bytecode_module.h"
+
+#include "absl/strings/str_replace.h"
+#include "base/flatbuffer_util.h"
+#include "base/status.h"
+#include "base/status_matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "hal/buffer_view.h"
+#include "hal/driver_registry.h"
+#include "rt/context.h"
+#include "rt/instance.h"
+#include "samples/rt/simple_module_test_bytecode_module.h"
+#include "schemas/module_def_generated.h"
+#include "vm/sequencer_module.h"
+
+namespace iree {
+namespace rt {
+namespace samples {
+namespace {
+
+using ::iree::hal::BufferView;
+using ::iree::vm::ModuleFile;
+
+struct TestParams {
+ // HAL driver to use for the test.
+ std::string driver_name;
+};
+
+std::ostream& operator<<(std::ostream& os, const TestParams& params) {
+ return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}});
+}
+
+// Builds a list of tests to run based on the linked in driver modules.
+std::vector<TestParams> GetAvailableDriverTestParams() {
+ std::vector<TestParams> all_test_params;
+ for (const auto& driver_name :
+ hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) {
+ TestParams test_params;
+ test_params.driver_name = driver_name;
+ all_test_params.push_back(std::move(test_params));
+ }
+ return all_test_params;
+}
+
+class BytecodeModuleTest : public ::testing::Test,
+ public ::testing::WithParamInterface<TestParams> {
+ protected:
+};
+
+TEST_P(BytecodeModuleTest, RunOnce) {
+ auto instance = make_ref<Instance>();
+
+ // Create driver for this test (based on params) and then get a default
+ // device.
+ const auto& test_params = GetParam();
+ LOG(INFO) << "Creating driver '" << test_params.driver_name << "'...";
+ auto driver_or =
+ hal::DriverRegistry::shared_registry()->Create(test_params.driver_name);
+ if (IsUnavailable(driver_or.status())) {
+ LOG(WARNING) << "Skipping test as driver is unavailable: "
+ << driver_or.status();
+ GTEST_SKIP();
+ return;
+ }
+ ASSERT_OK_AND_ASSIGN(auto driver, driver_or);
+ ASSERT_OK_AND_ASSIGN(auto available_devices,
+ driver->EnumerateAvailableDevices());
+ for (const auto& device_info : available_devices) {
+ LOG(INFO) << " Device: " << device_info.name();
+ }
+ LOG(INFO) << "Creating default device...";
+ ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice());
+ ASSERT_OK(instance->device_manager()->RegisterDevice(device));
+ LOG(INFO) << "Successfully created device '" << device->info().name() << "'";
+
+ // Make a new context and load the precompiled module file (from
+ // simple_module_test.mlir) into it.
+ LOG(INFO) << "Loading simple_module_test.mlir...";
+ auto policy = make_ref<Policy>();
+ Context context(add_ref(instance), add_ref(policy));
+ const auto* module_file_toc = simple_module_test_bytecode_module_create();
+ ASSERT_OK_AND_ASSIGN(auto module_file,
+ vm::ModuleFile::WrapBuffer(
+ ModuleDefIdentifier(),
+ absl::MakeSpan(reinterpret_cast<const uint8_t*>(
+ module_file_toc->data),
+ module_file_toc->size)));
+ ASSERT_OK_AND_ASSIGN(auto main_module,
+ vm::SequencerModule::FromFile(std::move(module_file)));
+ ASSERT_OK(context.RegisterModule(std::move(main_module)));
+ LOG(INFO) << "Module loaded and context is ready for use";
+
+ // Allocate buffers that can be mapped on the CPU and that can also be used
+ // on the device. Not all devices support this, but the ones we have now do.
+ LOG(INFO) << "Creating I/O buffers...";
+ constexpr int kElementCount = 4;
+ ASSERT_OK_AND_ASSIGN(
+ auto arg0_buffer,
+ instance->device_manager()->AllocateDeviceVisibleBuffer(
+ hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}}));
+ ASSERT_OK_AND_ASSIGN(
+ auto arg1_buffer,
+ instance->device_manager()->AllocateDeviceVisibleBuffer(
+ hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}}));
+
+ // Populate initial values for 4 * 2 = 8.
+ ASSERT_OK(arg0_buffer->Fill32(4.0f));
+ ASSERT_OK(arg1_buffer->Fill32(2.0f));
+
+ // Call into the @simple_mul function.
+ LOG(INFO) << "Calling @simple_mul...";
+ absl::InlinedVector<BufferView, 8> args{
+ BufferView{add_ref(arg0_buffer), {kElementCount}, sizeof(float)},
+ BufferView{add_ref(arg1_buffer), {kElementCount}, sizeof(float)},
+ };
+ ASSERT_OK_AND_ASSIGN(auto simple_mul,
+ context.ResolveFunction("module.simple_mul"));
+ ASSERT_OK_AND_ASSIGN(auto invocation,
+ Invocation::Create(add_ref(&context), simple_mul,
+ nullptr, {}, std::move(args)));
+ ASSERT_OK(invocation->Await(absl::InfiniteFuture()));
+ ASSERT_OK_AND_ASSIGN(auto results, invocation->ConsumeResults());
+
+ // Read back the results and ensure we got the right values.
+ LOG(INFO) << "Reading back results...";
+ auto& ret_buffer_view = results[0];
+ ASSERT_OK_AND_ASSIGN(
+ auto ret_mapping,
+ ret_buffer_view.buffer->MapMemory<float>(hal::MemoryAccess::kRead));
+ ASSERT_THAT(ret_mapping.contents(),
+ ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
+ LOG(INFO) << "Results match!";
+}
+
+INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleTest,
+ ::testing::ValuesIn(GetAvailableDriverTestParams()),
+ ::testing::PrintToStringParamName());
+
+} // namespace
+} // namespace samples
+} // namespace rt
+} // namespace iree
diff --git a/iree/samples/rt/simple_module_test.mlir b/samples/rt/simple_module_test.mlir
similarity index 100%
rename from iree/samples/rt/simple_module_test.mlir
rename to samples/rt/simple_module_test.mlir
diff --git a/schemas/BUILD b/schemas/BUILD
new file mode 100644
index 0000000..cbcbfc0
--- /dev/null
+++ b/schemas/BUILD
@@ -0,0 +1,252 @@
+load("//:build_defs.google.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library")
+load("///build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+FLATC_ARGS = [
+ # Preserve workspace-relative include paths in generated code.
+ "--keep-prefix",
+ # Use C++11 'enum class' for enums.
+ "--scoped-enums",
+ # Include reflection tables used for dumping debug representations.
+ "--reflect-names",
+ # Generate FooT types for unpack/pack support. Note that this should only
+ # be used in tooling as the code size/runtime overhead is non-trivial.
+ "--gen-object-api",
+]
+
+iree_flatbuffer_cc_library(
+ name = "archive_def_cc_fbs",
+ srcs = ["archive_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":bytecode_def_cc_fbs_includes",
+ ":device_def_cc_fbs_includes",
+ ":device_group_def_cc_fbs_includes",
+ ":device_table_def_cc_fbs_includes",
+ ":executable_def_cc_fbs_includes",
+ ":executable_table_def_cc_fbs_includes",
+ ":function_def_cc_fbs_includes",
+ ":function_table_def_cc_fbs_includes",
+ ":module_def_cc_fbs_includes",
+ ":source_map_def_cc_fbs_includes",
+ ":type_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "bytecode_def_cc_fbs",
+ srcs = ["bytecode_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "debug_service_cc_fbs",
+ srcs = ["debug_service.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":bytecode_def_cc_fbs_includes",
+ ":device_def_cc_fbs_includes",
+ ":device_group_def_cc_fbs_includes",
+ ":device_table_def_cc_fbs_includes",
+ ":executable_def_cc_fbs_includes",
+ ":executable_table_def_cc_fbs_includes",
+ ":function_def_cc_fbs_includes",
+ ":function_table_def_cc_fbs_includes",
+ ":module_def_cc_fbs_includes",
+ ":source_map_def_cc_fbs_includes",
+ ":type_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "device_def_cc_fbs",
+ srcs = ["device_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "device_group_def_cc_fbs",
+ srcs = ["device_group_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "device_table_def_cc_fbs",
+ srcs = ["device_table_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":device_def_cc_fbs_includes",
+ ":device_group_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "executable_def_cc_fbs",
+ srcs = ["executable_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "executable_table_def_cc_fbs",
+ srcs = ["executable_table_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":executable_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "function_def_cc_fbs",
+ srcs = ["function_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":bytecode_def_cc_fbs_includes",
+ ":type_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "function_table_def_cc_fbs",
+ srcs = ["function_table_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":bytecode_def_cc_fbs_includes",
+ ":function_def_cc_fbs_includes",
+ ":type_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "module_def_cc_fbs",
+ srcs = ["module_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ":bytecode_def_cc_fbs_includes",
+ ":device_def_cc_fbs_includes",
+ ":device_group_def_cc_fbs_includes",
+ ":device_table_def_cc_fbs_includes",
+ ":executable_def_cc_fbs_includes",
+ ":executable_table_def_cc_fbs_includes",
+ ":function_def_cc_fbs_includes",
+ ":function_table_def_cc_fbs_includes",
+ ":source_map_def_cc_fbs_includes",
+ ":type_def_cc_fbs_includes",
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "source_map_def_cc_fbs",
+ srcs = ["source_map_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "spirv_executable_def_cc_fbs",
+ srcs = ["spirv_executable_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_flatbuffer_cc_library(
+ name = "type_def_cc_fbs",
+ srcs = ["type_def.fbs"],
+ flatc_args = FLATC_ARGS,
+ includes = [
+ ],
+)
+
+iree_build_test(
+ name = "schema_build_test",
+ targets = [
+ ":archive_def_cc_fbs",
+ ":bytecode_def_cc_fbs",
+ ":debug_service_cc_fbs",
+ ":device_def_cc_fbs",
+ ":device_group_def_cc_fbs",
+ ":device_table_def_cc_fbs",
+ ":executable_def_cc_fbs",
+ ":executable_table_def_cc_fbs",
+ ":function_def_cc_fbs",
+ ":function_table_def_cc_fbs",
+ ":module_def_cc_fbs",
+ ":source_map_def_cc_fbs",
+ ":spirv_executable_def_cc_fbs",
+ ":type_def_cc_fbs",
+ ],
+)
+
+cc_library(
+ name = "schemas",
+ hdrs = [
+ ":archive_def_generated.h",
+ ":bytecode_def_generated.h",
+ ":debug_service_generated.h",
+ ":device_def_generated.h",
+ ":device_group_def_generated.h",
+ ":device_table_def_generated.h",
+ ":executable_def_generated.h",
+ ":executable_table_def_generated.h",
+ ":function_def_generated.h",
+ ":function_table_def_generated.h",
+ ":module_def_generated.h",
+ ":source_map_def_generated.h",
+ ":type_def_generated.h",
+ ],
+ deps = [
+ ":archive_def_cc_fbs",
+ ":bytecode_def_cc_fbs",
+ ":debug_service_cc_fbs",
+ ":device_def_cc_fbs",
+ ":device_group_def_cc_fbs",
+ ":device_table_def_cc_fbs",
+ ":executable_def_cc_fbs",
+ ":executable_table_def_cc_fbs",
+ ":function_def_cc_fbs",
+ ":function_table_def_cc_fbs",
+ ":module_def_cc_fbs",
+ ":source_map_def_cc_fbs",
+ ":spirv_executable_def_cc_fbs",
+ ":type_def_cc_fbs",
+ "@com_github_google_flatbuffers//:flatbuffers",
+ ],
+)
+
+REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [
+ "archive_def.bfbs",
+ "bytecode_def.bfbs",
+ "debug_service.bfbs",
+ "executable_def.bfbs",
+ "executable_table_def.bfbs",
+ "function_def.bfbs",
+ "function_table_def.bfbs",
+ "module_def.bfbs",
+ "source_map_def.bfbs",
+ "spirv_executable_def.bfbs",
+ "type_def.bfbs",
+ "device_def.bfbs",
+ "device_group_def.bfbs",
+ "device_table_def.bfbs",
+]
+
+cc_embed_data(
+ name = "reflection_data",
+ srcs = REFLECTION_SRCS,
+ cc_file_output = "reflection_data.cc",
+ cpp_namespace = "iree::schemas",
+ h_file_output = "reflection_data.h",
+)
diff --git a/iree/schemas/CMakeLists.txt b/schemas/CMakeLists.txt
similarity index 100%
rename from iree/schemas/CMakeLists.txt
rename to schemas/CMakeLists.txt
diff --git a/schemas/archive_def.fbs b/schemas/archive_def.fbs
new file mode 100644
index 0000000..114415c
--- /dev/null
+++ b/schemas/archive_def.fbs
@@ -0,0 +1,14 @@
+include "schemas/module_def.fbs";
+
+namespace iree;
+
+// 'Executable ARChive'.
+file_identifier "EARC";
+file_extension "earc";
+
+table ArchiveDef {
+ name:string;
+ modules:[ModuleDef];
+}
+
+root_type ArchiveDef;
diff --git a/schemas/bytecode/BUILD b/schemas/bytecode/BUILD
new file mode 100644
index 0000000..7b7dcdd
--- /dev/null
+++ b/schemas/bytecode/BUILD
@@ -0,0 +1,29 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "bytecode_v0",  # Shared V0 tables: constant encodings, builtin types, cmp predicates, operand encodings.
+ hdrs = ["bytecode_v0.h"],
+ deps = [
+ "//base:bitfield",  # bytecode_v0.h includes "base/bitfield.h"; label must start with exactly two slashes.
+ "@com_google_absl//absl/base:core_headers",
+ ],
+)
+
+cc_library(
+ name = "interpreter_bytecode_v0",
+ hdrs = ["interpreter_bytecode_v0.h"],
+ deps = [
+ ":bytecode_v0",
+ ],
+)
+
+cc_library(
+ name = "sequencer_bytecode_v0",
+ hdrs = ["sequencer_bytecode_v0.h"],
+ deps = [
+ ":bytecode_v0",
+ ],
+)
diff --git a/iree/schemas/bytecode/CMakeLists.txt b/schemas/bytecode/CMakeLists.txt
similarity index 100%
rename from iree/schemas/bytecode/CMakeLists.txt
rename to schemas/bytecode/CMakeLists.txt
diff --git a/schemas/bytecode/bytecode_v0.h b/schemas/bytecode/bytecode_v0.h
new file mode 100644
index 0000000..c9e6398
--- /dev/null
+++ b/schemas/bytecode/bytecode_v0.h
@@ -0,0 +1,143 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Opcode table for the V0 binary format.
+// Additions are fine but changing the behavior or order of any opcodes will
+// break parsing of existing files.
+//
+// Opcodes have been selected on frequency of use, general applicability, and
+// relative stability. Experimental ops should be implemented via the FFI first
+// before graduating into the core set. Ops that may only be present on certain
+// targets should also be kept as imports via the FFI.
+//
+// Opcodes may be specified for particular types (int32_t), categories of types
+// (all floating-point types), or implicit types (output matches input). Saving
+// opcode space by sharing a single opcode for multiple types is preferred
+// except where hot operations are performed (for example, comparison used in
+// loop iterators).
+
+#ifndef IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
+#define IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
+
+#include <cstdint>
+
+#include "base/bitfield.h"
+
+namespace iree {
+
+#define IREE_CONSTANT_ENCODING_LIST(ENC) \
+ ENC(0x00, kDense, "dense") \
+ ENC(0x01, kSplat, "splat")
+
+#define IREE_TYPE_LIST(TYP) \
+ TYP(0x00, kI8, "i8", 1) \
+ TYP(0x01, kI16, "i16", 2) \
+ TYP(0x02, kI32, "i32", 4) \
+ TYP(0x03, kI64, "i64", 8) \
+ TYP(0x04, kF16, "f16", 2) \
+ TYP(0x05, kF32, "f32", 4) \
+ TYP(0x06, kF64, "f64", 8) \
+ TYP(0x80, kDevice, "device", 0) \
+ TYP(0x81, kCommandBuffer, "command_buffer", 0) \
+ TYP(0x82, kEvent, "event", 0) \
+ TYP(0x83, kSemaphore, "semaphore", 0) \
+ TYP(0x84, kFence, "fence", 0) \
+ TYP(0xFF, kOpaque, "opaque", 0)
+
+#define IREE_CMPI_PREDICATE_LIST(PRED) \
+ PRED(0, kEq, "eq") \
+ PRED(1, kNe, "ne") \
+ PRED(2, kSlt, "slt") \
+ PRED(3, kSle, "sle") \
+ PRED(4, kSgt, "sgt") \
+ PRED(5, kSge, "sge") \
+ PRED(6, kUlt, "ult") \
+ PRED(7, kUle, "ule") \
+ PRED(8, kUgt, "ugt") \
+ PRED(9, kUge, "uge")
+
+#define IREE_CMPF_PREDICATE_LIST(PRED) \
+ PRED(0, kFalse, "false") \
+ PRED(1, kOeq, "oeq") \
+ PRED(2, kOgt, "ogt") \
+ PRED(3, kOge, "oge") \
+ PRED(4, kOlt, "olt") \
+ PRED(5, kOle, "ole") \
+ PRED(6, kOne, "one") \
+ PRED(7, kOrd, "ord") \
+ PRED(8, kUeq, "ueq") \
+ PRED(9, kUgt, "ugt") \
+ PRED(10, kUge, "uge") \
+ PRED(11, kUlt, "ult") \
+ PRED(12, kUle, "ule") \
+ PRED(13, kUne, "une") \
+ PRED(14, kUno, "uno") \
+ PRED(15, kTrue, "true")
+
+// NOTE: FF is a to-be-defined flag value for encoding/decoding.
+#define FLAG(V) ::iree::OpcodeFlag::V
+
+#define RSV(opcode, RESERVED_OPC) \
+ RESERVED_OPC(opcode, kReserved##opcode, "rsv." #opcode, FLAG(kDefault), "", \
+ FF)
+
+#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal,
+
+enum class ConstantEncoding : uint8_t {
+ IREE_CONSTANT_ENCODING_LIST(DECLARE_ENUM)
+};
+
+enum class BuiltinType : uint8_t { IREE_TYPE_LIST(DECLARE_ENUM) };
+
+enum class CmpIPredicate : uint8_t { IREE_CMPI_PREDICATE_LIST(DECLARE_ENUM) };
+
+enum class CmpFPredicate : uint8_t { IREE_CMPF_PREDICATE_LIST(DECLARE_ENUM) };
+
+#undef DECLARE_ENUM
+
+static constexpr uint8_t kBuiltinTypeCount =
+ static_cast<uint8_t>(BuiltinType::kF64) + 1;  // == 7: covers kI8..kF64 (0x00-0x06) only; the 0x80+ and 0xFF entries of IREE_TYPE_LIST are not counted.
+
+enum class OpcodeFlag : uint8_t {
+ kDefault = 0,
+};
+IREE_BITFIELD(OpcodeFlag);
+using OpcodeFlagBitfield = OpcodeFlag;
+
+enum class OperandEncoding : char {
+ kNone = '\0',
+ kInputSlot = 's',
+ kVariadicInputSlots = 'S',
+ kOutputSlot = 'o',
+ kVariadicOutputSlots = 'O',
+ kResultSlot = 'r',
+ kVariadicResultSlots = 'R',
+ kVariadicTransferSlots = 'T',
+ kConstant = 'c',
+ kFunctionOrdinal = 'f',
+ kImportOrdinal = 'F',
+ kDispatchOrdinal = 'd',
+ kBlockOffset = 'b',
+ kTypeIndex = 't',
+ kIndex = 'i',
+ kIndexList = 'I',
+ kCmpIPredicate = 'p',
+ kCmpFPredicate = 'P',
+};
+IREE_BITFIELD(OperandEncoding);
+using OperandEncodingBitfield = OperandEncoding;
+
+} // namespace iree
+
+#endif // IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
diff --git a/schemas/bytecode/interpreter_bytecode_v0.h b/schemas/bytecode/interpreter_bytecode_v0.h
new file mode 100644
index 0000000..9d21d32
--- /dev/null
+++ b/schemas/bytecode/interpreter_bytecode_v0.h
@@ -0,0 +1,325 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Opcode table for the V0 binary format.
+// Additions are fine but changing the behavior or order of any opcodes will
+// break parsing of existing files.
+//
+// Opcodes have been selected on frequency of use, general applicability, and
+// relative stability. Experimental ops should be implemented via the Foreign
+// Function Interface (FFI) first before graduating into the core set. Ops that
+// may only be present on certain targets should also be kept as imports via the
+// FFI.
+//
+// Opcodes may be specified for particular types (int32_t), categories of types
+// (all floating-point types), or implicit types (output matches input). Saving
+// opcode space by sharing a single opcode for multiple types is preferred
+// except where hot operations are performed (for example, comparison used in
+// loop iterators).
+
+#ifndef IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
+#define IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
+
+#include "schemas/bytecode/bytecode_v0.h"
+
+namespace iree {
+
+#define IREE_INTERPRETER_OPCODE_LIST(OPC, RESERVED_OPC) \
+ OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \
+ \
+ OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \
+ OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \
+ OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \
+ OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \
+ OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \
+ OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \
+ OPC(0x07, kCmpI, "cmp_i", FLAG(kDefault), "psso", FF) \
+ OPC(0x08, kCmpF, "cmp_f", FLAG(kDefault), "Psso", FF) \
+ \
+ RSV(0x09, RESERVED_OPC) \
+ RSV(0x0A, RESERVED_OPC) \
+ RSV(0x0B, RESERVED_OPC) \
+ RSV(0x0C, RESERVED_OPC) \
+ RSV(0x0D, RESERVED_OPC) \
+ RSV(0x0E, RESERVED_OPC) \
+ RSV(0x0F, RESERVED_OPC) \
+ RSV(0x10, RESERVED_OPC) \
+ RSV(0x11, RESERVED_OPC) \
+ RSV(0x12, RESERVED_OPC) \
+ RSV(0x13, RESERVED_OPC) \
+ RSV(0x14, RESERVED_OPC) \
+ RSV(0x15, RESERVED_OPC) \
+ RSV(0x16, RESERVED_OPC) \
+ RSV(0x17, RESERVED_OPC) \
+ RSV(0x18, RESERVED_OPC) \
+ RSV(0x19, RESERVED_OPC) \
+ RSV(0x1A, RESERVED_OPC) \
+ RSV(0x1B, RESERVED_OPC) \
+ RSV(0x1C, RESERVED_OPC) \
+ RSV(0x1D, RESERVED_OPC) \
+ RSV(0x1E, RESERVED_OPC) \
+ RSV(0x1F, RESERVED_OPC) \
+ \
+ OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \
+ OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \
+ OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \
+ OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \
+ OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \
+ \
+ RSV(0x25, RESERVED_OPC) \
+ RSV(0x26, RESERVED_OPC) \
+ RSV(0x27, RESERVED_OPC) \
+ RSV(0x28, RESERVED_OPC) \
+ RSV(0x29, RESERVED_OPC) \
+ RSV(0x2A, RESERVED_OPC) \
+ RSV(0x2B, RESERVED_OPC) \
+ RSV(0x2C, RESERVED_OPC) \
+ RSV(0x2D, RESERVED_OPC) \
+ RSV(0x2E, RESERVED_OPC) \
+ RSV(0x2F, RESERVED_OPC) \
+ \
+ OPC(0x30, kRank, "rank", FLAG(kDefault), "so", FF) \
+ OPC(0x31, kDim, "dim", FLAG(kDefault), "iso", FF) \
+ OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \
+ OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \
+ OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "sssr", FF) \
+ OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "sIIr", FF) \
+ OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \
+ OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sIoII", FF) \
+ OPC(0x38, kClone, "clone", FLAG(kDefault), "sr", FF) \
+ RSV(0x39, RESERVED_OPC) \
+ OPC(0x3A, kSplit, "split", FLAG(kDefault), "isR", FF) \
+ OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \
+ OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \
+ OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \
+ OPC(0x3E, kSelect, "select", FLAG(kDefault), "ssso", FF) \
+ OPC(0x3F, kTranspose, "transpose", FLAG(kDefault), "sso", FF) \
+ OPC(0x40, kBroadcast, "broadcast", FLAG(kDefault), "sso", FF) \
+ OPC(0x41, kTile, "tile", FLAG(kDefault), "sso", FF) \
+ OPC(0x42, kReverse, "reverse", FLAG(kDefault), "sso", FF) \
+ OPC(0x43, kPad, "pad", FLAG(kDefault), "ssssso", FF) \
+ \
+ RSV(0x44, RESERVED_OPC) \
+ RSV(0x45, RESERVED_OPC) \
+ RSV(0x46, RESERVED_OPC) \
+ RSV(0x47, RESERVED_OPC) \
+ RSV(0x48, RESERVED_OPC) \
+ RSV(0x49, RESERVED_OPC) \
+ RSV(0x4A, RESERVED_OPC) \
+ RSV(0x4B, RESERVED_OPC) \
+ RSV(0x4C, RESERVED_OPC) \
+ RSV(0x4D, RESERVED_OPC) \
+ RSV(0x4E, RESERVED_OPC) \
+ RSV(0x4F, RESERVED_OPC) \
+ \
+ OPC(0x50, kNot, "not", FLAG(kDefault), "so", FF) \
+ OPC(0x51, kAnd, "and", FLAG(kDefault), "sso", FF) \
+ OPC(0x52, kOr, "or", FLAG(kDefault), "sso", FF) \
+ OPC(0x53, kXor, "xor", FLAG(kDefault), "sso", FF) \
+ OPC(0x54, kShiftLeft, "sll", FLAG(kDefault), "sso", FF) \
+ OPC(0x55, kShiftRightLogical, "srl", FLAG(kDefault), "sso", FF) \
+ OPC(0x56, kShiftRightArithmetic, "sra", FLAG(kDefault), "sso", FF) \
+ \
+ RSV(0x57, RESERVED_OPC) \
+ RSV(0x58, RESERVED_OPC) \
+ RSV(0x59, RESERVED_OPC) \
+ RSV(0x5A, RESERVED_OPC) \
+ RSV(0x5B, RESERVED_OPC) \
+ RSV(0x5C, RESERVED_OPC) \
+ RSV(0x5D, RESERVED_OPC) \
+ RSV(0x5E, RESERVED_OPC) \
+ RSV(0x5F, RESERVED_OPC) \
+ RSV(0x60, RESERVED_OPC) \
+ RSV(0x61, RESERVED_OPC) \
+ RSV(0x62, RESERVED_OPC) \
+ RSV(0x63, RESERVED_OPC) \
+ RSV(0x64, RESERVED_OPC) \
+ RSV(0x65, RESERVED_OPC) \
+ RSV(0x66, RESERVED_OPC) \
+ RSV(0x67, RESERVED_OPC) \
+ RSV(0x68, RESERVED_OPC) \
+ RSV(0x69, RESERVED_OPC) \
+ RSV(0x6A, RESERVED_OPC) \
+ RSV(0x6B, RESERVED_OPC) \
+ RSV(0x6C, RESERVED_OPC) \
+ RSV(0x6D, RESERVED_OPC) \
+ RSV(0x6E, RESERVED_OPC) \
+ RSV(0x6F, RESERVED_OPC) \
+ \
+ /* TODO(benvanik): remove ones we don't need/can emulate */ \
+ OPC(0x70, kAddI, "add_i", FLAG(kDefault), "sso", FF) \
+ OPC(0x71, kAddF, "add_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x72, kSubI, "sub_i", FLAG(kDefault), "sso", FF) \
+ OPC(0x73, kSubF, "sub_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x74, kAbsI, "abs_i", FLAG(kDefault), "so", FF) \
+ OPC(0x75, kAbsF, "abs_f", FLAG(kDefault), "so", FF) \
+ OPC(0x76, kMulI, "mul_i", FLAG(kDefault), "sso", FF) \
+ OPC(0x77, kMulF, "mul_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x78, kDivIS, "div_i_s", FLAG(kDefault), "sso", FF) \
+ OPC(0x79, kDivIU, "div_i_u", FLAG(kDefault), "sso", FF) \
+ OPC(0x7A, kDivF, "div_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x7B, kMulAddI, "madd_i", FLAG(kDefault), "ssso", FF) \
+ OPC(0x7C, kMulAddF, "madd_f", FLAG(kDefault), "ssso", FF) \
+ OPC(0x7D, kCosF, "cos_f", FLAG(kDefault), "so", FF) \
+ OPC(0x7E, kSinF, "sin_f", FLAG(kDefault), "so", FF) \
+ OPC(0x7F, kTanhF, "tanh_f", FLAG(kDefault), "so", FF) \
+ OPC(0x80, kAtan2F, "atan2_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x81, kExpF, "exp_f", FLAG(kDefault), "so", FF) \
+ OPC(0x82, kLogF, "log_f", FLAG(kDefault), "so", FF) \
+ OPC(0x83, kRsqrtF, "rsqrt_f", FLAG(kDefault), "so", FF) \
+ \
+ RSV(0x84, RESERVED_OPC) \
+ RSV(0x85, RESERVED_OPC) \
+ RSV(0x86, RESERVED_OPC) \
+ RSV(0x87, RESERVED_OPC) \
+ RSV(0x88, RESERVED_OPC) \
+ RSV(0x89, RESERVED_OPC) \
+ RSV(0x8A, RESERVED_OPC) \
+ RSV(0x8B, RESERVED_OPC) \
+ RSV(0x8C, RESERVED_OPC) \
+ RSV(0x8D, RESERVED_OPC) \
+ RSV(0x8E, RESERVED_OPC) \
+ RSV(0x8F, RESERVED_OPC) \
+ \
+ OPC(0x90, kMinIS, "min_i_s", FLAG(kDefault), "sso", FF) \
+ OPC(0x91, kMinIU, "min_i_u", FLAG(kDefault), "sso", FF) \
+ OPC(0x92, kMinF, "min_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x93, kMaxIS, "max_i_s", FLAG(kDefault), "sso", FF) \
+ OPC(0x94, kMaxIU, "max_i_u", FLAG(kDefault), "sso", FF) \
+ OPC(0x95, kMaxF, "max_f", FLAG(kDefault), "sso", FF) \
+ OPC(0x96, kClampIS, "clamp_i_s", FLAG(kDefault), "ssso", FF) \
+ OPC(0x97, kClampIU, "clamp_i_u", FLAG(kDefault), "ssso", FF) \
+ OPC(0x98, kClampF, "clamp_f", FLAG(kDefault), "ssso", FF) \
+ OPC(0x99, kFloorF, "floor_f", FLAG(kDefault), "so", FF) \
+ OPC(0x9A, kCeilF, "ceil_f", FLAG(kDefault), "so", FF) \
+ \
+ OPC(0x9B, kConvertSS, "convert_s_s", FLAG(kDefault), "tsto", FF) \
+ OPC(0x9C, kConvertUU, "convert_u_u", FLAG(kDefault), "tsto", FF) \
+ OPC(0x9D, kConvertSU, "convert_s_u", FLAG(kDefault), "tsto", FF) \
+ OPC(0x9E, kConvertUS, "convert_u_s", FLAG(kDefault), "tsto", FF) \
+ \
+ RSV(0x9F, RESERVED_OPC) \
+ \
+ /* TODO(benvanik): reduction/sum/etc */ \
+ /* TODO(benvanik): sort */ \
+ \
+ OPC(0xA0, kMatMulI, "matmul_i", FLAG(kDefault), "sssso", FF) \
+ OPC(0xA1, kMatMulF, "matmul_f", FLAG(kDefault), "sso", FF) \
+ /* TODO(benvanik): convolution */ \
+ \
+ OPC(0xA2, kReduceSumI, "reduce_sum_i", FLAG(kDefault), "ssio", FF) \
+ OPC(0xA3, kReduceSumF, "reduce_sum_f", FLAG(kDefault), "ssio", FF) \
+ OPC(0xA4, kReduceMinI, "reduce_min_i", FLAG(kDefault), "ssio", FF) \
+ OPC(0xA5, kReduceMinF, "reduce_min_f", FLAG(kDefault), "ssio", FF) \
+ OPC(0xA6, kReduceMaxI, "reduce_max_i", FLAG(kDefault), "ssio", FF) \
+ OPC(0xA7, kReduceMaxF, "reduce_max_f", FLAG(kDefault), "ssio", FF) \
+ RSV(0xA8, RESERVED_OPC) \
+ RSV(0xA9, RESERVED_OPC) \
+ RSV(0xAA, RESERVED_OPC) \
+ RSV(0xAB, RESERVED_OPC) \
+ RSV(0xAC, RESERVED_OPC) \
+ RSV(0xAD, RESERVED_OPC) \
+ RSV(0xAE, RESERVED_OPC) \
+ RSV(0xAF, RESERVED_OPC) \
+ RSV(0xB0, RESERVED_OPC) \
+ RSV(0xB1, RESERVED_OPC) \
+ RSV(0xB2, RESERVED_OPC) \
+ RSV(0xB3, RESERVED_OPC) \
+ RSV(0xB4, RESERVED_OPC) \
+ RSV(0xB5, RESERVED_OPC) \
+ RSV(0xB6, RESERVED_OPC) \
+ RSV(0xB7, RESERVED_OPC) \
+ RSV(0xB8, RESERVED_OPC) \
+ RSV(0xB9, RESERVED_OPC) \
+ RSV(0xBA, RESERVED_OPC) \
+ RSV(0xBB, RESERVED_OPC) \
+ RSV(0xBC, RESERVED_OPC) \
+ RSV(0xBD, RESERVED_OPC) \
+ RSV(0xBE, RESERVED_OPC) \
+ RSV(0xBF, RESERVED_OPC) \
+ RSV(0xC0, RESERVED_OPC) \
+ RSV(0xC1, RESERVED_OPC) \
+ RSV(0xC2, RESERVED_OPC) \
+ RSV(0xC3, RESERVED_OPC) \
+ RSV(0xC4, RESERVED_OPC) \
+ RSV(0xC5, RESERVED_OPC) \
+ RSV(0xC6, RESERVED_OPC) \
+ RSV(0xC7, RESERVED_OPC) \
+ RSV(0xC8, RESERVED_OPC) \
+ RSV(0xC9, RESERVED_OPC) \
+ RSV(0xCA, RESERVED_OPC) \
+ RSV(0xCB, RESERVED_OPC) \
+ RSV(0xCC, RESERVED_OPC) \
+ RSV(0xCD, RESERVED_OPC) \
+ RSV(0xCE, RESERVED_OPC) \
+ RSV(0xCF, RESERVED_OPC) \
+ RSV(0xD0, RESERVED_OPC) \
+ RSV(0xD1, RESERVED_OPC) \
+ RSV(0xD2, RESERVED_OPC) \
+ RSV(0xD3, RESERVED_OPC) \
+ RSV(0xD4, RESERVED_OPC) \
+ RSV(0xD5, RESERVED_OPC) \
+ RSV(0xD6, RESERVED_OPC) \
+ RSV(0xD7, RESERVED_OPC) \
+ RSV(0xD8, RESERVED_OPC) \
+ RSV(0xD9, RESERVED_OPC) \
+ RSV(0xDA, RESERVED_OPC) \
+ RSV(0xDB, RESERVED_OPC) \
+ RSV(0xDC, RESERVED_OPC) \
+ RSV(0xDD, RESERVED_OPC) \
+ RSV(0xDE, RESERVED_OPC) \
+ RSV(0xDF, RESERVED_OPC) \
+ RSV(0xE0, RESERVED_OPC) \
+ RSV(0xE1, RESERVED_OPC) \
+ RSV(0xE2, RESERVED_OPC) \
+ RSV(0xE3, RESERVED_OPC) \
+ RSV(0xE4, RESERVED_OPC) \
+ RSV(0xE5, RESERVED_OPC) \
+ RSV(0xE6, RESERVED_OPC) \
+ RSV(0xE7, RESERVED_OPC) \
+ RSV(0xE8, RESERVED_OPC) \
+ RSV(0xE9, RESERVED_OPC) \
+ RSV(0xEA, RESERVED_OPC) \
+ RSV(0xEB, RESERVED_OPC) \
+ RSV(0xEC, RESERVED_OPC) \
+ RSV(0xED, RESERVED_OPC) \
+ RSV(0xEE, RESERVED_OPC) \
+ RSV(0xEF, RESERVED_OPC) \
+ RSV(0xF0, RESERVED_OPC) \
+ RSV(0xF1, RESERVED_OPC) \
+ RSV(0xF2, RESERVED_OPC) \
+ RSV(0xF3, RESERVED_OPC) \
+ RSV(0xF4, RESERVED_OPC) \
+ RSV(0xF5, RESERVED_OPC) \
+ RSV(0xF6, RESERVED_OPC) \
+ RSV(0xF7, RESERVED_OPC) \
+ RSV(0xF8, RESERVED_OPC) \
+ RSV(0xF9, RESERVED_OPC) \
+ RSV(0xFA, RESERVED_OPC) \
+ RSV(0xFB, RESERVED_OPC) \
+ RSV(0xFC, RESERVED_OPC) \
+ \
+ OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \
+ OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \
+ OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF)
+
+#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal,
+enum class InterpreterOpcode : uint8_t {
+ IREE_INTERPRETER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM)
+};
+#undef DECLARE_ENUM
+
+} // namespace iree
+
+#endif // IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
diff --git a/schemas/bytecode/sequencer_bytecode_v0.h b/schemas/bytecode/sequencer_bytecode_v0.h
new file mode 100644
index 0000000..a199011
--- /dev/null
+++ b/schemas/bytecode/sequencer_bytecode_v0.h
@@ -0,0 +1,313 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Opcode table for the V0 binary format.
+// Additions are fine but changing the behavior or order of any opcodes will
+// break parsing of existing files.
+//
+// Opcodes have been selected on frequency of use, general applicability, and
+// relative stability. Experimental ops should be implemented via the Foreign
+// Function Interface (FFI) first before graduating into the core set. Ops that
+// may only be present on certain targets should also be kept as imports via the
+// FFI.
+//
+// Opcodes may be specified for particular types (int32_t), categories of types
+// (all floating-point types), or implicit types (output matches input). Saving
+// opcode space by sharing a single opcode for multiple types is preferred
+// except where hot operations are performed (for example, comparison used in
+// loop iterators).
+
+#ifndef IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
+#define IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
+
+#include "schemas/bytecode/bytecode_v0.h"
+
+namespace iree {
+
+#define IREE_SEQUENCER_OPCODE_LIST(OPC, RESERVED_OPC) \
+ OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \
+ \
+ OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \
+ OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \
+ OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \
+ OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \
+ OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \
+ OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \
+ \
+ RSV(0x07, RESERVED_OPC) \
+ RSV(0x08, RESERVED_OPC) \
+ RSV(0x09, RESERVED_OPC) \
+ RSV(0x0A, RESERVED_OPC) \
+ RSV(0x0B, RESERVED_OPC) \
+ RSV(0x0C, RESERVED_OPC) \
+ RSV(0x0D, RESERVED_OPC) \
+ RSV(0x0E, RESERVED_OPC) \
+ RSV(0x0F, RESERVED_OPC) \
+ \
+ OPC(0x10, kDynamicDispatch, "dynamic_dispatch", FLAG(kDefault), "dsSOR", FF) \
+ OPC(0x11, kStaticDispatch, "static_dispatch", FLAG(kDefault), "diiiSOR", FF) \
+ \
+ RSV(0x12, RESERVED_OPC) \
+ RSV(0x13, RESERVED_OPC) \
+ RSV(0x14, RESERVED_OPC) \
+ RSV(0x15, RESERVED_OPC) \
+ RSV(0x16, RESERVED_OPC) \
+ RSV(0x17, RESERVED_OPC) \
+ RSV(0x18, RESERVED_OPC) \
+ RSV(0x19, RESERVED_OPC) \
+ RSV(0x1A, RESERVED_OPC) \
+ RSV(0x1B, RESERVED_OPC) \
+ RSV(0x1C, RESERVED_OPC) \
+ RSV(0x1D, RESERVED_OPC) \
+ RSV(0x1E, RESERVED_OPC) \
+ RSV(0x1F, RESERVED_OPC) \
+ \
+ OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \
+ OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \
+ OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \
+ OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \
+ OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \
+ \
+ RSV(0x25, RESERVED_OPC) \
+ RSV(0x26, RESERVED_OPC) \
+ RSV(0x27, RESERVED_OPC) \
+ RSV(0x28, RESERVED_OPC) \
+ RSV(0x29, RESERVED_OPC) \
+ RSV(0x2A, RESERVED_OPC) \
+ RSV(0x2B, RESERVED_OPC) \
+ RSV(0x2C, RESERVED_OPC) \
+ RSV(0x2D, RESERVED_OPC) \
+ RSV(0x2E, RESERVED_OPC) \
+ RSV(0x2F, RESERVED_OPC) \
+ RSV(0x30, RESERVED_OPC) \
+ \
+ OPC(0x31, kComputeRange, "compute_range", FLAG(kDefault), "sissoo", FF) \
+ OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \
+ OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \
+ OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "ssstsr", FF) \
+ OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "siitIr", FF) \
+ OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \
+ OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sioii", FF) \
+ OPC(0x38, kDynamicFill, "dynamic_fill", FLAG(kDefault), "soss", FF) \
+ OPC(0x39, kStaticFill, "static_fill", FLAG(kDefault), "ioii", FF) \
+ OPC(0x3A, kClone, "clone", FLAG(kDefault), "sr", FF) \
+ OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \
+ OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \
+ OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \
+ \
+ RSV(0x3E, RESERVED_OPC) \
+ RSV(0x3F, RESERVED_OPC) \
+ RSV(0x40, RESERVED_OPC) \
+ RSV(0x41, RESERVED_OPC) \
+ RSV(0x42, RESERVED_OPC) \
+ RSV(0x43, RESERVED_OPC) \
+ RSV(0x44, RESERVED_OPC) \
+ RSV(0x45, RESERVED_OPC) \
+ RSV(0x46, RESERVED_OPC) \
+ RSV(0x47, RESERVED_OPC) \
+ RSV(0x48, RESERVED_OPC) \
+ RSV(0x49, RESERVED_OPC) \
+ RSV(0x4A, RESERVED_OPC) \
+ RSV(0x4B, RESERVED_OPC) \
+ RSV(0x4C, RESERVED_OPC) \
+ RSV(0x4D, RESERVED_OPC) \
+ RSV(0x4E, RESERVED_OPC) \
+ RSV(0x4F, RESERVED_OPC) \
+ RSV(0x50, RESERVED_OPC) \
+ RSV(0x51, RESERVED_OPC) \
+ RSV(0x52, RESERVED_OPC) \
+ RSV(0x53, RESERVED_OPC) \
+ RSV(0x54, RESERVED_OPC) \
+ RSV(0x55, RESERVED_OPC) \
+ RSV(0x56, RESERVED_OPC) \
+ RSV(0x57, RESERVED_OPC) \
+ RSV(0x58, RESERVED_OPC) \
+ RSV(0x59, RESERVED_OPC) \
+ RSV(0x5A, RESERVED_OPC) \
+ RSV(0x5B, RESERVED_OPC) \
+ RSV(0x5C, RESERVED_OPC) \
+ RSV(0x5D, RESERVED_OPC) \
+ RSV(0x5E, RESERVED_OPC) \
+ RSV(0x5F, RESERVED_OPC) \
+ RSV(0x60, RESERVED_OPC) \
+ RSV(0x61, RESERVED_OPC) \
+ RSV(0x62, RESERVED_OPC) \
+ RSV(0x63, RESERVED_OPC) \
+ RSV(0x64, RESERVED_OPC) \
+ RSV(0x65, RESERVED_OPC) \
+ RSV(0x66, RESERVED_OPC) \
+ RSV(0x67, RESERVED_OPC) \
+ RSV(0x68, RESERVED_OPC) \
+ RSV(0x69, RESERVED_OPC) \
+ RSV(0x6A, RESERVED_OPC) \
+ RSV(0x6B, RESERVED_OPC) \
+ RSV(0x6C, RESERVED_OPC) \
+ RSV(0x6D, RESERVED_OPC) \
+ RSV(0x6E, RESERVED_OPC) \
+ RSV(0x6F, RESERVED_OPC) \
+ RSV(0x70, RESERVED_OPC) \
+ RSV(0x71, RESERVED_OPC) \
+ RSV(0x72, RESERVED_OPC) \
+ RSV(0x73, RESERVED_OPC) \
+ RSV(0x74, RESERVED_OPC) \
+ RSV(0x75, RESERVED_OPC) \
+ RSV(0x76, RESERVED_OPC) \
+ RSV(0x77, RESERVED_OPC) \
+ RSV(0x78, RESERVED_OPC) \
+ RSV(0x79, RESERVED_OPC) \
+ RSV(0x7A, RESERVED_OPC) \
+ RSV(0x7B, RESERVED_OPC) \
+ RSV(0x7C, RESERVED_OPC) \
+ RSV(0x7D, RESERVED_OPC) \
+ RSV(0x7E, RESERVED_OPC) \
+ RSV(0x7F, RESERVED_OPC) \
+ RSV(0x80, RESERVED_OPC) \
+ RSV(0x81, RESERVED_OPC) \
+ RSV(0x82, RESERVED_OPC) \
+ RSV(0x83, RESERVED_OPC) \
+ RSV(0x84, RESERVED_OPC) \
+ RSV(0x85, RESERVED_OPC) \
+ RSV(0x86, RESERVED_OPC) \
+ RSV(0x87, RESERVED_OPC) \
+ RSV(0x88, RESERVED_OPC) \
+ RSV(0x89, RESERVED_OPC) \
+ RSV(0x8A, RESERVED_OPC) \
+ RSV(0x8B, RESERVED_OPC) \
+ RSV(0x8C, RESERVED_OPC) \
+ RSV(0x8D, RESERVED_OPC) \
+ RSV(0x8E, RESERVED_OPC) \
+ RSV(0x8F, RESERVED_OPC) \
+ RSV(0x90, RESERVED_OPC) \
+ RSV(0x91, RESERVED_OPC) \
+ RSV(0x92, RESERVED_OPC) \
+ RSV(0x93, RESERVED_OPC) \
+ RSV(0x94, RESERVED_OPC) \
+ RSV(0x95, RESERVED_OPC) \
+ RSV(0x96, RESERVED_OPC) \
+ RSV(0x97, RESERVED_OPC) \
+ RSV(0x98, RESERVED_OPC) \
+ RSV(0x99, RESERVED_OPC) \
+ RSV(0x9A, RESERVED_OPC) \
+ RSV(0x9B, RESERVED_OPC) \
+ RSV(0x9C, RESERVED_OPC) \
+ RSV(0x9D, RESERVED_OPC) \
+ RSV(0x9E, RESERVED_OPC) \
+ RSV(0x9F, RESERVED_OPC) \
+ RSV(0xA0, RESERVED_OPC) \
+ RSV(0xA1, RESERVED_OPC) \
+ RSV(0xA2, RESERVED_OPC) \
+ RSV(0xA3, RESERVED_OPC) \
+ RSV(0xA4, RESERVED_OPC) \
+ RSV(0xA5, RESERVED_OPC) \
+ RSV(0xA6, RESERVED_OPC) \
+ RSV(0xA7, RESERVED_OPC) \
+ RSV(0xA8, RESERVED_OPC) \
+ RSV(0xA9, RESERVED_OPC) \
+ RSV(0xAA, RESERVED_OPC) \
+ RSV(0xAB, RESERVED_OPC) \
+ RSV(0xAC, RESERVED_OPC) \
+ RSV(0xAD, RESERVED_OPC) \
+ RSV(0xAE, RESERVED_OPC) \
+ RSV(0xAF, RESERVED_OPC) \
+ RSV(0xB0, RESERVED_OPC) \
+ RSV(0xB1, RESERVED_OPC) \
+ RSV(0xB2, RESERVED_OPC) \
+ RSV(0xB3, RESERVED_OPC) \
+ RSV(0xB4, RESERVED_OPC) \
+ RSV(0xB5, RESERVED_OPC) \
+ RSV(0xB6, RESERVED_OPC) \
+ RSV(0xB7, RESERVED_OPC) \
+ RSV(0xB8, RESERVED_OPC) \
+ RSV(0xB9, RESERVED_OPC) \
+ RSV(0xBA, RESERVED_OPC) \
+ RSV(0xBB, RESERVED_OPC) \
+ RSV(0xBC, RESERVED_OPC) \
+ RSV(0xBD, RESERVED_OPC) \
+ RSV(0xBE, RESERVED_OPC) \
+ RSV(0xBF, RESERVED_OPC) \
+ RSV(0xC0, RESERVED_OPC) \
+ RSV(0xC1, RESERVED_OPC) \
+ RSV(0xC2, RESERVED_OPC) \
+ RSV(0xC3, RESERVED_OPC) \
+ RSV(0xC4, RESERVED_OPC) \
+ RSV(0xC5, RESERVED_OPC) \
+ RSV(0xC6, RESERVED_OPC) \
+ RSV(0xC7, RESERVED_OPC) \
+ RSV(0xC8, RESERVED_OPC) \
+ RSV(0xC9, RESERVED_OPC) \
+ RSV(0xCA, RESERVED_OPC) \
+ RSV(0xCB, RESERVED_OPC) \
+ RSV(0xCC, RESERVED_OPC) \
+ RSV(0xCD, RESERVED_OPC) \
+ RSV(0xCE, RESERVED_OPC) \
+ RSV(0xCF, RESERVED_OPC) \
+ RSV(0xD0, RESERVED_OPC) \
+ RSV(0xD1, RESERVED_OPC) \
+ RSV(0xD2, RESERVED_OPC) \
+ RSV(0xD3, RESERVED_OPC) \
+ RSV(0xD4, RESERVED_OPC) \
+ RSV(0xD5, RESERVED_OPC) \
+ RSV(0xD6, RESERVED_OPC) \
+ RSV(0xD7, RESERVED_OPC) \
+ RSV(0xD8, RESERVED_OPC) \
+ RSV(0xD9, RESERVED_OPC) \
+ RSV(0xDA, RESERVED_OPC) \
+ RSV(0xDB, RESERVED_OPC) \
+ RSV(0xDC, RESERVED_OPC) \
+ RSV(0xDD, RESERVED_OPC) \
+ RSV(0xDE, RESERVED_OPC) \
+ RSV(0xDF, RESERVED_OPC) \
+ RSV(0xE0, RESERVED_OPC) \
+ RSV(0xE1, RESERVED_OPC) \
+ RSV(0xE2, RESERVED_OPC) \
+ RSV(0xE3, RESERVED_OPC) \
+ RSV(0xE4, RESERVED_OPC) \
+ RSV(0xE5, RESERVED_OPC) \
+ RSV(0xE6, RESERVED_OPC) \
+ RSV(0xE7, RESERVED_OPC) \
+ RSV(0xE8, RESERVED_OPC) \
+ RSV(0xE9, RESERVED_OPC) \
+ RSV(0xEA, RESERVED_OPC) \
+ RSV(0xEB, RESERVED_OPC) \
+ RSV(0xEC, RESERVED_OPC) \
+ RSV(0xED, RESERVED_OPC) \
+ RSV(0xEE, RESERVED_OPC) \
+ RSV(0xEF, RESERVED_OPC) \
+ RSV(0xF0, RESERVED_OPC) \
+ RSV(0xF1, RESERVED_OPC) \
+ RSV(0xF2, RESERVED_OPC) \
+ RSV(0xF3, RESERVED_OPC) \
+ RSV(0xF4, RESERVED_OPC) \
+ RSV(0xF5, RESERVED_OPC) \
+ RSV(0xF6, RESERVED_OPC) \
+ RSV(0xF7, RESERVED_OPC) \
+ RSV(0xF8, RESERVED_OPC) \
+ RSV(0xF9, RESERVED_OPC) \
+ RSV(0xFA, RESERVED_OPC) \
+ RSV(0xFB, RESERVED_OPC) \
+ RSV(0xFC, RESERVED_OPC) \
+ \
+ OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \
+ OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \
+ OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF)
+
+#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal,
+enum class SequencerOpcode : uint8_t {
+ IREE_SEQUENCER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM)
+};
+#undef DECLARE_ENUM
+
+} // namespace iree
+
+#endif // IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
diff --git a/iree/schemas/bytecode_def.fbs b/schemas/bytecode_def.fbs
similarity index 100%
rename from iree/schemas/bytecode_def.fbs
rename to schemas/bytecode_def.fbs
diff --git a/schemas/debug_service.fbs b/schemas/debug_service.fbs
new file mode 100644
index 0000000..0f31294
--- /dev/null
+++ b/schemas/debug_service.fbs
@@ -0,0 +1,347 @@
+include "schemas/function_def.fbs";
+include "schemas/module_def.fbs";
+
+namespace iree.rt.debug.rpc;
+
+table Status {
+ code:int;
+ message:string;
+}
+
+table CreateSessionRequest {
+}
+table CreateSessionResponse {
+ session_id:int;
+}
+
+table MakeReadyRequest {
+ session_id:int;
+}
+table MakeReadyResponse {
+}
+
+table GetStatusRequest {
+ session_id:int;
+ // TODO(benvanik): caps debugger supports? version expected?
+}
+table GetStatusResponse {
+ protocol:int;
+ // TODO(benvanik): run state.
+ // TODO(benvanik): profiling state.
+}
+
+table NativeFunctionDef {
+ name:string;
+ // TODO(benvanik): more information about the fns (stack trace of registrant?)
+}
+
+table ContextDef {
+ context_id:int;
+ native_functions:[NativeFunctionDef];
+ module_names:[string];
+}
+
+table ListContextsRequest {
+ session_id:int;
+}
+
+table ListContextsResponse {
+ contexts:[ContextDef];
+}
+
+table GetModuleRequest {
+ session_id:int;
+ context_id:int;
+ module_name:string;
+}
+table GetModuleResponse {
+ module:ModuleDef;
+}
+
+table GetFunctionRequest {
+ session_id:int;
+ context_id:int;
+ module_name:string;
+ function_ordinal:int;
+}
+table GetFunctionResponse {
+ bytecode:BytecodeDef;
+ // TODO(benvanik): import info (linked module, etc).
+}
+
+table ResolveFunctionRequest {
+ session_id:int;
+ module_name:string;
+ function_name:string;
+}
+table ResolveFunctionResponse {
+ context_ids:[int];
+ function_ordinal:int;
+}
+
+table BufferViewDef {
+ is_valid:bool;
+ shape:[int];
+ element_size:int;
+ // TODO(benvanik): buffer attrs (type, access, usage).
+ // TODO(benvanik): buffer size/allocated_size.
+ // TODO(benvanik): buffer data (if accessible).
+}
+
+table StackFrameDef {
+ module_name:string;
+ function_ordinal:int;
+ offset:int;
+ locals:[BufferViewDef];
+}
+
+table InvocationDef {
+ invocation_id:int;
+ frames:[StackFrameDef];
+}
+
+table ListInvocationsRequest {
+ session_id:int;
+}
+table ListInvocationsResponse {
+ invocations:[InvocationDef];
+}
+
+table SuspendInvocationsRequest {
+ session_id:int;
+ invocation_ids:[int];
+}
+table SuspendInvocationsResponse {
+ invocations:[InvocationDef];
+}
+
+table ResumeInvocationsRequest {
+ session_id:int;
+ invocation_ids:[int];
+}
+table ResumeInvocationsResponse {
+}
+
+enum StepMode : uint8 {
+ STEP_ONCE = 0,
+ STEP_TO_OFFSET = 1,
+}
+
+table StepInvocationRequest {
+ session_id:int;
+ step_id:int;
+ invocation_id:int;
+ step_mode:StepMode;
+ bytecode_offset:int;
+}
+table StepInvocationResponse {}
+
+table GetInvocationLocalRequest {
+ session_id:int;
+ invocation_id:int;
+ frame_index:int;
+ local_index:int;
+}
+table GetInvocationLocalResponse {
+ value:BufferViewDef;
+}
+
+table SetInvocationLocalRequest {
+ session_id:int;
+ invocation_id:int;
+ frame_index:int;
+ local_index:int;
+ value:BufferViewDef;
+}
+table SetInvocationLocalResponse {
+ value:BufferViewDef;
+}
+
+enum BreakpointType : uint8 {
+ BYTECODE_FUNCTION = 0,
+ NATIVE_FUNCTION = 1,
+}
+
+table BreakpointDef {
+ breakpoint_id:int;
+ breakpoint_type:BreakpointType;
+
+ module_name:string;
+ function_name:string;
+ function_ordinal:int;
+ bytecode_offset:int;
+}
+
+table ListBreakpointsRequest {
+ session_id:int;
+}
+table ListBreakpointsResponse {
+ breakpoints:[BreakpointDef];
+}
+
+table AddBreakpointRequest {
+ session_id:int;
+ breakpoint:BreakpointDef;
+}
+table AddBreakpointResponse {
+ breakpoint:BreakpointDef;
+}
+
+table RemoveBreakpointRequest {
+ session_id:int;
+ breakpoint_id:int;
+}
+table RemoveBreakpointResponse {
+}
+
+table StartProfilingRequest {
+ session_id:int;
+ context_id:int;
+ // TODO(benvanik): profiling mode.
+ // mode: sampling_timing, instrumented_coverage, instrumented_log,
+ // invoke_log
+}
+table StartProfilingResponse {
+ // TODO(benvanik): current/new mode.
+}
+
+table StopProfilingRequest {
+ session_id:int;
+ context_id:int;
+}
+table StopProfilingResponse {
+ // TODO(benvanik): profiling data.
+}
+
+// TODO(benvanik): streaming profiling data query.
+
+table ServiceShutdownEvent {
+}
+
+table ContextRegisteredEvent {
+ context_id:int;
+}
+table ContextUnregisteredEvent {
+ context_id:int;
+}
+
+table ModuleLoadedEvent {
+ context_id:int;
+ module_name:string;
+}
+
+table InvocationRegisteredEvent {
+ invocation_id:int;
+}
+table InvocationUnregisteredEvent {
+ invocation_id:int;
+}
+
+table BreakpointResolvedEvent {
+ breakpoint:BreakpointDef;
+ context_id:int;
+}
+
+table BreakpointHitEvent {
+ breakpoint_id:int;
+ invocation:InvocationDef;
+}
+
+table StepCompletedEvent {
+ step_id:int;
+ invocations:[InvocationDef];
+}
+
+union RequestUnion {
+ CreateSessionRequest,
+ MakeReadyRequest,
+ GetStatusRequest,
+ ListContextsRequest,
+ GetModuleRequest,
+ GetFunctionRequest,
+ ResolveFunctionRequest,
+ ListInvocationsRequest,
+ SuspendInvocationsRequest,
+ ResumeInvocationsRequest,
+ StepInvocationRequest,
+ GetInvocationLocalRequest,
+ SetInvocationLocalRequest,
+ ListBreakpointsRequest,
+ AddBreakpointRequest,
+ RemoveBreakpointRequest,
+ StartProfilingRequest,
+ StopProfilingRequest,
+}
+
+union ResponseUnion {
+ CreateSessionResponse,
+ MakeReadyResponse,
+ GetStatusResponse,
+ ListContextsResponse,
+ GetModuleResponse,
+ GetFunctionResponse,
+ ResolveFunctionResponse,
+ ListInvocationsResponse,
+ SuspendInvocationsResponse,
+ ResumeInvocationsResponse,
+ StepInvocationResponse,
+ GetInvocationLocalResponse,
+ SetInvocationLocalResponse,
+ ListBreakpointsResponse,
+ AddBreakpointResponse,
+ RemoveBreakpointResponse,
+ StartProfilingResponse,
+ StopProfilingResponse,
+}
+
+union EventUnion {
+ ServiceShutdownEvent,
+ ContextRegisteredEvent,
+ ContextUnregisteredEvent,
+ ModuleLoadedEvent,
+ InvocationRegisteredEvent,
+ InvocationUnregisteredEvent,
+ BreakpointResolvedEvent,
+ BreakpointHitEvent,
+ StepCompletedEvent,
+}
+
+table Request {
+ message:RequestUnion;
+}
+
+table Response {
+ status:Status;
+ message:ResponseUnion;
+}
+
+table ServicePacket {
+ response:Response;
+ event:EventUnion;
+}
+
+// NOTE: we aren't using this yet as the FlatBuffers gRPC code is... suspect.
+rpc_service DebugServiceRpc {
+ MakeReady(MakeReadyRequest):MakeReadyResponse;
+
+ GetStatus(GetStatusRequest):GetStatusResponse;
+
+ ListContexts(ListContextsRequest):ListContextsResponse;
+ GetModule(GetModuleRequest):GetModuleResponse;
+ GetFunction(GetFunctionRequest):GetFunctionResponse;
+ ResolveFunction(ResolveFunctionRequest):ResolveFunctionResponse;
+
+ ListInvocations(ListInvocationsRequest):ListInvocationsResponse;
+ SuspendInvocations(SuspendInvocationsRequest):SuspendInvocationsResponse;
+ ResumeInvocations(ResumeInvocationsRequest):ResumeInvocationsResponse;
+ StepInvocation(StepInvocationRequest):StepInvocationResponse;
+ GetInvocationLocal(GetInvocationLocalRequest):GetInvocationLocalResponse;
+ SetInvocationLocal(SetInvocationLocalRequest):SetInvocationLocalResponse;
+
+ ListBreakpoints(ListBreakpointsRequest):ListBreakpointsResponse;
+ AddBreakpoint(AddBreakpointRequest):AddBreakpointResponse;
+ RemoveBreakpoint(RemoveBreakpointRequest):RemoveBreakpointResponse;
+
+ StartProfiling(StartProfilingRequest):StartProfilingResponse;
+ StopProfiling(StopProfilingRequest):StopProfilingResponse;
+}
diff --git a/iree/schemas/device_def.fbs b/schemas/device_def.fbs
similarity index 100%
rename from iree/schemas/device_def.fbs
rename to schemas/device_def.fbs
diff --git a/iree/schemas/device_group_def.fbs b/schemas/device_group_def.fbs
similarity index 100%
rename from iree/schemas/device_group_def.fbs
rename to schemas/device_group_def.fbs
diff --git a/schemas/device_table_def.fbs b/schemas/device_table_def.fbs
new file mode 100644
index 0000000..597e126
--- /dev/null
+++ b/schemas/device_table_def.fbs
@@ -0,0 +1,13 @@
+include "schemas/device_def.fbs";
+include "schemas/device_group_def.fbs";
+
+namespace iree;
+
+// A table of devices used for runtime device resolution and referencing.
+table DeviceTableDef {
+ // One or more virtual devices referenced by ordinal in the sequencer ops.
+ devices:[DeviceDef];
+
+ // Zero or more device groups that specify which devices must be compatible.
+ device_groups:[DeviceGroupDef];
+}
diff --git a/iree/schemas/executable_def.fbs b/schemas/executable_def.fbs
similarity index 100%
rename from iree/schemas/executable_def.fbs
rename to schemas/executable_def.fbs
diff --git a/schemas/executable_table_def.fbs b/schemas/executable_table_def.fbs
new file mode 100644
index 0000000..6c55f76
--- /dev/null
+++ b/schemas/executable_table_def.fbs
@@ -0,0 +1,28 @@
+include "schemas/executable_def.fbs";
+
+namespace iree;
+
+// A fat executable containing multiple format variants for the same logical
+// entry points.
+table MultiArchExecutableDef {
+ // Friendly name of the executable used for diagnostics.
+ name:string;
+
+ // Number of available entry points.
+ // This is used for bytecode verification even when the executable is not
+ // fully loaded into a device. All executables must have the same entry
+ // points.
+ entry_point_count:uint;
+
+ // A set of executables of various formats and supported feature sets.
+ // The runtime will select the appropriate executable based on the dispatch
+ // requirements.
+ executables:[ExecutableDef];
+}
+
+// A table of executables used for runtime dispatch lookup.
+table ExecutableTableDef {
+ // One or more top level executables referenced by sequencer dispatch ops.
+ // Ordinal is referenced by dispatch ops to index into the table.
+ multi_arch_executables:[MultiArchExecutableDef];
+}
diff --git a/schemas/function_def.fbs b/schemas/function_def.fbs
new file mode 100644
index 0000000..b6eed64
--- /dev/null
+++ b/schemas/function_def.fbs
@@ -0,0 +1,18 @@
+include "schemas/bytecode_def.fbs";
+include "schemas/type_def.fbs";
+
+namespace iree;
+
+table FunctionAttributeDef {
+ key:string;
+ value:string;
+}
+
+table FunctionDef {
+ name:string;
+ type:FunctionTypeDef;
+
+ attrs:[FunctionAttributeDef];
+
+ bytecode:BytecodeDef;
+}
diff --git a/schemas/function_table_def.fbs b/schemas/function_table_def.fbs
new file mode 100644
index 0000000..3ad9c0a
--- /dev/null
+++ b/schemas/function_table_def.fbs
@@ -0,0 +1,9 @@
+include "schemas/function_def.fbs";
+
+namespace iree;
+
+table FunctionTableDef {
+ functions:[FunctionDef];
+ imports:[int];
+ exports:[int];
+}
diff --git a/schemas/module_def.fbs b/schemas/module_def.fbs
new file mode 100644
index 0000000..cebcbcb
--- /dev/null
+++ b/schemas/module_def.fbs
@@ -0,0 +1,20 @@
+include "schemas/executable_table_def.fbs";
+include "schemas/device_table_def.fbs";
+include "schemas/function_table_def.fbs";
+include "schemas/source_map_def.fbs";
+
+namespace iree;
+
+// 'Executable MODule'.
+file_identifier "EMOD";
+file_extension "emod";
+
+table ModuleDef {
+ name:string;
+ device_table:DeviceTableDef;
+ function_table:FunctionTableDef;
+ executable_table:ExecutableTableDef;
+ source_map:SourceMapDef;
+}
+
+root_type ModuleDef;
diff --git a/iree/schemas/source_map_def.fbs b/schemas/source_map_def.fbs
similarity index 100%
rename from iree/schemas/source_map_def.fbs
rename to schemas/source_map_def.fbs
diff --git a/iree/schemas/spirv_executable_def.fbs b/schemas/spirv_executable_def.fbs
similarity index 100%
rename from iree/schemas/spirv_executable_def.fbs
rename to schemas/spirv_executable_def.fbs
diff --git a/iree/schemas/type_def.fbs b/schemas/type_def.fbs
similarity index 100%
rename from iree/schemas/type_def.fbs
rename to schemas/type_def.fbs
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
index bba2d68..c83e6fc 100644
--- a/test/e2e/BUILD
+++ b/test/e2e/BUILD
@@ -1,6 +1,6 @@
# Tests for end-to-end IREE support.
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
package(
default_visibility = ["//visibility:public"],
@@ -9,7 +9,7 @@
iree_setup_lit_package(
data = [
- "//iree/tools:iree-run-mlir",
+ "//tools:iree-run-mlir",
],
)
diff --git a/test/e2e/xla/BUILD b/test/e2e/xla/BUILD
index f3eda03..0066c08 100644
--- a/test/e2e/xla/BUILD
+++ b/test/e2e/xla/BUILD
@@ -1,6 +1,6 @@
# Tests for end-to-end IREE support starting from the XLA HLO dialect.
-load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
package(
default_visibility = ["//visibility:public"],
@@ -9,7 +9,7 @@
iree_setup_lit_package(
data = [
- "//iree/tools:iree-run-mlir",
+ "//tools:iree-run-mlir",
],
)
diff --git a/tools/BUILD b/tools/BUILD
new file mode 100644
index 0000000..c35fa96
--- /dev/null
+++ b/tools/BUILD
@@ -0,0 +1,124 @@
+# Misc tools used to optimize, translate, and evaluate IREE.
+# Most of these are not designed to run on-device.
+
+load("//:build_defs.google.bzl", "PLATFORM_VULKAN_DEPS")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "run_lit.sh",
+ "sanitizer_suppressions.txt",
+])
+
+# MLIR optimizer driver. There are no local sources, so everything (including
+# the entry point) is linked in from the dependencies below.
+cc_binary(
+    name = "iree-opt",
+    deps = [
+        "//compiler/Transforms",
+        "//compiler/Transforms/Interpreter",
+        "//compiler/Transforms/Sequencer",
+        "//compiler/Translation/SPIRV",
+        "@llvm//:support",
+        "@local_config_mlir//:AffineDialectRegistration",
+        "@local_config_mlir//:MlirOptLib",
+        "@local_config_mlir//:MlirOptMain",
+        "@local_config_mlir//:StandardDialectRegistration",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
+    ],
+)
+
+# Built from run_mlir_main.cc. Links the sequencer, interpreter and SPIR-V
+# translation libraries together with the interpreter and Vulkan HAL driver
+# modules; the Dawn driver module stays disabled (see the TODO below).
+cc_binary(
+    name = "iree-run-mlir",
+    srcs = ["run_mlir_main.cc"],
+    deps = [
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/strings",
+        "//base:source_location",
+        "//rt",
+        "//vm:sequencer_module",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Parser",
+        "@local_config_mlir//:Support",
+        "//base:init",
+        "//base:status",
+        "//compiler/Translation/Sequencer",
+        "//compiler/Translation/Interpreter",
+        "//compiler/Translation/SPIRV",
+        "//hal:buffer_view_string_util",
+        "//hal:driver_registry",
+        "//schemas",
+        "//rt/debug:debug_server_flags",
+    ] + PLATFORM_VULKAN_DEPS + [
+        "//hal/interpreter:interpreter_driver_module",
+        # TODO(b/142004903): enable when Dawn HAL implementation is functional
+        # "//hal/dawn:dawn_driver_module",
+        "//hal/vulkan:vulkan_driver_module",
+    ],
+)
+
+# Translation tool built from iree_translate_main.cc. iree_bytecode_module()
+# in tools/compilation.bzl runs it with -mlir-to-iree-module.
+cc_binary(
+    name = "iree-translate",
+    srcs = ["iree_translate_main.cc"],
+    deps = [
+        "//compiler/Translation/Interpreter",
+        "//compiler/Translation/SPIRV",
+        "//compiler/Translation/Sequencer",
+        "@llvm//:support",
+        "@local_config_mlir//:AffineDialectRegistration",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardDialectRegistration",
+        "@local_config_mlir//:Support",
+        "@local_config_mlir//:TranslateClParser",
+        "@local_config_mlir//:Translation",
+        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration",
+    ],
+)
+
+# Built from run_module_main.cc. Only the interpreter HAL driver module is
+# linked in.
+cc_binary(
+    name = "run_module",
+    srcs = ["run_module_main.cc"],
+    deps = [
+        "//base:file_io",
+        "//base:file_path",
+        "//base:init",
+        "//base:source_location",
+        "//base:status",
+        "//hal:buffer_view_string_util",
+        "//hal:driver_registry",
+        "//hal/interpreter:interpreter_driver_module",
+        "//rt",
+        "//rt/debug:debug_server_flags",
+        "//schemas",
+        "//vm:sequencer_module",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+# Benchmark harness built from benchmark_module.cc (testonly). Same
+# dependencies as run_module plus the benchmark library.
+cc_binary(
+    name = "benchmark_module",
+    testonly = 1,
+    srcs = ["benchmark_module.cc"],
+    deps = [
+        "//base:file_io",
+        "//base:file_path",
+        "//base:init",
+        "//base:source_location",
+        "//base:status",
+        "//hal:buffer_view_string_util",
+        "//hal:driver_registry",
+        "//hal/interpreter:interpreter_driver_module",
+        "//rt",
+        "//rt/debug:debug_server_flags",
+        "//schemas",
+        "//vm:sequencer_module",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/strings",
+        "@com_google_benchmark//:benchmark",
+    ],
+)
diff --git a/iree/tools/CMakeLists.txt b/tools/CMakeLists.txt
similarity index 100%
rename from iree/tools/CMakeLists.txt
rename to tools/CMakeLists.txt
diff --git a/tools/benchmark_module.cc b/tools/benchmark_module.cc
new file mode 100644
index 0000000..b5f45b9
--- /dev/null
+++ b/tools/benchmark_module.cc
@@ -0,0 +1,157 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <utility>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "base/file_io.h"
+#include "base/file_path.h"
+#include "base/init.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "benchmark/benchmark.h"
+#include "hal/buffer_view_string_util.h"
+#include "hal/driver_registry.h"
+#include "rt/context.h"
+#include "rt/debug/debug_server_flags.h"
+#include "rt/instance.h"
+#include "rt/module_printer.h"
+#include "schemas/module_def_generated.h"
+#include "vm/sequencer_module.h"
+
+ABSL_FLAG(std::string, main_module, "", "Main module with entry point.");
+ABSL_FLAG(std::string, main_function, "",
+ "Function within the main module to execute.");
+
+ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
+ABSL_FLAG(std::string, input_file, "",
+ "Input shapes and optional values serialized in a file.");
+
+namespace iree {
+namespace {
+
+// Builds the benchmark input buffer views from the command-line flags.
+// The text comes from --input_values (each literal "\n" sequence is turned
+// into a newline) or, when that flag is empty, from the file named by
+// --input_file. If both flags are empty, no inputs are produced. Every
+// non-whitespace line is one value, parsed with |allocator| and written as
+//   [shape]xtype=[value]
+// Example:
+//   4x4xi8=0,1,2,3
+// Returns the parsed views in line order, or the first read/parse error.
+StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags(
+    hal::Allocator* allocator) {
+  std::string file_contents;
+  if (!absl::GetFlag(FLAGS_input_values).empty()) {
+    file_contents =
+        absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}});
+  } else if (!absl::GetFlag(FLAGS_input_file).empty()) {
+    ASSIGN_OR_RETURN(file_contents,
+                     file_io::GetFileContents(absl::GetFlag(FLAGS_input_file)));
+  }
+  std::vector<hal::BufferView> inputs;
+  for (const auto& line :
+       absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) {
+    ASSIGN_OR_RETURN(auto input,
+                     hal::ParseBufferViewFromString(line, allocator));
+    // |input| is not used again, so move it into the vector instead of
+    // copying the buffer view.
+    inputs.push_back(std::move(input));
+  }
+  return inputs;
+}
+
+// Runs the whole benchmark. Setup happens once: an instance is created with a
+// device from the "interpreter" driver, --main_module is loaded and registered
+// with a context, the function to run is resolved, and the inputs are parsed.
+// The `for (auto _ : state)` loop then creates and awaits one invocation per
+// benchmark iteration. Returns the first error hit, or OkStatus() on success.
+Status Run(benchmark::State& state) {
+ ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
+ auto instance = make_ref<rt::Instance>(std::move(debug_server));
+ ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
+ "interpreter"));
+ ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
+ RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
+ auto policy = make_ref<rt::Policy>();
+ auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));
+
+ // Load main module.
+ ASSIGN_OR_RETURN(
+ auto main_module_file,
+ vm::ModuleFile::LoadFile(ModuleDefIdentifier(),
+ absl::GetFlag(FLAGS_main_module)),
+ _ << "while loading module file " << absl::GetFlag(FLAGS_main_module));
+ ASSIGN_OR_RETURN(auto main_module,
+ vm::SequencerModule::FromFile(std::move(main_module_file)));
+
+ // Register the main module with the context.
+ // We could add additional modules (specializations, shared libraries, etc).
+ // ModuleFiles are stateless so we could have the same module_file used by
+ // multiple contexts simultaneously.
+ RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));
+
+ rt::Function main_function;
+ if (!absl::GetFlag(FLAGS_main_function).empty()) {
+ // User-specified main function.
+ ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName(
+ rt::Function::Linkage::kExport,
+ absl::GetFlag(FLAGS_main_function)));
+ } else {
+ // No main function specified; to prevent non-deterministic behavior we
+ // require one unless there's exactly one exported function in the module.
+ if (main_module->signature().export_function_count() == 1) {
+ ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal(
+ rt::Function::Linkage::kExport, 0));
+ } else {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "--main_function= must be specified to disambiguate the "
+ "function to run";
+ }
+ }
+
+ // Parse the arguments once, before the loop; every iteration below passes
+ // the same argument list.
+ ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator()));
+
+ // Each iteration creates a fresh invocation (with its own Policy) and blocks
+ // until it completes, so both creation and execution are repeated per
+ // iteration.
+ for (auto _ : state) {
+ ASSIGN_OR_RETURN(auto invocation,
+ rt::Invocation::Create(add_ref(context), main_function,
+ make_ref<rt::Policy>(), {},
+ absl::MakeConstSpan(arguments)));
+ RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
+ }
+
+ return OkStatus();
+}
+
+// Benchmark function handed to the BENCHMARK() registration. Any error status
+// returned by Run() is passed to CHECK_OK.
+void BM_RunModule(benchmark::State& state) {
+ // Delegate to a status-returning function so we can use the status macros.
+ CHECK_OK(Run(state));
+}
+
+// By default only the main thread is included in CPU time. Include all the
+// threads instead. To make single and multi-threaded benchmarks more
+// comparable, use the wall time to determine how many iterations to run.
+// See https://github.com/google/benchmark#cpu-timers.
+BENCHMARK(BM_RunModule)->MeasureProcessCPUTime()->UseRealTime();
+
+} // namespace
+
+// Program entry point. It is defined inside namespace iree and declared
+// extern "C".
+// NOTE(review): [basic.start.main] in newer C++ standards restricts declaring
+// main with C language linkage inside a namespace - confirm that the supported
+// toolchains accept this, or move main to global scope.
+extern "C" int main(int argc, char** argv) {
+ // The benchmark library uses a different mechanism for its flags. This
+ // consumes any arguments it understands from argv. It must come before
+ // InitializeEnvironment to avoid failures on unknown flags.
+ ::benchmark::Initialize(&argc, argv);
+ InitializeEnvironment(&argc, &argv);
+ size_t run_benchmark_count = ::benchmark::RunSpecifiedBenchmarks();
+ // Treat a run in which no benchmark executed as a failure.
+ CHECK_GT(run_benchmark_count, 0) << "No benchmarks were run";
+ return 0;
+}
+
+} // namespace iree
diff --git a/tools/compilation.bzl b/tools/compilation.bzl
new file mode 100644
index 0000000..65b2f6b
--- /dev/null
+++ b/tools/compilation.bzl
@@ -0,0 +1,43 @@
+"""Rules for compiling IREE executables, modules, and archives."""
+
+load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+
+# TODO(benvanik): port to a full starlark rule, document, etc.
+def iree_bytecode_module(
+        name,
+        srcs,
+        cc_namespace = None,
+        visibility = None):
+    """Compiles MLIR sources into an IREE module with iree-translate.
+
+    Declares a genrule that writes `<name>.emod`. When cc_namespace is set, a
+    `<name>_cc` cc_embed_data target wrapping that file is declared as well.
+
+    Args:
+      name: Name of the genrule; also the basename of the .emod output.
+      srcs: Source files handed to iree-translate.
+      cc_namespace: Optional C++ namespace. The cc_embed_data target is only
+        generated when this is set.
+      visibility: Visibility of the cc_embed_data target. It is not applied
+        to the genrule.
+    """
+    native.genrule(
+        name = name,
+        srcs = srcs,
+        outs = [
+            "%s.emod" % (name),
+        ],
+        cmd = " && ".join([
+            " ".join([
+                "$(location //tools:iree-translate)",
+                "-mlir-to-iree-module",
+                "-o $(location %s.emod)" % (name),
+            ] + ["$(locations %s)" % (src) for src in srcs]),
+        ]),
+        tools = [
+            "//tools:iree-translate",
+        ],
+        message = "Compiling IREE module %s..." % (name),
+        output_to_bindir = 1,
+    )
+
+    # Embed the module for use in C++. This avoids the need for file IO in
+    # tests and samples that would otherwise complicate execution/porting.
+    if cc_namespace:
+        cc_embed_data(
+            name = "%s_cc" % (name),
+            identifier = name,
+            srcs = ["%s.emod" % (name)],
+            cc_file_output = "%s.cc" % (name),
+            h_file_output = "%s.h" % (name),
+            cpp_namespace = cc_namespace,
+            visibility = visibility,
+            flatten = True,
+        )
diff --git a/tools/debugger/BUILD b/tools/debugger/BUILD
new file mode 100644
index 0000000..1146b8e
--- /dev/null
+++ b/tools/debugger/BUILD
@@ -0,0 +1,173 @@
+# IREE Debugger UIs.
+#
+# The main debugger UI can be used in standalone mode connected to a remote
+# host (via :debugger) or can be directly embedded into the IREE runtime to
+# allow for attaching (--iree_attach_debugger).
+#
+# By default the IREE runtime does not compile in debug support. To link it in
+# pass --define=IREE_DEBUG=1 to bazel builds of the runtime.
+
+# TODO(benvanik): re-enable debugger after refactoring.
+# load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# TODO(benvanik): re-enable debugger after refactoring.
+# alias(
+# name = "debugger",
+# actual = select({
+# "//tools/cc_target_os:emscripten": ":debug_app_emscripten_files",
+# "//conditions:default": ":debug_app_native",
+# }),
+# )
+#
+# cc_library(
+# name = "debug_app_library",
+# srcs = ["debug_app.cc"],
+# hdrs = ["debug_app.h"],
+# deps = [
+# "//third_party/GL:GLES2_headers",
+# "//third_party/SDL2",
+# "@com_google_absl//absl/flags:flag",
+# "@com_google_absl//absl/memory",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/types:optional",
+# "//third_party/dear_imgui",
+# "//third_party/dear_imgui:imgui_sdl_opengl3",
+# "///base:memory",
+# "///base:source_location",
+# "///base:status",
+# "///rt/debug:debug_client",
+# "///schemas",
+# ],
+# )
+#
+# # NOTE: users must also link in a GL implementation, like:
+# # "//third_party/GL/native:GLESv2", # build-cleaner: keep
+# cc_library(
+# name = "debug_app_embedded",
+# srcs = ["debug_app_embedded.cc"],
+# hdrs = ["debug_app_embedded.h"],
+# deps = [
+# ":debug_app_library",
+# "//third_party/SDL2",
+# "@com_google_absl//absl/base:core_headers",
+# "@com_google_absl//absl/memory",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/synchronization",
+# "//third_party/dear_imgui",
+# "///base:memory",
+# "///base:status",
+# ],
+# )
+#
+# EMSCRIPTEN_LINKOPTS_COMMON = [
+# # Error at compile time on unresolved symbols.
+# "-s ERROR_ON_UNDEFINED_SYMBOLS=1",
+#
+# # Required by SDL.
+# "-s EXTRA_EXPORTED_RUNTIME_METHODS=Pointer_stringify",
+#
+# # TODO(benvanik): tweak to enable support when needed.
+# "-s ALLOW_MEMORY_GROWTH=1",
+# # "-s WASM_MEM_MAX=268435456", # 256MB
+# # "-s TOTAL_MEMORY=268435456", # 256MB
+# ]
+#
+# EMSCRIPTEN_LINKOPTS_DBG = [
+# # Show WASM stack trace in Chrome debugger.
+# "-g2",
+# "-s DEMANGLE_SUPPORT=1",
+#
+# # Enable verbose assertions.
+# "-s ASSERTIONS=2",
+# "-s SAFE_HEAP=1",
+# "-s STACK_OVERFLOW_CHECK=2",
+# ]
+#
+# EMSCRIPTEN_LINKOPTS_OPT = []
+#
+# cc_binary(
+# name = "debug_app_emscripten",
+# srcs = ["debug_app_main_emscripten.cc"],
+# linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({
+# "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG,
+# "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT,
+# "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT,
+# }),
+# tags = [
+# "manual",
+# "notap", # TODO(b/137088911): Build/test on TAP
+# "wasm",
+# ],
+# deps = [
+# ":debug_app_library",
+# "//third_party/SDL2",
+# "@com_google_absl//absl/memory",
+# "//third_party/dear_imgui",
+# "//third_party/dear_imgui:imgui_sdl_opengl3",
+# "///base:init",
+# "///base:source_location",
+# "///base:status",
+# ],
+# )
+#
+# auto_wasm_binary(
+# name = "debug_app_emscripten_binary",
+# cc_target = ":debug_app_emscripten",
+# tags = ["manual"],
+# )
+#
+# Fileset(
+# name = "debug_app_emscripten_files",
+# out = "wasm_files",
+# entries = [
+# FilesetEntry(
+# files = [":debug_app_emscripten_binary"],
+# strip_prefix = "debug_app_emscripten_binary",
+# destdir = "wasm",
+# ),
+# FilesetEntry(
+# files = ["debug_app.html"],
+# destdir = "wasm",
+# ),
+# ],
+# tags = ["manual"],
+# )
+#
+# cc_binary(
+# name = "debug_app_native",
+# srcs = ["debug_app_main_native.cc"],
+# deps = [
+# ":debug_app_embedded",
+# "//third_party/GL/native:EGL", # build-cleaner: keep
+# "//third_party/GL/native:GLESv2", # build-cleaner: keep
+# "///base:init",
+# "///base:status",
+# ],
+# )
+#
+# cc_binary(
+# name = "debug_cli",
+# srcs = ["debug_cli_main.cc"],
+# deps = [
+# ":debug_prompt",
+# "@com_google_absl//absl/flags:flag",
+# "///base:init",
+# "///base:status",
+# ],
+# )
+#
+# cc_library(
+# name = "debug_prompt",
+# srcs = ["debug_prompt.cc"],
+# hdrs = ["debug_prompt.h"],
+# deps = [
+# "@com_google_absl//absl/strings",
+# "///base:status",
+# "///rt/debug:debug_client",
+# ],
+# )
diff --git a/tools/debugger/debug_app.cc b/tools/debugger/debug_app.cc
new file mode 100644
index 0000000..5136b04
--- /dev/null
+++ b/tools/debugger/debug_app.cc
@@ -0,0 +1,1422 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tools/debugger/debug_app.h"
+
+#include <GLES2/gl2.h>
+
+#include <algorithm>
+#include <cstdio>
+
+#include "absl/flags/flag.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/types/optional.h"
+#include "base/memory.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "rt/debug/debug_client.h"
+#include "schemas/debug_service_generated.h"
+#include "third_party/dear_imgui/imgui.h"
+#include "third_party/dear_imgui/imgui_internal.h"
+#include "vm/bytecode_module.h"
+#include "vm/bytecode_tables_sequencer.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+namespace {
+
+// Pushes three button style colors (normal, hovered, active) derived from
+// |hue|. The hue is divided by 7 before being passed to ImColor::HSV and the
+// saturation/value pair steps from 0.6 to 0.7 to 0.8 across the three states.
+// Balance with PopButtonStyle().
+void PushButtonHue(float hue) {
+  const float scaled_hue = hue / 7.0f;
+  const ImVec4 idle_color = ImColor::HSV(scaled_hue, 0.6f, 0.6f);
+  const ImVec4 hovered_color = ImColor::HSV(scaled_hue, 0.7f, 0.7f);
+  const ImVec4 active_color = ImColor::HSV(scaled_hue, 0.8f, 0.8f);
+  ImGui::PushStyleColor(ImGuiCol_Button, idle_color);
+  ImGui::PushStyleColor(ImGuiCol_ButtonHovered, hovered_color);
+  ImGui::PushStyleColor(ImGuiCol_ButtonActive, active_color);
+}
+
+// Pushes |color| for the normal, hovered and active button states alike.
+// Balance with PopButtonStyle().
+void PushButtonColor(const ImVec4& color) {
+  static const ImGuiCol kButtonStates[] = {
+      ImGuiCol_Button, ImGuiCol_ButtonHovered, ImGuiCol_ButtonActive};
+  for (const ImGuiCol state : kButtonStates) {
+    ImGui::PushStyleColor(state, color);
+  }
+}
+
+// Pops the three style colors pushed by PushButtonHue()/PushButtonColor().
+void PopButtonStyle() {
+  constexpr int kPushedColorCount = 3;
+  ImGui::PopStyleColor(kPushedColorCount);
+}
+
+// Returns true when the remote |breakpoint| and |user_breakpoint| describe the
+// same breakpoint.
+//
+// A user breakpoint already bound to this exact remote object always matches.
+// Otherwise the types must agree, and then:
+//  - bytecode breakpoints match on module name, function name and bytecode
+//    offset; the function ordinal is only compared when the user breakpoint
+//    has one (i.e. it is not -1);
+//  - native breakpoints match on function name;
+//  - any other type never matches.
+bool AreBreakpointsEqual(const RemoteBreakpoint& breakpoint,
+                         const DebugApp::UserBreakpoint& user_breakpoint) {
+  if (user_breakpoint.active_breakpoint == &breakpoint) return true;
+
+  const auto remote_type = breakpoint.type();
+  if (user_breakpoint.type != remote_type) return false;
+
+  if (remote_type == RemoteBreakpoint::Type::kNativeFunction) {
+    return breakpoint.function_name() == user_breakpoint.native_function;
+  }
+  if (remote_type != RemoteBreakpoint::Type::kBytecodeFunction) {
+    return false;
+  }
+
+  const bool has_ordinal = user_breakpoint.function_ordinal != -1;
+  if (has_ordinal &&
+      user_breakpoint.function_ordinal != breakpoint.function_ordinal()) {
+    return false;
+  }
+  const bool same_module =
+      breakpoint.module_name() == user_breakpoint.module_name;
+  const bool same_function =
+      breakpoint.function_name() == user_breakpoint.function_name;
+  const bool same_offset =
+      breakpoint.bytecode_offset() == user_breakpoint.bytecode_offset;
+  return same_module && same_function && same_offset;
+}
+
+} // namespace
+
+// static
+// Adapter with a void* signature: |arg| must point at the DebugApp whose
+// PumpMainLoop() should run once. A cancelled status is treated as a normal
+// exit request and ignored; any other failure trips CHECK_OK.
+void DebugApp::PumpMainLoopThunk(void* arg) {
+  auto* app = reinterpret_cast<DebugApp*>(arg);
+  auto status = app->PumpMainLoop();
+  if (IsCancelled(status)) return;
+  CHECK_OK(status);
+}
+
+// Creates the ImGui context and binds it to the given SDL window and GL
+// context.
+//
+// |window| and |gl_context| are stored as-is; the destructor destroys both.
+// |glsl_version| is forwarded to the ImGui OpenGL3 backend.
+DebugApp::DebugApp(SDL_Window* window, SDL_GLContext gl_context,
+ const char* glsl_version)
+ : window_(window), gl_context_(gl_context) {
+ VLOG(1) << "DebugApp initializing...";
+ IMGUI_CHECKVERSION();
+ ImGui::CreateContext();
+ ImGuiIO& io = ImGui::GetIO();
+ // Keyboard navigation and docking are both required by the layout built in
+ // LayoutInitialDockSpace().
+ io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard;
+ io.ConfigFlags |= ImGuiConfigFlags_DockingEnable;
+
+ // TODO(benvanik): ini file for settings.
+ // A null filename means no settings file is configured.
+ io.IniFilename = nullptr;
+ // ImGui::LoadIniSettingsFromMemory()
+ // ImGui::SaveIniSettingsToMemory()
+
+ // TODO(benvanik): theming.
+ ImGui::StyleColorsDark();
+
+ // Setup Platform/Renderer bindings
+ ImGui_ImplSDL2_InitForOpenGL(window_, gl_context_);
+ ImGui_ImplOpenGL3_Init(glsl_version);
+ // No context is left current here; PumpMainLoop() makes it current again
+ // before it touches GL.
+ SDL_GL_MakeCurrent(nullptr, nullptr);
+ VLOG(1) << "DebugApp initialized";
+}
+
+// Tears down the ImGui backends and context, then destroys the GL context and
+// window that were handed to the constructor and shuts SDL down.
+DebugApp::~DebugApp() {
+ VLOG(1) << "DebugApp shutting down...";
+ // Backends are shut down in the reverse order of their initialization.
+ ImGui_ImplOpenGL3_Shutdown();
+ ImGui_ImplSDL2_Shutdown();
+ ImGui::DestroyContext();
+
+ // NOTE(review): the context is deleted before it is detached with
+ // SDL_GL_MakeCurrent(nullptr, nullptr); confirm this order is intended.
+ SDL_GL_DeleteContext(gl_context_);
+ SDL_GL_MakeCurrent(nullptr, nullptr);
+ SDL_DestroyWindow(window_);
+ SDL_Quit();
+ VLOG(1) << "DebugApp shut down (SDL_Quit)";
+}
+
+// Connects to the debug service at |service_address| (passing this app to
+// DebugClient::Connect), seeds the user breakpoint list with an enabled
+// breakpoint at offset 0 of function "main" in module "module", pushes the
+// breakpoints to the service and marks the app as paused.
+//
+// Returns the connection or breakpoint-refresh error, if any.
+Status DebugApp::Connect(absl::string_view service_address) {
+  VLOG(1) << "Connecting to debug service at " << service_address << "...";
+  ASSIGN_OR_RETURN(debug_client_, DebugClient::Connect(service_address, this));
+
+  // TODO(benvanik): load breakpoints from file.
+  {
+    UserBreakpoint entry_breakpoint;
+    entry_breakpoint.wants_enabled = true;
+    entry_breakpoint.bytecode_offset = 0;
+    entry_breakpoint.function_name = "main";
+    entry_breakpoint.module_name = "module";
+    user_breakpoint_list_.push_back(std::move(entry_breakpoint));
+  }
+  RETURN_IF_ERROR(RefreshActiveBreakpoints());
+
+  // Start out paused: the user has to resume before execution continues.
+  is_paused_ = true;
+  return OkStatus();
+}
+
+// Drops the connection to the debug service.
+//
+// |hit_breakpoints_| and each UserBreakpoint::active_breakpoint hold raw
+// pointers to breakpoints obtained from the client, so both are cleared along
+// with it. Leaving stale hits behind would also make is_paused() report true
+// as soon as a new client is connected.
+Status DebugApp::Disconnect() {
+  VLOG(1) << "Disconnecting from debug service";
+  debug_client_.reset();
+  hit_breakpoints_.clear();
+  // With no client this only nulls every active_breakpoint pointer.
+  return RefreshActiveBreakpoints();
+}
+
+// Returns whether the UI should treat the target as paused.
+//
+// Never paused while disconnected. Otherwise paused when at least one
+// breakpoint is currently hit, when |is_paused_| is set, or when no step is
+// in flight.
+//
+// NOTE(review): the last clause means this returns true whenever
+// |is_stepping_| is false, even after Resume clears |is_paused_|. Confirm
+// that `is_paused_ || !is_stepping_` is the intended expression.
+bool DebugApp::is_paused() const {
+  if (!debug_client_) return false;
+  const bool stopped_at_breakpoint = !hit_breakpoints_.empty();
+  return stopped_at_breakpoint || is_paused_ || !is_stepping_;
+}
+
+// Resolves |selected_invocation_id_| to the matching invocation on the
+// connected client. Returns nullptr when disconnected, when nothing is
+// selected, or when the client lists no invocation with the selected id.
+RemoteInvocation* DebugApp::GetSelectedInvocation() const {
+  const bool can_resolve =
+      debug_client_ && selected_invocation_id_.has_value();
+  if (!can_resolve) return nullptr;
+
+  RemoteInvocation* found = nullptr;
+  for (auto* candidate : debug_client_->invocations()) {
+    if (candidate->id() == selected_invocation_id_.value()) {
+      found = candidate;
+      break;
+    }
+  }
+  return found;
+}
+
+// Reconciles |user_breakpoint_list_| with the breakpoints the connected debug
+// client currently reports.
+//
+//  1. Every user breakpoint is unbound (active_breakpoint = nullptr).
+//  2. Each client breakpoint is matched against the user list with
+//     AreBreakpointsEqual(); matches are re-bound and refreshed from the
+//     client's data, unmatched ones are appended to the user list.
+//  3. User breakpoints that want to be enabled but are not bound are requested
+//     from the service; bound ones that no longer want to be enabled are
+//     removed from it.
+//
+// Returns OkStatus() when disconnected. Returns an Unimplemented error for
+// native or unknown breakpoint types that need to be added, or the error from
+// the client request.
+Status DebugApp::RefreshActiveBreakpoints() {
+  // Set all breakpoints to disabled. We'll re-enable them as we find them
+  // below.
+  for (auto& user_breakpoint : user_breakpoint_list_) {
+    user_breakpoint.active_breakpoint = nullptr;
+  }
+
+  // If not connected then no breakpoints are active.
+  if (!debug_client_) {
+    return OkStatus();
+  }
+
+  // Reconcile the user breakpoint list with the breakpoints available on the
+  // server.
+  for (auto* breakpoint : debug_client_->breakpoints()) {
+    auto it =
+        std::find_if(user_breakpoint_list_.begin(), user_breakpoint_list_.end(),
+                     [breakpoint](const UserBreakpoint& user_breakpoint) {
+                       return AreBreakpointsEqual(*breakpoint, user_breakpoint);
+                     });
+    if (it == user_breakpoint_list_.end()) {
+      // Breakpoint not found - add to user list.
+      UserBreakpoint user_breakpoint;
+      user_breakpoint.type = breakpoint->type();
+      user_breakpoint.active_breakpoint = breakpoint;
+      user_breakpoint.module_name = breakpoint->module_name();
+      user_breakpoint.function_name = breakpoint->function_name();
+      user_breakpoint.function_ordinal = breakpoint->function_ordinal();
+      user_breakpoint.bytecode_offset = breakpoint->bytecode_offset();
+      user_breakpoint_list_.push_back(std::move(user_breakpoint));
+    } else {
+      // Breakpoint found - set the active pointer.
+      UserBreakpoint& user_breakpoint = *it;
+      user_breakpoint.active_breakpoint = breakpoint;
+      user_breakpoint.is_enabling = false;
+      user_breakpoint.module_name = breakpoint->module_name();
+      user_breakpoint.function_name = breakpoint->function_name();
+      user_breakpoint.function_ordinal = breakpoint->function_ordinal();
+      user_breakpoint.bytecode_offset = breakpoint->bytecode_offset();
+    }
+  }
+
+  // Ensure any breakpoint the user wants enabled is active/otherwise.
+  for (auto& user_breakpoint : user_breakpoint_list_) {
+    if (user_breakpoint.wants_enabled && !user_breakpoint.is_enabling &&
+        !user_breakpoint.active_breakpoint) {
+      // Add breakpoint on server.
+      switch (user_breakpoint.type) {
+        case RemoteBreakpoint::Type::kBytecodeFunction: {
+          // The completion callback is stored by the client, and the list may
+          // be resized or edited before it runs. Holding a reference to the
+          // element would dangle, so capture the identifying fields by value
+          // and look the entry up again when the callback fires.
+          const auto module_name = user_breakpoint.module_name;
+          const auto function_name = user_breakpoint.function_name;
+          const auto bytecode_offset = user_breakpoint.bytecode_offset;
+          RETURN_IF_ERROR(debug_client_->AddFunctionBreakpoint(
+              module_name, function_name, bytecode_offset,
+              [this, module_name, function_name,
+               bytecode_offset](const RemoteBreakpoint& breakpoint) {
+                const int index = FindMatchingUserBreakpointIndex(
+                    module_name, function_name, bytecode_offset);
+                if (index != -1) {
+                  user_breakpoint_list_[index].function_ordinal =
+                      breakpoint.function_ordinal();
+                }
+              }));
+          break;
+        }
+        case RemoteBreakpoint::Type::kNativeFunction:
+          // TODO(benvanik): native breakpoint support.
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Native function breakpoints are TODO";
+        default:
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Unimplemented breakpoint type";
+      }
+      // Remember the request is in flight so it is not re-sent every refresh.
+      user_breakpoint.is_enabling = true;
+    } else if (!user_breakpoint.wants_enabled &&
+               user_breakpoint.active_breakpoint) {
+      // Remove breakpoint from server.
+      RETURN_IF_ERROR(
+          debug_client_->RemoveBreakpoint(*user_breakpoint.active_breakpoint));
+
+      user_breakpoint.active_breakpoint = nullptr;
+    }
+  }
+
+  return OkStatus();
+}
+
+// Returns true when the remote breakpoint bound to |user_breakpoint| is among
+// the breakpoints currently recorded in |hit_breakpoints_|.
+bool DebugApp::IsStoppedAtBreakpoint(
+    const UserBreakpoint& user_breakpoint) const {
+  const auto hit_count =
+      std::count(hit_breakpoints_.begin(), hit_breakpoints_.end(),
+                 user_breakpoint.active_breakpoint);
+  return hit_count != 0;
+}
+
+// Returns the index of the first user breakpoint whose module name, function
+// ordinal and bytecode offset all equal the given values, or -1 if none does.
+int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name,
+                                              int function_ordinal,
+                                              int offset) {
+  const auto first = user_breakpoint_list_.begin();
+  const auto last = user_breakpoint_list_.end();
+  const auto match =
+      std::find_if(first, last, [&](const UserBreakpoint& candidate) {
+        return candidate.module_name == module_name &&
+               candidate.function_ordinal == function_ordinal &&
+               candidate.bytecode_offset == offset;
+      });
+  return match == last ? -1 : static_cast<int>(match - first);
+}
+
+// Returns the index of the first user breakpoint whose module name, function
+// name and bytecode offset all equal the given values, or -1 if none does.
+int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name,
+                                              absl::string_view function_name,
+                                              int offset) {
+  const auto first = user_breakpoint_list_.begin();
+  const auto last = user_breakpoint_list_.end();
+  const auto match =
+      std::find_if(first, last, [&](const UserBreakpoint& candidate) {
+        return candidate.module_name == module_name &&
+               candidate.function_name == function_name &&
+               candidate.bytecode_offset == offset;
+      });
+  return match == last ? -1 : static_cast<int>(match - first);
+}
+
+// Resumes execution from a breakpoint the app is currently stopped at.
+//
+// Fails with FailedPrecondition when |user_breakpoint| is not bound to a
+// remote breakpoint and with NotFound when that breakpoint is not in
+// |hit_breakpoints_|. On success the hit is forgotten and the client is told
+// to continue via MakeReady().
+Status DebugApp::ResumeFromBreakpoint(UserBreakpoint* user_breakpoint) {
+  auto* const active = user_breakpoint->active_breakpoint;
+  if (!active) {
+    return FailedPreconditionErrorBuilder(IREE_LOC) << "Breakpoint not active";
+  }
+  VLOG(1) << "Resuming from breakpoint " << active->id() << "...";
+
+  const auto hit =
+      std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(), active);
+  if (hit == hit_breakpoints_.end()) {
+    return NotFoundErrorBuilder(IREE_LOC) << "Breakpoint not found";
+  }
+  hit_breakpoints_.erase(hit);
+  return debug_client_->MakeReady();
+}
+
+// Called when a context is registered. The app keeps no per-context state
+// here, so |context| is unused and the event is only acknowledged.
+Status DebugApp::OnContextRegistered(const RemoteContext& context) {
+ // Ack event.
+ return debug_client_->MakeReady();
+}
+
+// Called when |context| is unregistered.
+//
+// Every open document whose function belongs to a module of |context| is
+// either retargeted or closed. It is retargeted to the function with the same
+// ordinal in a same-named, loaded module of another context. It is closed when
+// no such module or function exists. The event is then acknowledged.
+Status DebugApp::OnContextUnregistered(const RemoteContext& context) {
+  // Close documents that may reference modules in the context.
+  std::vector<CodeViewDocument*> closing_documents;
+  for (auto& document : documents_) {
+    auto* module = document->function->module();
+    if (module->context_id() != context.id()) {
+      // Document is not from this context so it's fine.
+      continue;
+    }
+
+    // See if any other live context still has the module loaded. We can change
+    // the document over to that.
+    RemoteModule* replacement_module = nullptr;
+    for (auto* live_context : debug_client_->contexts()) {
+      // Never retarget into the context that is going away. If the client
+      // still lists it, the module would match itself by name and the
+      // document would keep pointing at a function about to be released.
+      if (live_context->id() == context.id()) continue;
+      for (auto* other_module : live_context->modules()) {
+        if (other_module->name() == module->name()) {
+          replacement_module = other_module;
+          break;
+        }
+      }
+      if (replacement_module) break;
+    }
+    if (replacement_module && replacement_module->is_loaded()) {
+      // Replace document module reference.
+      int function_ordinal = document->function->ordinal();
+      auto functions = replacement_module->functions();
+      if (function_ordinal < functions.size()) {
+        document->function = functions[function_ordinal];
+      } else {
+        document->function = nullptr;
+      }
+    } else {
+      document->function = nullptr;
+    }
+
+    if (!document->function) {
+      // Close the document if we don't have a valid function for it.
+      VLOG(1)
+          << "Closing document " << document->title
+          << " because the last context using the module is being unregistered";
+      closing_documents.push_back(document.get());
+    }
+  }
+  // Erase after the scan so |documents_| is not modified while iterating it.
+  for (auto* document : closing_documents) {
+    auto it = std::find_if(
+        documents_.begin(), documents_.end(),
+        [document](const std::unique_ptr<CodeViewDocument>& open_document) {
+          return document == open_document.get();
+        });
+    documents_.erase(it);
+  }
+
+  // Ack event.
+  return debug_client_->MakeReady();
+}
+
+// Called when |module| is loaded in |context|. Neither argument is used; the
+// event is only acknowledged.
+Status DebugApp::OnModuleLoaded(const RemoteContext& context,
+ const RemoteModule& module) {
+ // Ack event.
+ return debug_client_->MakeReady();
+}
+
+// Called when |invocation| is registered. If no invocation is selected yet it
+// becomes the selection and the selected stack frame is cleared. The event is
+// then acknowledged.
+Status DebugApp::OnInvocationRegistered(const RemoteInvocation& invocation) {
+  const bool nothing_selected = !selected_invocation_id_.has_value();
+  if (nothing_selected) {
+    selected_invocation_id_ = invocation.id();
+    selected_stack_frame_index_.reset();
+  }
+
+  // Ack event.
+  return debug_client_->MakeReady();
+}
+
+// Called when |invocation| is unregistered. If it was the selected invocation
+// both the invocation and stack frame selections are cleared. The event is
+// then acknowledged.
+Status DebugApp::OnInvocationUnregistered(const RemoteInvocation& invocation) {
+  const bool was_selected =
+      selected_invocation_id_.has_value() &&
+      selected_invocation_id_.value() == invocation.id();
+  if (was_selected) {
+    selected_invocation_id_.reset();
+    selected_stack_frame_index_.reset();
+  }
+
+  // Ack event.
+  return debug_client_->MakeReady();
+}
+
+// Called when |invocation| stops at |breakpoint|. Records the hit and
+// navigates the code view to the invocation. Unlike the other event handlers
+// this does not call MakeReady() here; that happens when the hit is resumed
+// (see ResumeFromBreakpoint() and the toolbar's Resume button).
+Status DebugApp::OnBreakpointHit(const RemoteBreakpoint& breakpoint,
+ const RemoteInvocation& invocation) {
+ // Keep track of where we are stopped.
+ hit_breakpoints_.push_back(&breakpoint);
+ // NOTE(review): -1 presumably selects the top stack frame, as it does in
+ // DrawLocalListPanel() - confirm against NavigateToCodeView().
+ return NavigateToCodeView(invocation, -1, NavigationMode::kMatchDocument);
+}
+
+// Runs one iteration of the application loop:
+// 1. polls the debug client (when connected) and reconciles breakpoints;
+// 2. drains pending SDL events, forwarding each to the ImGui SDL backend;
+// 3. builds one ImGui frame via DrawUI() and renders it with OpenGL;
+// 4. swaps the window buffers.
+//
+// Returns a Cancelled status when an SDL_QUIT event or a close event for this
+// window is seen, and propagates client poll/refresh errors. Errors from
+// DrawUI() are only logged; the frame is still rendered.
+Status DebugApp::PumpMainLoop() {
+ ImGuiIO& io = ImGui::GetIO();
+
+ if (debug_client_) {
+ RETURN_IF_ERROR(debug_client_->Poll());
+ }
+ RETURN_IF_ERROR(RefreshActiveBreakpoints());
+
+ SDL_GL_MakeCurrent(window_, gl_context_);
+
+ SDL_Event event;
+ while (SDL_PollEvent(&event)) {
+ ImGui_ImplSDL2_ProcessEvent(&event);
+ if (event.type == SDL_QUIT) {
+ return CancelledErrorBuilder(IREE_LOC) << "Quit hotkey";
+ } else if (event.type == SDL_WINDOWEVENT &&
+ event.window.event == SDL_WINDOWEVENT_CLOSE &&
+ event.window.windowID == SDL_GetWindowID(window_)) {
+ // Only a close of our own window ends the loop.
+ return CancelledErrorBuilder(IREE_LOC) << "Window closed";
+ }
+ }
+ // Begin a new frame on both backends before issuing any ImGui calls.
+ ImGui_ImplOpenGL3_NewFrame();
+ ImGui_ImplSDL2_NewFrame(window_);
+ ImGui::NewFrame();
+
+ auto draw_status = DrawUI();
+ if (!draw_status.ok()) {
+ // TODO(benvanik): show on screen? Probably all messed up.
+ LOG(ERROR) << draw_status;
+ }
+
+ // Blit the entire ImGui UI.
+ ImGui::Render();
+ SDL_GL_MakeCurrent(window_, gl_context_);
+ glViewport(0, 0, (int)io.DisplaySize.x, (int)io.DisplaySize.y);
+ glClearColor(0.45f, 0.55f, 0.60f, 1.0f);
+ glClear(GL_COLOR_BUFFER_BIT);
+ // Workaround for terrible bad SDL/graphics driver leaks.
+ IREE_DISABLE_LEAK_CHECKS();
+ ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
+ IREE_ENABLE_LEAK_CHECKS();
+
+ // Render additional viewport windows (desktop only).
+ if (io.ConfigFlags & ImGuiConfigFlags_ViewportsEnable) {
+ // Save the current window/context and restore them after the platform
+ // windows have been rendered.
+ SDL_Window* backup_current_window = SDL_GL_GetCurrentWindow();
+ SDL_GLContext backup_current_context = SDL_GL_GetCurrentContext();
+ ImGui::UpdatePlatformWindows();
+ ImGui::RenderPlatformWindowsDefault();
+ SDL_GL_MakeCurrent(backup_current_window, backup_current_context);
+ }
+
+ SDL_GL_SwapWindow(window_);
+ return OkStatus();
+}
+
+// Builds the default dock layout the first time it is called and records the
+// resulting node ids in the dock_*_id_ members. Does nothing beyond refreshing
+// |dockspace_id_| when the "MainDockSpace" node already exists.
+//
+// The root is split, in order, into a top strip (5%), a left column (20%), a
+// bottom strip (20%) and a right column (20%); whatever remains is
+// |dock_content_id_|. The bottom strip is then halved into left and right.
+Status DebugApp::LayoutInitialDockSpace() {
+ dockspace_id_ = ImGui::GetID("MainDockSpace");
+ if (ImGui::DockBuilderGetNode(dockspace_id_)) {
+ // Already configured.
+ return OkStatus();
+ }
+ ImGui::DockBuilderAddNode(dockspace_id_, ImGuiDockNodeFlags_DockSpace);
+
+ // Each split returns the carved-off node and writes the id of the remaining
+ // area back into dock_content_id_, so the order of the splits matters.
+ dock_content_id_ = dockspace_id_;
+ dock_top_id_ = ImGui::DockBuilderSplitNode(dock_content_id_, ImGuiDir_Up,
+ 0.05f, nullptr, &dock_content_id_);
+ dock_left_id_ = ImGui::DockBuilderSplitNode(
+ dock_content_id_, ImGuiDir_Left, 0.20f, nullptr, &dock_content_id_);
+ dock_bottom_id_ = ImGui::DockBuilderSplitNode(
+ dock_content_id_, ImGuiDir_Down, 0.20f, nullptr, &dock_content_id_);
+ dock_right_id_ = ImGui::DockBuilderSplitNode(
+ dock_content_id_, ImGuiDir_Right, 0.20f, nullptr, &dock_content_id_);
+ dock_bottom_left_id_ = ImGui::DockBuilderSplitNode(
+ dock_bottom_id_, ImGuiDir_Left, 0.50f, nullptr, &dock_bottom_right_id_);
+
+ // The toolbar node is locked down: no splitting, no resizing, and its tab
+ // bar is auto-hidden.
+ ImGui::DockBuilderDockWindow("Toolbar", dock_top_id_);
+ auto* dock_top_node = ImGui::DockBuilderGetNode(dock_top_id_);
+ dock_top_node->LocalFlags = ImGuiDockNodeFlags_NoSplit |
+ ImGuiDockNodeFlags_NoResize |
+ ImGuiDockNodeFlags_AutoHideTabBar;
+
+ // Window names here must match the names passed to ImGui::Begin() by the
+ // corresponding Draw*Panel() functions.
+ ImGui::DockBuilderDockWindow("Modules", dock_left_id_);
+ ImGui::DockBuilderDockWindow("Locals", dock_bottom_left_id_);
+ ImGui::DockBuilderDockWindow("Invocations", dock_bottom_right_id_);
+ ImGui::DockBuilderDockWindow("Breakpoints", dock_bottom_right_id_);
+
+ ImGui::DockBuilderFinish(dockspace_id_);
+ return OkStatus();
+}
+
+// Draws the whole debugger UI for the current frame: a borderless root window
+// covering the main viewport that hosts the dock space, the main menu, the
+// toolbar, the docked panels and the code view panels.
+//
+// Returns the first error reported by any of the drawing helpers. The root
+// window and the panel style var are always closed/popped before returning,
+// because PumpMainLoop() still renders the frame after a failed DrawUI() and
+// ImGui requires its Begin/End and Push/Pop calls to be balanced by then.
+Status DebugApp::DrawUI() {
+  ImGuiWindowFlags window_flags =
+      ImGuiWindowFlags_MenuBar | ImGuiWindowFlags_NoDocking;
+  window_flags |= ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_NoCollapse |
+                  ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoMove |
+                  ImGuiWindowFlags_NoNavFocus;
+
+  ImGuiViewport* viewport = ImGui::GetMainViewport();
+  ImGui::SetNextWindowPos(viewport->Pos);
+  ImGui::SetNextWindowSize(viewport->Size);
+  ImGui::SetNextWindowViewport(viewport->ID);
+  // These three style vars only apply to the root window itself, so they are
+  // popped right after Begin().
+  ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0.0f);
+  ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 0.0f);
+  ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0.0f, 0.0f));
+  ImGui::Begin("IREEDebugRoot", nullptr, window_flags);
+  ImGui::PopStyleVar(3);
+
+  // The fallible drawing steps live in lambdas so that an early return only
+  // leaves the lambda and the cleanup below still runs.
+  const auto draw_docked_panels = [this]() -> Status {
+    RETURN_IF_ERROR(DrawBreakpointListPanel());
+    RETURN_IF_ERROR(DrawModuleListPanel());
+    RETURN_IF_ERROR(DrawLocalListPanel());
+    RETURN_IF_ERROR(DrawInvocationListPanel());
+    return OkStatus();
+  };
+  const auto draw_contents = [this, &draw_docked_panels]() -> Status {
+    RETURN_IF_ERROR(LayoutInitialDockSpace());
+    ImGui::DockSpace(dockspace_id_, ImVec2(0.0f, 0.0f),
+                     ImGuiDockNodeFlags_None);
+
+    RETURN_IF_ERROR(DrawMainMenu());
+    RETURN_IF_ERROR(DrawToolbar());
+
+    ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(2, 2));
+    auto panel_status = draw_docked_panels();
+    ImGui::PopStyleVar();
+    RETURN_IF_ERROR(panel_status);
+
+    RETURN_IF_ERROR(DrawCodeViewPanels());
+    return OkStatus();
+  };
+
+  auto status = draw_contents();
+  ImGui::End();
+  return status;
+}
+
+// Draws the menu bar of the current window. Only an empty "File" menu exists
+// so far. Always returns OkStatus().
+Status DebugApp::DrawMainMenu() {
+ // EndMenuBar() must only be called when BeginMenuBar() returned true.
+ if (!ImGui::BeginMenuBar()) return OkStatus();
+
+ // TODO(benvanik): main menu.
+ if (ImGui::BeginMenu("File")) {
+ ImGui::EndMenu();
+ }
+
+ ImGui::EndMenuBar();
+ return OkStatus();
+}
+
+// Draws the "Toolbar" window: the optional ImGui demo toggle, the
+// connect/disconnect button and status text, the invocation selector, and the
+// Pause / Resume / Step buttons.
+//
+// Buttons that do not apply to the current state are drawn greyed-out and
+// disabled rather than hidden.
+//
+// Returns the first error from a client request or navigation.
+// NOTE(review): those early returns skip the EndCombo/EndGroup/End and Pop*
+// calls that follow them, leaving ImGui's stacks unbalanced for the frame.
+Status DebugApp::DrawToolbar() {
+ // TODO(benvanik): figure out how to make this not grow.
+ ImGui::Begin("Toolbar", nullptr,
+ ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoTitleBar |
+ ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse |
+ ImGuiWindowFlags_NoScrollbar);
+ ImGui::BeginGroup();
+
+#if !defined(IMGUI_DISABLE_DEMO_WINDOWS)
+ // The toggle state is function-static, so it persists across frames.
+ static bool show_demo_window = false;
+ if (ImGui::Button("Demo")) {
+ show_demo_window = !show_demo_window;
+ }
+ if (show_demo_window) {
+ ImGui::SetNextWindowDockID(dock_content_id_);
+ ImGui::ShowDemoWindow(&show_demo_window);
+ }
+#endif // !IMGUI_DISABLE_DEMO_WINDOWS
+
+ // --- Connection button and status text. ---
+ ImGui::SameLine();
+ if (!debug_client_) {
+ if (ImGui::Button("Connect")) {
+ // TODO(benvanik): connection dialog and/or autoconnect.
+ }
+ } else {
+ if (ImGui::Button("Disconnect")) {
+ // NOTE(review): this resets the client directly instead of calling
+ // Disconnect(); confirm the two are meant to stay equivalent.
+ debug_client_.reset();
+ }
+ }
+
+ ImGui::SameLine();
+ if (debug_client_) {
+ ImGui::Text("<status>");
+ } else {
+ ImGui::TextDisabled("disconnected");
+ }
+
+ ImGui::SameLine();
+ ImGui::Spacing();
+ ImGui::SameLine();
+ ImGui::Spacing();
+
+ // --- Invocation selector. ---
+ ImGui::SameLine();
+ ImGui::BeginGroup();
+ ImGui::Text("Invocation: ");
+ ImGui::SameLine();
+ ImGui::SetNextItemWidth(300);
+ // Looked up after the Disconnect button above, so this is null if the
+ // client was just dropped.
+ auto* selected_invocation = GetSelectedInvocation();
+ const std::string& active_invocation_name =
+ selected_invocation ? selected_invocation->name() : "";
+ if (ImGui::BeginCombo("##active_invocation", active_invocation_name.c_str(),
+ ImGuiComboFlags_PopupAlignLeft)) {
+ if (debug_client_) {
+ for (auto* invocation : debug_client_->invocations()) {
+ // Scope widget ids by invocation id.
+ ImGui::PushID(invocation->id());
+ bool is_selected = invocation == selected_invocation;
+ if (ImGui::Selectable(invocation->name().c_str(), is_selected)) {
+ RETURN_IF_ERROR(NavigateToCodeView(*invocation, -1,
+ NavigationMode::kMatchDocument));
+ }
+ if (is_selected) {
+ ImGui::SetItemDefaultFocus();
+ }
+ ImGui::PopID();
+ }
+ }
+ ImGui::EndCombo();
+ }
+ ImGui::EndGroup();
+
+ // --- Execution control buttons. ---
+ ImGui::SameLine();
+ ImGui::BeginGroup();
+ // Hue arguments for PushButtonHue().
+ static const float kPauseButtonHue = 0.0f;
+ static const float kResumeButtonHue = 2.0f;
+ static const float kStepButtonHue = 1.0f;
+ if (debug_client_ && !is_paused()) {
+ PushButtonHue(kPauseButtonHue);
+ if (ImGui::Button("Pause")) {
+ RETURN_IF_ERROR(debug_client_->SuspendAllInvocations());
+ }
+ PopButtonStyle();
+ } else if (debug_client_ && is_paused()) {
+ // Already paused: show a greyed-out, disabled Pause button.
+ ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666);
+ ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA);
+ ImGui::ButtonEx("Pause", {}, ImGuiButtonFlags_Disabled);
+ ImGui::PopStyleColor(2);
+ }
+ if (debug_client_ && is_paused()) {
+ ImGui::SameLine();
+ PushButtonHue(kResumeButtonHue);
+ if (ImGui::Button("Resume")) {
+ // One MakeReady() for the explicit pause, plus one per hit breakpoint.
+ if (is_paused_) {
+ is_paused_ = false;
+ RETURN_IF_ERROR(debug_client_->MakeReady());
+ }
+ while (!hit_breakpoints_.empty()) {
+ hit_breakpoints_.pop_back();
+ RETURN_IF_ERROR(debug_client_->MakeReady());
+ }
+ }
+ PopButtonStyle();
+ } else {
+ ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666);
+ ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA);
+ ImGui::SameLine();
+ ImGui::ButtonEx("Resume", {}, ImGuiButtonFlags_Disabled);
+ ImGui::PopStyleColor(2);
+ }
+
+ if (debug_client_ && is_paused() && selected_invocation) {
+ // Each step request sets is_stepping_ once it has been issued; the
+ // callback passed to the client flips the app back to paused.
+ ImGui::SameLine();
+ PushButtonHue(kStepButtonHue);
+ if (ImGui::Button("Step Into")) {
+ RETURN_IF_ERROR(
+ debug_client_->StepInvocation(*selected_invocation, [this]() {
+ is_paused_ = true;
+ is_stepping_ = false;
+ }));
+ is_stepping_ = true;
+ }
+ // Only "Step Into" uses the hued style; it is popped before the others.
+ PopButtonStyle();
+ ImGui::SameLine();
+ if (ImGui::Button("Step Over")) {
+ RETURN_IF_ERROR(
+ debug_client_->StepInvocationOver(*selected_invocation, [this]() {
+ is_paused_ = true;
+ is_stepping_ = false;
+ }));
+ is_stepping_ = true;
+ }
+ ImGui::SameLine();
+ if (ImGui::Button("Step Out")) {
+ RETURN_IF_ERROR(
+ debug_client_->StepInvocationOut(*selected_invocation, [this]() {
+ is_paused_ = true;
+ is_stepping_ = false;
+ }));
+ is_stepping_ = true;
+ }
+ // The popup shares its name with the button below that opens it.
+ if (ImGui::BeginPopup("Step to...")) {
+ // TODO(benvanik): step to Invoke exit, next FFI call, etc
+ ImGui::MenuItem("(stuff)");
+ ImGui::EndPopup();
+ }
+ ImGui::SameLine();
+ if (ImGui::Button("Step to...")) {
+ ImGui::OpenPopup("Step to...");
+ }
+ } else {
+ // Stepping unavailable: draw all four step buttons disabled.
+ ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666);
+ ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA);
+ ImGui::SameLine();
+ ImGui::ButtonEx("Step Into", {}, ImGuiButtonFlags_Disabled);
+ ImGui::SameLine();
+ ImGui::ButtonEx("Step Over", {}, ImGuiButtonFlags_Disabled);
+ ImGui::SameLine();
+ ImGui::ButtonEx("Step Out", {}, ImGuiButtonFlags_Disabled);
+ ImGui::SameLine();
+ ImGui::ButtonEx("Step to...", {}, ImGuiButtonFlags_Disabled);
+ ImGui::PopStyleColor(2);
+ }
+ ImGui::EndGroup();
+
+ ImGui::EndGroup();
+ ImGui::End();
+ return OkStatus();
+}
+
+// Draws the "Breakpoints" panel: the "+ Function" popup and add dialogs, the
+// "Remove All" button, and one entry per user breakpoint. Breakpoints whose
+// header was closed this frame are removed from the service (when bound) and
+// from |user_breakpoint_list_|.
+//
+// Returns the first error from drawing a breakpoint or from a client request.
+Status DebugApp::DrawBreakpointListPanel() {
+  static bool is_panel_visible = true;
+  if (!ImGui::Begin("Breakpoints", &is_panel_visible, ImGuiWindowFlags_None)) {
+    ImGui::End();
+    return OkStatus();
+  }
+
+  // The popup only records which kind was picked; the matching dialog is
+  // opened by DrawAddBreakpointDialogs() below.
+  ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8));
+  absl::optional<RemoteBreakpoint::Type> add_breakpoint_type;
+  if (ImGui::BeginPopup("+ Function")) {
+    if (ImGui::MenuItem("Bytecode Function")) {
+      add_breakpoint_type = RemoteBreakpoint::Type::kBytecodeFunction;
+    }
+    if (ImGui::MenuItem("Native Function")) {
+      add_breakpoint_type = RemoteBreakpoint::Type::kNativeFunction;
+    }
+    ImGui::EndPopup();
+  }
+  ImGui::PopStyleVar();
+  if (ImGui::Button("+ Function")) {
+    ImGui::OpenPopup("+ Function");
+  }
+  RETURN_IF_ERROR(DrawAddBreakpointDialogs(add_breakpoint_type));
+
+  ImGui::SameLine();
+  if (ImGui::Button("Remove All")) {
+    // TODO(benvanik): removal all is broken - need removebreakpoints or a
+    // 'want_removal' flag so that RefreshActiveBreakpoints handles things.
+    // Right now if you have 2 breakpoints and hit remove all the second will
+    // come back during the next refresh (as the server hasn't removed it yet).
+    for (auto& user_breakpoint : user_breakpoint_list_) {
+      if (user_breakpoint.active_breakpoint) {
+        RETURN_IF_ERROR(debug_client_->RemoveBreakpoint(
+            *user_breakpoint.active_breakpoint));
+        user_breakpoint.active_breakpoint = nullptr;
+      }
+    }
+    user_breakpoint_list_.clear();
+  }
+  ImGui::Separator();
+
+  ImGui::BeginChild("BreakpointList", ImVec2(-1, -1), false,
+                    ImGuiWindowFlags_AlwaysVerticalScrollbar);
+  // Remember dead entries by index, not by address: erasing one element
+  // shifts the ones after it, so saved element pointers would end up
+  // referring to the wrong breakpoint (or past the end).
+  std::vector<int> dead_indices;
+  const int breakpoint_count = static_cast<int>(user_breakpoint_list_.size());
+  for (int i = 0; i < breakpoint_count; ++i) {
+    ASSIGN_OR_RETURN(bool should_keep,
+                     DrawBreakpoint(&user_breakpoint_list_[i]));
+    if (!should_keep) {
+      dead_indices.push_back(i);
+    }
+  }
+  // Erase back-to-front so the indices still to be processed stay valid.
+  for (auto it = dead_indices.rbegin(); it != dead_indices.rend(); ++it) {
+    auto& user_breakpoint = user_breakpoint_list_[*it];
+    if (user_breakpoint.active_breakpoint) {
+      RETURN_IF_ERROR(
+          debug_client_->RemoveBreakpoint(*user_breakpoint.active_breakpoint));
+    }
+    user_breakpoint_list_.erase(user_breakpoint_list_.begin() + *it);
+  }
+  ImGui::EndChild();
+
+  ImGui::End();
+  return OkStatus();
+}
+
+// Draws one breakpoint entry: a collapsing header with a close button, and a
+// checkbox bound to |user_breakpoint->wants_enabled|.
+//
+// Returns true when the entry should be kept and false when the header's
+// close button was used this frame.
+StatusOr<bool> DebugApp::DrawBreakpoint(UserBreakpoint* user_breakpoint) {
+  // Build the label; it doubles as the ImGui id of the entry.
+  std::string label;
+  if (user_breakpoint->type == RemoteBreakpoint::Type::kBytecodeFunction) {
+    label = absl::StrCat("[bytecode] ", user_breakpoint->module_name, ":",
+                         user_breakpoint->function_name, ":",
+                         user_breakpoint->bytecode_offset);
+    const bool has_ordinal = user_breakpoint->function_ordinal != -1;
+    if (has_ordinal) {
+      absl::StrAppend(&label, " @", user_breakpoint->function_ordinal);
+    }
+  } else if (user_breakpoint->type ==
+             RemoteBreakpoint::Type::kNativeFunction) {
+    label = absl::StrCat("[native ] ", user_breakpoint->native_function);
+  }
+
+  const ImGuiTreeNodeFlags header_flags =
+      ImGuiTreeNodeFlags_Framed | ImGuiTreeNodeFlags_NoTreePushOnOpen |
+      ImGuiTreeNodeFlags_NoAutoOpenOnLog | ImGuiTreeNodeFlags_OpenOnArrow |
+      ImGuiTreeNodeFlags_OpenOnDoubleClick;
+  const std::string header_label = "##" + label;
+
+  ImGui::BeginGroup();
+  // Passed to CollapsingHeader as its p_open flag.
+  bool keep_open = true;
+  const bool is_expanded =
+      ImGui::CollapsingHeader(header_label.c_str(), &keep_open, header_flags);
+  ImGui::SameLine();
+  ImGui::Checkbox(label.c_str(), &user_breakpoint->wants_enabled);
+  ImGui::EndGroup();
+
+  if (is_expanded) {
+    ImGui::PushID(label.c_str());
+    ImGui::Text("(breakpoint stats/etc)");
+    ImGui::PopID();
+  }
+  return keep_open;
+}
+
+// Opens the "add breakpoint" modal matching |add_breakpoint_type| (if one was
+// requested this frame) and then draws both modals; each only renders while it
+// is open.
+Status DebugApp::DrawAddBreakpointDialogs(
+    absl::optional<RemoteBreakpoint::Type> add_breakpoint_type) {
+  if (add_breakpoint_type.has_value()) {
+    const auto requested_type = add_breakpoint_type.value();
+    if (requested_type == RemoteBreakpoint::Type::kBytecodeFunction) {
+      ImGui::OpenPopup("Add Bytecode Function Breakpoint");
+    } else if (requested_type == RemoteBreakpoint::Type::kNativeFunction) {
+      ImGui::OpenPopup("Add Native Function Breakpoint");
+    }
+  }
+
+  // Both dialogs share the same window padding.
+  ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8));
+  RETURN_IF_ERROR(DrawAddBytecodeFunctionBreakpointDialog());
+  RETURN_IF_ERROR(DrawAddNativeFunctionBreakpointDialog());
+  ImGui::PopStyleVar();
+  return OkStatus();
+}
+
+// Draws the "Add Bytecode Function Breakpoint" modal while it is open.
+//
+// "Add" appends an enabled bytecode breakpoint at offset 0 for the entered
+// module/function, unless the user list already holds one with the same
+// module, function and offset. "Add" and "Cancel" both close the popup.
+// Always returns OkStatus().
+Status DebugApp::DrawAddBytecodeFunctionBreakpointDialog() {
+ ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver);
+ bool close_popup = true;
+ // Nothing to draw (and no EndPopup() to call) unless the modal is open.
+ if (!ImGui::BeginPopupModal("Add Bytecode Function Breakpoint", &close_popup,
+ ImGuiWindowFlags_None)) {
+ return OkStatus();
+ }
+ ImGui::BeginGroup();
+ // The negative height is one frame-with-spacing, leaving room below the
+ // child for the button row.
+ ImGui::BeginChild("##data_entry",
+ ImVec2(0, -ImGui::GetFrameHeightWithSpacing()));
+
+ ImGui::TextWrapped(
+ "Adds a breakpoint set on the entry of the function (offset=0).");
+ ImGui::Separator();
+
+ // TODO(benvanik): fancy list, filtering, etc.
+
+ // The input buffers are function-static, so typed text persists between
+ // frames and between openings of the dialog.
+ static char module_name[256] = {0};
+ ImGui::InputText("Module", module_name, sizeof(module_name));
+ ImGui::SetItemDefaultFocus();
+
+ static char function_name[256] = {0};
+ ImGui::InputText("Function", function_name, sizeof(function_name));
+
+ ImGui::EndChild();
+ ImGui::Separator();
+
+ if (ImGui::Button("Add")) {
+ int offset = 0;
+ // Skip duplicates of an existing user breakpoint.
+ if (FindMatchingUserBreakpointIndex(module_name, function_name, offset) ==
+ -1) {
+ UserBreakpoint user_breakpoint;
+ user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction;
+ user_breakpoint.module_name = module_name;
+ user_breakpoint.function_name = function_name;
+ user_breakpoint.bytecode_offset = offset;
+ user_breakpoint.wants_enabled = true;
+ user_breakpoint_list_.push_back(std::move(user_breakpoint));
+ }
+ ImGui::CloseCurrentPopup();
+ }
+ ImGui::SameLine();
+ if (ImGui::Button("Cancel")) {
+ ImGui::CloseCurrentPopup();
+ }
+
+ ImGui::EndGroup();
+ ImGui::EndPopup();
+ return OkStatus();
+}
+
+// Draws the "Add Native Function Breakpoint" modal while it is open.
+//
+// "Add" appends an enabled native-function breakpoint for the entered name;
+// unlike the bytecode dialog, no duplicate check is made. "Add" and "Cancel"
+// both close the popup. Always returns OkStatus().
+Status DebugApp::DrawAddNativeFunctionBreakpointDialog() {
+ ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver);
+ bool close_popup = true;
+ // Nothing to draw (and no EndPopup() to call) unless the modal is open.
+ if (!ImGui::BeginPopupModal("Add Native Function Breakpoint", &close_popup,
+ ImGuiWindowFlags_None)) {
+ return OkStatus();
+ }
+ ImGui::BeginGroup();
+ // The negative height is one frame-with-spacing, leaving room below the
+ // child for the button row.
+ ImGui::BeginChild("##data_entry",
+ ImVec2(0, -ImGui::GetFrameHeightWithSpacing()));
+
+ ImGui::TextWrapped(
+ "Adds a breakpoint set on any call to the given FFI imported "
+ "function.");
+ ImGui::Separator();
+
+ // Function-static buffer: the typed text persists between frames.
+ static char function_name[256] = {0};
+ ImGui::InputText("Function", function_name, sizeof(function_name));
+ ImGui::SetItemDefaultFocus();
+
+ ImGui::EndChild();
+ ImGui::Separator();
+
+ if (ImGui::Button("Add")) {
+ UserBreakpoint user_breakpoint;
+ user_breakpoint.type = RemoteBreakpoint::Type::kNativeFunction;
+ user_breakpoint.native_function = function_name;
+ user_breakpoint.wants_enabled = true;
+ user_breakpoint_list_.push_back(std::move(user_breakpoint));
+ ImGui::CloseCurrentPopup();
+ }
+ ImGui::SameLine();
+ if (ImGui::Button("Cancel")) {
+ ImGui::CloseCurrentPopup();
+ }
+
+ ImGui::EndGroup();
+ ImGui::EndPopup();
+ return OkStatus();
+}
+
+// Draws the "Modules" panel: a function-name filter box followed by every
+// context reported by the client, each drawn by DrawContext().
+//
+// Shows "disconnected" when there is no client. Returns the first error from
+// DrawContext().
+// NOTE(review): that early return skips EndChild/EndGroup/PopStyleVar/End.
+Status DebugApp::DrawModuleListPanel() {
+ static bool is_panel_visible = true;
+ // End() is called on every path after Begin(), including when it returns
+ // false.
+ if (!ImGui::Begin("Modules", &is_panel_visible, ImGuiWindowFlags_None)) {
+ ImGui::End();
+ return OkStatus();
+ } else if (!debug_client_) {
+ ImGui::TextDisabled("disconnected");
+ ImGui::End();
+ return OkStatus();
+ }
+ ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(4, 4));
+
+ ImGui::BeginGroup();
+ ImGui::SetNextItemWidth(ImGui::GetContentRegionAvailWidth());
+ // Function-static buffer: the filter text persists between frames. A fresh
+ // ImGuiTextFilter is built from it every frame.
+ static char function_name_filter_text[256] = {0};
+ ImGui::InputTextWithHint(
+ "##function_name_filter", "Filter functions", function_name_filter_text,
+ sizeof(function_name_filter_text), ImGuiInputTextFlags_AutoSelectAll);
+ ImGuiTextFilter function_name_filter(function_name_filter_text);
+ ImGui::EndGroup();
+
+ ImGui::Separator();
+
+ ImGui::BeginGroup();
+ ImGui::BeginChild("##context_list", ImVec2(0, -ImGui::GetFrameHeight()));
+ for (auto* context : debug_client_->contexts()) {
+ RETURN_IF_ERROR(DrawContext(*context, function_name_filter));
+ }
+ ImGui::EndChild();
+ ImGui::EndGroup();
+
+ ImGui::PopStyleVar();
+ ImGui::End();
+ return OkStatus();
+}
+
+// Draws a collapsing header labelled "Context <id>" for |context| and, while
+// it is expanded, every module of the context via DrawModule(). |filter| is
+// passed through to DrawModule() unchanged.
+//
+// Returns the first error from DrawModule().
+Status DebugApp::DrawContext(const RemoteContext& context,
+                             const ImGuiTextFilter& filter) {
+  const ImGuiTreeNodeFlags header_flags =
+      ImGuiTreeNodeFlags_DefaultOpen | ImGuiTreeNodeFlags_Framed |
+      ImGuiTreeNodeFlags_NoTreePushOnOpen |
+      ImGuiTreeNodeFlags_NoAutoOpenOnLog | ImGuiTreeNodeFlags_OpenOnArrow |
+      ImGuiTreeNodeFlags_OpenOnDoubleClick;
+  const std::string header_label = absl::StrCat("Context ", context.id());
+  const bool is_expanded =
+      ImGui::CollapsingHeader(header_label.c_str(), nullptr, header_flags);
+  if (!is_expanded) {
+    return OkStatus();
+  }
+
+  // Scope the module widgets' ids by context id.
+  ImGui::PushID(context.id());
+  for (auto* context_module : context.modules()) {
+    RETURN_IF_ERROR(DrawModule(context_module, filter));
+  }
+  ImGui::PopID();
+  return OkStatus();
+}
+
+// Draws a tree node for |module| listing its functions as "@<ordinal>" or
+// "@<ordinal> <name>". Functions rejected by an active |filter| are skipped.
+// Selecting a function navigates the code view to its offset 0.
+//
+// While CheckLoadedOrRequest() returns false, "Loading..." is shown instead
+// of the function list. Returns the first error from NavigateToCodeView().
+Status DebugApp::DrawModule(RemoteModule* module,
+ const ImGuiTextFilter& filter) {
+ ImGui::PushID(module->name().c_str());
+ if (ImGui::TreeNodeEx(module->name().c_str(),
+ ImGuiTreeNodeFlags_Framed |
+ ImGuiTreeNodeFlags_DefaultOpen |
+ ImGuiTreeNodeFlags_OpenOnDoubleClick |
+ ImGuiTreeNodeFlags_OpenOnArrow)) {
+ if (module->CheckLoadedOrRequest()) {
+ for (auto* function : module->functions()) {
+ // snprintf bounds the label to the buffer, so long names are
+ // truncated rather than overflowing.
+ char function_name[128];
+ if (function->name().empty()) {
+ std::snprintf(function_name, sizeof(function_name), "@%d",
+ function->ordinal());
+ } else {
+ std::snprintf(function_name, sizeof(function_name), "@%d %s",
+ function->ordinal(), function->name().c_str());
+ }
+ if (filter.IsActive() && !filter.PassFilter(function_name)) {
+ continue;
+ }
+ // Every row uses the same "##selectable" label, so ids are scoped
+ // by function ordinal.
+ ImGui::PushID(function->ordinal());
+ bool is_selected = false;
+ if (ImGui::Selectable("##selectable", &is_selected,
+ ImGuiSelectableFlags_AllowDoubleClick |
+ ImGuiSelectableFlags_DrawFillAvailWidth)) {
+ if (is_selected) {
+ RETURN_IF_ERROR(NavigateToCodeView(module->name(),
+ function->ordinal(), 0,
+ NavigationMode::kMatchDocument));
+ }
+ }
+ // The visible text is drawn on the same line as the selectable.
+ ImGui::SameLine();
+ // TODO(benvanik): detect if breakpoint active at offset 0.
+ ImGui::BulletText("%s", function_name);
+ ImGui::PopID();
+ }
+ } else {
+ ImGui::TextDisabled("Loading...");
+ }
+ // TreePop() is only called when TreeNodeEx() returned true.
+ ImGui::TreePop();
+ }
+ ImGui::PopID();
+ return OkStatus();
+}
+
+Status DebugApp::DrawLocalListPanel() {  // Draws the "Locals" window for the selected stack frame.
+ static bool is_panel_visible = true;  // Function-local static shared across calls.
+ if (!ImGui::Begin("Locals", &is_panel_visible, ImGuiWindowFlags_None)) {
+ ImGui::End();
+ return OkStatus();
+ } else if (!debug_client_) {
+ ImGui::TextDisabled("disconnected");
+ ImGui::End();
+ return OkStatus();
+ }
+ auto* invocation = GetSelectedInvocation();
+ if (!invocation) {
+ ImGui::TextDisabled("select a invocation to view locals");
+ ImGui::End();
+ return OkStatus();
+ } else if (invocation->def().frames.empty()) {
+ ImGui::TextDisabled("(invocation has no frames)");
+ ImGui::End();
+ return OkStatus();
+ }
+ int stack_frame_index = selected_stack_frame_index_.value_or(-1);  // -1: no explicit selection.
+ if (stack_frame_index == -1) {
+ stack_frame_index = invocation->def().frames.size() - 1;  // Fall back to the last frame.
+ }
+ auto& stack_frame = invocation->def().frames[stack_frame_index];
+
+ // TODO(benvanik): toggle for IREE VM locals vs. source locals.
+ for (int i = 0; i < stack_frame->locals.size(); ++i) {
+ auto& local = stack_frame->locals[i];
+ RETURN_IF_ERROR(DrawLocal(invocation, stack_frame_index, i, *local));  // NOTE(review): an error returns before ImGui::End() below.
+ }
+
+ ImGui::End();
+ return OkStatus();
+}
+
+Status DebugApp::DrawLocal(RemoteInvocation* invocation, int stack_frame_index,
+                           int local_index, const rpc::BufferViewDefT& local) {
+  // TODO(benvanik): columns and such in fancy table.
+  ImGui::Text("l%d", local_index);  // Label column: "l<index>".
+  ImGui::SameLine(50);
+  if (!local.is_valid) {
+    ImGui::TextDisabled("∅");  // Invalid locals get a placeholder glyph.
+  } else {
+    const std::string dims_text =
+        absl::StrCat(absl::StrJoin(local.shape, "x"), "x", local.element_size);
+    ImGui::Text("%s", dims_text.c_str());
+  }
+  // TODO(benvanik): editing options (change shape, change contents, upload).
+  // TODO(benvanik): save/download/log options.
+  return OkStatus();
+}
+
+// Draws the "Invocations" window with one entry per known invocation.
+Status DebugApp::DrawInvocationListPanel() {
+  static bool is_panel_visible = true;
+  const bool panel_open =
+      ImGui::Begin("Invocations", &is_panel_visible, ImGuiWindowFlags_None);
+  if (panel_open && !debug_client_) ImGui::TextDisabled("disconnected");
+  if (!panel_open || !debug_client_) {
+    ImGui::End();
+    return OkStatus();
+  }
+  for (auto* invocation : debug_client_->invocations()) {
+    RETURN_IF_ERROR(DrawInvocation(*invocation));
+  }
+  ImGui::End();
+  return OkStatus();
+}
+
+Status DebugApp::DrawInvocation(const RemoteInvocation& invocation) {  // Draws one invocation header and its stack frames.
+ // TODO(benvanik): expand if any breakpoints are stopped in invocation.
+ if (selected_invocation_id_.has_value() &&
+ selected_invocation_id_.value() == invocation.id()) {
+ ImGui::SetNextTreeNodeOpen(true);  // Keep the selected invocation's header open.
+ }
+ if (!ImGui::CollapsingHeader(invocation.name().c_str())) {
+ return OkStatus();  // Header collapsed: frames are not drawn.
+ }
+ ImGui::PushID(invocation.id());
+
+ for (int i = 0; i < invocation.def().frames.size(); ++i) {
+ const auto& stack_frame = invocation.def().frames[i];
+ ImGui::PushID(i);
+ // TODO(benvanik): highlight frames with breakpoints in them.
+ bool is_selected = selected_invocation_id_.has_value() &&
+ selected_invocation_id_.value() == invocation.id() &&
+ selected_stack_frame_index_.has_value() &&
+ selected_stack_frame_index_.value() == i;  // True only for the currently selected frame.
+ if (ImGui::Selectable("##selectable", &is_selected,
+ ImGuiSelectableFlags_AllowDoubleClick |
+ ImGuiSelectableFlags_DrawFillAvailWidth)) {
+ // TODO(benvanik): detect when clicking but already selected.
+ if (is_selected) {
+ RETURN_IF_ERROR(
+ NavigateToCodeView(invocation, i, NavigationMode::kMatchDocument));  // NOTE(review): an error returns with both PushID scopes still pushed.
+ }
+ }
+ ImGui::SameLine();
+ ImGui::Bullet();
+ ImGui::SameLine();
+ // TODO(benvanik): better naming/etc (resolve function).
+ ImGui::Text("%s:%d:%d", stack_frame->module_name.c_str(),
+ stack_frame->function_ordinal, stack_frame->offset);  // Shown as module:ordinal:offset.
+
+ ImGui::PopID();
+ }
+
+ ImGui::PopID();
+ return OkStatus();
+}
+
+DebugApp::CodeViewDocument* DebugApp::FindMatchingDocument(
+    absl::string_view module_name, int function_ordinal) {
+  // Linear scan; a document matches on both module name and ordinal.
+  for (const auto& candidate : documents_) {
+    auto* shown_function = candidate->function;
+    if (shown_function->module()->name() != module_name) continue;
+    if (shown_function->ordinal() == function_ordinal) return candidate.get();
+  }
+  return nullptr;  // No open document shows this function.
+}
+
+Status DebugApp::NavigateToCodeView(absl::string_view module_name,
+ int function_ordinal, int offset,
+ NavigationMode navigation_mode) {  // Focuses or opens a code view for a function ordinal.
+ if (!debug_client_) {
+ return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
+ }
+ VLOG(1) << "NavigateToCodeView(" << module_name << ", " << function_ordinal
+ << ", " << offset << ")";
+ CodeViewDocument* existing_document = nullptr;  // Stays null unless kMatchDocument finds one.
+ switch (navigation_mode) {
+ case NavigationMode::kNewDocument:
+ // Fall through and create below.
+ break;
+ case NavigationMode::kCurrentDocument:
+ // Not yet done - treat as a new document.
+ break;
+ case NavigationMode::kMatchDocument:
+ existing_document = FindMatchingDocument(module_name, function_ordinal);
+ break;
+ }
+ if (existing_document) {
+ ImGui::SetWindowFocus(existing_document->title.c_str());
+ return OkStatus();  // Reused an open document; |offset| is not applied on this path.
+ }
+
+ // TODO(benvanik): make this common code.
+ RETURN_IF_ERROR(debug_client_->GetFunction(
+ std::string(module_name), function_ordinal,
+ [this, offset](StatusOr<RemoteFunction*> function_or) {
+ if (!function_or.ok()) {
+ // TODO(benvanik): error dialog.
+ CHECK_OK(function_or.status());
+ }
+ auto* function = function_or.ValueOrDie();
+ auto document = absl::make_unique<CodeViewDocument>();
+ document->title =
+ absl::StrCat(function->module()->name(), ":", function->name());  // Title is "<module>:<function>".
+ document->function = function;
+ document->focus_offset = offset;  // Consumed later by DrawBytecodeCodeView.
+ ImGui::SetWindowFocus(document->title.c_str());
+ documents_.push_back(std::move(document));
+ }));
+ return OkStatus();
+}
+
+Status DebugApp::NavigateToCodeView(absl::string_view module_name,
+                                    absl::string_view function_name, int offset,
+                                    NavigationMode navigation_mode) {
+  if (!debug_client_) {
+    return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
+  }
+  const std::string owned_module(module_name);  // The view may dangle later.
+  return debug_client_->ResolveFunction(
+      std::string(module_name), std::string(function_name),
+      [this, navigation_mode, owned_module, offset](StatusOr<int> ordinal_or) {
+        CHECK_OK(ordinal_or.status());  // Resolution failure is fatal for now.
+        CHECK_OK(NavigateToCodeView(owned_module, ordinal_or.ValueOrDie(),
+                                    offset, navigation_mode));
+      });
+}
+
+Status DebugApp::NavigateToCodeView(const RemoteInvocation& invocation,
+ int stack_frame_index,
+ NavigationMode navigation_mode) {  // Selects a stack frame and navigates to its code.
+ if (!debug_client_) {
+ return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
+ }
+ const auto& stack_frame = stack_frame_index == -1
+ ? *invocation.def().frames.back()  // -1 selects the last frame.
+ : *invocation.def().frames[stack_frame_index];
+ selected_invocation_id_ = invocation.id();
+ selected_stack_frame_index_ = stack_frame_index;  // Stored as passed, so -1 is kept as -1.
+ return NavigateToCodeView(stack_frame.module_name,
+ stack_frame.function_ordinal, stack_frame.offset,
+ NavigationMode::kMatchDocument);  // NOTE(review): |navigation_mode| is ignored here; confirm intent.
+}
+
+Status DebugApp::NavigateToCodeView(const UserBreakpoint& user_breakpoint,
+                                    NavigationMode navigation_mode) {
+  if (!debug_client_) {
+    return UnavailableErrorBuilder(IREE_LOC) << "No connection established";
+  }
+  switch (user_breakpoint.type) {
+    case RemoteBreakpoint::Type::kBytecodeFunction:
+      if (user_breakpoint.function_ordinal != -1) {  // Ordinal already known.
+        return NavigateToCodeView(
+            user_breakpoint.module_name, user_breakpoint.function_ordinal,
+            user_breakpoint.bytecode_offset, navigation_mode);
+      }
+      return NavigateToCodeView(  // Otherwise resolve the function by name.
+          user_breakpoint.module_name, user_breakpoint.function_name,
+          user_breakpoint.bytecode_offset, navigation_mode);
+    case RemoteBreakpoint::Type::kNativeFunction:
+      break;  // Unimplemented; reported below.
+  }  // Also reached for unexpected enum values, so every path returns.
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Navigation to non-bytecode functions unimplemented";
+}
+
+Status DebugApp::DrawCodeViewPanels() {
+  // If we've disconnected then we need to clear bodies.
+  // TODO(benvanik): allow documents to persist by caching all required info.
+  if (!debug_client_) {
+    documents_.clear();
+    return OkStatus();
+  }
+
+  // Draw everything first; removal is deferred so the list is not mutated
+  // while it is being iterated.
+  std::vector<CodeViewDocument*> closed;
+  for (auto& document : documents_) {
+    ASSIGN_OR_RETURN(bool still_open, DrawCodeViewDocument(document.get()));
+    if (!still_open) closed.push_back(document.get());
+  }
+  // Drop every document whose window was closed this frame.
+  documents_.erase(
+      std::remove_if(documents_.begin(), documents_.end(),
+                     [&closed](const std::unique_ptr<CodeViewDocument>& doc) {
+                       return std::find(closed.begin(), closed.end(),
+                                        doc.get()) != closed.end();
+                     }),
+      documents_.end());
+  return OkStatus();
+}
+
+StatusOr<bool> DebugApp::DrawCodeViewDocument(CodeViewDocument* document) {  // Returns false once the window has been closed.
+ ImGui::SetNextWindowDockID(dockspace_id_, ImGuiCond_FirstUseEver);
+ ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0, 0));
+ bool is_open = true;
+ bool is_visible =
+ ImGui::Begin(document->title.c_str(), &is_open, ImGuiWindowFlags_None);
+ if (!is_open || !is_visible) {
+ ImGui::End();
+ ImGui::PopStyleVar();
+ return is_open;  // Nothing drawn; DrawCodeViewPanels removes the document when this is false.
+ }
+ ImGui::PopStyleVar();
+
+ auto* remote_module = document->function->module();
+ auto* remote_function = document->function;
+ if (remote_module->CheckLoadedOrRequest() &&
+ remote_function->CheckLoadedOrRequest()) {  // Function is only checked once the module reports loaded.
+ // TODO(benvanik): draw function signature.
+ if (remote_function->bytecode()) {
+ RETURN_IF_ERROR(DrawBytecodeCodeView(document));  // NOTE(review): an error returns before ImGui::End() below.
+ } else {
+ // TODO(benvanik): display native registration info.
+ ImGui::TextDisabled("(native)");
+ }
+ } else {
+ ImGui::TextDisabled("loading...");
+ }
+
+ ImGui::End();
+ return true;
+}
+
+// Caches the text lines shown for |document| (currently just its name).
+Status DebugApp::PrepareBytecodeCodeView(CodeViewDocument* document) {
+  auto* remote_function = document->function;
+  // |lines| is a vector of strings, so the name is stored as its single
+  // entry; assigning the bare string to the vector does not type-check.
+  document->bytecode_info.lines = {remote_function->name()};
+  return OkStatus();
+}
+
+Status DebugApp::DrawBytecodeCodeView(CodeViewDocument* document) {  // Draws gutter, execution marker and text for each cached line.
+ // Ensure we have cached our line information.
+ RETURN_IF_ERROR(PrepareBytecodeCodeView(document));
+
+ auto* remote_module = document->function->module();
+ auto* remote_function = document->function;
+
+ ImGui::BeginGroup();
+ ImGui::BeginChild("##bytecode_view", ImVec2(0, 0), false,
+ ImGuiWindowFlags_AlwaysVerticalScrollbar);
+ ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0, 0));
+
+ // TODO(benvanik): cache breakpoints for this function for faster lookup.
+
+ auto& bytecode_info = document->bytecode_info;
+ ImGuiListClipper clipper(bytecode_info.lines.size(),
+ ImGui::GetTextLineHeightWithSpacing());
+ while (clipper.Step()) {
+ for (int i = clipper.DisplayStart; i < clipper.DisplayEnd; ++i) {  // Only the range reported by the clipper is drawn.
+ ImGui::PushID(i);
+
+ // TODO(benvanik): lookup line info.
+ int bytecode_offset = 0;  // Placeholder: every line currently maps to offset 0.
+ int breakpoint_index = FindMatchingUserBreakpointIndex(
+ remote_module->name(), remote_function->ordinal(), bytecode_offset);
+ bool has_breakpoint = breakpoint_index != -1;  // -1 means no matching user breakpoint.
+ bool active_on_any_invocation = false;  // Never set to true in this function yet.
+ bool active_on_selected_invocation = false;  // Never set to true in this function yet.
+
+ ImGui::Dummy(ImVec2(4, 0));
+
+ // Gutter breakpoint button.
+ ImGui::SameLine();
+ if (has_breakpoint) {
+ PushButtonHue(0.0f); // Red
+ if (ImGui::Button(" ##toggle_breakpoint")) {
+ CHECK_GE(breakpoint_index, 0);
+ auto& user_breakpoint = user_breakpoint_list_[breakpoint_index];
+ if (user_breakpoint.active_breakpoint) {
+ RETURN_IF_ERROR(debug_client_->RemoveBreakpoint(
+ *user_breakpoint.active_breakpoint));  // NOTE(review): an error returns with ImGui ID/style/child scopes still open.
+ }
+ user_breakpoint_list_.erase(user_breakpoint_list_.begin() +
+ breakpoint_index);  // |user_breakpoint| is not used after this erase.
+ }
+ PopButtonStyle();
+ if (ImGui::IsItemHovered()) {
+ ImGui::SetTooltip("Remove the breakpoint at this offset.");
+ }
+ } else {
+ PushButtonColor(ImGui::GetStyleColorVec4(ImGuiCol_ChildBg));
+ if (ImGui::Button(" ##toggle_breakpoint")) {
+ UserBreakpoint user_breakpoint;
+ user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction;
+ user_breakpoint.module_name = remote_module->name();
+ user_breakpoint.function_name = remote_function->name();  // function_ordinal keeps its default of -1.
+ user_breakpoint.bytecode_offset = bytecode_offset;
+ user_breakpoint.wants_enabled = true;
+ user_breakpoint_list_.push_back(std::move(user_breakpoint));
+ }
+ PopButtonStyle();
+ if (ImGui::IsItemHovered()) {
+ ImGui::SetTooltip("Add a breakpoint at this offset.");
+ }
+ }
+
+ // Active execution chevron (shows when active or any invocation is
+ // executing this region).
+ ImGui::SameLine();
+ if (active_on_selected_invocation) {
+ // The selected invocation is active here.
+ ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_SeparatorActive),
+ " > ");
+ } else if (active_on_any_invocation) {
+ // At least one other invocation is active here.
+ ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_Separator), " > ");
+ } else {
+ // Not active.
+ ImGui::Text(" ");
+ }
+
+ // Line contents.
+ ImGui::SameLine();
+ ImGui::Text("%s", bytecode_info.lines[i].c_str());
+
+ if (document->focus_offset.has_value() &&
+ bytecode_offset == document->focus_offset.value()) {
+ document->bytecode_offset = document->focus_offset.value();
+ document->focus_offset = {};  // Cleared so the scroll request is handled once.
+ ImGui::SetScrollHereY();
+ }
+
+ ImGui::PopID();
+ }
+ }
+
+ ImGui::PopStyleVar();
+ ImGui::EndChild();
+ ImGui::EndGroup();
+
+ return OkStatus();
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/tools/debugger/debug_app.h b/tools/debugger/debug_app.h
new file mode 100644
index 0000000..a7954cf
--- /dev/null
+++ b/tools/debugger/debug_app.h
@@ -0,0 +1,200 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
+#define IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
+
+#include <SDL.h>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "base/status.h"
+#include "rt/debug/debug_client.h"
+
+// NOTE: order matters here, imgui must come first:
+#include "third_party/dear_imgui/imgui.h"
+// NOTE: must follow imgui.h:
+#include "third_party/dear_imgui/examples/imgui_impl_opengl3.h"
+#include "third_party/dear_imgui/examples/imgui_impl_sdl.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Debug client app UI.
+// Uses a DebugClient to communicate with a remote DebugServer and ImGui to
+// display a nifty UI.
+//
+// See the ImGui site for more info: https://github.com/ocornut/imgui
+// The most useful thing is the imgui_demo.cpp file that contains example usage
+// of most features.
+class DebugApp : private DebugClient::Listener {
+ public:
+ struct UserBreakpoint {  // A breakpoint as requested by the user in the UI.
+ RemoteBreakpoint::Type type = RemoteBreakpoint::Type::kBytecodeFunction;
+ const RemoteBreakpoint* active_breakpoint = nullptr;  // Null when no remote breakpoint is attached.
+ bool wants_enabled = true;
+ bool is_enabling = false;
+ // TODO(benvanik): reuse BreakpointDef here?
+ std::string module_name;
+ std::string function_name;
+ int function_ordinal = -1;  // -1: unknown; navigation then resolves by function_name.
+ int bytecode_offset = 0;
+ std::string native_function;
+ };
+
+ static void PumpMainLoopThunk(void* arg);
+
+ DebugApp(SDL_Window* window, SDL_GLContext gl_context,
+ const char* glsl_version);
+ ~DebugApp();
+
+ // Connects to the service at the specified address.
+ Status Connect(absl::string_view service_address);
+ // Disconnects from the currently connected service, if any.
+ Status Disconnect();
+
+ // Returns true if the remote service is paused at our request.
+ bool is_paused() const;
+
+ // Pumps the main UI loop once.
+ // This polls the DebugClient, SDL input, and renders the UI.
+ // It should be called as frequently as possible to ensure snappy UI updates.
+ // Returns CancelledError if the app is being closed by the user.
+ Status PumpMainLoop();
+
+ // Defines how the NavigateToCodeView methods behave.
+ enum class NavigationMode {
+ // The target will be opened in a new document tab.
+ kNewDocument,
+ // The target will be opened in the current document tab, replacing the
+ // current contents.
+ kCurrentDocument,
+ // The target will be opened in a document tab that mostly matches (like
+ // the same function in a module at a different offset), otherwise a new
+ // document will be opened.
+ kMatchDocument,
+ };
+
+ // Navigates to a particular function offset based on resolution of the given
+ // arguments. Navigation may happen asynchronously if targets need to be
+ // resolved or contents fetched.
+ Status NavigateToCodeView(absl::string_view module_name, int function_ordinal,
+ int offset, NavigationMode navigation_mode);
+ Status NavigateToCodeView(absl::string_view module_name,
+ absl::string_view function_name, int offset,
+ NavigationMode navigation_mode);
+ Status NavigateToCodeView(const RemoteInvocation& invocation,
+ int stack_frame_index,
+ NavigationMode navigation_mode);
+ Status NavigateToCodeView(const UserBreakpoint& user_breakpoint,
+ NavigationMode navigation_mode);
+
+ private:
+ struct CodeViewDocument {
+ // Document display title (and ID).
+ std::string title;
+ // Function (and offset within the function) being displayed.
+ RemoteFunction* function = nullptr;
+ int bytecode_offset = 0;
+ // Set to a bytecode offset to have the document focus there.
+ absl::optional<int> focus_offset;
+ // Cached info for bytecode display.
+ struct {
+ std::vector<std::string> lines;  // One entry per displayed line.
+ } bytecode_info;
+ };
+
+ CodeViewDocument* FindMatchingDocument(absl::string_view module_name,
+ int function_ordinal);
+ RemoteInvocation* GetSelectedInvocation() const;
+
+ Status RefreshActiveBreakpoints();
+ bool IsStoppedAtBreakpoint(const UserBreakpoint& user_breakpoint) const;
+ int FindMatchingUserBreakpointIndex(absl::string_view module_name,
+ int function_ordinal, int offset);
+ int FindMatchingUserBreakpointIndex(absl::string_view module_name,
+ absl::string_view function_name,
+ int offset);
+ Status ResumeFromBreakpoint(UserBreakpoint* user_breakpoint);
+
+ Status OnContextRegistered(const RemoteContext& context) override;
+ Status OnContextUnregistered(const RemoteContext& context) override;
+ Status OnModuleLoaded(const RemoteContext& context,
+ const RemoteModule& module) override;
+ Status OnInvocationRegistered(const RemoteInvocation& invocation) override;
+ Status OnInvocationUnregistered(const RemoteInvocation& invocation) override;
+ Status OnBreakpointHit(const RemoteBreakpoint& breakpoint,
+ const RemoteInvocation& invocation) override;
+
+ Status LayoutInitialDockSpace();
+
+ Status DrawUI();
+ Status DrawMainMenu();
+ Status DrawToolbar();
+
+ Status DrawBreakpointListPanel();
+ StatusOr<bool> DrawBreakpoint(UserBreakpoint* user_breakpoint);
+ Status DrawAddBreakpointDialogs(
+ absl::optional<RemoteBreakpoint::Type> add_breakpoint_type);
+ Status DrawAddBytecodeFunctionBreakpointDialog();
+ Status DrawAddNativeFunctionBreakpointDialog();
+
+ Status DrawModuleListPanel();
+ Status DrawContext(const RemoteContext& context,
+ const ImGuiTextFilter& filter);
+ Status DrawModule(RemoteModule* module, const ImGuiTextFilter& filter);
+
+ Status DrawLocalListPanel();
+ Status DrawLocal(RemoteInvocation* invocation, int stack_frame_index,
+ int local_index, const rpc::BufferViewDefT& local);
+
+ Status DrawInvocationListPanel();
+ Status DrawInvocation(const RemoteInvocation& invocation);
+
+ Status DrawCodeViewPanels();
+ StatusOr<bool> DrawCodeViewDocument(CodeViewDocument* document);
+ Status PrepareBytecodeCodeView(CodeViewDocument* document);
+ Status DrawBytecodeCodeView(CodeViewDocument* document);
+
+ SDL_Window* window_ = nullptr;
+ SDL_GLContext gl_context_ = nullptr;
+
+ ImGuiID dockspace_id_;
+ ImGuiID dock_top_id_;
+ ImGuiID dock_left_id_;
+ ImGuiID dock_bottom_id_;
+ ImGuiID dock_bottom_left_id_;
+ ImGuiID dock_bottom_right_id_;
+ ImGuiID dock_right_id_;
+ ImGuiID dock_content_id_;
+
+ std::unique_ptr<DebugClient> debug_client_;  // Null while disconnected.
+ std::vector<UserBreakpoint> user_breakpoint_list_;
+
+ bool is_paused_ = false;
+ std::vector<const RemoteBreakpoint*> hit_breakpoints_;
+ bool is_stepping_ = false;
+
+ absl::optional<int> selected_invocation_id_;
+ absl::optional<int> selected_stack_frame_index_;  // May hold -1, which readers treat as the last frame.
+
+ std::vector<std::unique_ptr<CodeViewDocument>> documents_;  // Open code view documents.
+};
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
diff --git a/iree/tools/debugger/debug_app.html b/tools/debugger/debug_app.html
similarity index 100%
rename from iree/tools/debugger/debug_app.html
rename to tools/debugger/debug_app.html
diff --git a/tools/debugger/debug_app_embedded.cc b/tools/debugger/debug_app_embedded.cc
new file mode 100644
index 0000000..ba70792
--- /dev/null
+++ b/tools/debugger/debug_app_embedded.cc
@@ -0,0 +1,153 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tools/debugger/debug_app_embedded.h"
+
+#include <SDL.h>
+
+#include <thread> // NOLINT
+
+#include "absl/base/thread_annotations.h"
+#include "absl/memory/memory.h"
+#include "absl/synchronization/mutex.h"
+#include "base/memory.h"
+#include "base/status.h"
+#include "third_party/SDL2/include/SDL_thread.h"
+#include "tools/debugger/debug_app.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+class InProcessEmbeddedDebugger : public EmbeddedDebugger {  // Runs a DebugApp on its own SDL thread.
+ public:
+ explicit InProcessEmbeddedDebugger(std::unique_ptr<DebugApp> app)
+ : app_(std::move(app)) {
+ thread_ =
+ SDL_CreateThread(&ThreadMainThunk, "InProcessEmbeddedDebugger", this);  // NOTE(review): the result is not checked for failure.
+ }
+
+ ~InProcessEmbeddedDebugger() override {
+ VLOG(1) << "Setting shutdown flag and waiting on thread...";
+ shutdown_flag_ = true;  // Asks ThreadMain's loop to exit.
+ int status = 0;
+ SDL_WaitThread(thread_, &status);
+ VLOG(1) << "Thread shutdown, killing app...";
+ app_.reset();
+ }
+
+ Status AwaitClose() override {
+ await_mutex_.LockWhen(absl::Condition(
+ +[](bool* is_shutdown) { return *is_shutdown; }, &is_shutdown_));  // Blocks until ThreadMain sets is_shutdown_.
+ auto status = std::move(shutdown_status_);
+ await_mutex_.Unlock();
+ return status;
+ }
+
+ private:
+ static int ThreadMainThunk(void* arg) {  // C-style entry point handed to SDL_CreateThread.
+ return reinterpret_cast<InProcessEmbeddedDebugger*>(arg)->ThreadMain();
+ }
+
+ int ThreadMain() {
+ VLOG(1) << "Thread entry";
+ while (!shutdown_flag_) {
+ auto status = app_->PumpMainLoop();
+ if (IsCancelled(status)) {
+ shutdown_flag_ = true;  // Cancelled status ends the loop without recording an error.
+ break;
+ } else if (!shutdown_flag_ && !status.ok()) {
+ absl::MutexLock lock(&await_mutex_);
+ shutdown_status_ = std::move(status);
+ // TODO(benvanik): don't check unless no one is watching.
+ CHECK_OK(shutdown_status_);
+ }
+ }
+ app_.reset();  // The app is destroyed on this thread before waiters are released.
+ {
+ absl::MutexLock lock(&await_mutex_);
+ is_shutdown_ = true;
+ }
+ VLOG(1) << "Thread exit";
+ return 0;
+ }
+
+ std::unique_ptr<DebugApp> app_;
+ SDL_Thread* thread_;
+ std::atomic<bool> shutdown_flag_ = {false};  // Written by the destructor and by ThreadMain.
+ absl::Mutex await_mutex_;
+ bool is_shutdown_ ABSL_GUARDED_BY(await_mutex_) = false;
+ Status shutdown_status_ ABSL_GUARDED_BY(await_mutex_);
+};
+
+StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger() {
+ return AttachDebugger("");  // Empty address: AttachDebugger skips Connect().
+}
+
+// Creates the SDL window and GL context, then starts the debugger thread.
+StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger(
+    absl::string_view service_address) {
+  LOG(INFO) << "Launching embedded debugger; service=" << service_address;
+  // Workaround for terrible bad SDL/graphics driver leaks.
+  IREE_DISABLE_LEAK_CHECKS();
+  if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
+    return InternalErrorBuilder(IREE_LOC)
+           << "Unable to init SDL: " << SDL_GetError();
+  }
+
+#if __APPLE__
+  // GL 3.2 Core + GLSL 150
+  const char* glsl_version = "#version 150";
+  SDL_GL_SetAttribute(
+      SDL_GL_CONTEXT_FLAGS,
+      SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG);  // Always required on Mac
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 2);
+#else
+  // GL 3.0 + GLSL 130
+  const char* glsl_version = "#version 130";
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0);
+#endif
+
+  SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
+  SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24);
+  SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8);
+  SDL_DisplayMode current;
+  SDL_GetCurrentDisplayMode(0, &current);  // Out-parameter: pass its address.
+  SDL_WindowFlags window_flags = static_cast<SDL_WindowFlags>(
+      SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
+  SDL_Window* window =
+      SDL_CreateWindow("IREE Debugger (embedded)", SDL_WINDOWPOS_CENTERED,
+                       SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
+  SDL_GLContext gl_context = SDL_GL_CreateContext(window);
+  SDL_GL_MakeCurrent(nullptr, nullptr);  // Release the context on this thread.
+
+  IREE_ENABLE_LEAK_CHECKS();
+
+  auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version);
+  if (!service_address.empty()) {  // Empty address: launch without connecting.
+    RETURN_IF_ERROR(app->Connect(service_address));
+  }
+
+  auto handle = absl::make_unique<InProcessEmbeddedDebugger>(std::move(app));
+  return handle;
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/tools/debugger/debug_app_embedded.h b/tools/debugger/debug_app_embedded.h
new file mode 100644
index 0000000..e597ad2
--- /dev/null
+++ b/tools/debugger/debug_app_embedded.h
@@ -0,0 +1,52 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
+#define IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
+
+#include <memory>
+
+#include "absl/strings/string_view.h"
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// RAII handle for keeping the debugger alive.
+// When the instance is destroyed the debugger app will be closed.
+class EmbeddedDebugger {
+ public:
+ virtual ~EmbeddedDebugger() = default;  // Virtual: handles are held through this base type.
+
+ // Blocks the caller until the debugger is closed by the user.
+ virtual Status AwaitClose() = 0;
+};
+
+// Launches the debugger app.
+// Returns a handle that can be used to wait for the debugger to close or
+// force it to close.
+StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger();
+
+// Launches the debugger app and attaches to the given server address.
+// Returns a handle that can be used to wait for the debugger to close or
+// force it to close.
+StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger(
+ absl::string_view service_address);
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
diff --git a/tools/debugger/debug_app_main_emscripten.cc b/tools/debugger/debug_app_main_emscripten.cc
new file mode 100644
index 0000000..2746681
--- /dev/null
+++ b/tools/debugger/debug_app_main_emscripten.cc
@@ -0,0 +1,69 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Emscripten debug_app entry point.
+// Though we are using SDL here we need to do some emscripten-specific magic to
+// handle the different main looping mode (as we can't block in main() like on
+// other platforms) as well as support some emscripten-specific features for
+// file upload/download/etc.
+
+#include <SDL.h>
+#include <emscripten.h>
+
+#include "base/init.h"
+#include "tools/debugger/debug_app.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Emscripten entry point: sets up SDL/WebGL and registers the main loop.
+extern "C" int main(int argc, char** argv) {
+  InitializeEnvironment(&argc, &argv);
+  if (SDL_Init(SDL_INIT_VIDEO) != 0) {
+    printf("Error: %s\n", SDL_GetError());
+    return -1;
+  }
+
+  const char* glsl_version = "#version 100";
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_ES);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 2);
+  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0);
+
+  SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
+  SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24);
+  SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8);
+  SDL_DisplayMode current;
+  SDL_GetCurrentDisplayMode(0, &current);  // Out-parameter: pass its address.
+  SDL_WindowFlags window_flags = static_cast<SDL_WindowFlags>(
+      SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
+  SDL_Window* window =
+      SDL_CreateWindow("IREE Debugger", SDL_WINDOWPOS_CENTERED,
+                       SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
+  SDL_GLContext gl_context = SDL_GL_CreateContext(window);
+  if (!gl_context) {
+    printf("Failed to initialize WebGL context!\n");
+    return 1;
+  }
+
+  auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version);
+  ::emscripten_set_main_loop_arg(DebugApp::PumpMainLoopThunk, app.release(), 0,
+                                 false);  // Ownership of the app passes to the loop.
+  return 0;
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/tools/debugger/debug_app_main_native.cc b/tools/debugger/debug_app_main_native.cc
new file mode 100644
index 0000000..1753e9f
--- /dev/null
+++ b/tools/debugger/debug_app_main_native.cc
@@ -0,0 +1,45 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Native (linux/etc) debug_app entry point.
+// This should work on any platform with pthreads and SDL support.
+
+#include "base/init.h"
+#include "base/status.h"
+#include "tools/debugger/debug_app_embedded.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+Status Run() {  // Launches the debugger UI and waits for it to close.
+ ASSIGN_OR_RETURN(auto handle, LaunchDebugger());
+ RETURN_IF_ERROR(handle->AwaitClose());
+ handle.reset();  // Destroy the debugger explicitly before returning.
+ return OkStatus();
+}
+
+extern "C" int main(int argc, char** argv) {
+  InitializeEnvironment(&argc, &argv);
+  const Status run_status = Run();
+  if (run_status.ok()) {
+    return 0;
+  }
+  LOG(ERROR) << "Debugger error: " << run_status;  // Exit code 1 on failure.
+  return 1;
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/tools/debugger/debug_cli_main.cc b/tools/debugger/debug_cli_main.cc
new file mode 100644
index 0000000..4ff33e4
--- /dev/null
+++ b/tools/debugger/debug_cli_main.cc
@@ -0,0 +1,40 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "absl/flags/flag.h"
+#include "base/init.h"
+#include "base/status.h"
+#include "tools/debugger/debug_prompt.h"
+
+// Address of the debug service that the prompt attaches to. Run() reads this
+// flag and passes its value to AttachDebugPrompt().
+ABSL_FLAG(std::string, debug_service_uri, "0.0.0.0:6000",
+ "IP/port of debug service to connect to.");
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Attaches a debug prompt to the service named by --debug_service_uri and
+// returns whatever status AttachDebugPrompt() produces. A single attach attempt
+// is made; see the TODO below about retrying.
+Status Run() {
+ // TODO(benvanik): retry until connected? would allow auto-build reconnects.
+ return AttachDebugPrompt(absl::GetFlag(FLAGS_debug_service_uri));
+}
+
+// Process entry point: initializes the environment and runs the prompt.
+extern "C" int main(int argc, char** argv) {
+ InitializeEnvironment(&argc, &argv);
+ // NOTE(review): CHECK_OK presumably terminates the process when Run() fails,
+ // so the `return 0` below is only reached on success - confirm against the
+ // macro definition.
+ CHECK_OK(Run());
+ return 0;
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/tools/debugger/debug_prompt.cc b/tools/debugger/debug_prompt.cc
new file mode 100644
index 0000000..729d432
--- /dev/null
+++ b/tools/debugger/debug_prompt.cc
@@ -0,0 +1,90 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tools/debugger/debug_prompt.h"
+
+#include "base/status.h"
+#include "rt/debug/debug_client.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+namespace {
+
+// Debug prompt that connects to a debug service and services it by polling.
+// The class registers itself as the DebugClient listener (private inheritance)
+// and currently answers every event by calling MakeReady() on the client.
+class DebugPrompt : private DebugClient::Listener {
+ public:
+ // Connects to the service at |debug_service_uri|, registering this object as
+ // the listener, and stores the resulting client in debug_client_.
+ Status Connect(absl::string_view debug_service_uri) {
+ // Connect to the debug service.
+ ASSIGN_OR_RETURN(debug_client_,
+ DebugClient::Connect(debug_service_uri, this));
+ return OkStatus();
+ }
+
+ // Polls the debug client in an endless loop. The only way out of the loop is
+ // a failing Poll(), whose status is returned through RETURN_IF_ERROR.
+ // Expects Connect() to have populated debug_client_ first.
+ Status Run() {
+ // Query commands, transmit requests, and dispatch responses.
+ while (true) {
+ RETURN_IF_ERROR(debug_client_->Poll());
+
+ // TODO(benvanik): ask for a command.
+ // TODO(benvanik): display stuff.
+ }
+ }
+
+ private:
+ // DebugClient::Listener overrides. None of the event payloads are inspected;
+ // each handler just acknowledges the event via MakeReady().
+
+ Status OnContextRegistered(const RemoteContext& context) override {
+ // Ack.
+ return debug_client_->MakeReady();
+ }
+
+ Status OnContextUnregistered(const RemoteContext& context) override {
+ // Ack.
+ return debug_client_->MakeReady();
+ }
+
+ Status OnModuleLoaded(const RemoteContext& context,
+ const RemoteModule& module) override {
+ // Ack.
+ return debug_client_->MakeReady();
+ }
+
+ Status OnInvocationRegistered(const RemoteInvocation& invocation) override {
+ // Ack.
+ return debug_client_->MakeReady();
+ }
+
+ Status OnInvocationUnregistered(const RemoteInvocation& invocation) override {
+ // Ack.
+ return debug_client_->MakeReady();
+ }
+
+ Status OnBreakpointHit(const RemoteBreakpoint& breakpoint,
+ const RemoteInvocation& invocation) override {
+ // Ack.
+ return debug_client_->MakeReady();
+ }
+
+ // Client connection created by Connect(); null until then.
+ std::unique_ptr<DebugClient> debug_client_;
+};
+
+} // namespace
+
+// Creates a prompt, connects it to the service at |debug_service_uri| and then
+// services it until its Run() loop returns. A connection failure is returned
+// before the loop is entered.
+Status AttachDebugPrompt(absl::string_view debug_service_uri) {
+  DebugPrompt prompt;
+  RETURN_IF_ERROR(prompt.Connect(debug_service_uri));
+  return prompt.Run();
+}
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
diff --git a/tools/debugger/debug_prompt.h b/tools/debugger/debug_prompt.h
new file mode 100644
index 0000000..1eb2183
--- /dev/null
+++ b/tools/debugger/debug_prompt.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
+#define IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
+
+#include "absl/strings/string_view.h"
+#include "base/status.h"
+
+namespace iree {
+namespace rt {
+namespace debug {
+
+// Attaches a debug prompt reading stdin for commands and printing results to
+// stdout. The calling thread will block until the debugger is exited or the
+// debug service closes.
+//
+// |debug_service_uri| is the address of the debug service to connect to.
+// Returns the failing status if connecting to or servicing the debug service
+// fails.
+//
+// TODO(benvanik): take stdin/stdout as arguments.
+Status AttachDebugPrompt(absl::string_view debug_service_uri);
+
+} // namespace debug
+} // namespace rt
+} // namespace iree
+
+#endif // IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
diff --git a/iree/tools/iree_translate_main.cc b/tools/iree_translate_main.cc
similarity index 100%
rename from iree/tools/iree_translate_main.cc
rename to tools/iree_translate_main.cc
diff --git a/iree/tools/run_lit.sh b/tools/run_lit.oss.sh
similarity index 100%
rename from iree/tools/run_lit.sh
rename to tools/run_lit.oss.sh
diff --git a/tools/run_mlir_main.cc b/tools/run_mlir_main.cc
new file mode 100644
index 0000000..5c832fb
--- /dev/null
+++ b/tools/run_mlir_main.cc
@@ -0,0 +1,360 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// IREE source.mlir -> execution output test runner.
+// This is meant to be called from LIT for FileCheck tests, and tries to match
+// the interface of mlir-opt (featuring -split-input-file, etc) so it's easier
+// to work with there. If you want a more generalized runner for standalone
+// precompiled IREE modules use //third_party/iree/tools:run_module.
+//
+// By default all exported functions in the module will be run in order.
+// All input values, provided via --input_values, will be passed to the
+// functions (this means all input signatures must match). Results from the
+// executed functions will be printed to stdout for checking.
+// Use --output_types to set the function output data types, which like args
+// will be used for all functions executed.
+//
+// Example input:
+// // RUN: iree-run %s | FileCheck %s
+// // CHECK-LABEL: @foo
+// // CHECK: 1xf32: 2
+// func @foo() -> memref<f32> attributes {iree.module.export} {
+// %0 = "iree.constant"() {value: dense<tensor<f32>, 2.0>} : () -> memref<f32>
+// return %0 : memref<f32>
+// }
+
+#include <iostream>
+
+#include "absl/flags/flag.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "base/init.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h"
+#include "hal/buffer_view_string_util.h"
+#include "hal/driver_registry.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "rt/context.h"
+#include "rt/debug/debug_server_flags.h"
+#include "rt/instance.h"
+#include "rt/invocation.h"
+#include "rt/module.h"
+#include "rt/module_printer.h"
+#include "schemas/module_def_generated.h"
+#include "vm/sequencer_module.h"
+
+// Command-line flags. They are read via absl::GetFlag() by the helpers below:
+// --split_input_file in RunFile(), --target_backends and --run in
+// EvaluateFile(), --export_all and --print_mlir in PrepareModule(),
+// --input_values in ParseInputsFromFlags(), --output_types in
+// OutputFunctionResults() and --print_bytecode in EvaluateFunctions().
+ABSL_FLAG(bool, split_input_file, true,
+ "Split the input file into multiple modules.");
+
+ABSL_FLAG(std::string, target_backends, "",
+ "Comma-separated list of target backends to translate executables "
+ "into. Omit to translate using all linked-in backend translators.");
+ABSL_FLAG(
+ bool, export_all, true,
+ "Automatically add the iree.module.export attribute to all functions.");
+
+ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
+ABSL_FLAG(std::string, output_types, "",
+ "Output data types (comma delimited list of b/i/u/f for "
+ "binary/signed int/unsigned int/float).");
+
+// TODO(benvanik): is there a more canonical flag we can use?
+ABSL_FLAG(bool, print_mlir, true, "Prints MLIR IR during translation.");
+
+ABSL_FLAG(bool, print_bytecode, false,
+ "Prints IREE bytecode after translation.");
+
+ABSL_FLAG(bool, run, true,
+ "Option to run the file. Setting it to false just compiles it.");
+
+namespace iree {
+namespace {
+
+using ::iree::hal::BufferView;
+using ::iree::rt::Function;
+using ::iree::rt::Module;
+
+// Maps a translation backend name onto a driver name by dropping everything
+// from the first '-' onwards (e.g. 'vulkan-spirv' becomes 'vulkan'). Names
+// without a '-' are returned unchanged.
+std::string BackendToDriverName(std::string backend) {
+  const size_t dash_pos = backend.find('-');
+  return dash_pos != std::string::npos ? backend.substr(0, dash_pos) : backend;
+}
+
+// Prepares a module for evaluation by running MLIR import and IREE translation.
+//
+// Parses |file_buffer| as MLIR source, optionally marks every function as
+// exported (--export_all), translates the module for |target_backend| and
+// wraps the translated bytes in a vm::SequencerModule.
+//
+// Returns an error if the source cannot be parsed, if translation yields no
+// bytes, or if the translated bytes cannot be loaded as a module.
+StatusOr<ref_ptr<Module>> PrepareModule(
+    std::string target_backend,
+    std::unique_ptr<llvm::MemoryBuffer> file_buffer) {
+  mlir::MLIRContext context;
+
+  // Parse input MLIR module.
+  llvm::SourceMgr source_mgr;
+  source_mgr.AddNewSourceBuffer(std::move(file_buffer), llvm::SMLoc());
+  mlir::OwningModuleRef mlir_module =
+      mlir::parseSourceFile(source_mgr, &context);
+  if (!mlir_module) {
+    // The parser hands back a null module for malformed input; fail here
+    // instead of dereferencing it below.
+    return iree::InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Error parsing input MLIR source";
+  }
+
+  if (absl::GetFlag(FLAGS_export_all)) {
+    for (auto function : mlir_module->getOps<mlir::FuncOp>()) {
+      function.setAttr("iree.module.export", mlir::UnitAttr::get(&context));
+    }
+  }
+
+  // Translate from MLIR to IREE bytecode.
+  mlir::iree_compiler::ModuleTranslationOptions options;
+  options.print_mlir = absl::GetFlag(FLAGS_print_mlir);
+  options.target_backends = {target_backend};
+  auto iree_module_bytes =
+      mlir::iree_compiler::translateMlirToIreeSequencerModule(mlir_module.get(),
+                                                              options);
+  if (iree_module_bytes.empty()) {
+    return iree::InternalErrorBuilder(IREE_LOC)
+           << "Error translating MLIR to an IREE sequencer module";
+  }
+
+  // Dump the module as it looks after translation has run over it.
+  if (absl::GetFlag(FLAGS_print_mlir)) {
+    mlir_module->dump();
+  }
+
+  // Wrap module in a file handle.
+  ASSIGN_OR_RETURN(auto iree_module_file,
+                   vm::ModuleFile::FromBuffer(ModuleDefIdentifier(),
+                                              std::move(iree_module_bytes)));
+  return vm::SequencerModule::FromFile(std::move(iree_module_file));
+}
+
+// Builds the list of input buffer views described by --input_values.
+// Literal "\n" escape sequences in the flag are first turned into real
+// newlines; the text is then split on newlines and ';', skipping blank
+// entries. Every remaining entry must have the form
+//   [shape]xtype=[value]
+// for example
+//   4x4xi8=0,1,2,3
+// Returns the first parse error encountered, if any.
+StatusOr<std::vector<BufferView>> ParseInputsFromFlags(
+    hal::Allocator *allocator) {
+  const std::string flag_text =
+      absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}});
+  std::vector<BufferView> buffer_views;
+  for (const auto &entry : absl::StrSplit(flag_text, absl::ByAnyChar("\n;"),
+                                          absl::SkipWhitespace())) {
+    ASSIGN_OR_RETURN(auto buffer_view,
+                     hal::ParseBufferViewFromString(entry, allocator));
+    buffer_views.push_back(buffer_view);
+  }
+  return buffer_views;
+}
+
+// Outputs all results from the function to stdout in IREE BufferView format.
+//
+// When --output_types is set it must list exactly one print mode per result;
+// otherwise every result is printed in floating-point mode. Each result is
+// also logged at INFO level with its buffer's debug string.
+//
+// Returns an error if the number of output types does not match |results|,
+// if a print mode cannot be parsed, if a result has no buffer, or if
+// formatting a result fails.
+Status OutputFunctionResults(const Function &function,
+                             absl::Span<BufferView> results) {
+  std::vector<std::string> output_types =
+      absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "),
+                     absl::SkipWhitespace());
+  if (!output_types.empty() && output_types.size() != results.size()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "--output_types= specified but has " << output_types.size()
+           << " types when the function returns " << results.size();
+  }
+
+  for (size_t i = 0; i < results.size(); ++i) {
+    const auto &result = results[i];
+    // Reject buffer-less results up front: the buffer is dereferenced for the
+    // log line below.
+    if (!result.buffer) {
+      return InternalErrorBuilder(IREE_LOC)
+             << "result[" << i << "] unexpectedly has no buffer";
+    }
+    auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
+    if (!output_types.empty()) {
+      ASSIGN_OR_RETURN(print_mode,
+                       hal::ParseBufferViewPrintMode(output_types[i]));
+    }
+    ASSIGN_OR_RETURN(auto result_str,
+                     hal::PrintBufferViewToString(result, print_mode, 1024));
+    LOG(INFO) << "result[" << i << "]: " << result.buffer->DebugString();
+    std::cout << result_str << "\n";
+  }
+
+  return OkStatus();
+}
+
+// Invokes |function| within |context| and prints its results to stdout.
+// The inputs are re-parsed from --input_values (using |allocator|) on every
+// call, and the calling thread blocks until the invocation has completed.
+Status EvaluateFunction(const ref_ptr<rt::Context> &context,
+                        hal::Allocator *allocator, const Function &function) {
+  std::cout << "EXEC @" << function.name() << std::endl;
+
+  // Build the invocation from the flag-provided inputs.
+  ASSIGN_OR_RETURN(auto inputs, ParseInputsFromFlags(allocator));
+  ASSIGN_OR_RETURN(auto invocation,
+                   rt::Invocation::Create(add_ref(context), function,
+                                          make_ref<rt::Policy>(), {},
+                                          absl::MakeConstSpan(inputs)));
+
+  // Block until execution has finished.
+  RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
+
+  // Take the outputs from the invocation and print them.
+  ASSIGN_OR_RETURN(auto outputs, invocation->ConsumeResults());
+  RETURN_IF_ERROR(OutputFunctionResults(function, absl::MakeSpan(outputs)));
+
+  return OkStatus();
+}
+
+// Evaluates all exported functions within given module.
+//
+// Creates an instance plus a default device from the driver named
+// |target_backend|, optionally prints the module disassembly
+// (--print_bytecode), then runs each exported function in ordinal order, each
+// in a freshly created context. Evaluation stops at the first failing
+// function; the device is still unregistered before that failure is returned.
+Status EvaluateFunctions(absl::string_view target_backend,
+ ref_ptr<Module> module) {
+ // Create the context we'll use for this (ensuring that we can't interfere
+ // with other running evaluations, such as when in a multithreaded test
+ // runner).
+ ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
+ auto instance = make_ref<rt::Instance>(std::move(debug_server));
+ ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
+ target_backend));
+ ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
+ RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
+
+ if (absl::GetFlag(FLAGS_print_bytecode)) {
+ RETURN_IF_ERROR(rt::PrintModuleToStream(
+ *module, rt::PrintModuleFlag::kDisassemble, &std::cout));
+ }
+
+ // Evaluate all exported functions.
+ auto policy = make_ref<rt::Policy>();
+ // Runs the export with the given ordinal; captures instance, policy, module
+ // and device by reference, all of which outlive the lambda.
+ auto run_function = [&](int ordinal) -> Status {
+ // Setup a new context for this invocation.
+ auto context = make_ref<rt::Context>(add_ref(instance), add_ref(policy));
+ RETURN_IF_ERROR(context->RegisterModule(add_ref(module)));
+
+ // Invoke the function and print results.
+ ASSIGN_OR_RETURN(auto function,
+ module->LookupFunctionByOrdinal(
+ rt::Function::Linkage::kExport, ordinal));
+ RETURN_IF_ERROR(EvaluateFunction(context, device->allocator(), function));
+ return OkStatus();
+ };
+
+ // Remember the first failure instead of returning immediately so that the
+ // device cleanup below always runs.
+ Status evaluate_status = OkStatus();
+ for (int i = 0; i < module->signature().export_function_count(); ++i) {
+ evaluate_status = run_function(i);
+ if (!evaluate_status.ok()) {
+ break;
+ }
+ }
+
+ // Unregister the device, then drop the device before the driver.
+ RETURN_IF_ERROR(instance->device_manager()->UnregisterDevice(device.get()));
+ device.reset();
+ driver.reset();
+
+ return evaluate_status;
+}
+
+// Translates and runs a single LLVM file buffer.
+//
+// The buffer is translated once per target backend: the comma-separated
+// --target_backends list when set, otherwise every driver reported by the
+// driver registry. Unless --run is false, each translated module is then
+// evaluated on the driver matching its backend. Returns the first error.
+Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer) {
+  std::vector<std::string> target_backends;
+  if (absl::GetFlag(FLAGS_target_backends).empty()) {
+    target_backends =
+        hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers();
+  } else {
+    target_backends = absl::StrSplit(absl::GetFlag(FLAGS_target_backends), ',');
+  }
+
+  // Iterate by reference: the backend names are only read.
+  for (const auto &target_backend : target_backends) {
+    // Prepare the module for execution and evaluate it. Every backend gets its
+    // own copy of the source because PrepareModule() takes ownership of the
+    // buffer it is given.
+    auto cloned_file_buffer = llvm::MemoryBuffer::getMemBufferCopy(
+        file_buffer->getBuffer(), file_buffer->getBufferIdentifier());
+    ASSIGN_OR_RETURN(auto module, PrepareModule(target_backend + '*',
+                                                std::move(cloned_file_buffer)));
+    if (!absl::GetFlag(FLAGS_run)) {
+      continue;
+    }
+    // We need to map specific backends names to drivers (like 'vulkan-spirv'
+    // to the driver 'vulkan').
+    RETURN_IF_ERROR(EvaluateFunctions(BackendToDriverName(target_backend),
+                                      std::move(module)));
+  }
+
+  return OkStatus();
+}
+
+// Runs the given .mlir file based on the current flags.
+//
+// With --split_input_file disabled the whole file is evaluated as one module.
+// Otherwise the file is cut at every "// -----\n" marker and each chunk is
+// evaluated independently; all chunk failures are logged and the first one is
+// returned. Returns a not-found error when |mlir_filename| cannot be opened.
+Status RunFile(std::string mlir_filename) {
+ // Load input file/from stdin.
+ std::string error_message;
+ auto file = mlir::openInputFile(mlir_filename, &error_message);
+ if (!file) {
+ return NotFoundErrorBuilder(IREE_LOC)
+ << "Unable to open input file " << mlir_filename << ": "
+ << error_message;
+ }
+
+ if (!absl::GetFlag(FLAGS_split_input_file)) {
+ // Use entire buffer as a single module.
+ return EvaluateFile(std::move(file));
+ }
+
+ // Split the buffer into separate modules and evaluate independently.
+ // This matches the -split-input-file arg to mlir-opt.
+ const char kSplitMarker[] = "// -----\n";
+ // Keep a raw pointer: ownership of |file| moves into the source manager
+ // below, but the buffer is still read afterwards.
+ auto *full_buffer = file.get();
+ llvm::SmallVector<llvm::StringRef, 8> source_buffers;
+ full_buffer->getBuffer().split(source_buffers, kSplitMarker);
+
+ // Add the original buffer to the source manager.
+ llvm::SourceMgr fileSourceMgr;
+ fileSourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+
+ // Process each chunk in turn. Only return the first error (but log all).
+ Status any_failure;
+ for (auto &sub_source_buffer : source_buffers) {
+ // Look up the line the chunk starts on in the original file and put it in
+ // the chunk's buffer identifier.
+ auto split_loc = llvm::SMLoc::getFromPointer(sub_source_buffer.data());
+ unsigned split_line = fileSourceMgr.getLineAndColumn(split_loc).first;
+ auto sub_buffer = llvm::MemoryBuffer::getMemBufferCopy(
+ sub_source_buffer, full_buffer->getBufferIdentifier() +
+ llvm::Twine(" split at line #") +
+ llvm::Twine(split_line));
+ auto sub_failure = EvaluateFile(std::move(sub_buffer));
+ if (!sub_failure.ok()) {
+ LOG(ERROR) << sub_failure;
+ if (any_failure.ok()) {
+ any_failure = std::move(sub_failure);
+ }
+ }
+ }
+
+ return any_failure;
+}
+
+} // namespace
+
+// Process entry point. Expects the path of the input .mlir file as the first
+// positional argument; exits with 1 when it is missing or when running the
+// file fails, and with 0 otherwise.
+extern "C" int main(int argc, char **argv) {
+  InitializeEnvironment(&argc, &argv);
+  if (argc < 2) {
+    LOG(ERROR) << "Must supply an input .mlir file.";
+    return 1;
+  }
+  const char *mlir_filename = argv[1];
+  const auto run_status = RunFile(mlir_filename);
+  if (run_status.ok()) {
+    return 0;
+  }
+  std::cerr << "ERROR running file (" << mlir_filename << "): " << run_status
+            << "\n";
+  return 1;
+}
+
+} // namespace iree
diff --git a/tools/run_module_main.cc b/tools/run_module_main.cc
new file mode 100644
index 0000000..93b1553
--- /dev/null
+++ b/tools/run_module_main.cc
@@ -0,0 +1,183 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "base/file_io.h"
+#include "base/file_path.h"
+#include "base/init.h"
+#include "base/source_location.h"
+#include "base/status.h"
+#include "hal/buffer_view_string_util.h"
+#include "hal/driver_registry.h"
+#include "rt/context.h"
+#include "rt/debug/debug_server_flags.h"
+#include "rt/instance.h"
+#include "rt/module_printer.h"
+#include "schemas/module_def_generated.h"
+#include "vm/sequencer_module.h"
+
+// Module to load and the exported function to call. Run() allows
+// --main_function to be omitted only when the module has exactly one export.
+ABSL_FLAG(std::string, main_module, "", "Main module with entry point.");
+ABSL_FLAG(std::string, main_function, "",
+ "Function within the main module to execute.");
+
+ABSL_FLAG(bool, print_disassembly, true,
+ "Prints bytecode disassembly for the module.");
+
+// Function inputs. ParseInputsFromFlags() consults --input_file only when
+// --input_values is empty.
+ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
+ABSL_FLAG(std::string, input_file, "",
+ "Input shapes and optional values serialized in a file.");
+
+ABSL_FLAG(std::string, output_types, "",
+ "Output data types (comma delimited list of b/i/u/f for "
+ "binary/signed int/unsigned int/float).");
+
+namespace iree {
+namespace {
+
+// Builds the list of input buffer views for the invocation.
+// The text comes from --input_values (with literal "\n" escape sequences
+// turned into real newlines) or, when that flag is empty, from the file named
+// by --input_file. Each non-blank line must have the form
+//   [shape]xtype=[value]
+// for example
+//   4x4xi8=0,1,2,3
+// With neither flag set the returned list is empty.
+StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags(
+    hal::Allocator* allocator) {
+  const std::string inline_values = absl::GetFlag(FLAGS_input_values);
+  const std::string values_path = absl::GetFlag(FLAGS_input_file);
+
+  std::string input_text;
+  if (!inline_values.empty()) {
+    input_text = absl::StrReplaceAll(inline_values, {{"\\n", "\n"}});
+  } else if (!values_path.empty()) {
+    ASSIGN_OR_RETURN(input_text, file_io::GetFileContents(values_path));
+  }
+
+  std::vector<hal::BufferView> buffer_views;
+  for (const auto& entry :
+       absl::StrSplit(input_text, '\n', absl::SkipWhitespace())) {
+    ASSIGN_OR_RETURN(auto buffer_view,
+                     hal::ParseBufferViewFromString(entry, allocator));
+    buffer_views.push_back(buffer_view);
+  }
+  return buffer_views;
+}
+
+} // namespace
+
+// Loads the module named by --main_module, invokes one of its exported
+// functions on the "interpreter" driver and prints every result to stdout.
+//
+// The function is chosen by --main_function, or defaults to the only export
+// when the module has exactly one. Inputs come from ParseInputsFromFlags() and
+// --output_types optionally selects the print mode of each result.
+//
+// Returns the first error hit while setting up, invoking or printing.
+Status Run() {
+  ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
+  auto instance = make_ref<rt::Instance>(std::move(debug_server));
+  ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
+                                    "interpreter"));
+  ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
+  RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
+  auto policy = make_ref<rt::Policy>();
+  auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));
+
+  // Load main module.
+  ASSIGN_OR_RETURN(
+      auto main_module_file,
+      vm::ModuleFile::LoadFile(ModuleDefIdentifier(),
+                               absl::GetFlag(FLAGS_main_module)),
+      _ << "while loading module file " << absl::GetFlag(FLAGS_main_module));
+  ASSIGN_OR_RETURN(auto main_module,
+                   vm::SequencerModule::FromFile(std::move(main_module_file)));
+
+  // Register the main module with the context.
+  // We could add additional modules (specializations, shared libraries, etc).
+  // ModuleFiles are stateless so we could have the same module_file used by
+  // multiple contexts simultaneously.
+  RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));
+
+  // Dump the registered modules.
+  rt::PrintModuleFlagBitfield print_flags = rt::PrintModuleFlag::kNone;
+  if (absl::GetFlag(FLAGS_print_disassembly)) {
+    print_flags |= rt::PrintModuleFlag::kDisassemble;
+  }
+  for (const auto& module : context->modules()) {
+    RETURN_IF_ERROR(PrintModuleToStream(*module, print_flags, &std::cout));
+  }
+
+  rt::Function main_function;
+  if (!absl::GetFlag(FLAGS_main_function).empty()) {
+    // User-specified main function.
+    ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName(
+                                        rt::Function::Linkage::kExport,
+                                        absl::GetFlag(FLAGS_main_function)));
+  } else {
+    // No main function specified; to prevent non-deterministic behavior we
+    // require one unless there's exactly one exported function in the module.
+    if (main_module->signature().export_function_count() == 1) {
+      ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal(
+                                          rt::Function::Linkage::kExport, 0));
+    } else {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "--main_function= must be specified to disambiguate the "
+                "function to run";
+    }
+  }
+
+  // Call into the main function.
+  ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator()));
+  ASSIGN_OR_RETURN(auto invocation,
+                   rt::Invocation::Create(add_ref(context), main_function,
+                                          make_ref<rt::Policy>(), {},
+                                          absl::MakeConstSpan(arguments)));
+
+  // Wait until invocation completes.
+  RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
+  ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults());
+
+  // Dump all results to stdout.
+  std::vector<std::string> output_types =
+      absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "),
+                     absl::SkipWhitespace());
+  if (!output_types.empty() && output_types.size() != results.size()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "--output_types= specified but has " << output_types.size()
+           << " types when the function returns " << results.size();
+  }
+  for (size_t i = 0; i < results.size(); ++i) {
+    const auto& result = results[i];
+    // Validate the buffer before the result is handed to the printer so a
+    // missing buffer is always reported with this message.
+    const auto& buffer = result.buffer;
+    if (!buffer) {
+      return InternalErrorBuilder(IREE_LOC)
+             << "result[" << i << "] unexpectedly has no buffer";
+    }
+    auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
+    if (!output_types.empty()) {
+      ASSIGN_OR_RETURN(print_mode,
+                       hal::ParseBufferViewPrintMode(output_types[i]));
+    }
+    ASSIGN_OR_RETURN(auto result_str,
+                     PrintBufferViewToString(result, print_mode, 1024));
+    LOG(INFO) << "result[" << i << "]: " << buffer->DebugString();
+    std::cout << result_str << "\n";
+  }
+
+  return OkStatus();
+}
+
+// Process entry point: initializes the environment and runs the module.
+extern "C" int main(int argc, char** argv) {
+ InitializeEnvironment(&argc, &argv);
+ // NOTE(review): CHECK_OK presumably terminates the process when Run() fails,
+ // so the `return 0` below is only reached on success - confirm against the
+ // macro definition.
+ CHECK_OK(Run());
+ return 0;
+}
+
+} // namespace iree
diff --git a/iree/tools/sanitizer_suppressions.txt b/tools/sanitizer_suppressions.txt
similarity index 100%
rename from iree/tools/sanitizer_suppressions.txt
rename to tools/sanitizer_suppressions.txt
diff --git a/tools/web/BUILD b/tools/web/BUILD
new file mode 100644
index 0000000..c24b4b2
--- /dev/null
+++ b/tools/web/BUILD
@@ -0,0 +1,89 @@
+# IREE web tools.
+
+load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Linker options applied to run_module_emscripten in every compilation mode.
+EMSCRIPTEN_LINKOPTS_COMMON = [
+ # Error at compile time on unresolved symbols.
+ "-s ERROR_ON_UNDEFINED_SYMBOLS=1",
+
+ # Note: If pthreads and memory growth are enabled, WASM_MEM_MAX must be set.
+ # Also, USE_PTHREADS + ALLOW_MEMORY_GROWTH may run non-wasm code slowly.
+ # "-s ALLOW_MEMORY_GROWTH=1",
+ # "-s WASM_MEM_MAX=268435456", # 256MB
+ # "-s TOTAL_MEMORY=268435456", # 256MB
+
+ # Request a prepopulated pool of web workers for pthreads to use.
+ # Without this, threads may not start until the javascript thread yields.
+ # See considerations at https://emscripten.org/docs/porting/pthreads.html.
+ "-s PTHREAD_POOL_SIZE=1",
+]
+
+# Extra linker options added when the "dbg" compilation mode is selected.
+EMSCRIPTEN_LINKOPTS_DBG = [
+ # Show WASM stack trace in Chrome debugger.
+ "-g2",
+ "-s DEMANGLE_SUPPORT=1",
+
+ # Enable verbose assertions.
+ "-s ASSERTIONS=2",
+ "-s SAFE_HEAP=1",
+ "-s STACK_OVERFLOW_CHECK=2",
+]
+
+# Extra linker options for the "opt" and default modes (currently none).
+EMSCRIPTEN_LINKOPTS_OPT = []
+
+# To use run_module_emscripten:
+# > bazel build third_party/iree/tools/web:run_module_emscripten_files
+
+# C++ binary built from run_module_emscripten.cc. The common link options are
+# always applied; the select() adds the debug or optimized set depending on
+# the compilation mode. Tagged "manual", so it is only built on request.
+cc_binary(
+ name = "run_module_emscripten",
+ srcs = ["run_module_emscripten.cc"],
+ linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({
+ "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG,
+ "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT,
+ "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT,
+ }),
+ tags = [
+ "manual",
+ "notap", # TODO(b/137088911): Build/test on TAP
+ "wasm",
+ ],
+ deps = [
+ "///base:init",
+ "///base:status",
+ "///hal:buffer_view_string_util",
+ "///hal:driver_registry",
+ "///hal/interpreter:interpreter_driver_module",
+ "///rt",
+ "///vm:sequencer_module",
+ "//third_party/emscripten:embind",
+ ],
+)
+
+# Wraps the :run_module_emscripten cc_binary with the auto_wasm_binary rule
+# loaded from //third_party/emscripten, with threads set to "emscripten".
+auto_wasm_binary(
+ name = "run_module_emscripten_binary",
+ cc_target = ":run_module_emscripten",
+ tags = ["manual"],
+ threads = "emscripten",
+)
+
+# Collects the wasm binary outputs and run_module.html under a common "wasm"
+# destination directory inside the "wasm_files" output.
+Fileset(
+ name = "run_module_emscripten_files",
+ out = "wasm_files",
+ entries = [
+ FilesetEntry(
+ files = [":run_module_emscripten_binary"],
+ strip_prefix = "run_module_emscripten_binary",
+ destdir = "wasm",
+ ),
+ FilesetEntry(
+ files = ["run_module.html"],
+ destdir = "wasm",
+ ),
+ ],
+ tags = ["manual"],
+)
diff --git a/iree/tools/web/run_module.html b/tools/web/run_module.html
similarity index 100%
rename from iree/tools/web/run_module.html
rename to tools/web/run_module.html
diff --git a/tools/web/run_module_emscripten.cc b/tools/web/run_module_emscripten.cc
new file mode 100644
index 0000000..d433f1f
--- /dev/null
+++ b/tools/web/run_module_emscripten.cc
@@ -0,0 +1,140 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <emscripten.h>
+#include <emscripten/bind.h>
+
+#include <vector>
+
+#include "absl/strings/str_replace.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "base/flatbuffer_util.h"
+#include "base/init.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "hal/buffer_view_string_util.h"
+#include "hal/driver_registry.h"
+#include "rt/context.h"
+#include "rt/instance.h"
+#include "schemas/module_def_generated.h"
+#include "vm/sequencer_module.h"
+
+namespace iree {
+
+// Builds the list of input buffer views described by |inputs_string|.
+// Literal "\n" escape sequences are first turned into real newlines; each
+// non-blank line must then have the form
+//   [shape]xtype=[value]
+// for example
+//   4x4xi8=0,1,2,3
+// Returns the first parse error encountered, if any.
+StatusOr<std::vector<hal::BufferView>> ParseInputs(
+    absl::string_view inputs_string, hal::Allocator* allocator) {
+  const std::string unescaped =
+      absl::StrReplaceAll(inputs_string, {{"\\n", "\n"}});
+  std::vector<hal::BufferView> buffer_views;
+  for (const auto& entry :
+       absl::StrSplit(unescaped, '\n', absl::SkipWhitespace())) {
+    ASSIGN_OR_RETURN(auto buffer_view,
+                     hal::ParseBufferViewFromString(entry, allocator));
+    buffer_views.push_back(buffer_view);
+  }
+  return buffer_views;
+}
+
+// Runs the exported "main" function of the module serialized in
+// |module_file_data| on the "interpreter" driver, feeding it the inputs
+// parsed from |inputs_string|.
+//
+// Returns the first result formatted as a string in floating-point print
+// mode, or an error if any step fails, if the function produced no results,
+// or if the first result has no buffer.
+StatusOr<std::string> RunIreeModule(std::string module_file_data,
+                                    absl::string_view inputs_string) {
+  auto instance = make_ref<rt::Instance>();
+
+  // Bring up the interpreter driver and its default device.
+  ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
+                                    "interpreter"));
+  ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
+  RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
+
+  auto policy = make_ref<rt::Policy>();
+  auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));
+
+  // Deserialize the module FlatBuffer and register it with the context.
+  ASSIGN_OR_RETURN(auto main_module_file,
+                   FlatBufferFile<ModuleDef>::FromString(ModuleDefIdentifier(),
+                                                         module_file_data));
+  ASSIGN_OR_RETURN(auto main_module,
+                   vm::SequencerModule::FromFile(std::move(main_module_file)));
+  RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));
+
+  // Resolve the entry point and parse its arguments.
+  // TODO(scotttodd): Receive main function name from JS.
+  ASSIGN_OR_RETURN(auto main_function,
+                   main_module->LookupFunctionByName(
+                       rt::Function::Linkage::kExport, "main"));
+  ASSIGN_OR_RETURN(auto arguments,
+                   ParseInputs(inputs_string, device->allocator()));
+
+  // Invoke the function and block until it has completed.
+  // TODO(scotttodd): make this an async callback.
+  ASSIGN_OR_RETURN(auto invocation,
+                   rt::Invocation::Create(add_ref(context), main_function,
+                                          make_ref<rt::Policy>(), {},
+                                          absl::MakeConstSpan(arguments)));
+  RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
+  ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults());
+
+  // Only the first result is reported back to the caller.
+  // TODO(scotttodd): Receive output types / print mode from JS.
+  // TODO(scotttodd): Return list of outputs instead of just the first (proto?)
+  if (results.empty()) {
+    return InternalErrorBuilder(IREE_LOC) << "Received no results";
+  }
+  constexpr int kResultIndex = 0;
+  const auto& first_result = results[kResultIndex];
+  const auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
+  ASSIGN_OR_RETURN(auto result_str,
+                   PrintBufferViewToString(first_result, print_mode, 1024));
+  if (!first_result.buffer) {
+    return InternalErrorBuilder(IREE_LOC)
+           << "result[" << kResultIndex << "] unexpectedly has no buffer";
+  }
+  return result_str;
+}
+
+// String-in/string-out wrapper around RunIreeModule() that is registered with
+// embind. On success returns the formatted result; on failure returns
+// "Error: " followed by the status text instead of propagating a Status.
+std::string RunIreeModuleEntry(std::string module_file_data,
+                               std::string inputs_string) {
+  // TODO(scotttodd): optimize, minimize copies
+  // https://groups.google.com/d/msg/emscripten-discuss/CMfYljLWMvY/Di52WB2QAgAJ
+  auto result_or = RunIreeModule(std::move(module_file_data), inputs_string);
+  if (result_or.ok()) {
+    return result_or.ValueOrDie();
+  }
+  return "Error: " + result_or.status().ToString();
+}
+
+// Registers RunIreeModuleEntry with embind under the name "runIreeModule".
+EMSCRIPTEN_BINDINGS(iree) {
+ emscripten::function("runIreeModule", &RunIreeModuleEntry);
+}
+
+// Process entry point. Only initializes the environment and returns; the
+// module-running logic in this file is reached through the embind binding.
+extern "C" int main(int argc, char** argv) {
+ InitializeEnvironment(&argc, &argv);
+ return 0;
+}
+
+} // namespace iree
diff --git a/vm/BUILD b/vm/BUILD
new file mode 100644
index 0000000..c7fe7a2
--- /dev/null
+++ b/vm/BUILD
@@ -0,0 +1,200 @@
+# Bytecode VM used by the IREE sequencer and interpreter.
+
+package(
+ default_visibility = ["//visibility:public"],  # Targets default to public visibility.
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+    name = "api",
+    srcs = ["api.cc"],
+    hdrs = ["api.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":api_hdrs",
+        ":sequencer_module",
+        "//base:api",
+        "//base:api_util",
+        "//base:flatbuffer_util",
+        "//base:tracing",
+        "//rt:api",
+    ],
+)
+
+cc_library(
+    name = "api_hdrs",
+    hdrs = ["api.h"],
+    deps = [
+        "//base:api_hdrs",
+        "//rt:api_hdrs",
+    ],
+)
+
+cc_library(
+    name = "bytecode_module",
+    srcs = [
+        "bytecode_disassembler.cc",
+        "bytecode_module.cc",
+    ],
+    hdrs = [
+        "bytecode_disassembler.h",
+        "bytecode_module.h",
+    ],
+    deps = [
+        ":bytecode_util",
+        ":opcode_info",
+        ":source_map_resolver",
+        ":type",
+        "//base:flatbuffer_util",
+        "//base:status",
+        "//base:tracing",
+        "//hal:buffer_view",
+        "//rt",
+        "//schemas",
+        "//schemas/bytecode:bytecode_v0",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "bytecode_reader",
+    srcs = ["bytecode_reader.cc"],
+    hdrs = ["bytecode_reader.h"],
+    deps = [
+        ":bytecode_module",
+        ":type",
+        "//base:shape",
+        "//base:status",
+        "//hal:buffer_view",
+        "//hal:heap_buffer",
+        "//rt",
+        "//schemas/bytecode:bytecode_v0",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+    ],
+)
+
+cc_library(
+    name = "bytecode_tables_interpreter",
+    srcs = ["bytecode_tables_interpreter.cc"],
+    hdrs = ["bytecode_tables_interpreter.h"],
+    deps = [
+        ":opcode_info",
+        "//schemas/bytecode:interpreter_bytecode_v0",
+    ],
+)
+
+cc_library(
+    name = "bytecode_tables_sequencer",
+    srcs = ["bytecode_tables_sequencer.cc"],
+    hdrs = ["bytecode_tables_sequencer.h"],
+    deps = [
+        ":opcode_info",
+        "//schemas/bytecode:sequencer_bytecode_v0",
+    ],
+)
+
+cc_library(
+    name = "bytecode_util",
+    srcs = ["bytecode_util.cc"],
+    hdrs = ["bytecode_util.h"],
+    deps = [
+        "//schemas/bytecode:bytecode_v0",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
+    name = "bytecode_validator",
+    srcs = ["bytecode_validator.cc"],
+    hdrs = ["bytecode_validator.h"],
+    deps = [
+        ":bytecode_module",
+        "//base:status",
+        "//schemas",
+    ],
+)
+
+cc_library(
+    name = "opcode_info",
+    hdrs = ["opcode_info.h"],
+    deps = [
+        "//schemas/bytecode:bytecode_v0",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "sequencer_dispatch",
+    srcs = ["sequencer_dispatch.cc"],
+    hdrs = ["sequencer_dispatch.h"],
+    deps = [
+        ":bytecode_module",
+        ":bytecode_reader",
+        ":bytecode_tables_sequencer",
+        ":bytecode_util",
+        ":opcode_info",
+        "//base:logging",
+        "//base:memory",
+        "//base:status",
+        "//hal:buffer_view",
+        "//hal:command_queue",
+        "//hal:device",
+        "//hal:device_placement",
+        "//hal:heap_buffer",
+        "//rt",
+        "//schemas/bytecode:sequencer_bytecode_v0",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "sequencer_module",
+    srcs = ["sequencer_module.cc"],
+    hdrs = ["sequencer_module.h"],
+    deps = [
+        ":bytecode_module",
+        ":bytecode_tables_sequencer",
+        ":sequencer_dispatch",
+        "//base:status",
+        "//base:tracing",
+        "//hal:buffer_view",
+        "//rt",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+cc_library(
+    name = "source_map_resolver",
+    srcs = ["source_map_resolver.cc"],
+    hdrs = ["source_map_resolver.h"],
+    deps = [
+        "//base:flatbuffer_util",
+        "//base:status",
+        "//rt",
+        "//schemas",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
+cc_library(
+    name = "type",
+    srcs = ["type.cc"],
+    hdrs = ["type.h"],
+    deps = [
+        "//base:status",
+        "//schemas",
+        "//schemas/bytecode:bytecode_v0",
+    ],
+)
diff --git a/iree/vm/CMakeLists.txt b/vm/CMakeLists.txt
similarity index 100%
rename from iree/vm/CMakeLists.txt
rename to vm/CMakeLists.txt
diff --git a/vm/api.cc b/vm/api.cc
new file mode 100644
index 0000000..08bd6e0
--- /dev/null
+++ b/vm/api.cc
@@ -0,0 +1,99 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/api.h"
+
+#include "base/api.h"
+#include "base/api_util.h"
+#include "base/flatbuffer_util.h"
+#include "base/tracing.h"
+#include "vm/sequencer_module.h"
+
+namespace iree {
+namespace vm {
+
+//===----------------------------------------------------------------------===//
+// iree::vm::BytecodeModule
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_vm_bytecode_module_create_from_buffer(
+    iree_const_byte_span_t buffer_data,
+    void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data),
+    void* buffer_free_self, iree_allocator_t allocator,
+    iree_rt_module_t** out_module) {
+  IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_buffer");
+
+  // The out-pointer is mandatory and is cleared before any other failure path.
+  if (!out_module) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_module = nullptr;
+  const bool has_contents = buffer_data.data && buffer_data.data_length;
+  if (!has_contents) return IREE_STATUS_INVALID_ARGUMENT;
+
+  // Callback passed to FromBuffer: invokes the caller's free function (when
+  // one was supplied) with the original buffer span.
+  auto release_buffer = [buffer_free_fn, buffer_free_self, buffer_data]() {
+    if (!buffer_free_fn) return;
+    buffer_free_fn(buffer_free_self, {const_cast<uint8_t*>(buffer_data.data),
+                                      buffer_data.data_length});
+  };
+  IREE_API_ASSIGN_OR_RETURN(
+      auto flatbuffer_file,
+      FlatBufferFile<ModuleDef>::FromBuffer(
+          ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length},
+          std::move(release_buffer)));
+
+  // NOTE(review): |allocator| is not referenced anywhere in this function.
+  IREE_API_ASSIGN_OR_RETURN(
+      auto vm_module,
+      SequencerModule::FromFile(std::move(flatbuffer_file)));
+
+  // Ownership of the module passes to the caller through |out_module|.
+  *out_module = reinterpret_cast<iree_rt_module_t*>(vm_module.release());
+  return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_vm_bytecode_module_create_from_file_mapping(
+    iree_file_mapping_t* file_mapping, iree_allocator_t allocator,
+    iree_rt_module_t** out_module) {
+  IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_file_mapping");
+
+  // |out_module| is required and is nulled before the mapping is inspected.
+  if (!out_module) return IREE_STATUS_INVALID_ARGUMENT;
+  *out_module = nullptr;
+  if (!file_mapping) return IREE_STATUS_INVALID_ARGUMENT;
+
+  // The callback given to FromBuffer drops one reference on the mapping.
+  auto release_mapping = [file_mapping]() {
+    iree_file_mapping_release(file_mapping);
+  };
+  auto mapped_bytes = iree_file_mapping_data(file_mapping);
+  IREE_API_ASSIGN_OR_RETURN(
+      auto flatbuffer_file,
+      FlatBufferFile<ModuleDef>::FromBuffer(
+          ModuleDefIdentifier(), {mapped_bytes.data, mapped_bytes.data_length},
+          std::move(release_mapping)));
+  // The matching retain happens only after FromBuffer has succeeded.
+  // NOTE(review): confirm FromBuffer never runs the callback on failure.
+  iree_file_mapping_retain(file_mapping);
+
+  IREE_API_ASSIGN_OR_RETURN(auto vm_module, SequencerModule::FromFile(
+                                                std::move(flatbuffer_file)));
+  *out_module = reinterpret_cast<iree_rt_module_t*>(vm_module.release());
+  return IREE_STATUS_OK;
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/api.h b/vm/api.h
new file mode 100644
index 0000000..adaf774
--- /dev/null
+++ b/vm/api.h
@@ -0,0 +1,60 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_VM_API_H_
+#define IREE_VM_API_H_
+
+#include <stdint.h>
+
+#include "base/api.h"
+#include "rt/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree::vm::BytecodeModule
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a VM module from an in-memory ModuleDef FlatBuffer.
+// |buffer_free_fn| is invoked with |buffer_free_self| when the module is
+// destroyed, and only if this creation function succeeds. Pass NULL for
+// |buffer_free_fn| if the caller keeps ownership of |buffer_data|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_vm_bytecode_module_create_from_buffer(
+ iree_const_byte_span_t buffer_data,
+ void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data),
+ void* buffer_free_self, iree_allocator_t allocator,
+ iree_rt_module_t** out_module);
+
+// Creates a VM module from a mapped ModuleDef FlatBuffer.
+// |file_mapping| is retained for the life of the module and its contents are
+// accessed by reference; |out_module| receives the new module on success.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_vm_bytecode_module_create_from_file_mapping(
+ iree_file_mapping_t* file_mapping, iree_allocator_t allocator,
+ iree_rt_module_t** out_module);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_VM_API_H_
diff --git a/vm/bytecode_disassembler.cc b/vm/bytecode_disassembler.cc
new file mode 100644
index 0000000..532abc7
--- /dev/null
+++ b/vm/bytecode_disassembler.cc
@@ -0,0 +1,482 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_disassembler.h"
+
+#include <iomanip>
+#include <sstream>
+
+#include "absl/base/macros.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "base/status.h"
+#include "schemas/bytecode/bytecode_v0.h"
+#include "schemas/source_map_def_generated.h"
+#include "vm/bytecode_module.h"
+#include "vm/bytecode_util.h"
+#include "vm/type.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+using ::iree::rt::SourceOffset;
+
+template <typename T>
+StatusOr<T> ReadValue(absl::Span<const uint8_t> data, SourceOffset* offset) {
+  const auto end_offset = *offset + sizeof(T);  // Offset just past the value.
+  if (end_offset > data.size())
+    return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode data underrun";
+  auto loaded = *reinterpret_cast<const T*>(&data[*offset]);
+  *offset = end_offset;
+  return loaded;
+}
+
+StatusOr<const Type> ReadType(absl::Span<const uint8_t> data,
+ SourceOffset* offset) {
+ ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>(data, offset));  // One-byte type index.
+ return Type::FromTypeIndex(type_index);
+}
+
+StatusOr<uint8_t> ReadCount(absl::Span<const uint8_t> data,
+ SourceOffset* offset) {
+ return ReadValue<uint8_t>(data, offset);  // Counts are read as a single byte.
+}
+
+StatusOr<uint16_t> ReadValueSlot(absl::Span<const uint8_t> data,
+ SourceOffset* offset) {
+ return ReadValue<uint16_t>(data, offset);  // Slot ordinals are read as 16 bits.
+}
+
+absl::string_view ConstantEncodingToString(ConstantEncoding encoding) {
+ switch (encoding) {  // Cases are generated from IREE_CONSTANT_ENCODING_LIST.
+#define GET_NAME(ordinal, enum_name, str, ...) \
+ case ConstantEncoding::enum_name: \
+ return str;
+ IREE_CONSTANT_ENCODING_LIST(GET_NAME)
+#undef GET_NAME
+ default:
+ return "unknown";  // Encoding value not covered by the list.
+ }
+}
+
+template <typename T>
+std::string TypedDataToString(absl::Span<const uint8_t> bytes) {
+  const auto* first = reinterpret_cast<const T*>(bytes.data());
+  const size_t count = bytes.size() / sizeof(T);  // Partial tail is dropped.
+  return absl::StrJoin(absl::Span<const T>(first, count), ",");
+}
+
+std::string ConstantToString(const Type& type,
+                             absl::Span<const uint8_t> bytes) {
+  // Non-builtin types are printed as their raw bytes.
+  if (!type.is_builtin()) return absl::StrJoin(bytes, ",");
+
+  // Integers are rendered unsigned; f16 is rendered as its raw 16-bit pattern.
+  switch (type.builtin_type()) {
+    case BuiltinType::kI8:
+      return TypedDataToString<uint8_t>(bytes);
+    case BuiltinType::kI16:
+    case BuiltinType::kF16:
+      return TypedDataToString<uint16_t>(bytes);
+    case BuiltinType::kI32:
+      return TypedDataToString<uint32_t>(bytes);
+    case BuiltinType::kI64:
+      return TypedDataToString<uint64_t>(bytes);
+    case BuiltinType::kF32:
+      return TypedDataToString<float>(bytes);
+    case BuiltinType::kF64:
+      return TypedDataToString<double>(bytes);
+    default:
+      return "<unsupported>";
+  }
+}
+
+} // namespace
+
+// Disassembles up to |instruction_count| instructions of |function|, starting
+// at byte |offset| within its bytecode. Each instruction's text has the form
+// "[<results> = ]<mnemonic>[ <operands>]\n". Operands are walked twice: once
+// to print result slots before the mnemonic and once to print everything else.
+// Fails if the function has no body, an opcode or operand encoding is not
+// handled, or a read runs past the end of the bytecode.
+StatusOr<std::vector<rt::Instruction>>
+BytecodeDisassembler::DisassembleInstructions(const rt::Function& function,
+                                              SourceOffset offset,
+                                              int32_t instruction_count) const {
+  std::vector<rt::Instruction> instructions;
+
+  ASSIGN_OR_RETURN(
+      auto* function_def,
+      static_cast<const BytecodeModule*>(function.module())
+          ->GetFunctionDef(function.linkage(), function.ordinal()));
+  auto* bytecode_def = function_def->bytecode();
+  if (!bytecode_def) {
+    return UnavailableErrorBuilder(IREE_LOC) << "Function contains no body";
+  }
+  auto data = absl::MakeSpan(
+      reinterpret_cast<const uint8_t*>(bytecode_def->contents()->data()),
+      bytecode_def->contents()->size());
+
+  // TODO(benvanik): scan and find all branch offsets to insert labels
+
+  while (offset < data.length() && instructions.size() < instruction_count) {
+    instructions.push_back({});
+    auto& instruction = instructions.back();
+    instruction.offset = offset;
+
+    uint8_t opcode = data[offset++];
+    const auto& opcode_info = GetOpcodeInfo(opcode_table_, opcode);
+    if (!opcode_info.mnemonic) {
+      // Cast so the opcode is printed as a number rather than a raw character.
+      return UnimplementedErrorBuilder(IREE_LOC)
+             << "Unhandled opcode " << static_cast<int>(opcode)
+             << " at offset " << (offset - 1);
+    }
+    int payload_offset = offset;
+
+    std::ostringstream stream;
+
+    // Pass 1: print result slots (if any) and skip over all other operands.
+    int printed_result_count = 0;
+    for (int i = 0; i < ABSL_ARRAYSIZE(opcode_info.operands); ++i) {
+      const auto operand = opcode_info.operands[i];
+      if (operand == OperandEncoding::kNone) break;
+      // Separate results from each other only; other operands print nothing
+      // in this pass and must not contribute stray separators.
+      if (printed_result_count > 0 &&
+          (operand == OperandEncoding::kResultSlot ||
+           operand == OperandEncoding::kVariadicResultSlots)) {
+        stream << ", ";
+      }
+      switch (operand) {
+        default:
+        case OperandEncoding::kNone:
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Unhandled op encoding " << static_cast<int>(operand)
+                 << " at offset " << (offset - 1);
+        case OperandEncoding::kInputSlot:
+        case OperandEncoding::kOutputSlot: {
+          // Printing handled below.
+          offset += sizeof(uint16_t);
+          break;
+        }
+        case OperandEncoding::kVariadicInputSlots:
+        case OperandEncoding::kVariadicOutputSlots: {
+          // Printing handled below.
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          offset += count * sizeof(uint16_t);
+          break;
+        }
+        case OperandEncoding::kResultSlot: {
+          ++printed_result_count;
+          ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset));
+          stream << "%" << slot_ordinal;
+          break;
+        }
+        case OperandEncoding::kVariadicResultSlots: {
+          ++printed_result_count;
+          stream << "[";
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          for (int j = 0; j < count; ++j) {
+            ASSIGN_OR_RETURN(uint16_t slot_ordinal,
+                             ReadValueSlot(data, &offset));
+            if (j > 0) stream << ", ";
+            stream << "%" << slot_ordinal;
+          }
+          stream << "]";
+          break;
+        }
+        case OperandEncoding::kVariadicTransferSlots: {
+          // Printing handled below.
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          offset += count * 2 * sizeof(uint16_t);
+          break;
+        }
+        case OperandEncoding::kConstant: {
+          // Printing handled below. The skip must mirror the layout parsed
+          // there: a splat stores one element regardless of the shape.
+          ASSIGN_OR_RETURN(auto type, ReadType(data, &offset));
+          ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset));
+          int element_count = 1;
+          for (int j = 0; j < rank; ++j) {
+            ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset));
+            element_count *= dim;
+          }
+          ASSIGN_OR_RETURN(auto encoding,
+                           ReadValue<ConstantEncoding>(data, &offset));
+          if (encoding == ConstantEncoding::kSplat) element_count = 1;
+          offset += element_count * type.element_size();
+          break;
+        }
+        case OperandEncoding::kFunctionOrdinal:
+        case OperandEncoding::kImportOrdinal:
+        case OperandEncoding::kBlockOffset:
+        case OperandEncoding::kIndex: {
+          // Printing handled below. All of these are 32-bit payloads; import
+          // ordinals must be skipped here too or ops using them cannot be
+          // disassembled.
+          offset += sizeof(uint32_t);
+          break;
+        }
+        case OperandEncoding::kDispatchOrdinal: {
+          // Printing handled below.
+          offset += sizeof(uint32_t) + sizeof(uint16_t);
+          break;
+        }
+        case OperandEncoding::kTypeIndex:
+        case OperandEncoding::kCmpIPredicate:
+        case OperandEncoding::kCmpFPredicate: {
+          // Printing handled below.
+          offset += sizeof(uint8_t);
+          break;
+        }
+        case OperandEncoding::kIndexList: {
+          // Printing handled below.
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          offset += count * sizeof(int32_t);
+          break;
+        }
+      }
+    }
+    if (printed_result_count > 0) {
+      stream << " = ";
+    }
+    offset = payload_offset;
+
+    stream << opcode_info.mnemonic;
+
+    // Pass 2: print every non-result operand after the mnemonic.
+    int printed_operand_count = 0;
+    for (int i = 0; i < ABSL_ARRAYSIZE(opcode_info.operands); ++i) {
+      const auto operand = opcode_info.operands[i];
+      if (operand == OperandEncoding::kNone) break;
+      if (operand != OperandEncoding::kResultSlot &&
+          operand != OperandEncoding::kVariadicResultSlots) {
+        // A space precedes the first printed operand and ", " every later
+        // one, regardless of where result slots sit in the operand list.
+        stream << (printed_operand_count > 0 ? ", " : " ");
+      }
+      switch (operand) {
+        default:
+        case OperandEncoding::kNone:
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Unhandled op encoding " << static_cast<int>(operand)
+                 << " at offset " << (offset - 1);
+        case OperandEncoding::kInputSlot: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset));
+          stream << "%" << slot_ordinal;
+          break;
+        }
+        case OperandEncoding::kVariadicInputSlots: {
+          ++printed_operand_count;
+          stream << "[";
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          for (int j = 0; j < count; ++j) {
+            ASSIGN_OR_RETURN(uint16_t slot_ordinal,
+                             ReadValueSlot(data, &offset));
+            if (j > 0) stream << ", ";
+            stream << "%" << slot_ordinal;
+          }
+          stream << "]";
+          break;
+        }
+        case OperandEncoding::kOutputSlot: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset));
+          stream << "&"
+                 << "%" << slot_ordinal;
+          break;
+        }
+        case OperandEncoding::kVariadicOutputSlots: {
+          ++printed_operand_count;
+          stream << "[";
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          for (int j = 0; j < count; ++j) {
+            ASSIGN_OR_RETURN(uint16_t slot_ordinal,
+                             ReadValueSlot(data, &offset));
+            if (j > 0) stream << ", ";
+            stream << "&"
+                   << "%" << slot_ordinal;
+          }
+          stream << "]";
+          break;
+        }
+        case OperandEncoding::kResultSlot: {
+          // Printing handled above.
+          offset += sizeof(uint16_t);
+          break;
+        }
+        case OperandEncoding::kVariadicResultSlots: {
+          // Printing handled above.
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          offset += count * sizeof(uint16_t);
+          break;
+        }
+        case OperandEncoding::kVariadicTransferSlots: {
+          ++printed_operand_count;
+          stream << "[";
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          for (int j = 0; j < count; ++j) {
+            ASSIGN_OR_RETURN(uint16_t src_slot_ordinal,
+                             ReadValueSlot(data, &offset));
+            ASSIGN_OR_RETURN(uint16_t dst_slot_ordinal,
+                             ReadValueSlot(data, &offset));
+            if (j > 0) stream << ", ";
+            stream << "%" << src_slot_ordinal << "=>%" << dst_slot_ordinal;
+          }
+          stream << "]";
+          break;
+        }
+        case OperandEncoding::kConstant: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto type, ReadType(data, &offset));
+          ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset));
+          absl::InlinedVector<int32_t, 4> shape(rank);
+          int element_count = 1;
+          for (int j = 0; j < rank; ++j) {
+            ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset));
+            shape[j] = dim;
+            element_count *= dim;
+          }
+          ASSIGN_OR_RETURN(auto encoding,
+                           ReadValue<ConstantEncoding>(data, &offset));
+          stream << ConstantEncodingToString(encoding);
+          int serialized_element_count = 1;
+          switch (encoding) {
+            case ConstantEncoding::kDense:
+              serialized_element_count = element_count;
+              break;
+            case ConstantEncoding::kSplat:
+              serialized_element_count = 1;
+              break;
+            default:
+              return UnimplementedErrorBuilder(IREE_LOC)
+                     << "Unimplemented constant encoding "
+                     << static_cast<int>(encoding);
+          }
+          stream << " buffer_view<";
+          if (!shape.empty()) {
+            stream << absl::StrJoin(shape, "x") << "x";
+          }
+          stream << type << ">{";
+          size_t element_size = type.element_size();
+          auto bytes = data.subspan(
+              offset, std::min(serialized_element_count, 1024) * element_size);
+          stream << ConstantToString(type, bytes);
+          if (serialized_element_count > 1024) stream << "...";
+          offset += serialized_element_count * element_size;
+          stream << "}";
+          break;
+        }
+        case OperandEncoding::kFunctionOrdinal: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto function_ordinal,
+                           ReadValue<uint32_t>(data, &offset));
+          ASSIGN_OR_RETURN(
+              auto target_function,
+              function.module()->LookupFunctionByOrdinal(
+                  rt::Function::Linkage::kInternal, function_ordinal));
+          stream << "@" << function_ordinal << " " << target_function.name();
+          break;
+        }
+        case OperandEncoding::kDispatchOrdinal: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto dispatch_ordinal,
+                           ReadValue<uint32_t>(data, &offset));
+          ASSIGN_OR_RETURN(auto export_ordinal,
+                           ReadValue<uint16_t>(data, &offset));
+          // TODO(benvanik): lookup in executable table.
+          stream << "@" << dispatch_ordinal << ":" << export_ordinal;
+          break;
+        }
+        case OperandEncoding::kImportOrdinal: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto import_ordinal,
+                           ReadValue<uint32_t>(data, &offset));
+          ASSIGN_OR_RETURN(auto target_function,
+                           function.module()->LookupFunctionByOrdinal(
+                               rt::Function::Linkage::kImport, import_ordinal));
+          stream << "@i" << import_ordinal << " " << target_function.name();
+          break;
+        }
+        case OperandEncoding::kBlockOffset: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(uint32_t block_offset,
+                           ReadValue<uint32_t>(data, &offset));
+          stream << ":" << block_offset;
+          break;
+        }
+        case OperandEncoding::kTypeIndex: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto type, ReadType(data, &offset));
+          stream << type;
+          break;
+        }
+        case OperandEncoding::kIndex: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto index, ReadValue<int32_t>(data, &offset));
+          stream << "#" << index;
+          break;
+        }
+        case OperandEncoding::kIndexList: {
+          ++printed_operand_count;
+          stream << "{";
+          ASSIGN_OR_RETURN(int count, ReadCount(data, &offset));
+          for (int j = 0; j < count; ++j) {
+            ASSIGN_OR_RETURN(auto dim, ReadValue<int32_t>(data, &offset));
+            if (j > 0) stream << ",";
+            stream << dim;
+          }
+          stream << "}";
+          break;
+        }
+        case OperandEncoding::kCmpIPredicate: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto predicate_value,
+                           ReadValue<uint8_t>(data, &offset));
+          stream << "<"
+                 << PredicateToString(
+                        static_cast<CmpIPredicate>(predicate_value))
+                 << ">";
+          break;
+        }
+        case OperandEncoding::kCmpFPredicate: {
+          ++printed_operand_count;
+          ASSIGN_OR_RETURN(auto predicate_value,
+                           ReadValue<uint8_t>(data, &offset));
+          stream << "<"
+                 << PredicateToString(
+                        static_cast<CmpFPredicate>(predicate_value))
+                 << ">";
+          break;
+        }
+      }
+    }
+
+    stream << "\n";
+
+    instruction.long_text = stream.str();
+    instruction.short_text = instruction.long_text;
+  }
+
+  return instructions;
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_disassembler.h b/vm/bytecode_disassembler.h
new file mode 100644
index 0000000..4639777
--- /dev/null
+++ b/vm/bytecode_disassembler.h
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_DISASSEMBLER_H_
+#define IREE_VM_BYTECODE_DISASSEMBLER_H_
+
+#include <ostream>
+
+#include "base/status.h"
+#include "rt/disassembler.h"
+#include "schemas/bytecode_def_generated.h"
+#include "schemas/source_map_def_generated.h"
+#include "vm/opcode_info.h"
+
+namespace iree {
+namespace vm {
+
+// Disassembles bytecode into text using the op set described by an OpcodeTable.
+class BytecodeDisassembler final : public rt::Disassembler {
+ public:
+ explicit BytecodeDisassembler(OpcodeTable opcode_table)
+ : opcode_table_(opcode_table) {}
+
+ StatusOr<std::vector<rt::Instruction>> DisassembleInstructions(
+ const rt::Function& function, rt::SourceOffset offset,
+ int32_t instruction_count) const override;  // At most |instruction_count| instructions, starting at |offset|.
+
+ private:
+ OpcodeTable opcode_table_;  // Used to look up each opcode's mnemonic and operand encodings.
+};
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_DISASSEMBLER_H_
diff --git a/vm/bytecode_module.cc b/vm/bytecode_module.cc
new file mode 100644
index 0000000..9ba5e84
--- /dev/null
+++ b/vm/bytecode_module.cc
@@ -0,0 +1,310 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_module.h"
+
+#include "absl/memory/memory.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/buffer_view.h"
+#include "vm/bytecode_disassembler.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+using ::iree::hal::BufferView;
+using ::iree::rt::Function;
+using ::iree::rt::FunctionSignature;
+using ::iree::rt::Module;
+using ::iree::rt::ModuleSignature;
+
+// Verifies that |element_bit_width| matches the width declared by
+// |expected_element_type|; only float and integer element types are accepted.
+Status ValidateElementSize(int element_bit_width,
+                           const ElementTypeDef& expected_element_type) {
+  const auto type_kind = expected_element_type.type_union_type();
+  switch (type_kind) {
+    case ElementTypeDefUnion::FloatTypeDef: {
+      const auto declared_width =
+          expected_element_type.type_union_as_FloatTypeDef()->width();
+      if (element_bit_width == declared_width) return OkStatus();
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Has element bit width " << element_bit_width
+             << " but expected " << declared_width;
+    }
+    case ElementTypeDefUnion::IntegerTypeDef: {
+      const auto declared_width =
+          expected_element_type.type_union_as_IntegerTypeDef()->width();
+      if (element_bit_width == declared_width) return OkStatus();
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Has element bit width " << element_bit_width
+             << " but expected " << declared_width;
+    }
+    case ElementTypeDefUnion::UnknownTypeDef:
+    case ElementTypeDefUnion::NONE:
+      break;
+  }
+
+  // Reached for unknown/unset tags and for values the switch does not list.
+  return InvalidArgumentErrorBuilder(IREE_LOC)
+         << "Defined type has unsupported element type "
+         << EnumNameElementTypeDefUnion(type_kind);
+}
+
+Status ValidateTypeStructure(const FunctionTypeDef& type_def) {
+ // Placeholder: no fields are checked yet, so every type is accepted.
+ return OkStatus();
+}
+
+// Validates the function table: the listing must exist, every function needs a
+// type, and each import/export must reference an existing, named function.
+Status ValidateFunctionTableStructure(
+    const FunctionTableDef& function_table_def) {
+  if (!function_table_def.functions()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Function table is missing the function listing";
+  }
+
+  // All functions must contain a valid type.
+  const auto& functions = *function_table_def.functions();
+  for (int i = 0; i < functions.size(); ++i) {
+    const auto* function = functions[i];
+    if (!function) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Function ordinal " << i << " is missing its contents";
+    }
+    if (!function->type()) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Function ordinal " << i << " is missing its type";
+    }
+    RETURN_IF_ERROR(ValidateTypeStructure(*function->type()));
+  }
+
+  // Imports/exports are optional lists of indices into |functions|. The table
+  // may come from a malformed buffer, so each index is range-checked before
+  // use; the referenced function must also have a name.
+  const auto validate_ordinals = [&functions](const auto* ordinals,
+                                              const char* kind) -> Status {
+    if (!ordinals) return OkStatus();
+    for (int i = 0; i < ordinals->size(); ++i) {
+      const int function_index = ordinals->Get(i);
+      if (function_index < 0 || function_index >= functions.size()) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << kind << " ordinal " << i << " references function ordinal "
+               << function_index << " outside the function table";
+      }
+      if (!functions[function_index]->name()) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << kind << " ordinal " << i << " is missing its contents";
+      }
+    }
+    return OkStatus();
+  };
+  RETURN_IF_ERROR(validate_ordinals(function_table_def.imports(), "Import"));
+  RETURN_IF_ERROR(validate_ordinals(function_table_def.exports(), "Export"));
+  return OkStatus();
+}
+
+// Validates the executable table. A table without any multi-arch executables
+// is accepted; otherwise each entry must hold at least one executable.
+Status ValidateExecutableTableStructure(
+    const ExecutableTableDef& executable_table_def) {
+  // May have sequencer only fns. Fine to not have dispatchable executables.
+  const auto* fat_executables = executable_table_def.multi_arch_executables();
+  if (!fat_executables) return OkStatus();
+
+  // All fat executables need at least one device-specific executable.
+  for (int ordinal = 0; ordinal < fat_executables->size(); ++ordinal) {
+    const auto* fat_executable = fat_executables->Get(ordinal);
+    const bool has_executables = fat_executable &&
+                                 fat_executable->executables() &&
+                                 fat_executable->executables()->size() != 0;
+    if (has_executables) continue;
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Multi-arch executable ordinal " << ordinal
+           << " is missing its contents";
+  }
+
+  return OkStatus();
+}
+
+} // namespace
+
+// static
+// Checks that |module_def| has both a function table and an executable table
+// and that each passes its structural validation (function table first).
+Status BytecodeModule::ValidateStructure(const ModuleDef& module_def) {
+  IREE_TRACE_SCOPE0("BytecodeModule::ValidateStructure");
+
+  // Must have a function table.
+  const auto* function_table = module_def.function_table();
+  if (!function_table) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "ModuleDef is missing a function table";
+  }
+  RETURN_IF_ERROR(ValidateFunctionTableStructure(*function_table));
+
+  // Must have an executable table.
+  const auto* executable_table = module_def.executable_table();
+  if (!executable_table) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "ModuleDef is missing an executable table";
+  }
+  RETURN_IF_ERROR(ValidateExecutableTableStructure(*executable_table));
+
+  return OkStatus();
+}
+
+BytecodeModule::BytecodeModule(std::unique_ptr<ModuleFile> module_file,
+ OpcodeTable opcode_table)
+ : module_file_(std::move(module_file)),  // Takes ownership of the module file.
+ module_def_(*module_file_->root()),  // NOTE(review): needs module_file_ initialized first (member declaration order) - confirm in the header.
+ source_resolver_(SourceMapResolver::FromModule(module_def_)),
+ disassembler_(absl::make_unique<BytecodeDisassembler>(opcode_table)) {}
+
+BytecodeModule::~BytecodeModule() = default;
+
+const ModuleSignature BytecodeModule::signature() const {
+  const auto& t = function_table_def();  // imports()/exports() may be null.
+  auto n = [](const auto* v) { return v ? v->size() : 0; };  // null => 0
+  return ModuleSignature(n(t.imports()), n(t.exports()), n(t.functions()), 0);
+}
+
+std::string BytecodeModule::DebugStringShort() const {
+ return std::string(name());  // The short debug string is just the module name.
+}
+
+StatusOr<int32_t> BytecodeModule::MapFunctionOrdinal(Function::Linkage linkage,
+ int32_t ordinal) const {  // Maps a linkage-relative ordinal to an index into functions().
+ const auto& function_table = function_table_def();
+ switch (linkage) {
+ case Function::Linkage::kImport:
+ if (ordinal < 0 || ordinal >= function_table.imports()->size()) {  // imports() is dereferenced without a null check.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Import ordinal " << ordinal
+ << " is outside the valid range [0, "
+ << function_table.imports()->size() << ")";
+ }
+ ordinal = function_table.imports()->Get(ordinal);  // The import entry holds the function-table index.
+ break;
+ case Function::Linkage::kExport:
+ if (ordinal < 0 || ordinal >= function_table.exports()->size()) {  // exports() is dereferenced without a null check.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Export ordinal " << ordinal
+ << " is outside the valid range [0, "
+ << function_table.exports()->size() << ")";
+ }
+ ordinal = function_table.exports()->Get(ordinal);  // The export entry holds the function-table index.
+ break;
+ default:  // Any other linkage: the ordinal is used as a function-table index as-is.
+ break;
+ }
+ if (ordinal < 0 || ordinal >= function_table.functions()->size()) {  // The mapped value comes from module data, so it is range-checked too.
+ return OutOfRangeErrorBuilder(IREE_LOC)  // OutOfRange here, unlike the InvalidArgument used for bad import/export ordinals.
+ << "Function ordinal " << ordinal
+ << " is outside the valid range [0, "
+ << function_table.functions()->size() << ")";
+ }
+ return ordinal;
+}
+
+StatusOr<const Function> BytecodeModule::LookupFunctionByOrdinal(
+ Function::Linkage linkage, int32_t ordinal) const {
+ ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));  // ordinal is now a validated function-table index.
+ return Function(this, Function::Linkage::kInternal, ordinal);  // Hence the result is always tagged kInternal.
+}
+
+StatusOr<const Function> BytecodeModule::LookupFunctionByName(
+ Function::Linkage linkage, absl::string_view name) const {  // NOTE(review): |linkage| is never read below; every function is searched.
+ const auto& functions = *function_table_def().functions();
+ for (int i = 0; i < functions.size(); ++i) {  // Linear scan; the first matching name wins.
+ const auto* function_def = functions.Get(i);
+ if (WrapString(function_def->name()) == name) {
+ return LookupFunctionByOrdinal(Function::Linkage::kInternal, i);  // i is already a function-table index.
+ }
+ }
+ return NotFoundErrorBuilder(IREE_LOC)
+ << "Function '" << name
+ << "' not found in function table (or names have been stripped)";
+}
+
+StatusOr<absl::string_view> BytecodeModule::GetFunctionName(
+ Function::Linkage linkage, int32_t ordinal) const {
+ ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));  // Fails for ordinals outside the table.
+ const auto* function_def = function_table_def().functions()->Get(ordinal);
+ return WrapString(function_def->name());  // name() is handed to WrapString as-is (no null check here).
+}
+
+StatusOr<const FunctionSignature> BytecodeModule::GetFunctionSignature(
+ Function::Linkage linkage, int32_t ordinal) const {
+ ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));  // Fails for ordinals outside the table.
+ const auto* function_def = function_table_def().functions()->Get(ordinal);
+ const auto* type_def = function_def->type();  // NOTE(review): dereferenced below without a null check.
+ return FunctionSignature(
+ type_def->inputs() ? type_def->inputs()->size() : 0,  // A missing inputs list counts as zero arguments.
+ type_def->results() ? type_def->results()->size() : 0);  // A missing results list counts as zero results.
+}
+
+StatusOr<const FunctionDef*> BytecodeModule::GetFunctionDef(
+ rt::Function::Linkage linkage, int32_t ordinal) const {
+ ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal));  // Already rejects indices outside functions().
+ const auto& function_defs = *function_table_def().functions();
+ if (ordinal >= function_defs.size()) {  // Repeats the range check MapFunctionOrdinal performed above.
+ return OutOfRangeErrorBuilder(IREE_LOC)
+ << "Internal function ordinal " << ordinal
+ << " out of range of table (" << function_defs.size() << ")";
+ }
+ return function_defs.Get(ordinal);  // Pointer into the module's function table.
+}
+
+StatusOr<const MultiArchExecutableDef*>
+BytecodeModule::LookupMultiArchExecutable(int executable_ordinal) const {
+  // Optional flatbuffer vector: a null list is treated like an empty one.
+  const auto* executables = executable_table_def().multi_arch_executables();
+  if (!executables || executable_ordinal < 0 ||
+      executable_ordinal >= executables->size()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Invalid multi-arch executable ordinal " << executable_ordinal;
+  }
+  return executables->Get(executable_ordinal);  // In range by the check above.
+}
+
+// static
+Status BytecodeModule::ValidateArgType(const BufferView& arg,
+ const MemRefTypeDef& expected_type) {  // Checks |arg| against the expected element size and exact shape.
+ RETURN_IF_ERROR(
+ ValidateElementSize(arg.element_size * 8, *expected_type.element_type()));  // Size is passed multiplied by 8 (presumably bytes -> bits; confirm).
+
+ auto expected_shape = expected_type.shape();  // NOTE(review): dereferenced below without a null check.
+ if (arg.shape.size() != expected_shape->size()) {  // Ranks must match before dimensions are compared.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Argument should have rank " << expected_shape->size()
+ << " but has rank " << arg.shape.size();
+ }
+ for (int i = 0; i < expected_shape->size(); ++i) {  // Every dimension must be equal; the first mismatch is reported.
+ auto dim_size = arg.shape[i];
+ auto expected_dim_size = expected_shape->Get(i);
+ if (dim_size != expected_dim_size) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Argument dimension " << i << " should have size "
+ << expected_dim_size << " but has size " << dim_size;
+ }
+ }
+ return OkStatus();
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_module.h b/vm/bytecode_module.h
new file mode 100644
index 0000000..b2df6d9
--- /dev/null
+++ b/vm/bytecode_module.h
@@ -0,0 +1,103 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_MODULE_H_
+#define IREE_VM_BYTECODE_MODULE_H_
+
+#include <memory>
+
+#include "base/flatbuffer_util.h"
+#include "rt/function.h"
+#include "rt/module.h"
+#include "schemas/executable_table_def_generated.h"
+#include "schemas/function_table_def_generated.h"
+#include "schemas/module_def_generated.h"
+#include "vm/opcode_info.h"
+#include "vm/source_map_resolver.h"
+
+namespace iree {
+namespace vm {
+
+using ModuleFile = FlatBufferFile<ModuleDef>;
+
+// A loaded bytecode module backed by a FlatBuffer.
+class BytecodeModule : public rt::Module {
+ public:
+ static Status ValidateStructure(const ModuleDef& module_def);
+
+ ~BytecodeModule() override;
+
+ const ModuleDef& def() const { return module_def_; }
+ const FunctionTableDef& function_table_def() const {
+ return *module_def_.function_table();
+ }
+ const ExecutableTableDef& executable_table_def() const {
+ return *module_def_.executable_table();
+ }
+
+ absl::string_view name() const override {
+ return WrapString(module_def_.name());
+ }
+
+ const rt::ModuleSignature signature() const override;
+
+ rt::SourceResolver* source_resolver() const override {
+ return &source_resolver_;
+ }
+
+ rt::Disassembler* disassembler() const override {
+ return disassembler_.get();
+ }
+
+ std::string DebugStringShort() const override;
+
+ StatusOr<const rt::Function> LookupFunctionByOrdinal(
+ rt::Function::Linkage linkage, int32_t ordinal) const override;
+
+ StatusOr<const rt::Function> LookupFunctionByName(
+ rt::Function::Linkage linkage, absl::string_view name) const override;
+
+ StatusOr<absl::string_view> GetFunctionName(rt::Function::Linkage linkage,
+ int32_t ordinal) const override;
+
+ StatusOr<const rt::FunctionSignature> GetFunctionSignature(
+ rt::Function::Linkage linkage, int32_t ordinal) const override;
+
+ StatusOr<const FunctionDef*> GetFunctionDef(rt::Function::Linkage linkage,
+ int32_t ordinal) const;
+
+ StatusOr<const MultiArchExecutableDef*> LookupMultiArchExecutable(
+ int executable_ordinal) const;
+
+ protected:
+ BytecodeModule(std::unique_ptr<ModuleFile> module_file,
+ OpcodeTable opcode_table);
+
+ static Status ValidateArgType(const hal::BufferView& arg,
+ const MemRefTypeDef& expected_type);
+
+ private:
+ StatusOr<int32_t> MapFunctionOrdinal(rt::Function::Linkage linkage,
+ int32_t ordinal) const;
+
+ std::unique_ptr<ModuleFile> module_file_;
+ const ModuleDef& module_def_;
+ mutable SourceMapResolver source_resolver_;
+ mutable std::unique_ptr<rt::Disassembler> disassembler_;
+};
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_MODULE_H_
diff --git a/vm/bytecode_reader.cc b/vm/bytecode_reader.cc
new file mode 100644
index 0000000..7730857
--- /dev/null
+++ b/vm/bytecode_reader.cc
@@ -0,0 +1,289 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_reader.h"
+
+#include "base/shape.h"
+#include "base/status.h"
+#include "hal/heap_buffer.h"
+#include "vm/bytecode_module.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+using ::iree::hal::BufferView;
+using ::iree::rt::StackFrame;
+
+} // namespace
+
+StatusOr<const uint8_t*> BytecodeReader::AdvanceOffset() {
+ *stack_frame_->mutable_offset() = offset();  // Publish the current pc offset to the active stack frame.
+ // TODO(benvanik): make a flag and/or remove.
+ DVLOG(1) << "dispatch(" << stack_frame_->function().name() << "@" << offset()
+ << "): " << int(*bytecode_pc_);  // Reads the byte at the pc; no bounds check is made here.
+ for (int i = 0; i < registers_->buffer_views.size(); ++i) {  // Logs every local of the current frame.
+ DVLOG(1) << "local[" << i << "] "
+ << registers_->buffer_views[i].DebugStringShort();
+ }
+ return bytecode_pc_++;  // Returns a pointer to the current byte, then steps past it.
+}
+
+Status BytecodeReader::SkipLocals(int count) {  // Skips |count| 16-bit locals.
+  const int64_t stride = static_cast<int64_t>(sizeof(uint16_t)) * count;
+  if (count < 0 || stride > bytecode_limit_ - bytecode_pc_) {  // vs bytes left
+    return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow";
+  }
+  bytecode_pc_ += stride;  // May land exactly on bytecode_limit_ (in bounds).
+  return OkStatus();
+}
+
+Status BytecodeReader::ReadShape(Shape* out_shape) {
+ ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList());
+ *out_shape = Shape(shape_dims);
+ return OkStatus();
+}
+
+StatusOr<Shape> BytecodeReader::ReadShapePieces() {  // Reads a shape whose -1 dims are filled in from locals.
+ // TODO(benvanik): rewrite to be faster (multiple offsets to walk both lists).
+ ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList());
+ if (shape_dims.size() >= kMaxRank) {  // NOTE(review): also rejects rank == kMaxRank although the message says "limited to" - confirm intent.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Shapes limited to rank " << kMaxRank << " right now";
+ }
+ int expected_dynamic_dims = 0;
+ for (int i = 0; i < shape_dims.size(); ++i) {  // Count the dims encoded as -1.
+ if (shape_dims[i] == -1) {
+ ++expected_dynamic_dims;
+ }
+ }
+
+ Shape shape(shape_dims);
+ ASSIGN_OR_RETURN(int dynamic_dims, ReadCount());  // Number of dynamic-dim operands that follow.
+ if (dynamic_dims != expected_dynamic_dims) {  // Any mismatch fails, even though the message says "only".
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Expected " << expected_dynamic_dims << " dynamic dims but only "
+ << dynamic_dims << " provided";
+ } else if (dynamic_dims) {
+ for (int i = 0; i < shape_dims.size(); ++i) {
+ if (shape_dims[i] != -1) {  // Static dims keep the value read from the index list.
+ continue;
+ }
+ // TODO(benvanik): kill this embarrassment.
+ ASSIGN_OR_RETURN(auto dims_piece, ReadSlotElements<int32_t>());  // One local operand per dynamic dim, in dim order.
+ if (dims_piece.size() != 1) {  // The local must hold exactly one element.
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Dims piece has rank " << dims_piece.size() << "; must be 1";
+ }
+ shape[i] = dims_piece[0];
+ }
+ }
+ return shape;
+}
+
+StatusOr<Shape> BytecodeReader::ReadShapePieces(size_t* out_element_count) {
+ ASSIGN_OR_RETURN(auto shape, ReadShapePieces());
+ *out_element_count = shape.element_count();
+ return shape;
+}
+
+StatusOr<absl::Span<const int32_t>> BytecodeReader::ReadIndexList() {
+  ASSIGN_OR_RETURN(int count, ReadCount());  // 8-bit element count prefix.
+  int stride = count * sizeof(int32_t);
+  if (stride > bytecode_limit_ - bytecode_pc_) {  // Exact fit is in bounds.
+    return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow";
+  }
+  auto list = absl::Span<const int32_t>(  // Aliases the bytecode; no copy.
+      reinterpret_cast<const int32_t*>(bytecode_pc_), count);
+  bytecode_pc_ += stride;
+  return list;
+}
+
+Status BytecodeReader::SwitchStackFrame(StackFrame* new_stack_frame) {
+ // Flush old state.
+ auto* old_stack_frame = stack_frame_;
+ if (old_stack_frame) {  // Null when no frame has been active yet.
+ *old_stack_frame->mutable_offset() = offset();
+ }
+
+ // Switch the frame. The FiberState holds the full stack, this is just the
+ // current one for easy access.
+ stack_frame_ = new_stack_frame;
+
+ // Setup state pointers for faster dereferencing.
+ const auto& function = new_stack_frame->function();
+ ASSIGN_OR_RETURN(
+ const auto* function_def,
+ static_cast<const BytecodeModule*>(function.module())  // Unchecked downcast: assumes the module is a BytecodeModule.
+ ->GetFunctionDef(function.linkage(), function.ordinal()));
+ const auto& bytecode = *function_def->bytecode();  // bytecode() and contents() are dereferenced without null checks.
+ bytecode_base_ = bytecode.contents()->Data();
+ bytecode_limit_ = bytecode_base_ + bytecode.contents()->size();  // One past the last bytecode byte.
+ bytecode_pc_ = bytecode_base_ + new_stack_frame->offset();  // Resumes at the frame's saved offset (not range-checked here).
+ registers_ = new_stack_frame->mutable_registers();
+ return OkStatus();
+}
+
+Status BytecodeReader::CopyInputsAndSwitchStackFrame(
+ StackFrame* src_stack_frame, StackFrame* dst_stack_frame) {
+ ASSIGN_OR_RETURN(size_t src_count, ReadCount());  // Number of source locals encoded at the pc.
+ auto& dst_buffer_views = dst_stack_frame->mutable_registers()->buffer_views;
+ for (int i = 0; i < std::min(src_count, dst_buffer_views.size()); ++i) {  // Operands beyond the destination's local count are neither read nor copied.
+ ASSIGN_OR_RETURN(auto* src_local,
+ ReadLocal(src_stack_frame->mutable_registers()));
+ dst_buffer_views[i] = *src_local;  // Input i lands in destination local i.
+ }
+ return SwitchStackFrame(dst_stack_frame);  // Saves the current offset and moves the pc into the destination frame.
+}
+
+Status BytecodeReader::CopyResultsAndSwitchStackFrame(
+ StackFrame* src_stack_frame, StackFrame* dst_stack_frame) {
+ ASSIGN_OR_RETURN(int32_t src_count, ReadCount());  // Read from the bytecode at the current pc, before the switch.
+ // TODO(benvanik): avoid vector.
+ absl::InlinedVector<BufferView*, 8> src_locals(src_count);
+ for (int i = 0; i < src_count; ++i) {  // Collect the source locals before the pc leaves this bytecode.
+ ASSIGN_OR_RETURN(src_locals[i],
+ ReadLocal(src_stack_frame->mutable_registers()));
+ }
+ RETURN_IF_ERROR(SwitchStackFrame(dst_stack_frame));  // From here on, reads come from the destination frame's bytecode.
+ ASSIGN_OR_RETURN(int32_t dst_count, ReadCount());
+ if (src_count != dst_count) {  // Both sides must name the same number of values.
+ return OutOfRangeErrorBuilder(IREE_LOC)
+ << "Src and dst value counts differ: " << src_count << " vs "
+ << dst_count;
+ }
+ for (int i = 0; i < dst_count; ++i) {
+ ASSIGN_OR_RETURN(auto* dst_local,
+ ReadLocal(dst_stack_frame->mutable_registers()));
+ *dst_local = *src_locals[i];  // Result i is copied into the i-th destination local named in the bytecode.
+ }
+ return OkStatus();
+}
+
+Status BytecodeReader::CopySlots() {
+ ASSIGN_OR_RETURN(int32_t count, ReadCount());  // Number of (source, destination) operand pairs that follow.
+ for (int i = 0; i < count; ++i) {
+ ASSIGN_OR_RETURN(auto* src_local,
+ ReadLocal(stack_frame_->mutable_registers()));  // First operand of the pair: the source local.
+ ASSIGN_OR_RETURN(auto* dst_local,
+ ReadLocal(stack_frame_->mutable_registers()));  // Second operand: the destination local in the same frame.
+ *dst_local = *src_local;
+ }
+ return OkStatus();
+}
+
+Status BytecodeReader::BranchToOffset(int32_t offset) {
+  // Check the offset itself: forming an out-of-range pointer first is UB.
+  if (offset < 0 || offset > bytecode_limit_ - bytecode_base_) {
+    return OutOfRangeErrorBuilder(IREE_LOC)
+           << "Branch target " << offset
+           << " is out of bounds of the function bytecode ("
+           << static_cast<size_t>(bytecode_limit_ - bytecode_base_)
+           << "b total)";
+  }
+  bytecode_pc_ = bytecode_base_ + offset;  // offset == size (the end) is OK.
+  return OkStatus();
+}
+
+StatusOr<BufferView> BytecodeReader::ReadConstant() {
+  BufferView buffer_view;  // Result: element size, shape and backing buffer.
+
+  // Element type defines the element size (the data format is irrelevant).
+  ASSIGN_OR_RETURN(auto element_type, ReadType());
+  buffer_view.element_size = element_type.element_size();
+
+  // Parse shape - constants always define a full shape.
+  RETURN_IF_ERROR(ReadShape(&buffer_view.shape));
+
+  // Read encoding to determine how the constant data is stored in the file.
+  ASSIGN_OR_RETURN(auto encoding, ReadValue<ConstantEncoding>());
+
+  // Get buffer for the constant data; it must fit in the bytes that remain.
+  const size_t remaining =
+      bytecode_pc_ < bytecode_limit_ ? bytecode_limit_ - bytecode_pc_ : 0;
+  switch (encoding) {
+    case ConstantEncoding::kDense: {
+      // Validate we have all constant data present (an exact fit is valid).
+      device_size_t serialized_length = buffer_view.byte_length();
+      if (serialized_length > remaining) {
+        return OutOfRangeErrorBuilder(IREE_LOC)
+               << "Constant data out of bounds";
+      }
+      buffer_view.buffer = hal::HeapBuffer::Wrap(
+          hal::MemoryType::kHostLocal, hal::BufferUsage::kAll, bytecode_pc_,
+          serialized_length);
+      bytecode_pc_ += serialized_length;
+      break;
+    }
+    case ConstantEncoding::kSplat: {
+      // Validate we have at least one element worth of data in the buffer.
+      if (static_cast<size_t>(buffer_view.element_size) > remaining) {
+        return OutOfRangeErrorBuilder(IREE_LOC)
+               << "Constant data out of bounds";
+      }
+
+      // TODO(benvanik): replace with fancy constant pool and such.
+      // NOTE: this is not much different than if a alloc_heap+broadcast pair
+      // had been in the IR.
+      buffer_view.buffer = hal::HeapBuffer::Allocate(
+          hal::MemoryType::kHostLocal, hal::BufferUsage::kAll,
+          buffer_view.byte_length());
+      switch (buffer_view.element_size) {
+        case 1: {
+          uint8_t value = *reinterpret_cast<const uint8_t*>(bytecode_pc_);
+          RETURN_IF_ERROR(buffer_view.buffer->Fill8(value));
+          break;
+        }
+        case 2: {
+          uint16_t value = *reinterpret_cast<const uint16_t*>(bytecode_pc_);
+          RETURN_IF_ERROR(buffer_view.buffer->Fill16(value));
+          break;
+        }
+        case 4: {
+          uint32_t value = *reinterpret_cast<const uint32_t*>(bytecode_pc_);
+          RETURN_IF_ERROR(buffer_view.buffer->Fill32(value));
+          break;
+        }
+        case 8: {
+          // TODO(benvanik): add Fill64.
+          uint64_t value = *reinterpret_cast<const uint64_t*>(bytecode_pc_);
+          ASSIGN_OR_RETURN(auto mapping,
+                           buffer_view.buffer->MapMemory<uint64_t>(
+                               hal::MemoryAccess::kDiscardWrite));
+          auto mapped_data = mapping.mutable_contents();
+          for (int i = 0; i < mapping.size(); ++i) {
+            mapped_data[i] = value;
+          }
+          break;
+        }
+        default:
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Unimplemented splat element stride "
+                 << buffer_view.element_size;
+      }
+      bytecode_pc_ += buffer_view.element_size;
+      break;
+    }
+    default:
+      return UnimplementedErrorBuilder(IREE_LOC)
+             << "Unimplemented constant encoding "
+             << static_cast<int>(encoding);
+  }
+
+  return buffer_view;
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_reader.h b/vm/bytecode_reader.h
new file mode 100644
index 0000000..499d679
--- /dev/null
+++ b/vm/bytecode_reader.h
@@ -0,0 +1,169 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_READER_H_
+#define IREE_VM_BYTECODE_READER_H_
+
+#include "absl/base/attributes.h"
+#include "absl/container/inlined_vector.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "rt/context.h"
+#include "rt/stack.h"
+#include "rt/stack_frame.h"
+#include "schemas/bytecode/bytecode_v0.h"
+#include "vm/type.h"
+
+namespace iree {
+namespace vm {
+
+class BytecodeReader {
+ public:
+ explicit BytecodeReader(rt::Stack* stack) : stack_(stack) {}
+
+ int offset() const { return static_cast<int>(bytecode_pc_ - bytecode_base_); }
+
+ StatusOr<const uint8_t*> AdvanceOffset();
+
+ Status SwitchStackFrame(rt::StackFrame* new_stack_frame);
+ Status BranchToOffset(int32_t offset);
+
+ Status CopyInputsAndSwitchStackFrame(rt::StackFrame* src_stack_frame,
+ rt::StackFrame* dst_stack_frame);
+ Status CopyResultsAndSwitchStackFrame(rt::StackFrame* src_stack_frame,
+ rt::StackFrame* dst_stack_frame);
+ Status CopySlots();
+
+ StatusOr<hal::BufferView> ReadConstant();
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int> ReadCount() {
+ return ReadValue<uint8_t>();
+ }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const Type> ReadType() {
+ ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>());
+ return Type::FromTypeIndex(type_index);
+ }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function> ReadFunction() {
+ ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>());
+ const auto& module = stack_frame_->module();
+ return module.LookupFunctionByOrdinal(rt::Function::Linkage::kInternal,
+ value);
+ }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function>
+ ReadImportFunction() {
+ ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>());
+ const auto& module = stack_frame_->module();
+ return stack_->context()->ResolveImport(&module, value);
+ }
+
+  ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal(
+      rt::Registers* registers) {  // Decodes a local ordinal into a pointer.
+    ASSIGN_OR_RETURN(auto value, ReadValue<uint16_t>());  // 16-bit ordinal.
+    if (value >= registers->buffer_views.size()) {  // size() is out of range.
+      return OutOfRangeErrorBuilder(IREE_LOC)
+             << "Out of bounds local access " << value << " of "
+             << registers->buffer_views.size();
+    }
+    return &registers->buffer_views[value];  // Points into |registers|.
+  }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal() {
+ return ReadLocal(registers_);
+ }
+
+ Status SkipLocals(int count);
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint8_t> ReadUint8_t() {
+ return ReadValue<uint8_t>();
+ }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint16_t> ReadUint16_t() {
+ return ReadValue<uint16_t>();
+ }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int32_t> ReadInt32() {
+ return ReadValue<int32_t>();
+ }
+
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint32_t> ReadBlockOffset() {
+ return ReadValue<uint32_t>();
+ }
+
+ template <typename T, size_t N = 8>
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<absl::InlinedVector<T, N>>
+ ReadSlotElements() {
+ ASSIGN_OR_RETURN(auto* local, ReadLocal(registers_));
+ absl::InlinedVector<T, N> result(local->shape.element_count());
+ if (sizeof(T) == local->element_size) {
+ // Fast(ish) path: requested element size matches the actual element size.
+ RETURN_IF_ERROR(
+ local->buffer->ReadData(0, result.data(), result.size() * sizeof(T)));
+ } else {
+ // Slow path: need to convert the data.
+ switch (local->element_size) {
+ case 4: {
+ ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int32_t>(
+ hal::MemoryAccess::kRead));
+ for (size_t i = 0; i < result.size(); ++i) {
+ result[i] = static_cast<T>(mapping[i]);
+ }
+ break;
+ }
+ case 8: {
+ ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int64_t>(
+ hal::MemoryAccess::kRead));
+ for (size_t i = 0; i < result.size(); ++i) {
+ result[i] = static_cast<T>(mapping[i]);
+ }
+ break;
+ }
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unsupported local element size: " << local->element_size;
+ }
+ }
+ return result;
+ }
+
+ Status ReadShape(Shape* out_shape);
+
+ StatusOr<Shape> ReadShapePieces();
+ StatusOr<Shape> ReadShapePieces(size_t* out_element_count);
+
+ StatusOr<absl::Span<const int32_t>> ReadIndexList();
+
+ private:
+ template <typename T>
+ ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<T> ReadValue() {  // Reads one T at the pc and advances past it.
+ // TODO(benvanik): validate bounds.
+ T value = *reinterpret_cast<const T*>(bytecode_pc_);  // Direct load through a cast pointer; no bounds or alignment check.
+ bytecode_pc_ += sizeof(T);
+ return value;  // Always succeeds today; StatusOr leaves room for the bounds check in the TODO.
+ }
+
+ rt::Stack* stack_ = nullptr;
+ rt::StackFrame* stack_frame_ = nullptr;
+ const uint8_t* bytecode_base_ = nullptr;
+ const uint8_t* bytecode_limit_ = nullptr;
+ const uint8_t* bytecode_pc_ = nullptr;
+ rt::Registers* registers_ = nullptr;
+};
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_READER_H_
diff --git a/vm/bytecode_tables_interpreter.cc b/vm/bytecode_tables_interpreter.cc
new file mode 100644
index 0000000..5744bf9
--- /dev/null
+++ b/vm/bytecode_tables_interpreter.cc
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_tables_interpreter.h"
+
+#include "schemas/bytecode/interpreter_bytecode_v0.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+// Info table mapping 1:1 with bytecode ops.
+//
+// Note that we ensure the table is 256 elements long exactly to make sure
+// that unused opcodes are handled gracefully.
+static const OpcodeInfo kInfoTable[256] = {
+#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \
+ OpcodeInfo{ \
+ name, \
+ flags, \
+ {operand_encodings}, \
+ },
+ IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)
+#undef DECLARE_INFO
+};
+
+} // namespace
+
+OpcodeTable interpreter_opcode_table() { return kInfoTable; }
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_tables_interpreter.h b/vm/bytecode_tables_interpreter.h
new file mode 100644
index 0000000..6407e84
--- /dev/null
+++ b/vm/bytecode_tables_interpreter.h
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
+#define IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
+
+#include "vm/opcode_info.h"
+
+namespace iree {
+namespace vm {
+
+OpcodeTable interpreter_opcode_table();
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
diff --git a/vm/bytecode_tables_sequencer.cc b/vm/bytecode_tables_sequencer.cc
new file mode 100644
index 0000000..02880dc
--- /dev/null
+++ b/vm/bytecode_tables_sequencer.cc
@@ -0,0 +1,44 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_tables_sequencer.h"
+
+#include "schemas/bytecode/sequencer_bytecode_v0.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+// Info table mapping 1:1 with bytecode ops.
+//
+// Note that we ensure the table is 256 elements long exactly to make sure
+// that unused opcodes are handled gracefully.
+static const OpcodeInfo kInfoTable[256] = {
+#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \
+ OpcodeInfo{ \
+ name, \
+ flags, \
+ {operand_encodings}, \
+ },
+ IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)
+#undef DECLARE_INFO
+};
+
+} // namespace
+
+OpcodeTable sequencer_opcode_table() { return kInfoTable; }
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_tables_sequencer.h b/vm/bytecode_tables_sequencer.h
new file mode 100644
index 0000000..0ead8f6
--- /dev/null
+++ b/vm/bytecode_tables_sequencer.h
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
+#define IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
+
+#include "vm/opcode_info.h"
+
+namespace iree {
+namespace vm {
+
+OpcodeTable sequencer_opcode_table();
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
diff --git a/vm/bytecode_util.cc b/vm/bytecode_util.cc
new file mode 100644
index 0000000..48deb32
--- /dev/null
+++ b/vm/bytecode_util.cc
@@ -0,0 +1,43 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_util.h"
+
+namespace iree {
+namespace vm {
+
+absl::string_view PredicateToString(CmpIPredicate p) {  // Returns the string paired with |p| in the predicate list.
+#define PRED(index, name, str, ...) \
+ case CmpIPredicate::name: \
+ return str;
+ switch (p) {  // One case per entry expanded from IREE_CMPI_PREDICATE_LIST.
+ IREE_CMPI_PREDICATE_LIST(PRED)
+#undef PRED
+ }
+ return "<unknown>";  // Reached only when no expanded case matched |p|.
+}
+
+absl::string_view PredicateToString(CmpFPredicate p) {  // Returns the string paired with |p| in the predicate list.
+#define PRED(index, name, str, ...) \
+ case CmpFPredicate::name: \
+ return str;
+ switch (p) {  // One case per entry expanded from IREE_CMPF_PREDICATE_LIST.
+ IREE_CMPF_PREDICATE_LIST(PRED)
+#undef PRED
+ }
+ return "<unknown>";  // Reached only when no expanded case matched |p|.
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_util.h b/vm/bytecode_util.h
new file mode 100644
index 0000000..560ab36
--- /dev/null
+++ b/vm/bytecode_util.h
@@ -0,0 +1,31 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_UTIL_H_
+#define IREE_VM_BYTECODE_UTIL_H_
+
+#include "absl/strings/string_view.h"
+#include "schemas/bytecode/bytecode_v0.h"
+
+namespace iree {
+namespace vm {
+
+absl::string_view PredicateToString(CmpIPredicate predicate);
+
+absl::string_view PredicateToString(CmpFPredicate predicate);
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_UTIL_H_
diff --git a/vm/bytecode_validator.cc b/vm/bytecode_validator.cc
new file mode 100644
index 0000000..54e8a82
--- /dev/null
+++ b/vm/bytecode_validator.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/bytecode_validator.h"
+
+namespace iree {
+namespace vm {
+
+// static
+Status BytecodeValidator::Validate(const BytecodeModule& module,
+ const BytecodeDef& bytecode_def) {  // Neither argument is inspected yet.
+ // TODO(benvanik): validate bytecode.
+ return OkStatus();  // Stub: every input is currently reported as valid.
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/bytecode_validator.h b/vm/bytecode_validator.h
new file mode 100644
index 0000000..1ebb580
--- /dev/null
+++ b/vm/bytecode_validator.h
@@ -0,0 +1,37 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_BYTECODE_VALIDATOR_H_
+#define IREE_VM_BYTECODE_VALIDATOR_H_
+
+#include "base/status.h"
+#include "schemas/bytecode_def_generated.h"
+#include "vm/bytecode_module.h"
+
+namespace iree {
+namespace vm {
+
+// Validates bytecode. Success is intended to mean the bytecode references no
+// undefined types, functions, or required imports and that all imports resolve
+// with matching signatures. NOTE: Validate() currently always returns OK.
+class BytecodeValidator {
+ public:
+ static Status Validate(const BytecodeModule& module,
+ const BytecodeDef& bytecode_def);
+};
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_BYTECODE_VALIDATOR_H_
diff --git a/vm/opcode_info.h b/vm/opcode_info.h
new file mode 100644
index 0000000..1b61366
--- /dev/null
+++ b/vm/opcode_info.h
@@ -0,0 +1,45 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_OPCODE_INFO_H_
+#define IREE_VM_OPCODE_INFO_H_
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "schemas/bytecode/bytecode_v0.h"
+
+namespace iree {
+namespace vm {
+
+struct OpcodeInfo {
+ const char* mnemonic;
+ OpcodeFlagBitfield flag;
+ union {
+ const char operands_value[8];
+ const OperandEncoding operands[8];
+ };
+};
+
+using OpcodeTable = absl::Span<const OpcodeInfo>;
+
+template <typename T>
+inline const OpcodeInfo& GetOpcodeInfo(OpcodeTable opcode_table, T opcode) {
+ return opcode_table[static_cast<uint8_t>(opcode)];
+}
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_OPCODE_INFO_H_
diff --git a/vm/sequencer_dispatch.cc b/vm/sequencer_dispatch.cc
new file mode 100644
index 0000000..d8e8fa0
--- /dev/null
+++ b/vm/sequencer_dispatch.cc
@@ -0,0 +1,561 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Implements a full bytecode dispatch system for sequencer ops.
+// TODO(benvanik): rework to be async against CommandBuffers.
+
+#include "vm/sequencer_dispatch.h"
+
+#include <algorithm>
+
+#include "absl/base/attributes.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_join.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "base/logging.h"
+#include "base/memory.h"
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "hal/command_queue.h"
+#include "hal/device.h"
+#include "hal/heap_buffer.h"
+#include "schemas/bytecode/sequencer_bytecode_v0.h"
+#include "vm/bytecode_module.h"
+#include "vm/bytecode_reader.h"
+#include "vm/bytecode_tables_sequencer.h"
+#include "vm/bytecode_util.h"
+#include "vm/opcode_info.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+using ::iree::hal::Buffer;
+using ::iree::hal::BufferView;
+
+// TODO(benvanik): remove (this should happen via predication).
+bool BufferViewIsTrue(const BufferView& buffer_view) {
+ if (buffer_view.element_size == 0 || !buffer_view.buffer ||
+ buffer_view.byte_length() == 0) {
+ return false;
+ }
+ // TODO(benvanik): map more efficiently (based on element size?).
+ auto mapping =
+ buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead);
+ if (!mapping.ok()) {
+ return false;
+ }
+ for (uint8_t value : mapping.ValueOrDie().contents()) {
+ if (value) return true;
+ }
+ return false;
+}
+
+// TODO(benvanik): insert fence callbacks and wait on fence.
+Status CallExternalFunction(rt::Stack* stack, const rt::Function& function) {
+ // Marshal inputs and outputs.
+ const auto* stack_frame = stack->current_frame();
+ auto buffer_views = absl::MakeSpan(stack_frame->registers().buffer_views);
+ absl::InlinedVector<hal::BufferView, 8> arguments(
+ buffer_views.begin(),
+ buffer_views.begin() + function.signature().argument_count());
+ absl::InlinedVector<hal::BufferView, 8> results(
+ buffer_views.begin() + arguments.size(), buffer_views.end());
+ return function.module()->Execute(stack, function, std::move(arguments),
+ &results);
+}
+
+// Pretty prints an array with no spaces after the commas, e.g. [1,2,3,4]
+inline std::string PrettyPrint(absl::Span<const int32_t> arr) {
+ return "[" + absl::StrJoin(arr, ",") + "]";
+}
+
+// Calculates the row-major byte offset of |indices| within |shape|. Fails if
+// |shape| is empty, has fewer dims than indices, or an index is out of range.
+StatusOr<device_size_t> CalculateOffset(absl::Span<const int32_t> indices,
+                                        Shape shape, uint8_t element_size) {
+  if (shape.empty() || indices.size() > shape.size()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Indices " << PrettyPrint(indices) << " out of bounds of shape "
+           << PrettyPrint(shape.subspan());
+  }
+  device_size_t offset = 0;
+  for (int i = 0; i < indices.size(); ++i) {
+    if (indices[i] < 0 || indices[i] >= shape[i]) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Indices[" << i << "]=" << indices[i]
+             << " out of bounds of shape " << PrettyPrint(shape.subspan());
+    }
+    device_size_t axis_offset = indices[i];  // Non-negative: checked above.
+    for (int j = i + 1; j < shape.size(); ++j) {
+      axis_offset *= shape[j];
+    }
+    offset += axis_offset;
+  }
+  offset *= element_size;
+  return offset;
+}
+
+} // namespace
+
+Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack,
+ rt::StackFrame* entry_stack_frame,
+ absl::Span<BufferView> entry_results) {
+ // Dispatch table mapping 1:1 with bytecode ops.
+ // Each entry is a label within this function that can be used for computed
+ // goto. You can find more information on computed goto here:
+ // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables
+ //
+ // Note that we ensure the table is 256 elements long exactly to make sure
+ // that unused opcodes are handled gracefully.
+ static const void* kDispatchTable[256] = {
+#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name,
+#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled,
+ IREE_SEQUENCER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED)
+#undef DECLARE_DISPATCH
+#undef DECLARE_DISPATCH_RESERVED
+ };
+
+ // Primary dispatch state. This is our 'native stack frame' and really just
+ // enough to make dereferencing common addresses (like the current offset)
+ // faster. You can think of this like CPU state (like PC).
+ //
+ // We hope that LLVM decides to keep these in registers (as they are touched
+ // for every instruction executed). The stack_frame will change as we call
+ // into different functions.
+ BytecodeReader reader(stack);
+ RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame));
+
+#define DISPATCH_NEXT() \
+ { \
+ uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \
+ DVLOG(1) << "Sequencer dispatching op code: " \
+ << GetOpcodeInfo(sequencer_opcode_table(), opcode).mnemonic; \
+ goto* kDispatchTable[opcode]; \
+ }
+
+#define DISPATCH_CORE_OPCODE(opcode, body) \
+ _dispatch_##opcode : {body} DISPATCH_NEXT()
+
+ DISPATCH_NEXT();
+
+ DISPATCH_CORE_OPCODE(kConstant, {
+ ASSIGN_OR_RETURN(auto value, reader.ReadConstant());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ // TODO(b/139121143): until we have full command buffers we need to do this.
+ ASSIGN_OR_RETURN(value.buffer,
+ placement.device->allocator()->AllocateConstant(
+ hal::BufferUsage::kConstant | hal::BufferUsage::kAll,
+ std::move(value.buffer)));
+ *dst_local = std::move(value);
+ });
+
+ DISPATCH_CORE_OPCODE(kCall, {
+ auto* old_stack_frame = stack->current_frame();
+ ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction());
+ // TODO(benvanik): rework register storage interface.
+ ASSIGN_OR_RETURN(
+ const auto* function_def,
+ static_cast<const BytecodeModule*>(target_function.module())
+ ->GetFunctionDef(target_function.linkage(),
+ target_function.ordinal()));
+ ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function));
+ new_stack_frame->mutable_registers()->buffer_views.resize(
+ function_def->bytecode()->local_count());
+ RETURN_IF_ERROR(
+ reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame));
+ DVLOG(1) << "Call; stack now: " << stack->DebugString();
+ });
+
+ DISPATCH_CORE_OPCODE(kCallImport, {
+ auto* old_stack_frame = stack->current_frame();
+ ASSIGN_OR_RETURN(const auto& target_function, reader.ReadImportFunction());
+ ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function));
+ // TODO(benvanik): rework register storage interface.
+ const auto& signature = target_function.signature();
+ new_stack_frame->mutable_registers()->buffer_views.resize(
+ signature.argument_count() + signature.result_count());
+ RETURN_IF_ERROR(
+ reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame));
+ DVLOG(1) << "Call native import; stack now: " << stack->DebugString();
+ RETURN_IF_ERROR(CallExternalFunction(stack, target_function));
+ RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame,
+ new_stack_frame));
+ RETURN_IF_ERROR(stack->PopFrame());
+ DVLOG(1) << "Return from native; stack now: " << stack->DebugString();
+ });
+
+ DISPATCH_CORE_OPCODE(kCallIndirect, {
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect";
+ });
+
+ DISPATCH_CORE_OPCODE(kReturn, {
+ auto* old_stack_frame = stack->current_frame();
+ auto* new_stack_frame = stack->caller_frame();
+ if (old_stack_frame == entry_stack_frame) {
+ // Returning from entry function. Marshal results from the return stmt.
+ ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount());
+ for (int i = 0; i < src_count; ++i) {
+ ASSIGN_OR_RETURN(
+ auto* src_local,
+ reader.ReadLocal(old_stack_frame->mutable_registers()));
+ entry_results[i] = std::move(*src_local);
+ }
+ DVLOG(1) << "Returning to entry";
+ return OkStatus();
+ } else if (!new_stack_frame) {
+ return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow";
+ }
+ RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame,
+ new_stack_frame));
+ RETURN_IF_ERROR(stack->PopFrame());
+ DVLOG(1) << "Return; stack now: " << stack->DebugString();
+ });
+
+ DISPATCH_CORE_OPCODE(kBranch, {
+ ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset());
+ RETURN_IF_ERROR(reader.CopySlots());
+ RETURN_IF_ERROR(reader.BranchToOffset(offset));
+ });
+
+ DISPATCH_CORE_OPCODE(kCondBranch, {
+ // Evaluate condition first so we can do the copies as we read them for
+ // which side of the branch we take.
+ ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
+ bool cond_value = BufferViewIsTrue(*cond_local);
+ ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset());
+
+ if (cond_value) {
+ RETURN_IF_ERROR(reader.CopySlots());
+ RETURN_IF_ERROR(reader.BranchToOffset(true_offset));
+ } else {
+ ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount());
+ RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count));
+ ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset());
+
+ RETURN_IF_ERROR(reader.CopySlots());
+ RETURN_IF_ERROR(reader.BranchToOffset(false_offset));
+ }
+ });
+
+ DISPATCH_CORE_OPCODE(kDynamicDispatch, {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented dynamic_dispatch";
+ });
+
+  DISPATCH_CORE_OPCODE(kStaticDispatch, {
+    // TODO(benvanik): the real sequencer :)
+    ASSIGN_OR_RETURN(auto dispatch_ordinal, reader.ReadInt32());
+    ASSIGN_OR_RETURN(auto export_ordinal, reader.ReadUint16_t());
+    ASSIGN_OR_RETURN(
+        const auto* multi_arch_executable_def,
+        static_cast<const BytecodeModule&>(stack->current_frame()->module())
+            .LookupMultiArchExecutable(dispatch_ordinal));
+    if (export_ordinal >= multi_arch_executable_def->entry_point_count()) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Invalid executable export ordinal " << export_ordinal;
+    }
+    // Prepare the first executable in the multi-arch bundle whose format this
+    // device's executable cache reports it can handle. The executable list is
+    // only walked by the loop below, so an empty list falls through to the
+    // "No executable found" error rather than indexing element 0 of an empty
+    // flatbuffer vector.
+    auto executable_cache = placement.device->CreateExecutableCache();
+    ref_ptr<hal::Executable> executable;
+    for (auto* executable_def : *multi_arch_executable_def->executables()) {
+      if (!executable_cache->CanPrepareFormat(executable_def->format())) {
+        continue;
+      }
+      hal::ExecutableSpec executable_spec;
+      executable_spec.format = executable_def->format();
+      executable_spec.executable_data =
+          absl::Span<const uint8_t>(executable_def->contents()->data(),
+                                    executable_def->contents()->size());
+      ASSIGN_OR_RETURN(executable,
+                       executable_cache->PrepareExecutable(
+                           hal::ExecutableCachingMode::kDefault |
+                               hal::ExecutableCachingMode::kAliasProvidedData,
+                           executable_spec),
+                       _.LogError());
+      break;
+    }
+    if (!executable) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "No executable found for the current driver";
+    }
+
+    ASSIGN_OR_RETURN(int workload_x, reader.ReadInt32());
+    ASSIGN_OR_RETURN(int workload_y, reader.ReadInt32());
+    ASSIGN_OR_RETURN(int workload_z, reader.ReadInt32());
+
+    std::vector<hal::BufferBinding> bindings;
+    ASSIGN_OR_RETURN(int input_count, reader.ReadCount());
+    for (int i = 0; i < input_count; ++i) {
+      ASSIGN_OR_RETURN(auto* input_local, reader.ReadLocal());
+      bindings.push_back(hal::BufferBinding(
+          input_local->buffer->allowed_access() & hal::MemoryAccess::kAll,
+          *input_local));
+    }
+    ASSIGN_OR_RETURN(int output_count, reader.ReadCount());
+    for (int i = 0; i < output_count; ++i) {
+      ASSIGN_OR_RETURN(auto* output_local, reader.ReadLocal());
+      bindings.push_back(
+          hal::BufferBinding(hal::MemoryAccess::kWrite, *output_local));
+    }
+    ASSIGN_OR_RETURN(int result_count, reader.ReadCount());
+    CHECK_EQ(0, result_count) << "Results not yet implemented";
+
+    ASSIGN_OR_RETURN(
+        auto cmd,
+        placement.device->CreateCommandBuffer(
+            hal::CommandBufferMode::kOneShot,
+            hal::CommandCategory::kTransfer | hal::CommandCategory::kDispatch),
+        _.LogError());
+    RETURN_IF_ERROR(cmd->Begin());
+    hal::DispatchRequest dispatch_request;
+    dispatch_request.executable = executable.get();
+    dispatch_request.entry_point = export_ordinal;
+    dispatch_request.workload[0] = workload_x;
+    dispatch_request.workload[1] = workload_y;
+    dispatch_request.workload[2] = workload_z;
+    dispatch_request.bindings = bindings;
+    RETURN_IF_ERROR(cmd->Dispatch(dispatch_request));
+    RETURN_IF_ERROR(cmd->End());
+    auto* cmd_ptr = cmd.get();
+
+    auto* queue = placement.device->dispatch_queues().front();
+    hal::SubmissionBatch batch;
+    batch.command_buffers = absl::MakeConstSpan(&cmd_ptr, 1);
+    ASSIGN_OR_RETURN(auto fence, placement.device->CreateFence(0u));
+    RETURN_IF_ERROR(queue->Submit(batch, {fence.get(), 1u}));
+    RETURN_IF_ERROR(placement.device->WaitAllFences({{fence.get(), 1u}},
+                                                    absl::InfiniteFuture()));
+  });
+
+ DISPATCH_CORE_OPCODE(kAllocStatic, {
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static";
+ });
+
+ DISPATCH_CORE_OPCODE(kAllocStack, {
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack";
+ });
+
+ DISPATCH_CORE_OPCODE(kAllocStackInit, {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unimplemented alloc_stack_init";
+ });
+
+ DISPATCH_CORE_OPCODE(kAllocHeap, {
+ ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto type, reader.ReadType());
+ size_t element_size = type.element_size();
+
+ // TODO(benvanik): more efficient reading and storage.
+ size_t element_count = 0;
+ ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count));
+ size_t allocation_size = element_size * element_count;
+
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ dst_local->element_size = element_size;
+ dst_local->shape = shape;
+
+ // TODO(benvanik): pick an allocator and use that instead.
+ CHECK_EQ(heap_type, 0);
+ auto* allocator = placement.device->allocator();
+ ASSIGN_OR_RETURN(
+ dst_local->buffer,
+ allocator->Allocate(
+ hal::MemoryType::kHostLocal | hal::MemoryType::kDeviceVisible,
+ hal::BufferUsage::kAll, allocation_size));
+ });
+
+ DISPATCH_CORE_OPCODE(kDiscard, {
+ // NOTE: if we were an encoder we would actually discard the buffer.
+ ASSIGN_OR_RETURN(auto* local, reader.ReadLocal());
+ *local = {};
+ });
+
+ DISPATCH_CORE_OPCODE(kComputeRange, {
+ ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto element_size, reader.ReadUint8_t());
+ ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_offset_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_length_local, reader.ReadLocal());
+
+ Shape shape(shape_data);
+ ASSIGN_OR_RETURN(device_size_t dst_offset,
+ CalculateOffset(indices, shape, element_size));
+ RETURN_IF_ERROR(
+ dst_offset_local->buffer->WriteData(0, &dst_offset, sizeof(int32_t)));
+
+ // A buffer range can only be computed for contiguous memory. To ensure that
+ // this only requests such, we validate that the offset in the buffer
+ // between the start and end indices is the same as the requested size.
+ device_size_t dst_length = element_size;
+ for (int i = 0; i < lengths.size(); ++i) {
+ dst_length *= lengths[i];
+ indices[i] += lengths[i] - 1;
+ }
+ ASSIGN_OR_RETURN(auto end_offset,
+ CalculateOffset(indices, shape, element_size));
+ auto offset_based_length = end_offset - dst_offset + element_size;
+ if (dst_length != offset_based_length) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Cannot compute range for non-contiguous region of memory;"
+ << " shape: " << PrettyPrint(shape.subspan())
+ << " indices: " << PrettyPrint(indices)
+ << " lengths: " << PrettyPrint(lengths);
+ }
+ RETURN_IF_ERROR(
+ dst_length_local->buffer->WriteData(0, &dst_length, sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kShape, {
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ RETURN_IF_ERROR(dst_local->buffer->WriteData(
+ 0, src_local->shape.subspan().data(),
+ src_local->shape.subspan().size() * sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kLength, {
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ int32_t length = src_local->shape.element_count();
+ RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t)));
+ });
+
+ DISPATCH_CORE_OPCODE(kDynamicSlice, {
+ // TODO(b/139299169): implement indirect copies to avoid CPU readback.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_slice";
+ });
+
+ DISPATCH_CORE_OPCODE(kStaticSlice, {
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto offset, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto length, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto type, reader.ReadType());
+ ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ Shape new_shape = Shape{shape_data};
+ if (new_shape.element_count() * type.element_size() != length) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "New element count " << new_shape.element_count()
+ << " != length slice " << length;
+ }
+ ASSIGN_OR_RETURN(dst_local->buffer,
+ Buffer::Subspan(src_local->buffer, offset, length));
+ dst_local->shape = new_shape;
+ dst_local->element_size = type.element_size();
+ });
+
+ DISPATCH_CORE_OPCODE(kDynamicCopy, {
+ // TODO(b/139299169): implement indirect copies to avoid CPU readback.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto src_offset_span, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_offset_span, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto length_span, reader.ReadSlotElements<int32_t>());
+ RETURN_IF_ERROR(dst_local->buffer->CopyData(
+ dst_offset_span.front(), src_local->buffer.get(),
+ src_offset_span.front(), length_span.front()));
+ });
+
+ DISPATCH_CORE_OPCODE(kStaticCopy, {
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto src_offset, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto length, reader.ReadInt32());
+ RETURN_IF_ERROR(dst_local->buffer->CopyData(
+ dst_offset, src_local->buffer.get(), src_offset, length));
+ });
+
+ DISPATCH_CORE_OPCODE(kDynamicFill, {
+ // TODO(b/139299169): implement indirect fills to avoid CPU readback.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_fill";
+ });
+
+ DISPATCH_CORE_OPCODE(kStaticFill, {
+ ASSIGN_OR_RETURN(auto value, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32());
+ ASSIGN_OR_RETURN(auto length, reader.ReadInt32());
+ RETURN_IF_ERROR(dst_local->buffer->Fill32(dst_offset, length, value));
+ });
+
+ DISPATCH_CORE_OPCODE(kClone, {
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ dst_local->element_size = src_local->element_size;
+ dst_local->shape = src_local->shape;
+ ASSIGN_OR_RETURN(dst_local->buffer, placement.device->allocator()->Allocate(
+ src_local->buffer->memory_type(),
+ src_local->buffer->usage(),
+ src_local->buffer->byte_length()));
+ RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get()));
+ });
+
+ DISPATCH_CORE_OPCODE(kAssign, {
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ *dst_local = *src_local;
+ });
+
+ DISPATCH_CORE_OPCODE(kCondAssign, {
+ ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local;
+ });
+
+ DISPATCH_CORE_OPCODE(kReshape, {
+ // TODO(benvanik): more logic required if strides differ.
+ ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal());
+ ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>());
+ ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal());
+ Shape new_shape = Shape{shape_data};
+ if (src_local->shape.element_count() != new_shape.element_count()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "New element count " << new_shape.element_count()
+ << " != source element count " << src_local->shape.element_count();
+ }
+ dst_local->shape = new_shape;
+ dst_local->buffer = add_ref(src_local->buffer);
+ dst_local->element_size = src_local->element_size;
+ });
+
+ DISPATCH_CORE_OPCODE(kTrace, {
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace";
+ });
+
+ DISPATCH_CORE_OPCODE(kBreak, {
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break";
+ });
+
+ DISPATCH_CORE_OPCODE(kCondBreak, {
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break";
+ });
+
+_dispatch_unhandled:
+ // TODO(benvanik): better tracing.
+ return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode";
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/sequencer_dispatch.h b/vm/sequencer_dispatch.h
new file mode 100644
index 0000000..814f1d5
--- /dev/null
+++ b/vm/sequencer_dispatch.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_SEQUENCER_DISPATCH_H_
+#define IREE_VM_SEQUENCER_DISPATCH_H_
+
+#include "base/status.h"
+#include "hal/buffer_view.h"
+#include "hal/device_placement.h"
+#include "rt/stack.h"
+#include "rt/stack_frame.h"
+
+namespace iree {
+namespace vm {
+
+// TODO(benvanik): API that supports yielding.
+Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack,
+ rt::StackFrame* entry_stack_frame,
+ absl::Span<hal::BufferView> entry_results);
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_SEQUENCER_DISPATCH_H_
diff --git a/vm/sequencer_module.cc b/vm/sequencer_module.cc
new file mode 100644
index 0000000..52b1ca7
--- /dev/null
+++ b/vm/sequencer_module.cc
@@ -0,0 +1,112 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/sequencer_module.h"
+
+#include "absl/memory/memory.h"
+#include "base/status.h"
+#include "base/tracing.h"
+#include "hal/buffer_view.h"
+#include "rt/context.h"
+#include "rt/instance.h"
+#include "vm/bytecode_tables_sequencer.h"
+#include "vm/sequencer_dispatch.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+
+using ::iree::hal::BufferView;
+using ::iree::rt::Function;
+using ::iree::rt::Module;
+
+} // namespace
+
+// static
+StatusOr<ref_ptr<rt::Module>> SequencerModule::FromDef(
+ const ModuleDef& module_def) {
+ ASSIGN_OR_RETURN(auto module_file, ModuleFile::Create(&module_def, []() {}));
+ return FromFile(std::move(module_file));
+}
+
+// static
+StatusOr<ref_ptr<rt::Module>> SequencerModule::FromFile(
+ std::unique_ptr<ModuleFile> module_file) {
+ if (module_file->root() == nullptr) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present";
+ }
+ const auto& module_def = *module_file->root();
+
+ // Validates the structure of the module (but not bytecode).
+ // This ensures we don't have flatbuffer vectors with null entries, etc.
+ RETURN_IF_ERROR(BytecodeModule::ValidateStructure(module_def));
+
+ auto module = assign_ref(new SequencerModule(std::move(module_file)));
+
+ // TODO(benvanik): validate internals here? or make explicit?
+
+ return {std::move(module)};
+}
+
+SequencerModule::SequencerModule(std::unique_ptr<ModuleFile> module_file)
+ : BytecodeModule(std::move(module_file), sequencer_opcode_table()) {}
+
+SequencerModule::~SequencerModule() = default;
+
+Status SequencerModule::Execute(
+ rt::Stack* stack, const Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const {
+ IREE_TRACE_SCOPE0("SequencerModule::Execute");
+
+ // Push stack frame for the function we are calling.
+ ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function));
+
+ // TODO(benvanik): rework register storage interface.
+ ASSIGN_OR_RETURN(const auto* function_def,
+ GetFunctionDef(function.linkage(), function.ordinal()));
+ auto* registers = callee_stack_frame->mutable_registers();
+ registers->buffer_views.resize(function_def->bytecode()->local_count());
+
+ // Marshal inputs. NOTE(review): arguments.size() is not checked vs inputs().
+ for (int i = 0; i < arguments.size(); ++i) {
+ auto arg = arguments[i];
+ auto expected_arg_type = function_def->type()->inputs()->Get(i);
+ RETURN_IF_ERROR(BytecodeModule::ValidateArgType(
+ arg, *expected_arg_type->type_union_as_MemRefTypeDef()))
+ << "Function " << function.name() << " argument " << i;
+ registers->buffer_views[i] = std::move(arg);
+ }
+
+ // TODO(benvanik): change to:
+ // get command queue (any command queue)
+ // make command buffer
+ // record dispatch
+ // submit
+ // wait on fence
+ ASSIGN_OR_RETURN(
+ auto placement,
+ stack->context()->instance()->device_manager()->ResolvePlacement({}));
+ RETURN_IF_ERROR(DispatchSequence(placement, stack, callee_stack_frame,
+ absl::MakeSpan(*results)));
+
+ // Pop the callee frame to balance out the stack.
+ RETURN_IF_ERROR(stack->PopFrame());
+
+ return OkStatus();
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/sequencer_module.h b/vm/sequencer_module.h
new file mode 100644
index 0000000..6be5464
--- /dev/null
+++ b/vm/sequencer_module.h
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_SEQUENCER_MODULE_H_
+#define IREE_VM_SEQUENCER_MODULE_H_
+
+#include <memory>
+
+#include "vm/bytecode_module.h"
+
+namespace iree {
+namespace vm {
+
+// A module using the sequencer bytecode ops.
+class SequencerModule final : public BytecodeModule {
+ public:
+ static StatusOr<ref_ptr<rt::Module>> FromDef(const ModuleDef& module_def);
+ static StatusOr<ref_ptr<rt::Module>> FromFile(
+ std::unique_ptr<ModuleFile> module_file);
+
+ ~SequencerModule() override;
+
+ Status Execute(
+ rt::Stack* stack, const rt::Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const override;
+
+ private:
+ explicit SequencerModule(std::unique_ptr<ModuleFile> module_file);
+};
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_SEQUENCER_MODULE_H_
diff --git a/vm/source_map_resolver.cc b/vm/source_map_resolver.cc
new file mode 100644
index 0000000..c5da2f7
--- /dev/null
+++ b/vm/source_map_resolver.cc
@@ -0,0 +1,194 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/source_map_resolver.h"
+
+#include "base/flatbuffer_util.h"
+#include "base/status.h"
+#include "schemas/source_map_def_generated.h"
+
+namespace iree {
+namespace vm {
+
+namespace {
+// Forward declaration: PrintFusedLocation() calls this before it is defined.
+Status PrintLocation(const SourceMapResolver& source_map,
+ const FunctionSourceMapDef& function_source_map,
+ const LocationDef& location, std::ostream* stream);
+
+// Writes "<filename>:<line>:<column>" to |stream|. Fails if the filename
+// index cannot be resolved through |source_map|.
+Status PrintFileLocation(const SourceMapResolver& source_map,
+    const FunctionSourceMapDef& function_source_map,
+    const FileLocationDef& location, std::ostream* stream) {
+  ASSIGN_OR_RETURN(auto file, source_map.GetUniqueString(location.filename()));
+  *stream << file << ":" << location.line() << ":" << location.column();
+  return OkStatus();
+}
+
+// Writes the location's name to |stream| wrapped in double quotes.
+Status PrintNameLocation(const SourceMapResolver& source_map,
+    const FunctionSourceMapDef& function_source_map,
+    const NameLocationDef& location, std::ostream* stream) {
+  ASSIGN_OR_RETURN(auto label, source_map.GetUniqueString(location.name()));
+  *stream << "\"" << label << "\"";
+  return OkStatus();
+}
+
+// Call-site locations are not expanded yet; a fixed placeholder is written.
+Status PrintCallSiteLocation(const SourceMapResolver& source_map,
+    const FunctionSourceMapDef& function_source_map,
+    const CallSiteLocationDef& location, std::ostream* stream) {
+  *stream << "(callsites todo)";
+  return OkStatus();
+}
+
+// Writes "fused[<child>, ...]"; each child ordinal is range-checked first.
+Status PrintFusedLocation(const SourceMapResolver& source_map,
+    const FunctionSourceMapDef& function_source_map,
+    const FusedLocationDef& location, std::ostream* stream) {
+  *stream << "fused[";
+  const auto* table = function_source_map.location_table();
+  const auto* ordinals = location.locations();
+  for (size_t i = 0; ordinals && i < ordinals->size(); ++i) {
+    if (i > 0) *stream << ", ";
+    const int ordinal = ordinals->Get(i);  // Comes from the serialized map.
+    if (!table || ordinal < 0 || ordinal >= static_cast<int>(table->size()))
+      return NotFoundErrorBuilder(IREE_LOC) << "Unknown location " << ordinal;
+    RETURN_IF_ERROR(PrintLocation(source_map, function_source_map,
+                                  *table->Get(ordinal), stream));
+  }
+  *stream << "]";
+  return OkStatus();
+}
+
+// Prints |location| by dispatching on its union type to the matching
+// Print*Location helper; any other type produces an Unimplemented error.
+Status PrintLocation(const SourceMapResolver& source_map,
+    const FunctionSourceMapDef& function_source_map,
+    const LocationDef& location, std::ostream* stream) {
+  const auto kind = location.location_union_type();
+  if (kind == LocationDefUnion::FileLocationDef) {
+    return PrintFileLocation(
+        source_map, function_source_map,
+        *location.location_union_as_FileLocationDef(), stream);
+  } else if (kind == LocationDefUnion::NameLocationDef) {
+    return PrintNameLocation(
+        source_map, function_source_map,
+        *location.location_union_as_NameLocationDef(), stream);
+  } else if (kind == LocationDefUnion::CallSiteLocationDef) {
+    return PrintCallSiteLocation(
+        source_map, function_source_map,
+        *location.location_union_as_CallSiteLocationDef(), stream);
+  } else if (kind == LocationDefUnion::FusedLocationDef) {
+    return PrintFusedLocation(
+        source_map, function_source_map,
+        *location.location_union_as_FusedLocationDef(), stream);
+  }
+  return UnimplementedErrorBuilder(IREE_LOC)
+         << "Unhandled location type " << static_cast<int>(kind);
+}
+
+} // namespace
+
+// static
+// Wraps the module's source map when present, else returns an empty resolver.
+SourceMapResolver SourceMapResolver::FromModule(const ModuleDef& module_def) {
+  const auto* source_map = module_def.source_map();
+  if (!source_map) return {};
+  return SourceMapResolver{*source_map};
+}
+
+// Returns entry |string_index| of the string table. NotFound is returned when
+// no source map is loaded or the index is negative or past the table's end.
+StatusOr<absl::string_view> SourceMapResolver::GetUniqueString(
+    int string_index) const {
+  if (empty()) return NotFoundErrorBuilder(IREE_LOC) << "No source map present";
+  const auto* table = source_map_def_->string_table();
+  const bool in_range = table && string_index >= 0 &&
+                        static_cast<size_t>(string_index) < table->size();
+  if (in_range) return WrapString(table->Get(string_index));
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "String index " << string_index << " not present in string table";
+}
+
+// Returns the source map entry for |function_ordinal|. NotFound is returned
+// when no source map is loaded, the ordinal is negative or out of range, or
+// the entry lacks a location table or bytecode map.
+StatusOr<const FunctionSourceMapDef*> SourceMapResolver::GetFunctionSourceMap(
+    int function_ordinal) const {
+  if (empty()) return NotFoundErrorBuilder(IREE_LOC) << "No source map present";
+  const auto* table = source_map_def_->function_table();
+  if (table && function_ordinal >= 0 &&
+      static_cast<size_t>(function_ordinal) < table->size()) {
+    const auto* entry = table->Get(function_ordinal);
+    // Callers dereference both sub-tables, so incomplete entries are rejected.
+    if (entry && entry->location_table() && entry->bytecode_map()) return entry;
+  }
+  return NotFoundErrorBuilder(IREE_LOC)
+         << "Function ordinal " << function_ordinal
+         << " source map not present in function table";
+}
+
+// Resolves a bytecode |offset| inside |function| to a source location.
+// Returns nullopt when there is no usable source map for the function or no
+// bytecode map entry has exactly this offset.
+absl::optional<rt::SourceLocation> SourceMapResolver::ResolveFunctionOffset(
+    const rt::Function& function, rt::SourceOffset offset) {
+  if (empty()) return absl::nullopt;
+  auto lookup = GetFunctionSourceMap(function.ordinal());
+  if (!lookup.ok()) return absl::nullopt;
+  const auto* function_map = lookup.ValueOrDie();
+  const auto* offset_entries = function_map->bytecode_map();
+  if (!offset_entries) return absl::nullopt;
+
+  // TODO(benvanik): allow fuzzy offset matching/table sparsity.
+  // Linear scan for an entry whose offset matches exactly.
+  constexpr int kNoMatch = -1;
+  int ordinal = kNoMatch;
+  for (const auto* entry : *offset_entries) {
+    if (entry->offset() != offset) continue;
+    ordinal = entry->location();  // First exact match wins.
+    break;
+  }
+  if (ordinal == kNoMatch) return absl::nullopt;
+
+  // Pack the function map pointer and the location ordinal as resolver args;
+  // PrintSourceLocation() reads them back in this order.
+  return rt::SourceLocation(
+      this, {reinterpret_cast<uint64_t>(function_map),
+             static_cast<uint64_t>(ordinal)});
+}
+
+// Prints the location packed in |resolver_args| ([0] function map pointer,
+// [1] location ordinal). Anything unresolvable is printed as "<unknown>".
+void SourceMapResolver::PrintSourceLocation(
+    rt::SourceResolverArgs resolver_args, std::ostream* stream) const {
+  const auto* function_map =
+      reinterpret_cast<const FunctionSourceMapDef*>(resolver_args[0]);
+  const int ordinal = static_cast<int>(resolver_args[1]);
+  // Validate before Get(): an out-of-range ordinal would read past the table.
+  const auto* table =
+      (empty() || !function_map) ? nullptr : function_map->location_table();
+  if (!table || ordinal < 0 || ordinal >= static_cast<int>(table->size())) {
+    *stream << "<unknown>";
+    return;
+  }
+  const auto status =
+      PrintLocation(*this, *function_map, *table->Get(ordinal), stream);
+  if (!status.ok()) *stream << status;
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/source_map_resolver.h b/vm/source_map_resolver.h
new file mode 100644
index 0000000..63a9f01
--- /dev/null
+++ b/vm/source_map_resolver.h
@@ -0,0 +1,57 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_SOURCE_MAP_RESOLVER_H_
+#define IREE_VM_SOURCE_MAP_RESOLVER_H_
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "base/status.h"
+#include "rt/source_resolver.h"
+#include "schemas/module_def_generated.h"
+#include "schemas/source_map_def_generated.h"
+
+namespace iree {
+namespace vm {
+// An rt::SourceResolver backed by a SourceMapDef; may be empty (no map).
+class SourceMapResolver final : public rt::SourceResolver {
+ public:  // FromModule() wraps |module_def|'s source map when it has one.
+ static SourceMapResolver FromModule(const ModuleDef& module_def);
+ // Default: empty resolver. Explicit: stores the def's address (not owned).
+ SourceMapResolver() = default;
+ explicit SourceMapResolver(const SourceMapDef& source_map_def)
+ : source_map_def_(&source_map_def) {}
+ // empty() is true when no def is attached; def() is nullptr in that case.
+ bool empty() const { return source_map_def_ == nullptr; }
+ const SourceMapDef* def() const { return source_map_def_; }
+ // Looks up |string_index| in the def's string table.
+ StatusOr<absl::string_view> GetUniqueString(int string_index) const;
+ // Looks up the per-function source map at |function_ordinal|.
+ StatusOr<const FunctionSourceMapDef*> GetFunctionSourceMap(
+ int function_ordinal) const;
+ // Maps a bytecode offset in |function| to a source location, if known.
+ absl::optional<rt::SourceLocation> ResolveFunctionOffset(
+ const rt::Function& function, rt::SourceOffset offset) override;
+ // Writes the location described by |resolver_args| to |stream|.
+ void PrintSourceLocation(rt::SourceResolverArgs resolver_args,
+ std::ostream* stream) const override;
+
+ private:
+ const SourceMapDef* source_map_def_ = nullptr;  // Not owned; null if empty.
+};
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_SOURCE_MAP_RESOLVER_H_
diff --git a/vm/type.cc b/vm/type.cc
new file mode 100644
index 0000000..8906371
--- /dev/null
+++ b/vm/type.cc
@@ -0,0 +1,64 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "vm/type.h"
+
+#include "base/status.h"
+
+namespace iree {
+namespace vm {
+
+// static
+// Accepts kOpaque and every index below kBuiltinTypeCount; rejects the rest.
+StatusOr<const Type> Type::FromTypeIndex(uint8_t type_index) {
+  // Currently we only support the builtin types.
+  const bool opaque = type_index == static_cast<uint8_t>(BuiltinType::kOpaque);
+  if (!opaque && type_index >= kBuiltinTypeCount) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+        << "Type index " << static_cast<int>(type_index) << " not supported";
+  }
+  return Type(type_index);
+}
+
+// static. Converts a BuiltinType enumerator into the Type with that index.
+const Type Type::FromBuiltin(BuiltinType type) {
+  return Type{static_cast<uint8_t>(type)};
+}
+
+// Returns the display string IREE_TYPE_LIST associates with this type index,
+// or "<invalid>" when the index has no entry in the list.
+std::string Type::DebugString() const {
+  switch (type_index_) {
+#define TYPE_NAME_CASE(index, name, str, size) case index: return str;
+    IREE_TYPE_LIST(TYPE_NAME_CASE)
+#undef TYPE_NAME_CASE
+    default:
+      return "<invalid>";
+  }
+}
+
+// Returns the element size IREE_TYPE_LIST records for this type index, or 0
+// when the index has no entry in the list.
+size_t Type::element_size() const {
+  switch (type_index_) {
+#define TYPE_SIZE_CASE(index, name, str, size) case index: return size;
+    IREE_TYPE_LIST(TYPE_SIZE_CASE)
+#undef TYPE_SIZE_CASE
+    default:
+      return 0;
+  }
+}
+
+} // namespace vm
+} // namespace iree
diff --git a/vm/type.h b/vm/type.h
new file mode 100644
index 0000000..0de5357
--- /dev/null
+++ b/vm/type.h
@@ -0,0 +1,65 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_VM_TYPE_H_
+#define IREE_VM_TYPE_H_
+
+#include "base/status.h"
+#include "schemas/bytecode/bytecode_v0.h"
+#include "schemas/type_def_generated.h"
+
+namespace iree {
+namespace vm {
+// Value type wrapping a uint8_t type index (builtin or opaque).
+class Type {
+ public:  // Instances come from the factories; the constructor is private.
+ static StatusOr<const Type> FromTypeIndex(uint8_t type_index);
+ static const Type FromBuiltin(BuiltinType type);
+ // Returns a printable name for the type.
+ std::string DebugString() const;
+ // Returns the raw index this Type wraps.
+ uint8_t type_index() const { return type_index_; }
+ // Every index other than kOpaque is reported as builtin.
+ bool is_opaque() const {
+ return type_index_ == static_cast<uint8_t>(BuiltinType::kOpaque);
+ }
+ bool is_builtin() const { return !is_opaque(); }
+ BuiltinType builtin_type() const {
+ DCHECK(is_builtin());
+ return static_cast<BuiltinType>(type_index_);
+ }
+ // Returns the element size for the type, or 0 if the index is unknown.
+ size_t element_size() const;
+
+ private:
+ explicit Type(uint8_t type_index) : type_index_(type_index) {}
+
+ uint8_t type_index_;
+};
+
+// Types compare equal exactly when their type indices match.
+inline bool operator==(const Type& a, const Type& b) {
+  return a.type_index() == b.type_index();
+}
+inline bool operator!=(const Type& a, const Type& b) { return !(a == b); }
+
+// Streams the type's DebugString().
+inline std::ostream& operator<<(std::ostream& stream, const Type& type) {
+  return stream << type.DebugString();
+}
+
+} // namespace vm
+} // namespace iree
+
+#endif // IREE_VM_TYPE_H_